RAG System Design

Scaling RAG Systems

Scaling RAG from a prototype to production requires careful consideration of indexing, retrieval, and infrastructure. This lesson covers patterns for handling millions of documents.

Scaling Challenges

| Challenge | At 10K docs | At 10M docs |
| --- | --- | --- |
| Index time | Minutes | Days |
| Query latency | < 100 ms | Seconds without optimization |
| Storage | GBs | TBs |
| Cost | $10/month | $1,000+/month |
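
Where do the storage numbers come from? A back-of-the-envelope estimate helps; the sketch below assumes 1536-dimensional float32 embeddings (a common embedding-model output size) and counts raw vectors only.

dims = 1536
bytes_per_vector = dims * 4  # float32 = 4 bytes per dimension

for n_docs in (10_000, 10_000_000):
    gb = n_docs * bytes_per_vector / 1e9
    print(f"{n_docs:>12,} docs -> {gb:,.2f} GB of raw vectors")

# ~0.06 GB at 10K docs, ~61 GB at 10M docs. Chunking (many vectors per
# document), index overhead, and replicas are what push a 10M-doc corpus
# into the TB range.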

Sharding Strategies

By Collection/Namespace

class ShardedVectorStore:
    def __init__(self, base_client):
        self.client = base_client
        self.shards = {}

    def get_shard(self, category: str):
        """Route to appropriate shard based on category."""
        if category not in self.shards:
            self.shards[category] = self.client.collection(f"docs_{category}")
        return self.shards[category]

    async def insert(self, doc: dict):
        category = doc["metadata"]["category"]
        shard = self.get_shard(category)
        await shard.insert(doc)

    async def search(self, query_embedding, category: str | None = None, top_k: int = 10):
        if category:
            # Search single shard
            shard = self.get_shard(category)
            return await shard.search(query_embedding, top_k=top_k)
        else:
            # Search all shards and merge
            all_results = []
            for shard in self.shards.values():
                results = await shard.search(query_embedding, top_k=top_k)
                all_results.extend(results)

            # Sort and return top-k across all shards
            all_results.sort(key=lambda x: x["score"], reverse=True)
            return all_results[:top_k]
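
A quick usage sketch (the client, document shape, and query_embedding are hypothetical, matching the class above):

store = ShardedVectorStore(base_client=vector_client)  # hypothetical client

# Writes are routed by the document's category
await store.insert({
    "id": "doc-1",
    "content": "...",
    "metadata": {"category": "legal"}
})

# Scoped search hits one shard; unscoped search fans out to every shard
scoped = await store.search(query_embedding, category="legal", top_k=5)
merged = await store.search(query_embedding, top_k=5)

Scoped queries cost one index lookup; unscoped queries cost one lookup per shard, so this pattern pays off when most traffic can be scoped to a category.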

By Time Period

import asyncio
from datetime import datetime

class TimeBasedSharding:
    def __init__(self, client):
        self.client = client

    def get_shard_name(self, timestamp: datetime) -> str:
        """Shard by month."""
        return f"docs_{timestamp.year}_{timestamp.month:02d}"

    def _get_shards_in_range(self, start: datetime, end: datetime) -> list[str]:
        """List the monthly shard names covering [start, end]."""
        shards = []
        year, month = start.year, start.month
        while (year, month) <= (end.year, end.month):
            shards.append(f"docs_{year}_{month:02d}")
            month += 1
            if month > 12:
                year, month = year + 1, 1
        return shards

    async def search_date_range(
        self,
        query_embedding,
        start_date: datetime,
        end_date: datetime,
        top_k: int = 10
    ):
        # Determine which monthly shards the date range touches
        shards = self._get_shards_in_range(start_date, end_date)

        # Query the relevant shards in parallel
        tasks = [
            self.client.collection(shard).search(query_embedding, top_k=top_k)
            for shard in shards
        ]
        results = await asyncio.gather(*tasks)

        # Merge and return the global top-k
        merged = [r for batch in results for r in batch]
        merged.sort(key=lambda x: x["score"], reverse=True)
        return merged[:top_k]

Index Optimization

Approximate Nearest Neighbor (ANN) Indexes

-- pgvector index types
CREATE INDEX ON documents
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);  -- Good for < 1M vectors

-- For larger datasets
CREATE INDEX ON documents
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);  -- Better for > 1M vectors
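
Both index types also expose query-time knobs in pgvector that trade recall for speed. These are session settings; the values below are illustrative starting points, not tuned recommendations:

-- IVF: number of inverted lists probed per query (default 1);
-- more probes = better recall, slower queries
SET ivfflat.probes = 10;

-- HNSW: size of the candidate list during search (default 40);
-- larger values = better recall, slower queries
SET hnsw.ef_search = 100;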

Index comparison:

| Index Type | Build Time | Query Speed | Memory | Best For |
| --- | --- | --- | --- | --- |
| Flat | Fast | Slow (O(n)) | Low | < 100K vectors |
| IVF | Medium | Fast | Medium | 100K–10M vectors |
| HNSW | Slow | Fastest | High | > 1M vectors |

Incremental Indexing

Don't reindex everything when documents change:

class IncrementalIndexer:
    def __init__(self, vector_store, embedding_model):
        self.store = vector_store
        self.embedder = embedding_model
        self.pending_updates = []
        self.batch_size = 100

    async def add_document(self, doc: dict):
        """Queue document for batch processing."""
        embedding = await self.embedder.embed(doc["content"])
        self.pending_updates.append({
            "id": doc["id"],
            "embedding": embedding,
            "metadata": doc["metadata"]
        })

        if len(self.pending_updates) >= self.batch_size:
            await self._flush()

    async def update_document(self, doc_id: str, new_content: str, metadata: dict | None = None):
        """Update an existing document in place."""
        # Re-embed the new content
        embedding = await self.embedder.embed(new_content)

        # Upsert overwrites the vector stored under the same id,
        # so no separate delete is needed
        await self.store.upsert({
            "id": doc_id,
            "embedding": embedding,
            "metadata": metadata or {}
        })

    async def _flush(self):
        """Batch insert pending updates."""
        if self.pending_updates:
            await self.store.upsert_batch(self.pending_updates)
            self.pending_updates = []
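
Driving the indexer might look like the sketch below (the vector store, embedder, and changed_docs are hypothetical). Note the explicit final flush: a partial batch smaller than batch_size would otherwise never be written, so a production version should expose flush() publicly.

indexer = IncrementalIndexer(vector_store, embedding_model)

for doc in changed_docs:
    await indexer.add_document(doc)  # auto-flushes every 100 docs

# Write out whatever remains in the final partial batch
await indexer._flush()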

Query Routing

Route queries to optimal retrieval paths:

from datetime import datetime, timedelta

class SmartQueryRouter:
    def __init__(self, retrievers: dict):
        # "fast"    - small index of recent docs, cached
        # "full"    - complete index
        # "archive" - historical data
        self.retrievers = retrievers

    async def route_and_search(self, query: str, metadata_filter: dict | None = None):
        # Determine the optimal route for this query
        if self._is_recent_query(metadata_filter):
            retriever = self.retrievers["fast"]
        elif self._needs_archive(metadata_filter):
            retriever = self.retrievers["archive"]
        else:
            retriever = self.retrievers["full"]

        return await retriever.search(query, filter=metadata_filter)

    def _is_recent_query(self, metadata_filter: dict | None) -> bool:
        """Queries scoped to the last 30 days can use the small, fast index."""
        if metadata_filter and "date" in metadata_filter:
            return metadata_filter["date"] > datetime.now() - timedelta(days=30)
        return False

    def _needs_archive(self, metadata_filter: dict | None) -> bool:
        """Queries scoped to data older than a year go to the archive."""
        if metadata_filter and "date" in metadata_filter:
            return metadata_filter["date"] < datetime.now() - timedelta(days=365)
        return False

Caching for RAG

import hashlib
import json

class RAGCache:
    def __init__(self, redis_client, ttl=3600):
        self.redis = redis_client
        self.ttl = ttl

    def _cache_key(self, query: str) -> str:
        # Use a stable hash: Python's built-in hash() is randomized
        # per process, so its values can't be shared through Redis
        digest = hashlib.sha256(query.encode("utf-8")).hexdigest()
        return f"rag:{digest}"

    async def get_or_compute(
        self,
        query: str,
        retriever,
        generator
    ) -> dict:
        # Check the cache first
        cache_key = self._cache_key(query)
        cached = await self.redis.get(cache_key)

        if cached:
            return json.loads(cached)

        # Compute a fresh result
        retrieved = await retriever.search(query)
        response = await generator.generate(query, retrieved)

        result = {
            "answer": response,
            "sources": [r["metadata"] for r in retrieved]
        }

        # Cache the result with a TTL so stale answers eventually expire
        await self.redis.setex(cache_key, self.ttl, json.dumps(result))

        return result
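
Exact-string caching only helps when users repeat a query verbatim. A cheap improvement, sketched below as an assumption rather than part of the class above, is to normalize the query before hashing so trivially different phrasings share a cache entry:

import re

def normalize_query(query: str) -> str:
    """Lowercase and collapse whitespace so near-identical queries collide."""
    return re.sub(r"\s+", " ", query.strip().lower())

# In RAGCache: cache_key = self._cache_key(normalize_query(query))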

Performance Benchmarks

Target metrics for production RAG:

| Metric | Prototype | Production | Enterprise |
| --- | --- | --- | --- |
| Index size | 10K docs | 1M docs | 100M docs |
| Query latency (p50) | 500 ms | 200 ms | 100 ms |
| Query latency (p99) | 2 s | 500 ms | 300 ms |
| Throughput | 10 QPS | 100 QPS | 1,000 QPS |
| Availability | 95% | 99.9% | 99.99% |
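
To check a system against these targets, measure percentiles rather than averages; a minimal sketch using only the standard library (sample_queries and rag_system are placeholders):

import statistics
import time

latencies = []
for query in sample_queries:  # a representative sample of real queries
    start = time.perf_counter()
    rag_system.answer(query)  # hypothetical synchronous entry point
    latencies.append((time.perf_counter() - start) * 1000)  # ms

cuts = statistics.quantiles(latencies, n=100)  # 99 percentile cut points
print(f"p50={cuts[49]:.0f} ms  p99={cuts[98]:.0f} ms")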

Now let's move on to multi-agent system design, another crucial interview topic.
