RAG System Design
Scaling RAG Systems
4 min read
Scaling RAG from a prototype to production requires careful consideration of indexing, retrieval, and infrastructure. This lesson covers patterns for handling millions of documents.
Scaling Challenges
| Challenge | At 10K docs | At 10M docs |
|---|---|---|
| Index time | Minutes | Days |
| Query latency | < 100ms | Seconds without optimization |
| Storage | GBs | TBs |
| Cost | $10/month | $1000+/month |
Sharding Strategies
By Collection/Namespace
from typing import Optional

class ShardedVectorStore:
    def __init__(self, base_client):
        self.client = base_client
        self.shards = {}
    def get_shard(self, category: str):
        """Route to the appropriate shard based on category."""
        if category not in self.shards:
            self.shards[category] = self.client.collection(f"docs_{category}")
        return self.shards[category]
    async def insert(self, doc: dict):
        category = doc["metadata"]["category"]
        shard = self.get_shard(category)
        await shard.insert(doc)
    async def search(self, query_embedding, category: Optional[str] = None, top_k: int = 10):
        if category:
            # Search a single shard
            shard = self.get_shard(category)
            return await shard.search(query_embedding, top_k=top_k)
        else:
            # Search all shards and merge; this fan-out can be parallelized
            # with asyncio.gather, as in the time-based example below
            all_results = []
            for shard in self.shards.values():
                results = await shard.search(query_embedding, top_k=top_k)
                all_results.extend(results)
            # Sort and return the top-k across all shards
            all_results.sort(key=lambda x: x["score"], reverse=True)
            return all_results[:top_k]
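A quick usage sketch. Here `base_client` and `query_embedding` are placeholders: any vector database client that exposes `.collection(name)` with async `insert`/`search` fits this interface.

store = ShardedVectorStore(base_client)

await store.insert({
    "id": "doc-1",
    "metadata": {"category": "finance"},
    "content": "Q3 revenue grew 12%...",
})

# Scoped search hits one shard; unscoped search fans out to every shard
finance_hits = await store.search(query_embedding, category="finance")
all_hits = await store.search(query_embedding, top_k=5)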
By Time Period
import asyncio
from datetime import datetime

class TimeBasedSharding:
    def __init__(self, client):
        self.client = client
    def get_shard_name(self, timestamp: datetime) -> str:
        """Shard by month."""
        return f"docs_{timestamp.year}_{timestamp.month:02d}"
    def _get_shards_in_range(self, start_date: datetime, end_date: datetime) -> list:
        """List the monthly shard names covering [start_date, end_date]."""
        shards = []
        year, month = start_date.year, start_date.month
        while (year, month) <= (end_date.year, end_date.month):
            shards.append(f"docs_{year}_{month:02d}")
            month += 1
            if month > 12:
                year, month = year + 1, 1
        return shards
    async def search_date_range(
        self,
        query_embedding,
        start_date: datetime,
        end_date: datetime,
        top_k: int = 10
    ):
        # Determine which monthly shards cover the requested range
        shards = self._get_shards_in_range(start_date, end_date)
        # Query the relevant shards in parallel
        tasks = [
            self.client.collection(shard).search(query_embedding, top_k=top_k)
            for shard in shards
        ]
        results = await asyncio.gather(*tasks)
        # Merge per-shard results and keep the global top-k
        merged = [r for batch in results for r in batch]
        merged.sort(key=lambda x: x["score"], reverse=True)
        return merged[:top_k]
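A query scoped to a single quarter then fans out to just three monthly shards. For example (dates illustrative, `client` and `query_embedding` are the same placeholders as above):

sharding = TimeBasedSharding(client)
results = await sharding.search_date_range(
    query_embedding,
    start_date=datetime(2024, 1, 1),
    end_date=datetime(2024, 3, 31),
)
# Touches only docs_2024_01, docs_2024_02, and docs_2024_03, in parallel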
Index Optimization
Approximate Nearest Neighbor (ANN) Indexes
-- pgvector index types
CREATE INDEX ON documents
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);  -- good for < 1M vectors

-- For larger datasets
CREATE INDEX ON documents
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);  -- better for > 1M vectors
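Both index types also expose query-time parameters that trade recall for speed. A minimal sketch, assuming psycopg and the pgvector Python adapter; the DSN is a placeholder and `query_embedding` should be a numpy array (or pgvector `Vector`):

import psycopg
from pgvector.psycopg import register_vector

def search(query_embedding, top_k: int = 10):
    with psycopg.connect("dbname=rag") as conn:  # placeholder DSN
        register_vector(conn)
        with conn.cursor() as cur:
            # ivfflat: more probes = better recall, slower queries (default 1)
            cur.execute("SET ivfflat.probes = 10")
            # hnsw: larger ef_search = better recall, slower queries (default 40)
            cur.execute("SET hnsw.ef_search = 100")
            cur.execute(
                "SELECT id, embedding <=> %s AS distance FROM documents "
                "ORDER BY distance LIMIT %s",
                (query_embedding, top_k),
            )
            return cur.fetchall()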
Index comparison:
| Index Type | Build Time | Query Speed | Memory | Best For |
|---|---|---|---|---|
| Flat | Fast | Slow (O(n)) | Low | < 100K |
| IVF | Medium | Fast | Medium | 100K-10M |
| HNSW | Slow | Fastest | High | > 1M |
Incremental Indexing
Don't reindex everything when documents change:
class IncrementalIndexer:
    def __init__(self, vector_store, embedding_model):
        self.store = vector_store
        self.embedder = embedding_model
        self.pending_updates = []
        self.batch_size = 100
    async def add_document(self, doc: dict):
        """Queue a document for batched insertion."""
        embedding = await self.embedder.embed(doc["content"])
        self.pending_updates.append({
            "id": doc["id"],
            "embedding": embedding,
            "metadata": doc["metadata"]
        })
        if len(self.pending_updates) >= self.batch_size:
            await self._flush()
    async def update_document(self, doc_id: str, new_content: str):
        """Re-embed and replace a single changed document."""
        # Delete first so no stale entry survives if the upsert fails partway
        await self.store.delete(doc_id)
        embedding = await self.embedder.embed(new_content)
        await self.store.upsert({
            "id": doc_id,
            "embedding": embedding
        })
    async def _flush(self):
        """Batch-insert pending updates."""
        if self.pending_updates:
            await self.store.upsert_batch(self.pending_updates)
            self.pending_updates = []
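One operational detail worth making explicit: a partial batch sits in memory until the 100-document threshold is reached, so drain it when ingestion finishes. A sketch, reusing the placeholder `vector_store` and `embedding_model` from above:

indexer = IncrementalIndexer(vector_store, embedding_model)
for doc in new_docs:
    await indexer.add_document(doc)  # auto-flushes every 100 documents
await indexer._flush()  # drain whatever remains in the final partial batch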
Query Routing
Route queries to optimal retrieval paths:
from datetime import datetime, timedelta

class SmartQueryRouter:
    def __init__(self, retrievers: dict):
        self.retrievers = retrievers
        # "fast"    - small index of recent data, cached
        # "full"    - complete index
        # "archive" - historical data
    async def route_and_search(self, query: str, metadata_filter: dict = None):
        # Pick the cheapest retriever that can answer the query
        if self._is_recent_query(metadata_filter):
            retriever = self.retrievers["fast"]
        elif self._needs_archive(metadata_filter):
            retriever = self.retrievers["archive"]
        else:
            retriever = self.retrievers["full"]
        return await retriever.search(query, filter=metadata_filter)
    def _is_recent_query(self, metadata_filter: dict) -> bool:
        if metadata_filter and "date" in metadata_filter:
            return metadata_filter["date"] > datetime.now() - timedelta(days=30)
        return False
    def _needs_archive(self, metadata_filter: dict) -> bool:
        # Illustrative rule: anything older than a year lives in the archive tier
        if metadata_filter and "date" in metadata_filter:
            return metadata_filter["date"] < datetime.now() - timedelta(days=365)
        return False
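Usage then looks like this; the retriever objects and filter values are illustrative:

router = SmartQueryRouter({
    "fast": recent_retriever,      # hot, cached index
    "full": primary_retriever,     # complete index
    "archive": archive_retriever,  # cold historical shards
})

# Routed to "fast": the date filter falls within the last 30 days
hits = await router.route_and_search(
    "recent deployment incidents",
    metadata_filter={"date": datetime.now() - timedelta(days=7)},
)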
Caching for RAG
import hashlib
import json

class RAGCache:
    def __init__(self, redis_client, ttl=3600):
        self.redis = redis_client
        self.ttl = ttl
    async def get_or_compute(
        self,
        query: str,
        retriever,
        generator
    ) -> dict:
        # Use a stable hash: Python's built-in hash() is randomized per process,
        # so it cannot serve as a cross-process cache key
        cache_key = f"rag:{hashlib.sha256(query.encode()).hexdigest()}"
        cached = await self.redis.get(cache_key)
        if cached:
            return json.loads(cached)
        # Compute a fresh result
        retrieved = await retriever.search(query)
        response = await generator.generate(query, retrieved)
        result = {
            "answer": response,
            "sources": [r["metadata"] for r in retrieved]
        }
        # Cache with a TTL so stale answers eventually expire
        await self.redis.setex(cache_key, self.ttl, json.dumps(result))
        return result
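One caveat: hashing the raw string means only verbatim repeats hit the cache. Normalizing the query first is a cheap improvement (sketch, reusing the hashlib import above); going further, semantic caching keyed on embedding similarity catches paraphrases at the cost of an extra embedding call per lookup.

def normalize(query: str) -> str:
    # Collapse case and whitespace so trivially different phrasings share a key
    return " ".join(query.lower().split())

cache_key = f"rag:{hashlib.sha256(normalize(query).encode()).hexdigest()}"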
Performance Benchmarks
Target metrics for production RAG:
| Metric | Prototype | Production | Enterprise |
|---|---|---|---|
| Index size | 10K docs | 1M docs | 100M docs |
| Query latency (p50) | 500ms | 200ms | 100ms |
| Query latency (p99) | 2s | 500ms | 300ms |
| Throughput | 10 QPS | 100 QPS | 1000 QPS |
| Availability | 95% | 99.9% | 99.99% |
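To check a system against these targets, measure latency percentiles over a replayed query log rather than trusting single-run numbers. A minimal sequential sketch, with your retrieval path standing in for `search`; throughput needs a separate concurrent load test, since sequential replay only bounds latency:

import time

async def benchmark(queries, search):
    """Report p50/p99 latency for a list of queries against a warm index."""
    latencies = []
    for q in queries:
        start = time.perf_counter()
        await search(q)
        latencies.append((time.perf_counter() - start) * 1000)  # milliseconds
    latencies.sort()
    p50 = latencies[len(latencies) // 2]
    p99 = latencies[min(int(len(latencies) * 0.99), len(latencies) - 1)]
    print(f"p50={p50:.0f}ms  p99={p99:.0f}ms over {len(latencies)} queries")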
Now let's move on to multi-agent system design, another crucial interview topic.