Production RAG Systems
Reliability & Error Handling
4 min read
Production systems fail. Every external dependency in a RAG pipeline (the embedding API, the vector store, the LLM API) will eventually rate-limit you, time out, or go down. This lesson covers patterns for handling those failures gracefully while maintaining service availability.
Common Failure Points
┌────────────────────────────────────────────────────────────────┐
│                       RAG Failure Points                       │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐       │
│  │  Embedding  │────▶│   Vector    │────▶│     LLM     │       │
│  │     API     │     │    Store    │     │     API     │       │
│  └─────────────┘     └─────────────┘     └─────────────┘       │
│         │                   │                   │              │
│         ▼                   ▼                   ▼              │
│  • Rate limits       • Connection        • Rate limits         │
│  • Timeouts          • Query timeout     • Timeouts            │
│  • API errors        • Index issues      • Context overflow    │
│  • Model changes     • Capacity          • Content filtering   │
│                                                                │
└────────────────────────────────────────────────────────────────┘
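The code samples below reference several exception types. RateLimitError and TimeoutError normally come from your provider SDK and the standard library; the others are application-level classes that the snippets assume exist. A minimal, purely illustrative sketch of that hierarchy:

# Illustrative exception hierarchy assumed by the examples in this lesson.
# RateLimitError and TimeoutError would normally come from the provider SDK
# (e.g. openai.RateLimitError) and the standard library; these are stand-ins.
class RAGError(Exception):
    """Base class for pipeline-specific failures."""

class EmbeddingError(RAGError):
    """Embedding API call failed after retries."""

class LLMError(RAGError):
    """LLM API call failed after retries."""

class PrimaryLLMError(LLMError):
    """Primary LLM is unavailable; try the fallback model."""

class FallbackLLMError(LLMError):
    """Fallback LLM is also unavailable; degrade further."""

class RateLimitError(RAGError):
    """Stand-in for the provider SDK's rate-limit exception."""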
Retry with Exponential Backoff
import asyncio
import random
from functools import wraps
from typing import Type, Tuple
def retry_with_backoff(
max_retries: int = 3,
base_delay: float = 1.0,
max_delay: float = 60.0,
exponential_base: float = 2.0,
jitter: bool = True,
retryable_exceptions: Tuple[Type[Exception], ...] = (Exception,),
):
"""Decorator for retry with exponential backoff."""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(max_retries + 1):
try:
return await func(*args, **kwargs)
except retryable_exceptions as e:
last_exception = e
if attempt == max_retries:
raise
# Calculate delay
delay = min(
base_delay * (exponential_base ** attempt),
max_delay
)
# Add jitter to prevent thundering herd
if jitter:
delay = delay * (0.5 + random.random())
print(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay:.1f}s")
await asyncio.sleep(delay)
raise last_exception
return wrapper
return decorator
# Usage
class RAGPipeline:
@retry_with_backoff(
max_retries=3,
retryable_exceptions=(RateLimitError, TimeoutError, ConnectionError)
)
async def embed(self, text: str) -> list[float]:
return await self.embedding_client.embed(text)
@retry_with_backoff(
max_retries=2,
base_delay=2.0,
retryable_exceptions=(RateLimitError, TimeoutError)
)
async def generate(self, prompt: str) -> str:
return await self.llm_client.generate(prompt)
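To see the decorator end to end, here is a toy, hypothetical coroutine (flaky_embed) that raises a retryable error twice before succeeding; with base_delay=0.1 it recovers after roughly 0.1s + 0.2s of jittered backoff:

# Toy demonstration (hypothetical): fails twice, then succeeds on attempt 3.
attempts = {"count": 0}

@retry_with_backoff(
    max_retries=3,
    base_delay=0.1,
    retryable_exceptions=(ConnectionError,),
)
async def flaky_embed(text: str) -> list[float]:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient network error")
    return [0.0] * 1536

# Succeeds on the third attempt after the two backoff sleeps.
print(len(asyncio.run(flaky_embed("hello"))))  # 1536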
Circuit Breaker Pattern
import time
from enum import Enum
from dataclasses import dataclass
class CircuitState(Enum):
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, reject requests
HALF_OPEN = "half_open" # Testing if recovered
@dataclass
class CircuitBreaker:
"""Circuit breaker for external service calls."""
failure_threshold: int = 5
recovery_timeout: float = 30.0
half_open_max_calls: int = 3
def __post_init__(self):
self.state = CircuitState.CLOSED
self.failure_count = 0
self.last_failure_time = 0
self.half_open_calls = 0
def can_execute(self) -> bool:
"""Check if request should proceed."""
if self.state == CircuitState.CLOSED:
return True
if self.state == CircuitState.OPEN:
# Check if recovery timeout passed
if time.time() - self.last_failure_time >= self.recovery_timeout:
self.state = CircuitState.HALF_OPEN
self.half_open_calls = 0
return True
return False
if self.state == CircuitState.HALF_OPEN:
return self.half_open_calls < self.half_open_max_calls
return False
def record_success(self):
"""Record successful call."""
if self.state == CircuitState.HALF_OPEN:
self.half_open_calls += 1
if self.half_open_calls >= self.half_open_max_calls:
# Recovered
self.state = CircuitState.CLOSED
self.failure_count = 0
elif self.state == CircuitState.CLOSED:
self.failure_count = 0
def record_failure(self):
"""Record failed call."""
self.failure_count += 1
self.last_failure_time = time.time()
if self.state == CircuitState.HALF_OPEN:
# Failed during recovery test
self.state = CircuitState.OPEN
elif self.failure_count >= self.failure_threshold:
self.state = CircuitState.OPEN
# Usage
class ResilientRAG:
def __init__(self):
self.embedding_breaker = CircuitBreaker(failure_threshold=5)
self.llm_breaker = CircuitBreaker(failure_threshold=3)
async def query(self, question: str) -> str:
# Check circuit breakers
if not self.embedding_breaker.can_execute():
return self._fallback_response("Embedding service unavailable")
if not self.llm_breaker.can_execute():
return self._fallback_response("LLM service unavailable")
try:
embedding = await self.embed(question)
self.embedding_breaker.record_success()
contexts = await self.retrieve(embedding)
response = await self.generate(question, contexts)
self.llm_breaker.record_success()
return response
except EmbeddingError as e:
self.embedding_breaker.record_failure()
return self._fallback_response(str(e))
except LLMError as e:
self.llm_breaker.record_failure()
return self._fallback_response(str(e))
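The _fallback_response helper is referenced above but not shown. A minimal sketch, assuming it simply wraps the failure reason in a user-facing message, could live inside ResilientRAG:

    def _fallback_response(self, reason: str) -> str:
        # Keep the message user-facing; log the raw error separately.
        return (
            "I can't fully answer right now because a backend service is "
            f"unavailable ({reason}). Please try again in a few minutes."
        )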
Fallback Strategies
from typing import Optional
class FallbackRAG:
"""RAG with multiple fallback strategies."""
def __init__(
self,
primary_llm,
fallback_llm,
        cached_responses: Optional[dict] = None,
):
self.primary_llm = primary_llm
self.fallback_llm = fallback_llm
self.cached_responses = cached_responses or {}
    async def query(self, question: str) -> dict:
        """Query with fallback chain."""
        contexts: list = []  # Keep defined for the fallback paths below
        # Try primary path
        try:
contexts = await self._retrieve(question)
response = await self._generate_with_primary(question, contexts)
return {
"answer": response,
"source": "primary",
"contexts": contexts,
}
except PrimaryLLMError:
pass # Fall through to fallback
# Fallback 1: Use backup LLM
try:
response = await self._generate_with_fallback(question, contexts)
return {
"answer": response,
"source": "fallback_llm",
"contexts": contexts,
}
except FallbackLLMError:
pass # Fall through to cache
# Fallback 2: Check semantic cache for similar answers
cached = self._find_similar_cached(question)
if cached:
return {
"answer": cached["answer"],
"source": "cache",
"contexts": cached.get("contexts", []),
"warning": "Response from cache - may not be current",
}
# Fallback 3: Return helpful error
return {
"answer": self._graceful_error_message(question),
"source": "error",
"contexts": [],
}
def _graceful_error_message(self, question: str) -> str:
"""Generate helpful error message."""
return (
"I'm currently unable to process your question due to a "
"temporary service issue. Please try again in a few moments. "
"If you need immediate assistance, please contact support."
)
async def _generate_with_primary(self, question: str, contexts: list) -> str:
"""Generate with primary LLM (e.g., GPT-4)."""
return await self.primary_llm.generate(
self._format_prompt(question, contexts),
timeout=30,
)
async def _generate_with_fallback(self, question: str, contexts: list) -> str:
"""Generate with fallback LLM (e.g., GPT-3.5 or local model)."""
return await self.fallback_llm.generate(
self._format_prompt(question, contexts),
timeout=60, # More generous timeout for fallback
)
Graceful Degradation
class DegradableRAG:
"""RAG that degrades gracefully under load."""
def __init__(self):
self.load_level = "normal" # normal, high, critical
async def query(self, question: str, priority: str = "normal") -> dict:
"""Query with load-aware degradation."""
# Check current load level
self.load_level = await self._check_load()
if self.load_level == "critical":
# Only process high-priority requests
if priority != "high":
return self._queue_response(question)
# Minimal processing
return await self._minimal_query(question)
if self.load_level == "high":
# Skip reranking, use fewer contexts
return await self._reduced_query(question)
# Normal processing
return await self._full_query(question)
async def _full_query(self, question: str) -> dict:
"""Full pipeline with all features."""
embedding = await self.embed(question)
docs = await self.retrieve(embedding, k=20)
reranked = await self.rerank(question, docs)
response = await self.generate(question, reranked[:5])
return {"answer": response, "mode": "full"}
async def _reduced_query(self, question: str) -> dict:
"""Reduced pipeline - skip reranking."""
embedding = await self.embed(question)
docs = await self.retrieve(embedding, k=5) # Fewer docs
response = await self.generate(question, docs)
return {"answer": response, "mode": "reduced"}
async def _minimal_query(self, question: str) -> dict:
"""Minimal pipeline - semantic cache or simple retrieval."""
# Try cache first
cached = self.cache.get(question)
if cached:
return {"answer": cached, "mode": "cached"}
# Simple keyword search + fast model
docs = await self.keyword_search(question, k=3)
response = await self.fast_llm.generate(question, docs)
return {"answer": response, "mode": "minimal"}
def _queue_response(self, question: str) -> dict:
"""Queue request for later processing."""
ticket_id = self.queue.add(question)
return {
"answer": f"System is under heavy load. Your request has been queued. Ticket: {ticket_id}",
"mode": "queued",
"ticket_id": ticket_id,
}
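_check_load, along with the cache, queue, fast_llm, and keyword_search dependencies used above, is assumed to be wired up elsewhere. A plausible sketch derives the load level from in-flight requests and recent latency via a hypothetical self.metrics helper, with illustrative thresholds:

    async def _check_load(self) -> str:
        # Hypothetical metrics source; thresholds are illustrative and should
        # be tuned against your own capacity and latency targets.
        in_flight = await self.metrics.in_flight_requests()
        p95_ms = await self.metrics.p95_latency_ms()
        if in_flight > 200 or p95_ms > 8000:
            return "critical"
        if in_flight > 80 or p95_ms > 3000:
            return "high"
        return "normal"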
Health Checks
import json
import time
from datetime import datetime, timezone

from fastapi import FastAPI, Response

app = FastAPI()
class HealthChecker:
    """Health check for RAG components.

    Assumes embedding_client, vectorstore, llm, and cache handles are
    attached to the instance (e.g. injected via the constructor).
    """
async def check_all(self) -> dict:
"""Run all health checks."""
checks = {
"embedding_api": await self._check_embedding(),
"vector_store": await self._check_vectorstore(),
"llm_api": await self._check_llm(),
"cache": await self._check_cache(),
}
overall = all(c["healthy"] for c in checks.values())
return {
"healthy": overall,
"timestamp": datetime.utcnow().isoformat(),
"checks": checks,
}
async def _check_embedding(self) -> dict:
try:
start = time.time()
await self.embedding_client.embed("health check")
latency = (time.time() - start) * 1000
return {"healthy": True, "latency_ms": latency}
except Exception as e:
return {"healthy": False, "error": str(e)}
async def _check_vectorstore(self) -> dict:
try:
start = time.time()
await self.vectorstore.search([0.0] * 1536, k=1)
latency = (time.time() - start) * 1000
return {"healthy": True, "latency_ms": latency}
except Exception as e:
return {"healthy": False, "error": str(e)}
async def _check_llm(self) -> dict:
try:
start = time.time()
await self.llm.generate("Say 'ok'", max_tokens=5)
latency = (time.time() - start) * 1000
return {"healthy": True, "latency_ms": latency}
except Exception as e:
return {"healthy": False, "error": str(e)}
health_checker = HealthChecker()
@app.get("/health")
async def health_endpoint():
result = await health_checker.check_all()
status_code = 200 if result["healthy"] else 503
    return Response(
        content=json.dumps(result),
        status_code=status_code,
        media_type="application/json",
    )
Key Insight: Design for failure from the start. Users prefer a degraded response with transparency ("I'm using cached data") over a cryptic error or timeout.
Next, let's cover monitoring and cost management.