Production RAG Systems
Performance Optimization
Production RAG systems need sub-second latency. This lesson covers techniques to optimize every stage of the pipeline.
Latency Breakdown
┌────────────────────────────────────────────────────────────────┐
│ Typical RAG Latency                                            │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│ Query Embedding:    50-150ms   ████░░░░░░                      │
│ Vector Search:      20-100ms   ███░░░░░░░                      │
│ Reranking:         100-300ms   ██████░░░░                      │
│ LLM Generation:   500-2000ms   ██████████████████████          │
│                                                                │
│ Total:            670-2550ms                                   │
│                                                                │
│ Target for production: < 1000ms (without streaming)            │
│ Target with streaming: < 300ms to first token                  │
│                                                                │
└────────────────────────────────────────────────────────────────┘
Semantic Caching
Cache responses for semantically similar queries:
import hashlib
import time
from typing import Optional

import numpy as np

class SemanticCache:
    """Cache RAG responses based on semantic similarity."""

    def __init__(
        self,
        embedding_model,
        similarity_threshold: float = 0.95,
        max_entries: int = 10000,
        ttl_seconds: int = 3600,
    ):
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold
        self.max_entries = max_entries
        self.ttl_seconds = ttl_seconds
        # In production, use Redis or similar
        self.cache = {}  # {query_hash: {embedding, response, timestamp}}

    def get(self, query: str) -> Optional[str]:
        """Check cache for semantically similar query."""
        query_embedding = self.embedding_model.embed(query)

        for cached in self.cache.values():
            # Check TTL
            if time.time() - cached["timestamp"] > self.ttl_seconds:
                continue

            # Check semantic similarity
            similarity = self._cosine_similarity(
                query_embedding,
                cached["embedding"]
            )
            if similarity >= self.similarity_threshold:
                return cached["response"]

        return None

    def set(self, query: str, response: str):
        """Cache a response."""
        if len(self.cache) >= self.max_entries:
            self._evict_oldest()

        query_hash = hashlib.md5(query.encode()).hexdigest()
        self.cache[query_hash] = {
            "embedding": self.embedding_model.embed(query),
            "response": response,
            "timestamp": time.time(),
        }

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    def _evict_oldest(self):
        oldest = min(self.cache.items(), key=lambda x: x[1]["timestamp"])
        del self.cache[oldest[0]]
# Usage
cache = SemanticCache(embedding_model, similarity_threshold=0.95)

def query_rag(question: str) -> str:
    # Check cache first
    cached = cache.get(question)
    if cached:
        return cached  # ~5ms vs ~1500ms

    # Run full pipeline
    response = rag_pipeline.query(question)

    # Cache for future
    cache.set(question, response)
    return response
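One caveat: `get()` scores cached entries one at a time in a Python loop. Once the cache holds thousands of entries, it is cheaper to keep the cached embeddings stacked in a single NumPy matrix and score them with one matrix-vector product. A minimal sketch of that lookup, assuming the cached embeddings are L2-normalized (the function name and arguments are illustrative, not part of the class above):

import numpy as np

def batched_cache_lookup(
    query_embedding: np.ndarray,
    cached_embeddings: np.ndarray,   # shape (n_entries, dim), rows L2-normalized
    cached_responses: list,          # responses in the same order as the rows
    threshold: float = 0.95,
):
    """Return the best cached response above the similarity threshold, else None."""
    if len(cached_responses) == 0:
        return None
    query = query_embedding / np.linalg.norm(query_embedding)
    scores = cached_embeddings @ query   # one matmul replaces the per-entry loop
    best = int(np.argmax(scores))
    return cached_responses[best] if scores[best] >= threshold else None

TTL checks and eviction stay as in the class above; only the similarity scan changes.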
Async and Parallel Processing
import asyncio
from concurrent.futures import ThreadPoolExecutor

class AsyncRAGPipeline:
    """RAG pipeline with async operations."""

    def __init__(self, embedding_model, vectorstore, reranker, llm):
        self.embedding_model = embedding_model
        self.vectorstore = vectorstore
        self.reranker = reranker
        self.llm = llm
        self.executor = ThreadPoolExecutor(max_workers=4)

    async def query(self, question: str) -> str:
        # Parallel: embedding + BM25 search
        embedding_task = asyncio.create_task(
            self._async_embed(question)
        )
        bm25_task = asyncio.create_task(
            self._async_bm25_search(question)
        )

        # Wait for both retrieval methods
        query_embedding, bm25_results = await asyncio.gather(
            embedding_task, bm25_task
        )

        # Vector search with embedding
        vector_results = await self._async_vector_search(query_embedding)

        # Combine and rerank
        combined = self._merge_results(vector_results, bm25_results)
        reranked = await self._async_rerank(question, combined)

        # Generate response
        response = await self._async_generate(question, reranked[:5])
        return response

    async def _async_embed(self, text: str):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            self.embedding_model.embed,
            text
        )

    async def _async_vector_search(self, embedding):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            lambda: self.vectorstore.search(embedding, k=20)
        )

    async def _async_rerank(self, query: str, docs: list):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            lambda: self.reranker.rerank(query, docs)
        )

    async def _async_generate(self, question: str, contexts: list):
        # Most LLM clients support async natively
        return await self.llm.agenerate(
            prompt=self._format_prompt(question, contexts)
        )

    # _async_bm25_search, _merge_results, and _format_prompt omitted for brevity

# Usage (inside an async context)
pipeline = AsyncRAGPipeline(embedding_model, vectorstore, reranker, llm)
response = await pipeline.query("How do I reset my password?")
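The `await` in that usage assumes you are already inside an event loop (for example, a FastAPI handler). From a synchronous entry point you can drive the pipeline with `asyncio.run`, and the same async design lets independent queries overlap. A minimal sketch (the helper name and questions are illustrative):

import asyncio

async def answer_all(pipeline: AsyncRAGPipeline, questions: list) -> list:
    # I/O-bound stages (embedding, search, LLM calls) overlap across queries
    return list(await asyncio.gather(*(pipeline.query(q) for q in questions)))

answers = asyncio.run(answer_all(pipeline, [
    "How do I reset my password?",
    "How do I enable two-factor authentication?",
]))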
Streaming Responses
from typing import AsyncGenerator

class StreamingRAG:
    """RAG with streaming for faster time-to-first-token."""

    def __init__(self, vectorstore, llm):
        self.vectorstore = vectorstore
        self.llm = llm

    async def query_stream(
        self,
        question: str
    ) -> AsyncGenerator[str, None]:
        """Stream response tokens as they're generated."""
        # Retrieval (can't stream, do it first)
        contexts = await self._retrieve(question)

        # Stream generation
        prompt = self._format_prompt(question, contexts)
        async for token in self.llm.astream(prompt):
            yield token

    async def _retrieve(self, question: str) -> list[str]:
        # Fast retrieval path
        docs = await self.vectorstore.asearch(question, k=5)
        return [doc.page_content for doc in docs]

    # _format_prompt omitted for brevity

# FastAPI endpoint
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
rag = StreamingRAG(vectorstore, llm)

@app.get("/query")
async def query_endpoint(question: str):
    async def generate():
        async for token in rag.query_stream(question):
            yield f"data: {token}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream"
    )
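On the client side, each `data:` line can be consumed as it arrives, which is what makes the sub-300ms first-token target visible to the user. A sketch of a client using httpx; the host, port, and route are assumptions matching the endpoint above:

import asyncio
import httpx

async def consume_stream(question: str) -> str:
    tokens = []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "GET", "http://localhost:8000/query", params={"question": question}
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                chunk = line[len("data: "):]
                if chunk == "[DONE]":
                    break
                tokens.append(chunk)
                print(chunk, end="", flush=True)  # render tokens as they arrive
    return "".join(tokens)

answer = asyncio.run(consume_stream("How do I reset my password?"))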
Batching for Throughput
import asyncio
from dataclasses import dataclass
from typing import List

@dataclass
class BatchRequest:
    question: str
    future: asyncio.Future

class BatchingRAG:
    """Batch multiple queries for efficient processing."""

    def __init__(
        self,
        embedding_model,
        vectorstore,
        llm,
        batch_size: int = 8,
        max_wait_ms: int = 50,
    ):
        self.embedding_model = embedding_model
        self.vectorstore = vectorstore
        self.llm = llm
        self.batch_size = batch_size
        self.max_wait_ms = max_wait_ms
        self.queue: List[BatchRequest] = []
        self.lock = asyncio.Lock()
        self._processing = False

    async def query(self, question: str) -> str:
        """Add query to batch and wait for result."""
        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(question=question, future=future)

        async with self.lock:
            self.queue.append(request)

            if len(self.queue) >= self.batch_size:
                await self._process_batch()
            elif not self._processing:
                # Start timer for partial batch
                asyncio.create_task(self._wait_and_process())

        return await future

    async def _wait_and_process(self):
        """Wait for more requests or timeout."""
        self._processing = True
        await asyncio.sleep(self.max_wait_ms / 1000)

        async with self.lock:
            if self.queue:
                await self._process_batch()
            self._processing = False

    async def _process_batch(self):
        """Process all queued requests in batch."""
        batch = self.queue
        self.queue = []

        # Batch embed all questions
        questions = [r.question for r in batch]
        embeddings = self.embedding_model.embed_batch(questions)

        # Batch vector search
        all_results = self.vectorstore.batch_search(embeddings, k=5)

        # Generate responses (can batch with some LLMs)
        for request, results in zip(batch, all_results):
            response = await self._generate(request.question, results)
            request.future.set_result(response)

    # _generate (prompt formatting + LLM call) omitted for brevity
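A usage sketch, assuming the batcher is wired to the same embedding model, vector store, and LLM used elsewhere in this lesson; concurrent callers are transparently grouped into shared embedding and search batches:

import asyncio

batching_rag = BatchingRAG(
    embedding_model=embedding_model,
    vectorstore=vectorstore,
    llm=llm,
    batch_size=8,
    max_wait_ms=50,
)

async def handle_burst():
    questions = [f"Question {i}" for i in range(16)]
    # 16 concurrent callers -> roughly two batches of 8
    return await asyncio.gather(*(batching_rag.query(q) for q in questions))

answers = asyncio.run(handle_burst())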
Optimization Checklist
| Optimization | Impact | Effort | When to Use |
|---|---|---|---|
| Semantic caching | High | Medium | Repeated similar queries |
| Async retrieval | Medium | Low | Always |
| Streaming | High UX | Low | User-facing apps |
| Batching | High throughput | Medium | High QPS |
| Quantized embeddings | Medium | Low | Large indexes |
| Smaller context | Medium | Low | When LLM is bottleneck |
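The quantized-embeddings row deserves a concrete example. Storing vectors as int8 instead of float32 cuts index memory roughly 4x at a small cost in recall; most vector databases expose this as a scalar-quantization option, but the idea fits in a few lines. A minimal sketch of symmetric per-vector quantization (the helper names are illustrative):

import numpy as np

def quantize_int8(embeddings: np.ndarray) -> tuple:
    """Symmetric per-vector int8 quantization (~4x smaller than float32)."""
    scales = np.abs(embeddings).max(axis=1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # avoid divide-by-zero
    quantized = np.round(embeddings / scales).astype(np.int8)
    return quantized, scales.astype(np.float32)

def dequantize(quantized: np.ndarray, scales: np.ndarray) -> np.ndarray:
    return quantized.astype(np.float32) * scales

vectors = np.random.randn(10_000, 768).astype(np.float32)
q, s = quantize_int8(vectors)
print(f"float32: {vectors.nbytes / 1e6:.1f} MB -> int8: {(q.nbytes + s.nbytes) / 1e6:.1f} MB")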
Benchmark Your Pipeline
import time
import statistics
def benchmark_rag(pipeline, questions: list, runs: int = 3) -> dict:
    """Benchmark RAG pipeline performance."""
    latencies = {
        "total": [],
        "embedding": [],
        "retrieval": [],
        "reranking": [],
        "generation": [],
    }

    for _ in range(runs):
        for q in questions:
            times = {}

            start = time.perf_counter()
            embedding = pipeline.embed(q)
            times["embedding"] = time.perf_counter() - start

            start = time.perf_counter()
            docs = pipeline.retrieve(embedding)
            times["retrieval"] = time.perf_counter() - start

            start = time.perf_counter()
            reranked = pipeline.rerank(q, docs)
            times["reranking"] = time.perf_counter() - start

            start = time.perf_counter()
            response = pipeline.generate(q, reranked)
            times["generation"] = time.perf_counter() - start

            times["total"] = sum(times.values())

            for key, value in times.items():
                latencies[key].append(value * 1000)  # Convert to ms

    return {
        key: {
            "mean": statistics.mean(values),
            "p50": statistics.median(values),
            "p95": sorted(values)[int(len(values) * 0.95)],
            "p99": sorted(values)[int(len(values) * 0.99)],
        }
        for key, values in latencies.items()
    }
# Run benchmark
results = benchmark_rag(pipeline, test_questions, runs=3)
print(f"Total P95: {results['total']['p95']:.0f}ms")
print(f"Generation P95: {results['generation']['p95']:.0f}ms")
Key Insight: LLM generation is typically 70-80% of total latency. Streaming doesn't reduce total time but dramatically improves perceived latency by showing progress immediately.
Next, let's implement reliability patterns for production.