Safety Classifiers Deep Dive
Custom Taxonomy Adaptation
3 min read
Pre-trained safety classifiers use generic taxonomies that may not fit your specific domain. This lesson covers adapting and extending safety categories for domain-specific requirements.
When to Customize
Customize your taxonomy when:
- Industry-specific content: Financial advice, medical information, legal guidance
- Cultural context: Regional content norms, language-specific nuances
- Business rules: Brand safety, competitive mentions, topic restrictions
- Use case: Customer support vs. creative writing vs. code generation
Extending LlamaGuard Taxonomy
from typing import Callable, Dict, List
import json
# Original LlamaGuard 3 taxonomy
LLAMAGUARD_BASE = {
"S1": "Violent Crimes",
"S2": "Non-Violent Crimes",
# ... S3-S14
}
# Custom extensions for financial services
FINANCIAL_TAXONOMY = {
**LLAMAGUARD_BASE,
"F1": "Investment Advice",
"F2": "Tax Guidance",
"F3": "Insurance Recommendations",
"F4": "Regulatory Violations",
"F5": "Competitor Mentions"
}
class CustomTaxonomyClassifier:
"""Classifier with custom taxonomy extensions."""
    def __init__(self, base_classifier, custom_rules: Dict[str, Callable]):
self.base = base_classifier
self.custom_rules = custom_rules
self.taxonomy = FINANCIAL_TAXONOMY
def classify(self, content: str) -> Dict:
"""Classify with base + custom rules."""
# Run base classification
base_decision, base_categories = self.base.classify(content)
# Run custom rule checks
custom_violations = []
for code, rule_fn in self.custom_rules.items():
if rule_fn(content):
custom_violations.append(code)
# Combine results
all_categories = base_categories + custom_violations
final_decision = "unsafe" if all_categories else "safe"
return {
"decision": final_decision,
"categories": all_categories,
"base_result": base_decision,
"custom_violations": custom_violations
}
# Define custom rules
import re

def check_investment_advice(text: str) -> bool:
    """Check for unauthorized investment advice."""
    investment_patterns = [
        r"you should (buy|sell|invest)",
        r"guaranteed returns",
        r"risk-free investment",
        r"i recommend (buying|selling)"  # lowercase: patterns are matched against text.lower()
    ]
    return any(re.search(p, text.lower()) for p in investment_patterns)
def check_competitor_mention(text: str) -> bool:
"""Check for competitor brand mentions."""
competitors = ["competitor_a", "competitor_b", "other_bank"]
return any(comp in text.lower() for comp in competitors)
# Usage
custom_rules = {
"F1": check_investment_advice,
"F5": check_competitor_mention
}
classifier = CustomTaxonomyClassifier(
base_classifier=LlamaGuard3Classifier(model_size="1B"),
custom_rules=custom_rules
)
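With the rules registered, a single classify() call combines the base model and the custom checks. A quick usage sketch; the base_result shown is illustrative and depends on the underlying LlamaGuard model:
result = classifier.classify("You should buy this fund now, guaranteed returns!")
# Illustrative output shape:
# {
#     "decision": "unsafe",
#     "categories": ["F1"],
#     "base_result": "safe",
#     "custom_violations": ["F1"]
# }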
Fine-tuning Safety Classifiers
For persistent custom categories, fine-tune a classifier:
from typing import Dict, List

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
class SafetyClassifierFineTuner:
"""Fine-tune safety classifiers for custom taxonomy."""
def __init__(self, base_model: str = "martin-ha/toxic-comment-model"):
self.tokenizer = AutoTokenizer.from_pretrained(base_model)
self.model = AutoModelForSequenceClassification.from_pretrained(
base_model,
num_labels=2 # safe/unsafe
)
def prepare_dataset(
self,
examples: List[Dict],
custom_category: str
) -> Dataset:
"""
Prepare training data for custom category.
examples format:
[{"text": "...", "label": 0/1}, ...]
"""
def tokenize(batch):
return self.tokenizer(
batch["text"],
truncation=True,
max_length=512,
padding="max_length"
)
dataset = Dataset.from_list(examples)
dataset = dataset.map(tokenize, batched=True)
return dataset
def fine_tune(
self,
train_dataset: Dataset,
eval_dataset: Dataset = None,
output_dir: str = "./custom_classifier",
epochs: int = 3
):
"""Fine-tune the classifier."""
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=epochs,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
warmup_steps=100,
weight_decay=0.01,
logging_steps=50,
evaluation_strategy="epoch" if eval_dataset else "no",
save_strategy="epoch",
load_best_model_at_end=True if eval_dataset else False,
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
trainer.save_model(output_dir)
self.tokenizer.save_pretrained(output_dir)
return trainer
# Example: Fine-tune for financial advice detection
training_data = [
{"text": "You should definitely buy AAPL stock", "label": 1},
{"text": "Here's how stock markets generally work", "label": 0},
{"text": "I guarantee 20% returns on this investment", "label": 1},
{"text": "Investments carry risk of loss", "label": 0},
# ... more examples
]
finetuner = SafetyClassifierFineTuner()
train_ds = finetuner.prepare_dataset(training_data, "investment_advice")
finetuner.fine_tune(train_ds, output_dir="./investment_advice_classifier")
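After training, the saved checkpoint loads like any sequence-classification model. A minimal inference sketch, assuming the output directory above; label names come from the model config, so check model.config.id2label before mapping predictions to safe/unsafe:
from transformers import pipeline

# Load the fine-tuned checkpoint saved by fine_tune()
investment_clf = pipeline(
    "text-classification",
    model="./investment_advice_classifier"
)

pred = investment_clf("I guarantee 20% returns if you buy this fund")[0]
print(pred)  # e.g. {"label": "...", "score": 0.97}; map the label to safe/unsafe via id2label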
Prompt-based Taxonomy with LLMs
For a flexible taxonomy without fine-tuning:
import json
from typing import Dict, List

from openai import OpenAI
class PromptBasedClassifier:
"""LLM-based classifier with dynamic taxonomy."""
def __init__(self, api_key: str):
self.client = OpenAI(api_key=api_key)
self.taxonomy = {}
def add_category(
self,
code: str,
name: str,
description: str,
examples: List[str] = None
):
"""Add a custom category to the taxonomy."""
self.taxonomy[code] = {
"name": name,
"description": description,
"examples": examples or []
}
def classify(self, content: str) -> Dict:
"""Classify content against custom taxonomy."""
taxonomy_prompt = self._build_taxonomy_prompt()
response = self.client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": f"""You are a content safety classifier.
Classify the user content against these categories:
{taxonomy_prompt}
Respond in JSON format:
{{"decision": "safe" or "unsafe", "categories": ["code1", "code2"], "reasoning": "..."}}"""
},
{"role": "user", "content": content}
],
response_format={"type": "json_object"}
)
return json.loads(response.choices[0].message.content)
def _build_taxonomy_prompt(self) -> str:
"""Build taxonomy description for prompt."""
lines = []
for code, info in self.taxonomy.items():
line = f"- {code}: {info['name']}\n Description: {info['description']}"
if info['examples']:
line += f"\n Examples: {', '.join(info['examples'][:3])}"
lines.append(line)
return "\n".join(lines)
# Usage for healthcare chatbot
classifier = PromptBasedClassifier(api_key="...")
classifier.add_category(
code="H1",
name="Medical Diagnosis",
description="Content that provides specific medical diagnoses or suggests conditions",
examples=["You have diabetes", "This sounds like cancer"]
)
classifier.add_category(
code="H2",
name="Medication Advice",
description="Recommendations for specific medications or dosages",
examples=["Take 500mg ibuprofen", "You should try this medication"]
)
classifier.add_category(
code="H3",
name="Emergency Redirection",
description="Situations requiring immediate medical attention",
examples=["chest pain", "difficulty breathing", "severe bleeding"]
)
result = classifier.classify("I think you might have appendicitis based on your symptoms")
# Returns: {"decision": "unsafe", "categories": ["H1"], "reasoning": "..."}
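The returned categories can then drive different handling paths, for example routing emergencies differently from ordinary blocks. A hypothetical policy mapping; the action names are illustrative, not part of the classifier:
# Hypothetical category-to-action policy; adapt to your escalation workflow
CATEGORY_ACTIONS = {
    "H1": "block_and_refer",       # no diagnoses; point the user to a clinician
    "H2": "block_and_refer",       # no medication or dosage advice
    "H3": "escalate_emergency",    # surface emergency resources immediately
}

def route(result: Dict) -> str:
    """Pick a handling action from a classification result."""
    if result["decision"] == "safe":
        return "allow"
    actions = [CATEGORY_ACTIONS.get(code, "block") for code in result["categories"]]
    # Emergencies take precedence over ordinary blocks
    return "escalate_emergency" if "escalate_emergency" in actions else actions[0]

print(route(result))  # "block_and_refer" for the appendicitis example above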
Ensemble Approach for Production
from dataclasses import dataclass
from typing import Callable, Dict, List
from enum import Enum
class ViolationType(Enum):
BASE_SAFETY = "base_safety"
CUSTOM_RULE = "custom_rule"
DOMAIN_SPECIFIC = "domain_specific"
@dataclass
class EnsembleResult:
is_safe: bool
violations: List[Dict]
confidence: float
class EnsembleClassifier:
"""Production ensemble with multiple taxonomy layers."""
def __init__(self):
self.classifiers = []
def add_classifier(
self,
name: str,
        classifier: Callable,
priority: int = 1,
veto_power: bool = False
):
"""Add a classifier to the ensemble."""
self.classifiers.append({
"name": name,
"fn": classifier,
"priority": priority,
"veto_power": veto_power
})
self.classifiers.sort(key=lambda x: x["priority"], reverse=True)
def classify(self, content: str) -> EnsembleResult:
"""Run all classifiers and aggregate results."""
all_violations = []
veto_triggered = False
for clf in self.classifiers:
result = clf["fn"](content)
if result.get("violations"):
for v in result["violations"]:
all_violations.append({
"source": clf["name"],
"category": v,
"priority": clf["priority"]
})
if clf["veto_power"]:
veto_triggered = True
# Calculate confidence based on agreement
if not all_violations:
confidence = 0.95
elif veto_triggered:
confidence = 0.99
else:
confidence = min(0.5 + (len(all_violations) * 0.15), 0.95)
return EnsembleResult(
is_safe=len(all_violations) == 0,
violations=all_violations,
confidence=confidence
)
# Production setup
ensemble = EnsembleClassifier()
# Base safety (highest priority, veto power); instantiate once, not on every call
llamaguard = LlamaGuard3Classifier(model_size="1B")
ensemble.add_classifier(
    name="llamaguard",
    classifier=lambda x: {"violations": llamaguard.classify(x)[1]},
    priority=10,
    veto_power=True
)
# Domain rules
ensemble.add_classifier(
name="financial_rules",
classifier=lambda x: {
"violations": ["F1"] if check_investment_advice(x) else []
},
priority=5
)
# Brand safety
ensemble.add_classifier(
name="brand_safety",
classifier=lambda x: {
"violations": ["B1"] if check_competitor_mention(x) else []
},
priority=3
)
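A single classify() call then aggregates all layers. A usage sketch, assuming the base LlamaGuard wrapper reports no generic violation for this text:
result = ensemble.classify(
    "You should buy our fund instead of competitor_a, guaranteed returns"
)
print(result.is_safe)      # False: both domain rules fire
print(result.violations)   # [{"source": "financial_rules", "category": "F1", ...},
                           #  {"source": "brand_safety", "category": "B1", ...}]
print(result.confidence)   # 0.80 with two rule hits and no base-safety veto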
Best Practice: Start with a base safety classifier (LlamaGuard/ShieldGemma) and layer domain-specific rules on top. This ensures you don't miss generic safety issues while catching domain-specific violations.
Next: Deep dive into the NVIDIA NeMo Guardrails framework.