Safety Classifiers Deep Dive

Custom Taxonomy Adaptation

Pre-trained safety classifiers use generic taxonomies that may not fit your specific domain. This lesson covers adapting and extending safety categories for domain-specific requirements.

When to Customize

Customize your taxonomy when:

  • Industry-specific content: Financial advice, medical information, legal guidance
  • Cultural context: Regional content norms, language-specific nuances
  • Business rules: Brand safety, competitive mentions, topic restrictions
  • Use case: Customer support vs. creative writing vs. code generation

Extending LlamaGuard Taxonomy

import json
import re
from typing import Callable, Dict, List

# Original LlamaGuard 3 taxonomy
LLAMAGUARD_BASE = {
    "S1": "Violent Crimes",
    "S2": "Non-Violent Crimes",
    # ... S3-S14
}

# Custom extensions for financial services
FINANCIAL_TAXONOMY = {
    **LLAMAGUARD_BASE,
    "F1": "Investment Advice",
    "F2": "Tax Guidance",
    "F3": "Insurance Recommendations",
    "F4": "Regulatory Violations",
    "F5": "Competitor Mentions"
}

class CustomTaxonomyClassifier:
    """Classifier with custom taxonomy extensions."""

    def __init__(self, base_classifier, custom_rules: Dict[str, Callable[[str], bool]]):
        self.base = base_classifier
        self.custom_rules = custom_rules
        self.taxonomy = FINANCIAL_TAXONOMY

    def classify(self, content: str) -> Dict:
        """Classify with base + custom rules."""
        # Run base classification
        base_decision, base_categories = self.base.classify(content)

        # Run custom rule checks
        custom_violations = []
        for code, rule_fn in self.custom_rules.items():
            if rule_fn(content):
                custom_violations.append(code)

        # Combine results
        all_categories = base_categories + custom_violations
        final_decision = "unsafe" if all_categories else "safe"

        return {
            "decision": final_decision,
            "categories": all_categories,
            "base_result": base_decision,
            "custom_violations": custom_violations
        }

# Define custom rules
def check_investment_advice(text: str) -> bool:
    """Check for unauthorized investment advice."""
    # Patterns are matched against lowercased text, so keep them lowercase
    investment_patterns = [
        r"you should (buy|sell|invest)",
        r"guaranteed returns",
        r"risk-free investment",
        r"i recommend (buying|selling)"
    ]
    return any(re.search(p, text.lower()) for p in investment_patterns)

def check_competitor_mention(text: str) -> bool:
    """Check for competitor brand mentions."""
    competitors = ["competitor_a", "competitor_b", "other_bank"]
    return any(comp in text.lower() for comp in competitors)

# Usage
custom_rules = {
    "F1": check_investment_advice,
    "F5": check_competitor_mention
}

classifier = CustomTaxonomyClassifier(
    base_classifier=LlamaGuard3Classifier(model_size="1B"),
    custom_rules=custom_rules
)
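
Calling classify() returns the combined verdict plus a breakdown of which layer flagged what, which is handy for audit logs. The output below is illustrative and assumes the base model returns "safe" for this input:

result = classifier.classify("I recommend buying this fund, guaranteed returns")
# Illustrative result (actual base-model categories depend on LlamaGuard's output):
# {
#     "decision": "unsafe",
#     "categories": ["F1"],
#     "base_result": "safe",
#     "custom_violations": ["F1"]
# }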

Fine-tuning Safety Classifiers

For persistent custom categories, fine-tune a classifier:

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
from datasets import Dataset
import torch

class SafetyClassifierFineTuner:
    """Fine-tune safety classifiers for custom taxonomy."""

    def __init__(self, base_model: str = "martin-ha/toxic-comment-model"):
        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            base_model,
            num_labels=2  # safe/unsafe
        )

    def prepare_dataset(
        self,
        examples: List[Dict],
        custom_category: str
    ) -> Dataset:
        """
        Prepare training data for custom category.

        examples format:
        [{"text": "...", "label": 0/1}, ...]
        """
        def tokenize(batch):
            return self.tokenizer(
                batch["text"],
                truncation=True,
                max_length=512,
                padding="max_length"
            )

        dataset = Dataset.from_list(examples)
        dataset = dataset.map(tokenize, batched=True)
        return dataset

    def fine_tune(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset = None,
        output_dir: str = "./custom_classifier",
        epochs: int = 3
    ):
        """Fine-tune the classifier."""
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            warmup_steps=100,
            weight_decay=0.01,
            logging_steps=50,
            evaluation_strategy="epoch" if eval_dataset else "no",
            save_strategy="epoch",
            load_best_model_at_end=True if eval_dataset else False,
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

        trainer.train()
        trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        return trainer

# Example: Fine-tune for financial advice detection
training_data = [
    {"text": "You should definitely buy AAPL stock", "label": 1},
    {"text": "Here's how stock markets generally work", "label": 0},
    {"text": "I guarantee 20% returns on this investment", "label": 1},
    {"text": "Investments carry risk of loss", "label": 0},
    # ... more examples
]

finetuner = SafetyClassifierFineTuner()
train_ds = finetuner.prepare_dataset(training_data, "investment_advice")
finetuner.fine_tune(train_ds, output_dir="./investment_advice_classifier")
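
After training, the saved checkpoint loads like any other sequence-classification model. A minimal inference sketch; note the label names come from the checkpoint's config (typically LABEL_0/LABEL_1 unless id2label is set), so the mapping shown is an assumption:

from transformers import pipeline

# Load the fine-tuned checkpoint saved above
detector = pipeline(
    "text-classification",
    model="./investment_advice_classifier",
    tokenizer="./investment_advice_classifier"
)

prediction = detector("Put everything into this fund, it cannot lose")[0]
print(prediction)  # e.g. {"label": "LABEL_1", "score": 0.97}; index 1 = violates the custom category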

Prompt-based Taxonomy with LLMs

For flexible taxonomy without fine-tuning:

from openai import OpenAI

class PromptBasedClassifier:
    """LLM-based classifier with dynamic taxonomy."""

    def __init__(self, api_key: str):
        self.client = OpenAI(api_key=api_key)
        self.taxonomy = {}

    def add_category(
        self,
        code: str,
        name: str,
        description: str,
        examples: List[str] = None
    ):
        """Add a custom category to the taxonomy."""
        self.taxonomy[code] = {
            "name": name,
            "description": description,
            "examples": examples or []
        }

    def classify(self, content: str) -> Dict:
        """Classify content against custom taxonomy."""
        taxonomy_prompt = self._build_taxonomy_prompt()

        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "system",
                    "content": f"""You are a content safety classifier.

Classify the user content against these categories:

{taxonomy_prompt}

Respond in JSON format:
{{"decision": "safe" or "unsafe", "categories": ["code1", "code2"], "reasoning": "..."}}"""
                },
                {"role": "user", "content": content}
            ],
            response_format={"type": "json_object"}
        )

        return json.loads(response.choices[0].message.content)

    def _build_taxonomy_prompt(self) -> str:
        """Build taxonomy description for prompt."""
        lines = []
        for code, info in self.taxonomy.items():
            line = f"- {code}: {info['name']}\n  Description: {info['description']}"
            if info['examples']:
                line += f"\n  Examples: {', '.join(info['examples'][:3])}"
            lines.append(line)
        return "\n".join(lines)

# Usage for healthcare chatbot
classifier = PromptBasedClassifier(api_key="...")

classifier.add_category(
    code="H1",
    name="Medical Diagnosis",
    description="Content that provides specific medical diagnoses or suggests conditions",
    examples=["You have diabetes", "This sounds like cancer"]
)

classifier.add_category(
    code="H2",
    name="Medication Advice",
    description="Recommendations for specific medications or dosages",
    examples=["Take 500mg ibuprofen", "You should try this medication"]
)

classifier.add_category(
    code="H3",
    name="Emergency Redirection",
    description="Situations requiring immediate medical attention",
    examples=["chest pain", "difficulty breathing", "severe bleeding"]
)

result = classifier.classify("I think you might have appendicitis based on your symptoms")
# Returns: {"decision": "unsafe", "categories": ["H1"], "reasoning": "..."}

Ensemble Approach for Production

from dataclasses import dataclass
from typing import Callable, Dict, List
from enum import Enum

class ViolationType(Enum):
    BASE_SAFETY = "base_safety"
    CUSTOM_RULE = "custom_rule"
    DOMAIN_SPECIFIC = "domain_specific"

@dataclass
class EnsembleResult:
    is_safe: bool
    violations: List[Dict]
    confidence: float

class EnsembleClassifier:
    """Production ensemble with multiple taxonomy layers."""

    def __init__(self):
        self.classifiers = []

    def add_classifier(
        self,
        name: str,
        classifier: Callable[[str], Dict],
        priority: int = 1,
        veto_power: bool = False
    ):
        """Add a classifier to the ensemble."""
        self.classifiers.append({
            "name": name,
            "fn": classifier,
            "priority": priority,
            "veto_power": veto_power
        })
        self.classifiers.sort(key=lambda x: x["priority"], reverse=True)

    def classify(self, content: str) -> EnsembleResult:
        """Run all classifiers and aggregate results."""
        all_violations = []
        veto_triggered = False

        for clf in self.classifiers:
            result = clf["fn"](content)

            if result.get("violations"):
                for v in result["violations"]:
                    all_violations.append({
                        "source": clf["name"],
                        "category": v,
                        "priority": clf["priority"]
                    })

                if clf["veto_power"]:
                    veto_triggered = True

        # Calculate confidence based on agreement
        if not all_violations:
            confidence = 0.95
        elif veto_triggered:
            confidence = 0.99
        else:
            confidence = min(0.5 + (len(all_violations) * 0.15), 0.95)

        return EnsembleResult(
            is_safe=len(all_violations) == 0,
            violations=all_violations,
            confidence=confidence
        )

# Production setup
ensemble = EnsembleClassifier()

# Base safety (highest priority, veto power)
base_guard = LlamaGuard3Classifier()  # instantiate once; reloading the model on every call would be slow
ensemble.add_classifier(
    name="llamaguard",
    classifier=lambda x: {"violations": base_guard.classify(x)[1]},
    priority=10,
    veto_power=True
)

# Domain rules
ensemble.add_classifier(
    name="financial_rules",
    classifier=lambda x: {
        "violations": ["F1"] if check_investment_advice(x) else []
    },
    priority=5
)

# Brand safety
ensemble.add_classifier(
    name="brand_safety",
    classifier=lambda x: {
        "violations": ["B1"] if check_competitor_mention(x) else []
    },
    priority=3
)
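
With all three layers registered, a single classify() call runs them in priority order and aggregates the result. Illustrative usage, assuming the base safety model flags nothing for this input:

result = ensemble.classify("I recommend buying shares in other_bank")

print(result.is_safe)     # False
print(result.violations)  # [{"source": "financial_rules", "category": "F1", "priority": 5},
                          #  {"source": "brand_safety", "category": "B1", "priority": 3}]
print(result.confidence)  # 0.8 -> two non-veto violations: 0.5 + 2 * 0.15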

Best Practice: Start with a base safety classifier (LlamaGuard/ShieldGemma) and layer domain-specific rules on top. This ensures you don't miss generic safety issues while catching domain-specific violations.

Next: Deep dive into the NVIDIA NeMo Guardrails framework.
