# Safety Classifiers Deep Dive
## LlamaGuard 3 Implementation

*4 min read*
LlamaGuard 3 is Meta's safety classifier family built on the Llama 3 model line, designed specifically for content moderation of LLM inputs and outputs. This lesson covers practical implementation patterns for both the 1B and 8B variants.
### LlamaGuard 3 Model Variants
| Model | Parameters | Typical Latency | GPU Memory (bf16) | Use Case |
|---|---|---|---|---|
| LlamaGuard 3 1B | 1B | 50-100 ms | ~2 GB | High-throughput filtering |
| LlamaGuard 3 8B | 8B | 200-400 ms | ~16 GB | High-accuracy decisions |
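If you route traffic between the two variants, the table translates into a simple dispatch rule. A hypothetical helper (`pick_variant` and its thresholds are illustrative, not from Meta):

```python
def pick_variant(latency_budget_ms: float, accuracy_critical: bool = False) -> str:
    """Illustrative routing rule derived from the table above."""
    if accuracy_critical or latency_budget_ms >= 400:
        return "8B"  # slower but more accurate
    return "1B"      # fast first-pass filtering
```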
### Safety Taxonomy (14 Categories)
LlamaGuard 3 classifies content across these harm categories:
```python
LLAMAGUARD_TAXONOMY = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex-Related Crimes",
    "S4": "Child Sexual Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Suicide & Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
    "S14": "Code Interpreter Abuse"
}
```
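When logging or surfacing verdicts, it helps to translate raw category codes back into names. A minimal helper using the taxonomy above (the `describe_violations` name is our own):

```python
def describe_violations(codes: list[str]) -> list[str]:
    """Map category codes like 'S1' to human-readable names."""
    return [
        f"{c}: {LLAMAGUARD_TAXONOMY.get(c, 'Unknown category')}"
        for c in codes
    ]

print(describe_violations(["S1", "S10"]))
# ['S1: Violent Crimes', 'S10: Hate']
```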
### Basic Implementation with Transformers
```python
from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class LlamaGuard3Classifier:
    """Production LlamaGuard 3 safety classifier."""

    def __init__(self, model_size: str = "1B"):
        model_id = f"meta-llama/Llama-Guard-3-{model_size}"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

    def classify(
        self,
        user_message: str,
        assistant_response: Optional[str] = None
    ) -> Tuple[str, List[str]]:
        """
        Classify content for safety violations.

        Returns:
            Tuple of (decision, violated_categories)
            decision: "safe" or "unsafe"
        """
        # Build the conversation; when an assistant response is included,
        # LlamaGuard assesses the last (assistant) turn instead of the user turn
        conversation = [{"role": "user", "content": user_message}]
        if assistant_response:
            conversation.append(
                {"role": "assistant", "content": assistant_response}
            )

        # Apply LlamaGuard's chat template, which injects the S1-S14 taxonomy
        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt"
        ).to(self.device)

        # Greedy decoding; the verdict is short ("safe" or "unsafe\nS1,...")
        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_new_tokens=100,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens
        response = self.tokenizer.decode(
            output[0][len(input_ids[0]):],
            skip_special_tokens=True
        ).strip()
        return self._parse_response(response)

    def _parse_response(self, response: str) -> Tuple[str, List[str]]:
        """Parse LlamaGuard output: 'safe', or 'unsafe' plus category codes."""
        lines = response.strip().split('\n')
        decision = lines[0].strip().lower()
        violated_categories = []
        if decision == "unsafe" and len(lines) > 1:
            # Remaining lines hold comma-separated codes (S1, S2, etc.)
            for line in lines[1:]:
                categories = [c.strip() for c in line.split(',') if c.strip()]
                violated_categories.extend(categories)
        return decision, violated_categories


# Usage
classifier = LlamaGuard3Classifier(model_size="1B")
decision, categories = classifier.classify(
    user_message="How do I make a simple website?"
)
print(f"Decision: {decision}")  # "safe"
```
### Optimized vLLM Deployment
For production throughput, use vLLM for batched inference:
```python
from vllm import LLM, SamplingParams


class LlamaGuardVLLM:
    """High-throughput LlamaGuard using vLLM."""

    def __init__(self, model_size: str = "8B"):
        self.llm = LLM(
            model=f"meta-llama/Llama-Guard-3-{model_size}",
            dtype="bfloat16",
            max_model_len=4096,
            gpu_memory_utilization=0.9
        )
        # temperature=0 gives deterministic verdicts
        self.sampling_params = SamplingParams(
            max_tokens=100,
            temperature=0
        )

    def batch_classify(
        self,
        messages: list[dict]
    ) -> list[tuple[str, list[str]]]:
        """Classify multiple messages in batch."""
        prompts = [self._format_prompt(m) for m in messages]
        outputs = self.llm.generate(prompts, self.sampling_params)
        results = []
        for output in outputs:
            response = output.outputs[0].text
            results.append(self._parse_response(response))
        return results

    def _format_prompt(self, message: dict) -> str:
        """Format a message for LlamaGuard.

        Note: this is an abbreviated template. The official LlamaGuard 3
        prompt also enumerates the S1-S14 category definitions; see the
        tokenizer-based variant after this block for a more faithful version.
        """
        return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Task: Check if there is unsafe content in 'User' messages.

<BEGIN CONVERSATION>

User: {message['content']}

<END CONVERSATION>

Provide your safety assessment.<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

    def _parse_response(self, response: str) -> tuple[str, list[str]]:
        lines = response.strip().split('\n')
        decision = lines[0].strip().lower()
        categories = []
        if decision == "unsafe" and len(lines) > 1:
            categories = [c.strip() for c in lines[1].split(',')]
        return decision, categories


# Batch processing
guard = LlamaGuardVLLM(model_size="8B")
messages = [
    {"content": "How do I learn Python?"},
    {"content": "What's the weather today?"}
]
results = guard.batch_classify(messages)
```
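The hand-written template above omits the category definitions that LlamaGuard's official prompt includes. A more faithful approach is to let the Hugging Face tokenizer render the full template and hand the resulting strings to vLLM. A sketch, assuming the 8B tokenizer and reusing `guard` and `messages` from above:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-Guard-3-8B")

def format_with_template(message: dict) -> str:
    # tokenize=False returns the fully rendered prompt string,
    # including the official S1-S14 category definitions
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": message["content"]}],
        tokenize=False
    )

prompts = [format_with_template(m) for m in messages]
outputs = guard.llm.generate(prompts, guard.sampling_params)
```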
### Integration with Guardrails Pipeline
```python
import asyncio
import time
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Optional


class SafetyDecision(Enum):
    SAFE = "safe"
    UNSAFE = "unsafe"
    ERROR = "error"


@dataclass
class SafetyResult:
    decision: SafetyDecision
    categories: list[str]
    confidence: Optional[float] = None
    latency_ms: float = 0.0


@lru_cache(maxsize=2)
def _get_classifier(model_size: str) -> LlamaGuard3Classifier:
    """Load each model variant once; reloading per request would add seconds of latency."""
    return LlamaGuard3Classifier(model_size=model_size)


async def llamaguard_filter(
    user_input: str,
    assistant_response: Optional[str] = None,
    model_size: str = "1B"
) -> SafetyResult:
    """Production LlamaGuard filter layer."""
    start = time.perf_counter()
    classifier = _get_classifier(model_size)
    try:
        # Run the blocking model call off the event loop
        decision, categories = await asyncio.to_thread(
            classifier.classify,
            user_message=user_input,
            assistant_response=assistant_response
        )
        latency = (time.perf_counter() - start) * 1000
        return SafetyResult(
            decision=SafetyDecision.SAFE if decision == "safe" else SafetyDecision.UNSAFE,
            categories=categories,
            latency_ms=latency
        )
    except Exception:
        # Surface an explicit error state; callers decide fail-open vs. fail-closed
        return SafetyResult(
            decision=SafetyDecision.ERROR,
            categories=[],
            latency_ms=(time.perf_counter() - start) * 1000
        )
```
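A quick smoke test of the filter (the verdict depends on the loaded model; `asyncio.run` is for scripts, not for code already inside a running event loop):

```python
import asyncio

result = asyncio.run(llamaguard_filter("How do I learn Python?"))
print(result.decision, result.categories, f"{result.latency_ms:.0f} ms")
```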
**Production Tip:** Use LlamaGuard 3 1B for first-pass filtering (50-100 ms), escalating to 8B only for uncertain cases. This tiered approach handles 80%+ of requests with the faster model.
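A minimal sketch of that tiering, reusing the `LlamaGuard3Classifier` class from earlier. The escalation rule here (re-check anything the 1B model flags) is our own heuristic, not part of LlamaGuard itself:

```python
class TieredGuard:
    """First-pass 1B filter; escalate flagged inputs to the 8B model."""

    def __init__(self):
        self.fast = LlamaGuard3Classifier(model_size="1B")
        self.accurate = LlamaGuard3Classifier(model_size="8B")

    def classify(self, user_message: str) -> tuple[str, list[str]]:
        decision, categories = self.fast.classify(user_message)
        if decision == "safe":
            # Fast path: most traffic never touches the 8B model
            return decision, categories
        # Re-check flagged inputs with the higher-accuracy model
        return self.accurate.classify(user_message)
```

A more principled escalation signal would be the log-probability of the "safe"/"unsafe" token, but the simple re-check above already keeps most traffic on the 1B model.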
*Next: Deploying ShieldGemma for alternative safety classification.*