Safety Classifiers Deep Dive

ShieldGemma Deployment

3 min read

ShieldGemma is Google's family of safety classifiers built on Gemma 2. It offers performance competitive with LlamaGuard while making different trade-offs in model size and capabilities.

ShieldGemma Model Family

| Model | Parameters | Latency (GPU) | Benchmark Score | Best For |
|---|---|---|---|---|
| ShieldGemma 2B | 2B | 30-60 ms | 0.814 | Edge/mobile |
| ShieldGemma 9B | 9B | 80-150 ms | 0.843 | Balanced |
| ShieldGemma 27B | 27B | 200-400 ms | 0.876 | Maximum accuracy |

Benchmark Note: ShieldGemma 27B achieves +10.8% improvement over LlamaGuard 2 on safety classification benchmarks (Google, 2024).
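If you choose the model size at startup from a latency budget, a small helper keeps that decision explicit. The sketch below is illustrative (the function name and thresholds are not part of the ShieldGemma release); the thresholds simply follow the GPU latency ranges in the table above and should be tuned on your own hardware.

def pick_shieldgemma_size(latency_budget_ms: float) -> str:
    """Hypothetical helper: map a per-request latency budget to a model size."""
    if latency_budget_ms < 80:
        return "2b"    # edge/mobile, ~30-60 ms per request
    if latency_budget_ms < 200:
        return "9b"    # balanced, ~80-150 ms per request
    return "27b"       # maximum accuracy, ~200-400 ms per request

# e.g. pick_shieldgemma_size(100) -> "9b"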

Safety Categories

ShieldGemma classifies content into four main harm categories:

SHIELDGEMMA_CATEGORIES = {
    "dangerous_content": "Content promoting dangerous activities",
    "harassment": "Bullying, threats, or targeted abuse",
    "hate_speech": "Content promoting hatred based on identity",
    "sexually_explicit": "Explicit sexual content"
}

Basic Implementation

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Dict

class ShieldGemmaClassifier:
    """ShieldGemma safety classifier for production use."""

    def __init__(self, model_size: str = "2b"):
        # Available: 2b, 9b, 27b
        model_id = f"google/shieldgemma-{model_size}"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"  # for input tensors; device_map handles model placement

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Token ids for "Yes"/"No"; the classifier scores P("Yes") at the final position
        self.yes_token = self.tokenizer.encode("Yes", add_special_tokens=False)[0]
        self.no_token = self.tokenizer.encode("No", add_special_tokens=False)[0]

    def classify(
        self,
        user_content: str,
        category: str = "dangerous_content"
    ) -> Dict[str, float]:
        """
        Classify content for a specific harm category.

        Returns:
            Dict with 'probability' (0-1) and 'is_harmful' (bool)
        """
        prompt = self._build_prompt(user_content, category)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits[0, -1, :]

            # Get probabilities for Yes/No tokens
            yes_prob = torch.softmax(
                logits[[self.yes_token, self.no_token]], dim=0
            )[0].item()

        return {
            "probability": yes_prob,
            "is_harmful": yes_prob > 0.5,
            "category": category
        }

    def classify_all(self, user_content: str) -> Dict[str, Dict]:
        """Check all harm categories."""
        results = {}
        for category in SHIELDGEMMA_CATEGORIES:
            results[category] = self.classify(user_content, category)
        return results

    def _build_prompt(self, content: str, category: str) -> str:
        """Build ShieldGemma prompt format."""
        category_desc = SHIELDGEMMA_CATEGORIES.get(category, category)

        return f"""You are a policy expert determining if the following content
violates the policy on {category_desc}.

User content: {content}

Does this content violate the {category} policy? Answer Yes or No."""

# Usage
classifier = ShieldGemmaClassifier(model_size="2b")
result = classifier.classify(
    user_content="How do I learn programming?",
    category="dangerous_content"
)
print(f"Harmful: {result['is_harmful']}, Probability: {result['probability']:.2%}")

Batch Processing for Throughput

import asyncio
from concurrent.futures import ThreadPoolExecutor

class ShieldGemmaBatch:
    """Batch processor for ShieldGemma classification."""

    def __init__(self, model_size: str = "9b", max_workers: int = 4):
        self.classifier = ShieldGemmaClassifier(model_size=model_size)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    async def batch_classify(
        self,
        contents: list[str],
        categories: list[str] | None = None
    ) -> list[Dict]:
        """Classify multiple contents asynchronously."""
        categories = categories or ["dangerous_content"]

        loop = asyncio.get_running_loop()
        tasks = []

        for content in contents:
            for category in categories:
                task = loop.run_in_executor(
                    self.executor,
                    self.classifier.classify,
                    content,
                    category
                )
                tasks.append(task)

        results = await asyncio.gather(*tasks)

        # Group results by content
        grouped = []
        idx = 0
        for content in contents:
            content_results = {}
            for category in categories:
                content_results[category] = results[idx]
                idx += 1
            grouped.append(content_results)

        return grouped

# Usage
async def main():
    batch = ShieldGemmaBatch(model_size="9b")
    contents = [
        "How do I cook pasta?",
        "What's the weather like?"
    ]
    results = await batch.batch_classify(
        contents,
        categories=["dangerous_content", "harassment"]
    )
    return results

# Run the async example
results = asyncio.run(main())
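The thread-pool approach above overlaps requests, but each item still gets its own forward pass. On a GPU, padding several prompts into a single batched forward pass usually gives better throughput. The sketch below shows that idea for one category, reusing the tokenizer, model, and Yes/No token ids from ShieldGemmaClassifier; it is an illustrative alternative, not part of the ShieldGemma API.

import torch

def classify_batch_single_pass(
    classifier: ShieldGemmaClassifier,
    contents: list[str],
    category: str = "dangerous_content",
) -> list[float]:
    """Score several prompts in one padded forward pass (sketch)."""
    prompts = [classifier._build_prompt(c, category) for c in contents]

    # Left-padding keeps the last position aligned with each prompt's final token.
    classifier.tokenizer.padding_side = "left"
    if classifier.tokenizer.pad_token is None:
        classifier.tokenizer.pad_token = classifier.tokenizer.eos_token

    inputs = classifier.tokenizer(
        prompts, return_tensors="pt", padding=True
    ).to(classifier.device)

    with torch.no_grad():
        logits = classifier.model(**inputs).logits[:, -1, :]
        pair = logits[:, [classifier.yes_token, classifier.no_token]]
        yes_probs = torch.softmax(pair, dim=-1)[:, 0]

    return yes_probs.tolist()

# probs = classify_batch_single_pass(classifier, ["How do I cook pasta?", "What's the weather like?"])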

Comparison: ShieldGemma vs LlamaGuard 3

| Aspect | ShieldGemma | LlamaGuard 3 |
|---|---|---|
| Categories | 4 (focused) | 14 (comprehensive) |
| Output | Probability score | Binary + categories |
| Smallest model | 2B | 1B |
| Largest model | 27B | 8B |
| Strength | Probability calibration | Category granularity |

Production Integration Pattern

from enum import Enum
from dataclasses import dataclass

class ClassifierChoice(Enum):
    SHIELDGEMMA = "shieldgemma"
    LLAMAGUARD = "llamaguard"

@dataclass
class UnifiedSafetyResult:
    is_safe: bool
    confidence: float
    categories: list[str]
    classifier_used: ClassifierChoice

def unified_safety_check(
    content: str,
    primary: ClassifierChoice = ClassifierChoice.SHIELDGEMMA,
    fallback: ClassifierChoice | None = ClassifierChoice.LLAMAGUARD,
    confidence_threshold: float = 0.8
) -> UnifiedSafetyResult:
    """
    Unified safety check with fallback.
    Uses primary classifier, falls back if uncertain.
    """
    if primary == ClassifierChoice.SHIELDGEMMA:
        classifier = ShieldGemmaClassifier(model_size="9b")  # in production, load once and reuse
        results = classifier.classify_all(content)

        # Check if any category is harmful
        harmful_cats = [
            cat for cat, r in results.items()
            if r["is_harmful"]
        ]
        max_prob = max(r["probability"] for r in results.values())

        # If uncertain, defer to the fallback classifier (when one is configured)
        if fallback is not None and 0.3 < max_prob < 0.7:
            return unified_safety_check(
                content,
                primary=fallback,
                fallback=None,
                confidence_threshold=confidence_threshold
            )

        return UnifiedSafetyResult(
            is_safe=len(harmful_cats) == 0,
            confidence=1 - max_prob if len(harmful_cats) == 0 else max_prob,
            categories=harmful_cats,
            classifier_used=primary
        )

    # LlamaGuard 3 fallback (assumes a LlamaGuard3Classifier wrapper is available,
    # e.g. from the LlamaGuard coverage elsewhere in this module)
    guard = LlamaGuard3Classifier(model_size="8B")
    decision, categories = guard.classify(content)

    return UnifiedSafetyResult(
        is_safe=decision == "safe",
        confidence=0.9 if decision in ["safe", "unsafe"] else 0.5,
        categories=categories,
        classifier_used=ClassifierChoice.LLAMAGUARD
    )
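A quick usage sketch of the fallback flow, following the document's other usage blocks (the printed values are illustrative, and both classifiers must be loadable in the environment):

# Usage
result = unified_safety_check("How do I learn programming?")
print(f"Safe: {result.is_safe}, confidence: {result.confidence:.2f}, "
      f"via: {result.classifier_used.value}")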

Deployment Tip: Use ShieldGemma 2B for edge deployments or when you need probability scores. Use LlamaGuard 3 when you need granular category information (14 vs 4 categories).
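If both models are available, that guidance can be expressed as a small routing function. The sketch below is illustrative; the DeploymentContext dataclass and Environment enum are assumptions introduced here, not part of either library.

from dataclasses import dataclass
from enum import Enum

class Environment(Enum):
    EDGE = "edge"
    SERVER = "server"

@dataclass
class DeploymentContext:
    environment: Environment
    needs_category_detail: bool   # require fine-grained (14-category) labels?
    needs_probabilities: bool     # require calibrated probability scores?

def choose_classifier(ctx: DeploymentContext) -> ClassifierChoice:
    """Route to a classifier based on deployment constraints (sketch)."""
    if ctx.environment is Environment.EDGE or ctx.needs_probabilities:
        return ClassifierChoice.SHIELDGEMMA   # 2B fits edge; probability output
    if ctx.needs_category_detail:
        return ClassifierChoice.LLAMAGUARD    # 14 granular categories
    return ClassifierChoice.SHIELDGEMMA

# choose_classifier(DeploymentContext(Environment.EDGE, False, True)) -> SHIELDGEMMA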

Next: Comparing classifier performance across benchmarks.
