Safety Classifiers Deep Dive
ShieldGemma Deployment
3 min read
ShieldGemma is Google's safety classifier family built on Gemma 2. It offers performance competitive with LlamaGuard while making different trade-offs in model size and capability.
ShieldGemma Model Family
| Model | Parameters | Latency (GPU) | Benchmark Score | Best For |
|---|---|---|---|---|
| ShieldGemma 2B | 2B | 30-60ms | 0.814 | Edge/mobile |
| ShieldGemma 9B | 9B | 80-150ms | 0.843 | Balanced |
| ShieldGemma 27B | 27B | 200-400ms | 0.876 | Maximum accuracy |
Benchmark Note: ShieldGemma 27B achieves a +10.8% improvement over LlamaGuard 2 on safety classification benchmarks (Google, 2024).
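The table also doubles as a capacity-planning guide: choose the largest variant whose worst-case latency still fits your per-request budget. A minimal sketch of that selection logic, using illustrative thresholds taken from the latency column above (not official guidance):

def pick_shieldgemma_size(latency_budget_ms: float) -> str:
    """Return the largest ShieldGemma variant whose worst-case GPU latency
    (per the table above) still fits the per-request budget."""
    if latency_budget_ms >= 400:
        return "27b"  # maximum accuracy, ~200-400ms
    if latency_budget_ms >= 150:
        return "9b"   # balanced, ~80-150ms
    return "2b"       # edge/mobile, ~30-60ms

# Example: a 100ms budget only leaves room for the 2B model
print(pick_shieldgemma_size(100))  # -> "2b"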
Safety Categories
ShieldGemma classifies content into four main harm categories:
SHIELDGEMMA_CATEGORIES = {
"dangerous_content": "Content promoting dangerous activities",
"harassment": "Bullying, threats, or targeted abuse",
"hate_speech": "Content promoting hatred based on identity",
"sexually_explicit": "Explicit sexual content"
}
Basic Implementation
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Dict
class ShieldGemmaClassifier:
"""ShieldGemma safety classifier for production use."""
def __init__(self, model_size: str = "2b"):
# Available: 2b, 9b, 27b
model_id = f"google/shieldgemma-{model_size}"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# ShieldGemma probability tokens
self.yes_token = self.tokenizer.encode("Yes", add_special_tokens=False)[0]
self.no_token = self.tokenizer.encode("No", add_special_tokens=False)[0]
def classify(
self,
user_content: str,
category: str = "dangerous_content"
) -> Dict[str, float]:
"""
Classify content for a specific harm category.
Returns:
            Dict with 'probability' (0-1), 'is_harmful' (bool), and 'category' (str)
"""
prompt = self._build_prompt(user_content, category)
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits[0, -1, :]
# Get probabilities for Yes/No tokens
yes_prob = torch.softmax(
logits[[self.yes_token, self.no_token]], dim=0
)[0].item()
return {
"probability": yes_prob,
"is_harmful": yes_prob > 0.5,
"category": category
}
def classify_all(self, user_content: str) -> Dict[str, Dict]:
"""Check all harm categories."""
results = {}
for category in SHIELDGEMMA_CATEGORIES:
results[category] = self.classify(user_content, category)
return results
def _build_prompt(self, content: str, category: str) -> str:
"""Build ShieldGemma prompt format."""
category_desc = SHIELDGEMMA_CATEGORIES.get(category, category)
return f"""You are a policy expert determining if the following content
violates the policy on {category_desc}.
User content: {content}
Does this content violate the {category} policy? Answer Yes or No."""
# Usage
classifier = ShieldGemmaClassifier(model_size="2b")
result = classifier.classify(
user_content="How do I learn programming?",
category="dangerous_content"
)
print(f"Harmful: {result['is_harmful']}, Probability: {result['probability']:.2%}")
Batch Processing for Throughput
import asyncio
from concurrent.futures import ThreadPoolExecutor
class ShieldGemmaBatch:
"""Batch processor for ShieldGemma classification."""
def __init__(self, model_size: str = "9b", max_workers: int = 4):
        self.classifier = ShieldGemmaClassifier(model_size=model_size)
        # Threads keep the event loop responsive; forward passes on a single
        # GPU model still largely serialize, so this is not true batching.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
async def batch_classify(
self,
contents: list[str],
        categories: list[str] | None = None
) -> list[Dict]:
"""Classify multiple contents asynchronously."""
categories = categories or ["dangerous_content"]
        loop = asyncio.get_running_loop()
tasks = []
for content in contents:
for category in categories:
task = loop.run_in_executor(
self.executor,
self.classifier.classify,
content,
category
)
tasks.append(task)
results = await asyncio.gather(*tasks)
# Group results by content
grouped = []
idx = 0
for content in contents:
content_results = {}
for category in categories:
content_results[category] = results[idx]
idx += 1
grouped.append(content_results)
return grouped
# Usage
async def main():
batch = ShieldGemmaBatch(model_size="9b")
contents = [
"How do I cook pasta?",
"What's the weather like?"
]
results = await batch.batch_classify(
contents,
categories=["dangerous_content", "harassment"]
)
    return results

results = asyncio.run(main())
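The thread-pool wrapper keeps an async service responsive, but each request still triggers its own forward pass. For raw GPU throughput, you can instead pad several prompts and score them in one forward pass. A minimal sketch that reuses the ShieldGemmaClassifier internals from above (no official batching API is assumed):

import torch

def classify_batch_single_pass(
    classifier: ShieldGemmaClassifier,
    contents: list[str],
    category: str = "dangerous_content",
) -> list[dict]:
    """Score several texts for one category in a single padded forward pass."""
    prompts = [classifier._build_prompt(c, category) for c in contents]
    # Left-pad so position -1 is the final prompt token for every row
    classifier.tokenizer.padding_side = "left"
    if classifier.tokenizer.pad_token is None:
        classifier.tokenizer.pad_token = classifier.tokenizer.eos_token
    inputs = classifier.tokenizer(
        prompts, return_tensors="pt", padding=True
    ).to(classifier.device)
    with torch.no_grad():
        logits = classifier.model(**inputs).logits[:, -1, :]
    # Softmax over the Yes/No token pair, per row
    pair = logits[:, [classifier.yes_token, classifier.no_token]]
    yes_probs = torch.softmax(pair, dim=-1)[:, 0].tolist()
    return [
        {"probability": p, "is_harmful": p > 0.5, "category": category}
        for p in yes_probs
    ]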
Comparison: ShieldGemma vs LlamaGuard 3
| Aspect | ShieldGemma | LlamaGuard 3 |
|---|---|---|
| Categories | 4 (focused) | 14 (comprehensive) |
| Output | Probability score | Binary + categories |
| Smallest | 2B | 1B |
| Largest | 27B | 8B |
| Strength | Probability calibration | Category granularity |
Production Integration Pattern
from enum import Enum
from dataclasses import dataclass
class ClassifierChoice(Enum):
SHIELDGEMMA = "shieldgemma"
LLAMAGUARD = "llamaguard"
@dataclass
class UnifiedSafetyResult:
is_safe: bool
confidence: float
categories: list[str]
classifier_used: ClassifierChoice
def unified_safety_check(
content: str,
primary: ClassifierChoice = ClassifierChoice.SHIELDGEMMA,
    fallback: ClassifierChoice | None = ClassifierChoice.LLAMAGUARD,
confidence_threshold: float = 0.8
) -> UnifiedSafetyResult:
"""
Unified safety check with fallback.
Uses primary classifier, falls back if uncertain.
"""
if primary == ClassifierChoice.SHIELDGEMMA:
        # In production, cache and reuse this instance rather than loading per call
        classifier = ShieldGemmaClassifier(model_size="9b")
results = classifier.classify_all(content)
# Check if any category is harmful
harmful_cats = [
cat for cat, r in results.items()
if r["is_harmful"]
]
max_prob = max(r["probability"] for r in results.values())
        # If uncertain, defer to the fallback classifier (when one is configured)
        if fallback is not None and 0.3 < max_prob < 0.7:
return unified_safety_check(
content,
primary=fallback,
fallback=None,
confidence_threshold=confidence_threshold
)
return UnifiedSafetyResult(
is_safe=len(harmful_cats) == 0,
            confidence=(1 - max_prob) if len(harmful_cats) == 0 else max_prob,
categories=harmful_cats,
classifier_used=primary
)
    # LlamaGuard fallback (assumes a LlamaGuard3Classifier wrapper is defined elsewhere)
    guard = LlamaGuard3Classifier(model_size="8B")
decision, categories = guard.classify(content)
return UnifiedSafetyResult(
is_safe=decision == "safe",
confidence=0.9 if decision in ["safe", "unsafe"] else 0.5,
categories=categories,
classifier_used=ClassifierChoice.LLAMAGUARD
)
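Example usage of the unified check:

result = unified_safety_check("How do I pick a strong password?")
print(result.is_safe, f"{result.confidence:.2f}", result.classifier_used.value)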
Deployment Tip: Use ShieldGemma when you need probability scores or an edge-friendly footprint (the 2B variant). Use LlamaGuard 3 when you need granular category information (14 categories vs ShieldGemma's 4).
Next: Comparing classifier performance across benchmarks.