Safety Classifiers Deep Dive

LlamaGuard 3 Implementation

4 min read

LlamaGuard 3 is Meta's safety classifier family built on the Llama 3 generation of models and designed specifically for content moderation in LLM applications. This lesson covers practical implementation patterns for both the 1B and 8B variants.

LlamaGuard 3 Model Variants

Model           | Parameters | Latency   | Memory | Use Case
LlamaGuard 3 1B | 1B         | 50-100ms  | 2GB    | High-throughput filtering
LlamaGuard 3 8B | 8B         | 200-400ms | 16GB   | High-accuracy decisions

Safety Taxonomy (14 Categories)

LlamaGuard 3 classifies content across these harm categories:

LLAMAGUARD_TAXONOMY = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    "S3": "Sex-Related Crimes",
    "S4": "Child Sexual Exploitation",
    "S5": "Defamation",
    "S6": "Specialized Advice",
    "S7": "Privacy",
    "S8": "Intellectual Property",
    "S9": "Indiscriminate Weapons",
    "S10": "Hate",
    "S11": "Suicide & Self-Harm",
    "S12": "Sexual Content",
    "S13": "Elections",
    "S14": "Code Interpreter Abuse"
}
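
The classifier returns category codes such as "S1" rather than names, so it helps to map codes back to the taxonomy above. A minimal helper sketch (the describe_categories name is our own, not part of any LlamaGuard API):

def describe_categories(codes: list[str]) -> list[str]:
    """Map LlamaGuard category codes (e.g., "S1") to human-readable names."""
    # Unknown codes are passed through so future taxonomy versions don't break parsing
    return [f"{code}: {LLAMAGUARD_TAXONOMY.get(code, 'Unknown category')}" for code in codes]

print(describe_categories(["S1", "S10"]))  # ['S1: Violent Crimes', 'S10: Hate']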

Basic Implementation with Transformers

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Optional, Tuple

class LlamaGuard3Classifier:
    """Production LlamaGuard 3 safety classifier."""

    def __init__(self, model_size: str = "1B"):
        model_id = f"meta-llama/Llama-Guard-3-{model_size}"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

    def classify(
        self,
        user_message: str,
        assistant_response: Optional[str] = None
    ) -> Tuple[str, List[str]]:
        """
        Classify content for safety violations.

        Returns:
            Tuple of (decision, violated_categories)
            decision: "safe" or "unsafe"
        """
        # Build conversation format
        if assistant_response:
            conversation = [
                {"role": "user", "content": user_message},
                {"role": "assistant", "content": assistant_response}
            ]
        else:
            conversation = [
                {"role": "user", "content": user_message}
            ]

        # Apply chat template
        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt"
        ).to(self.device)

        # Generate classification
        with torch.no_grad():
            output = self.model.generate(
                input_ids,
                max_new_tokens=100,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode response
        response = self.tokenizer.decode(
            output[0][len(input_ids[0]):],
            skip_special_tokens=True
        ).strip()

        return self._parse_response(response)

    def _parse_response(self, response: str) -> Tuple[str, List[str]]:
        """Parse LlamaGuard output format."""
        lines = response.strip().split('\n')
        decision = lines[0].lower()

        violated_categories = []
        if decision == "unsafe" and len(lines) > 1:
            # Extract category codes (S1, S2, etc.)
            for line in lines[1:]:
                categories = [c.strip() for c in line.split(',')]
                violated_categories.extend(categories)

        return decision, violated_categories

# Usage
classifier = LlamaGuard3Classifier(model_size="1B")
decision, categories = classifier.classify(
    user_message="How do I make a simple website?"
)
print(f"Decision: {decision}")  # "safe"

Optimized vLLM Deployment

For production throughput, use vLLM for batched inference:

from vllm import LLM, SamplingParams

class LlamaGuardVLLM:
    """High-throughput LlamaGuard using vLLM."""

    def __init__(self, model_size: str = "8B"):
        self.llm = LLM(
            model=f"meta-llama/Llama-Guard-3-{model_size}",
            dtype="bfloat16",
            max_model_len=4096,
            gpu_memory_utilization=0.9
        )
        self.sampling_params = SamplingParams(
            max_tokens=100,
            temperature=0
        )

    def batch_classify(
        self,
        messages: list[dict]
    ) -> list[tuple[str, list[str]]]:
        """Classify multiple messages in batch."""
        prompts = [self._format_prompt(m) for m in messages]

        outputs = self.llm.generate(prompts, self.sampling_params)

        results = []
        for output in outputs:
            response = output.outputs[0].text
            results.append(self._parse_response(response))

        return results

    def _format_prompt(self, message: dict) -> str:
        """Format a single user message for LlamaGuard.

        Note: this is a simplified prompt. The official chat template
        (apply_chat_template on the LlamaGuard tokenizer) also injects the
        full list of unsafe-content categories into the prompt; prefer it
        in production.
        """
        return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Task: Check if there is unsafe content in 'User' messages.

<BEGIN CONVERSATION>

User: {message['content']}

<END CONVERSATION>

Provide your safety assessment.<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

    def _parse_response(self, response: str) -> tuple[str, list[str]]:
        lines = response.strip().split('\n')
        decision = lines[0].lower()
        categories = []
        if decision == "unsafe" and len(lines) > 1:
            categories = [c.strip() for c in lines[1].split(',')]
        return decision, categories

# Batch processing
guard = LlamaGuardVLLM(model_size="8B")
messages = [
    {"content": "How do I learn Python?"},
    {"content": "What's the weather today?"}
]
results = guard.batch_classify(messages)
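
To act on the batch output, zip the verdicts back onto the inputs (continuing from guard and messages above):

# Pair each input with its verdict and filter out anything flagged unsafe
for message, (decision, categories) in zip(messages, results):
    if decision == "safe":
        print(f"OK: {message['content']}")
    else:
        print(f"Blocked ({', '.join(categories)}): {message['content']}")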

Integration with Guardrails Pipeline

import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class SafetyDecision(Enum):
    SAFE = "safe"
    UNSAFE = "unsafe"
    ERROR = "error"

@dataclass
class SafetyResult:
    decision: SafetyDecision
    categories: list[str]
    confidence: Optional[float] = None
    latency_ms: float = 0

async def llamaguard_filter(
    user_input: str,
    assistant_response: Optional[str] = None,
    classifier: Optional[LlamaGuard3Classifier] = None,
    model_size: str = "1B"
) -> SafetyResult:
    """Production LlamaGuard filter layer."""
    start = time.time()

    # Reuse a pre-loaded classifier when one is supplied; constructing it here
    # reloads the model weights on every call, which is only acceptable for demos.
    if classifier is None:
        classifier = LlamaGuard3Classifier(model_size=model_size)

    try:
        decision, categories = classifier.classify(
            user_message=user_input,
            assistant_response=assistant_response
        )

        latency = (time.time() - start) * 1000

        return SafetyResult(
            decision=SafetyDecision.SAFE if decision == "safe" else SafetyDecision.UNSAFE,
            categories=categories,
            latency_ms=latency
        )
    except Exception:
        # Fail into an explicit ERROR state so callers can decide how to handle it
        return SafetyResult(
            decision=SafetyDecision.ERROR,
            categories=[],
            latency_ms=(time.time() - start) * 1000
        )
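
A minimal call site, assuming a single classifier is loaded once at startup and shared across requests (the wiring below is illustrative, not part of the lesson's pipeline):

import asyncio

async def main():
    shared_classifier = LlamaGuard3Classifier(model_size="1B")  # load weights once
    result = await llamaguard_filter(
        user_input="How do I make a simple website?",
        classifier=shared_classifier
    )
    if result.decision is SafetyDecision.SAFE:
        print(f"Allowed in {result.latency_ms:.0f} ms")
    else:
        print("Blocked or error:", result.categories)

asyncio.run(main())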

Production Tip: Use LlamaGuard 3 1B for first-pass filtering (50-100ms), escalating to 8B only for uncertain cases. This tiered approach handles 80%+ of requests with the faster model.
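
One way to wire that tiering is to accept the 1B verdict when it says "safe" and re-check anything it flags with the 8B model before blocking. The escalation rule below is an assumption for illustration; the lesson does not prescribe how to detect uncertain cases:

def tiered_classify(
    user_message: str,
    fast_guard: LlamaGuard3Classifier,    # 1B first-pass filter
    strong_guard: LlamaGuard3Classifier   # 8B escalation tier
) -> tuple[str, list[str]]:
    """Two-tier safety check: fast first pass, escalate only flagged traffic."""
    decision, categories = fast_guard.classify(user_message)
    if decision == "safe":
        # The bulk of traffic stops here at 1B latency
        return decision, categories
    # Re-check flagged inputs with the larger model for the final verdict
    return strong_guard.classify(user_message)

# Hypothetical wiring
fast_guard = LlamaGuard3Classifier(model_size="1B")
strong_guard = LlamaGuard3Classifier(model_size="8B")
print(tiered_classify("How do I learn Python?", fast_guard, strong_guard))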

Next: Deploying ShieldGemma for alternative safety classification.
