Lesson 21 of 23

Production RAG Systems

Reliability & Error Handling

4 min read

Production systems fail. This lesson covers patterns to handle failures gracefully and maintain service availability.

Common Failure Points

┌──────────────────────────────────────────────────────────────┐
│                      RAG Failure Points                      │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐     │
│  │  Embedding  │────▶│   Vector    │────▶│     LLM     │     │
│  │     API     │     │    Store    │     │     API     │     │
│  └─────────────┘     └─────────────┘     └─────────────┘     │
│         │                   │                   │            │
│         ▼                   ▼                   ▼            │
│    • Rate limits      • Connection       • Rate limits       │
│    • Timeouts         • Query timeout    • Timeouts          │
│    • API errors       • Index issues     • Context overflow  │
│    • Model changes    • Capacity         • Content filtering │
│                                                              │
└──────────────────────────────────────────────────────────────┘
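
The code in this lesson refers to several exception types. RateLimitError comes from your provider's SDK (for example, openai.RateLimitError), while names like EmbeddingError, LLMError, PrimaryLLMError, and FallbackLLMError are assumed to be application-level exceptions you define yourself. A minimal sketch of those custom exceptions:

class EmbeddingError(Exception):
    """Raised when the embedding service fails (after retries are exhausted)."""

class LLMError(Exception):
    """Raised when an LLM call fails (after retries are exhausted)."""

class PrimaryLLMError(LLMError):
    """Raised when the primary model fails."""

class FallbackLLMError(LLMError):
    """Raised when the fallback model also fails."""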

Retry with Exponential Backoff

import asyncio
import random
from functools import wraps
from typing import Type, Tuple

def retry_with_backoff(
    max_retries: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exponential_base: float = 2.0,
    jitter: bool = True,
    retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
):
    """Decorator for retry with exponential backoff."""

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            last_exception = None

            for attempt in range(max_retries + 1):
                try:
                    return await func(*args, **kwargs)

                except retryable_exceptions as e:
                    last_exception = e

                    if attempt == max_retries:
                        raise

                    # Calculate delay
                    delay = min(
                        base_delay * (exponential_base ** attempt),
                        max_delay
                    )

                    # Add jitter to prevent thundering herd
                    if jitter:
                        delay = delay * (0.5 + random.random())

                    print(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay:.1f}s")
                    await asyncio.sleep(delay)

            raise last_exception

        return wrapper
    return decorator

# Usage
class RAGPipeline:

    @retry_with_backoff(
        max_retries=3,
        retryable_exceptions=(RateLimitError, TimeoutError, ConnectionError)
    )
    async def embed(self, text: str) -> list[float]:
        return await self.embedding_client.embed(text)

    @retry_with_backoff(
        max_retries=2,
        base_delay=2.0,
        retryable_exceptions=(RateLimitError, TimeoutError)
    )
    async def generate(self, prompt: str) -> str:
        return await self.llm_client.generate(prompt)

Circuit Breaker Pattern

import time
from enum import Enum
from dataclasses import dataclass

class CircuitState(Enum):
    CLOSED = "closed"      # Normal operation
    OPEN = "open"          # Failing, reject requests
    HALF_OPEN = "half_open"  # Testing if recovered

@dataclass
class CircuitBreaker:
    """Circuit breaker for external service calls."""

    failure_threshold: int = 5
    recovery_timeout: float = 30.0
    half_open_max_calls: int = 3

    def __post_init__(self):
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.last_failure_time = 0
        self.half_open_calls = 0

    def can_execute(self) -> bool:
        """Check if request should proceed."""

        if self.state == CircuitState.CLOSED:
            return True

        if self.state == CircuitState.OPEN:
            # Check if recovery timeout passed
            if time.time() - self.last_failure_time >= self.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
                self.half_open_calls = 0
                return True
            return False

        if self.state == CircuitState.HALF_OPEN:
            return self.half_open_calls < self.half_open_max_calls

        return False

    def record_success(self):
        """Record successful call."""
        if self.state == CircuitState.HALF_OPEN:
            self.half_open_calls += 1
            if self.half_open_calls >= self.half_open_max_calls:
                # Recovered
                self.state = CircuitState.CLOSED
                self.failure_count = 0

        elif self.state == CircuitState.CLOSED:
            self.failure_count = 0

    def record_failure(self):
        """Record failed call."""
        self.failure_count += 1
        self.last_failure_time = time.time()

        if self.state == CircuitState.HALF_OPEN:
            # Failed during recovery test
            self.state = CircuitState.OPEN

        elif self.failure_count >= self.failure_threshold:
            self.state = CircuitState.OPEN

# Usage
class ResilientRAG:

    def __init__(self):
        self.embedding_breaker = CircuitBreaker(failure_threshold=5)
        self.llm_breaker = CircuitBreaker(failure_threshold=3)

    async def query(self, question: str) -> str:
        # Check circuit breakers
        if not self.embedding_breaker.can_execute():
            return self._fallback_response("Embedding service unavailable")

        if not self.llm_breaker.can_execute():
            return self._fallback_response("LLM service unavailable")

        try:
            embedding = await self.embed(question)
            self.embedding_breaker.record_success()

            contexts = await self.retrieve(embedding)

            response = await self.generate(question, contexts)
            self.llm_breaker.record_success()

            return response

        except EmbeddingError as e:
            self.embedding_breaker.record_failure()
            return self._fallback_response(str(e))

        except LLMError as e:
            self.llm_breaker.record_failure()
            return self._fallback_response(str(e))
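
The _fallback_response helper used above is referenced but not shown. A minimal sketch, assuming you want to log the internal reason and return a generic user-facing message (both choices are placeholders, not a prescribed implementation):

    # Inside ResilientRAG:
    def _fallback_response(self, reason: str) -> str:
        """Turn an internal failure reason into a safe, user-facing answer."""
        # Log the real reason for operators; keep the user message generic.
        print(f"Falling back: {reason}")
        return (
            "I can't fully answer that right now because part of the system "
            "is temporarily unavailable. Please try again shortly."
        )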

Fallback Strategies

from typing import Optional

class FallbackRAG:
    """RAG with multiple fallback strategies."""

    def __init__(
        self,
        primary_llm,
        fallback_llm,
        cached_responses: Optional[dict] = None,
    ):
        self.primary_llm = primary_llm
        self.fallback_llm = fallback_llm
        self.cached_responses = cached_responses or {}

    async def query(self, question: str) -> dict:
        """Query with fallback chain."""

        # Keep contexts bound even if the primary path fails before retrieval completes.
        contexts: list = []

        # Try primary path
        try:
            contexts = await self._retrieve(question)
            response = await self._generate_with_primary(question, contexts)

            return {
                "answer": response,
                "source": "primary",
                "contexts": contexts,
            }

        except PrimaryLLMError:
            pass  # Fall through to fallback

        # Fallback 1: Use backup LLM
        try:
            response = await self._generate_with_fallback(question, contexts)

            return {
                "answer": response,
                "source": "fallback_llm",
                "contexts": contexts,
            }

        except FallbackLLMError:
            pass  # Fall through to cache

        # Fallback 2: Check semantic cache for similar answers
        cached = self._find_similar_cached(question)
        if cached:
            return {
                "answer": cached["answer"],
                "source": "cache",
                "contexts": cached.get("contexts", []),
                "warning": "Response from cache - may not be current",
            }

        # Fallback 3: Return helpful error
        return {
            "answer": self._graceful_error_message(question),
            "source": "error",
            "contexts": [],
        }

    def _graceful_error_message(self, question: str) -> str:
        """Generate helpful error message."""
        return (
            "I'm currently unable to process your question due to a "
            "temporary service issue. Please try again in a few moments. "
            "If you need immediate assistance, please contact support."
        )

    async def _generate_with_primary(self, question: str, contexts: list) -> str:
        """Generate with primary LLM (e.g., GPT-4)."""
        return await self.primary_llm.generate(
            self._format_prompt(question, contexts),
            timeout=30,
        )

    async def _generate_with_fallback(self, question: str, contexts: list) -> str:
        """Generate with fallback LLM (e.g., GPT-3.5 or local model)."""
        return await self.fallback_llm.generate(
            self._format_prompt(question, contexts),
            timeout=60,  # More generous timeout for fallback
        )
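
The _find_similar_cached step above assumes a semantic cache. One possible sketch, assuming each cached entry stores a precomputed question embedding (the entry layout, threshold, and helper names are illustrative assumptions, not part of the lesson's API):

import math
from typing import Optional

def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Plain-Python cosine similarity; fine for a small number of cached entries."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def find_similar_cached(
    question_embedding: list[float],
    cached_responses: dict,
    threshold: float = 0.92,  # hypothetical cutoff; tune on your own data
) -> Optional[dict]:
    """Return the closest cached entry if it clears the similarity threshold.

    Assumes each entry looks like:
    {"embedding": [...], "answer": "...", "contexts": [...]}
    """
    best_entry, best_score = None, 0.0
    for entry in cached_responses.values():
        score = cosine_similarity(question_embedding, entry["embedding"])
        if score > best_score:
            best_entry, best_score = entry, score
    return best_entry if best_score >= threshold else None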

Graceful Degradation

class DegradableRAG:
    """RAG that degrades gracefully under load."""

    def __init__(self):
        self.load_level = "normal"  # normal, high, critical

    async def query(self, question: str, priority: str = "normal") -> dict:
        """Query with load-aware degradation."""

        # Check current load level
        self.load_level = await self._check_load()

        if self.load_level == "critical":
            # Only process high-priority requests
            if priority != "high":
                return self._queue_response(question)

            # Minimal processing
            return await self._minimal_query(question)

        if self.load_level == "high":
            # Skip reranking, use fewer contexts
            return await self._reduced_query(question)

        # Normal processing
        return await self._full_query(question)

    async def _full_query(self, question: str) -> dict:
        """Full pipeline with all features."""
        embedding = await self.embed(question)
        docs = await self.retrieve(embedding, k=20)
        reranked = await self.rerank(question, docs)
        response = await self.generate(question, reranked[:5])

        return {"answer": response, "mode": "full"}

    async def _reduced_query(self, question: str) -> dict:
        """Reduced pipeline - skip reranking."""
        embedding = await self.embed(question)
        docs = await self.retrieve(embedding, k=5)  # Fewer docs
        response = await self.generate(question, docs)

        return {"answer": response, "mode": "reduced"}

    async def _minimal_query(self, question: str) -> dict:
        """Minimal pipeline - semantic cache or simple retrieval."""
        # Try cache first
        cached = self.cache.get(question)
        if cached:
            return {"answer": cached, "mode": "cached"}

        # Simple keyword search + fast model
        docs = await self.keyword_search(question, k=3)
        response = await self.fast_llm.generate(question, docs)

        return {"answer": response, "mode": "minimal"}

    def _queue_response(self, question: str) -> dict:
        """Queue request for later processing."""
        ticket_id = self.queue.add(question)

        return {
            "answer": f"System is under heavy load. Your request has been queued. Ticket: {ticket_id}",
            "mode": "queued",
            "ticket_id": ticket_id,
        }
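
The _check_load method above is left abstract. One simple approach, assuming the service tracks how many requests are currently in flight (the in_flight counter and the thresholds are illustrative assumptions):

    # Inside DegradableRAG:
    async def _check_load(self) -> str:
        """Map a rough load signal to a degradation level."""
        # in_flight is assumed to be incremented/decremented by the request
        # handler; the thresholds are placeholders to tune per service.
        if self.in_flight >= 200:
            return "critical"
        if self.in_flight >= 50:
            return "high"
        return "normal"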

Health Checks

import json
import time
from datetime import datetime, timezone

from fastapi import FastAPI, Response

app = FastAPI()

class HealthChecker:
    """Health check for RAG components."""

    async def check_all(self) -> dict:
        """Run all health checks."""

        checks = {
            "embedding_api": await self._check_embedding(),
            "vector_store": await self._check_vectorstore(),
            "llm_api": await self._check_llm(),
            "cache": await self._check_cache(),
        }

        overall = all(c["healthy"] for c in checks.values())

        return {
            "healthy": overall,
            "timestamp": datetime.utcnow().isoformat(),
            "checks": checks,
        }

    async def _check_embedding(self) -> dict:
        try:
            start = time.time()
            await self.embedding_client.embed("health check")
            latency = (time.time() - start) * 1000

            return {"healthy": True, "latency_ms": latency}
        except Exception as e:
            return {"healthy": False, "error": str(e)}

    async def _check_vectorstore(self) -> dict:
        try:
            start = time.time()
            await self.vectorstore.search([0.0] * 1536, k=1)  # dummy vector; 1536 matches the embedding dimension
            latency = (time.time() - start) * 1000

            return {"healthy": True, "latency_ms": latency}
        except Exception as e:
            return {"healthy": False, "error": str(e)}

    async def _check_llm(self) -> dict:
        try:
            start = time.time()
            await self.llm.generate("Say 'ok'", max_tokens=5)
            latency = (time.time() - start) * 1000

            return {"healthy": True, "latency_ms": latency}
        except Exception as e:
            return {"healthy": False, "error": str(e)}

    async def _check_cache(self) -> dict:
        """Cache connectivity check; check_all() above expects this method."""
        try:
            # The cache client is assumed to expose an async ping(),
            # as redis.asyncio's Redis.ping() does.
            await self.cache.ping()
            return {"healthy": True}
        except Exception as e:
            return {"healthy": False, "error": str(e)}

health_checker = HealthChecker()

@app.get("/health")
async def health_endpoint():
    result = await health_checker.check_all()
    status_code = 200 if result["healthy"] else 503
    return Response(
        content=json.dumps(result),
        status_code=status_code,
        media_type="application/json",
    )

Key Insight: Design for failure from the start. Users prefer a degraded response with transparency ("I'm using cached data") over a cryptic error or timeout.

Next, let's cover monitoring and cost management.
