Caching Strategies

Caching is one of the most effective ways to reduce LLM costs and latency. A well-designed cache can cut costs by 40-60% while improving response times.

Types of AI Caching

1. Exact Match Cache

The simplest approach—cache identical queries.

import hashlib
from typing import Optional

import redis.asyncio as redis  # async client; sync redis calls can't be awaited

class ExactMatchCache:
    def __init__(self, redis_client, ttl_seconds=3600):
        self.redis = redis_client
        self.ttl = ttl_seconds

    def _hash_key(self, prompt: str) -> str:
        return hashlib.sha256(prompt.encode()).hexdigest()

    async def get(self, prompt: str) -> Optional[str]:
        key = self._hash_key(prompt)
        cached = await self.redis.get(f"llm_cache:{key}")
        return cached.decode() if cached else None

    async def set(self, prompt: str, response: str):
        key = self._hash_key(prompt)
        await self.redis.setex(f"llm_cache:{key}", self.ttl, response)

# Usage (inside an async request handler)
cache = ExactMatchCache(redis_client)
cached_response = await cache.get(prompt)
if cached_response:
    return cached_response

response = await llm.complete(prompt)
await cache.set(prompt, response)

Hit rate: 20-40% in typical applications.
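
Exact-match hit rates are sensitive to trivial differences in the prompt. A minimal sketch of a pre-hash normalization step that recovers hits lost to whitespace and casing (the normalize_prompt helper is hypothetical, and only safe when the output shouldn't depend on spacing or case):

import re

def normalize_prompt(prompt: str) -> str:
    # Collapse runs of whitespace and lowercase before hashing
    return re.sub(r"\s+", " ", prompt).strip().lower()

# key = hashlib.sha256(normalize_prompt(prompt).encode()).hexdigest()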

2. Semantic Cache

Cache similar queries, not just identical ones.

from typing import Optional

class SemanticCache:
    def __init__(self, embedding_model, vector_store, threshold=0.95):
        self.embedder = embedding_model
        self.store = vector_store
        self.threshold = threshold

    async def get(self, query: str) -> Optional[str]:
        query_embedding = await self.embedder.embed(query)

        # Search for similar cached queries
        results = await self.store.search(
            embedding=query_embedding,
            top_k=1
        )

        if results and results[0].score > self.threshold:
            return results[0].metadata["response"]
        return None

    async def set(self, query: str, response: str):
        embedding = await self.embedder.embed(query)
        await self.store.insert(
            embedding=embedding,
            metadata={"query": query, "response": response}
        )
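
The threshold check above assumes the vector store returns a cosine-similarity score. If you need to verify candidates yourself, a minimal helper looks like this (a sketch; many stores already normalize embeddings and compute this for you):

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Dot product divided by the product of magnitudes;
    # 1.0 means the vectors point in the same direction
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))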

Hit rate: 40-60%, but with trade-offs:

Factor             Exact Match   Semantic Cache
Hit rate           Lower         Higher
Latency overhead   Minimal       Embedding cost
False positives    None          Possible
Storage cost       Lower         Higher

3. Contextual Cache

Cache responses based on context similarity, not just query.

class ContextualCache:
    def __init__(self, semantic_cache):
        self.cache = semantic_cache

    async def get(self, query: str, context: str) -> Optional[str]:
        # Combine query and context for cache lookup
        combined = f"Query: {query}\nContext: {context[:500]}"
        return await self.cache.get(combined)

    async def set(self, query: str, context: str, response: str):
        combined = f"Query: {query}\nContext: {context[:500]}"
        await self.cache.set(combined, response)
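
A sketch of how this might sit in a RAG-style handler (the retriever, llm, and contextual_cache objects are assumed; note that only the first 500 characters of context are keyed, so responses that depend on later context can collide):

async def answer(query: str) -> str:
    context = await retriever.fetch(query)  # assumed retriever interface
    cached = await contextual_cache.get(query, context)
    if cached:
        return cached
    response = await llm.complete(f"{context}\n\n{query}")
    await contextual_cache.set(query, context, response)
    return response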

Cache Invalidation

The hardest problem in computer science—now with AI complexity.

class SmartCacheManager:
    def __init__(self, cache, ttl_config):
        self.cache = cache
        self.ttl_config = ttl_config

    async def set_with_smart_ttl(self, key: str, value: str, query_type: str):
        # Different TTLs based on content type; a TTL of 0 means "do not cache"
        ttl = self.ttl_config.get(query_type, 3600)
        if ttl == 0:
            return
        await self.cache.set(key, value, ttl=ttl)

    async def invalidate_by_topic(self, topic: str):
        # Invalidate all cached responses related to a topic.
        # Useful when underlying data changes. Pattern scans are O(N)
        # over the keyspace, so reserve them for infrequent updates.
        keys = await self.cache.search_keys(f"*{topic}*")
        for key in keys:
            await self.cache.delete(key)

# TTL Configuration
TTL_CONFIG = {
    "factual": 86400,      # 24 hours - facts don't change often
    "time_sensitive": 300,  # 5 minutes - current events
    "personalized": 3600,   # 1 hour - user-specific
    "creative": 0           # No cache - each response should be unique
}
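
TTLs handle gradual staleness; for data that changes at a known moment, event-driven invalidation is the usual complement. A sketch, assuming your system emits an event when a source document is updated (the event bus API here is hypothetical):

# Hypothetical event hook: purge cached answers when source data changes
async def on_document_updated(event):
    # event.topic identifies the subject area of the changed document
    await cache_manager.invalidate_by_topic(event.topic)

# event_bus.subscribe("document.updated", on_document_updated)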

Multi-Layer Caching

┌─────────────────────────────────────────────┐
│              Request Flow                    │
├─────────────────────────────────────────────┤
│                                             │
│   Query ──▶ L1: In-Memory ──▶ Hit? Return  │
│                   │                         │
│                   ▼ Miss                    │
│            L2: Redis ──▶ Hit? Return       │
│                   │                         │
│                   ▼ Miss                    │
│            L3: Semantic ──▶ Hit? Return    │
│                   │                         │
│                   ▼ Miss                    │
│               LLM Call                      │
│                   │                         │
│                   ▼                         │
│            Update All Layers               │
│                                             │
└─────────────────────────────────────────────┘

class MultiLayerCache:
    def __init__(self, l1_memory, l2_redis, l3_semantic):
        self.layers = [l1_memory, l2_redis, l3_semantic]

    async def get(self, query: str) -> tuple[Optional[str], str]:
        for i, layer in enumerate(self.layers):
            result = await layer.get(query)
            if result:
                # Backfill faster layers
                for faster_layer in self.layers[:i]:
                    await faster_layer.set(query, result)
                return result, f"L{i+1}"
        return None, "MISS"

    async def set(self, query: str, response: str):
        for layer in self.layers:
            await layer.set(query, response)
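
Wiring the layers together might look like this (a sketch; InMemoryCache is an assumed in-process LRU, and all three layers share the get/set interface above):

cache = MultiLayerCache(
    l1_memory=InMemoryCache(max_items=1000),       # assumed in-process LRU
    l2_redis=ExactMatchCache(redis_client),
    l3_semantic=SemanticCache(embedder, vector_store),
)

result, layer = await cache.get(prompt)
if result is None:
    result = await llm.complete(prompt)
    await cache.set(prompt, result)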

Cache Metrics

Track these to optimize your cache:

class CacheMetrics:
    def __init__(self):
        self.hits = 0
        self.misses = 0
        self.hit_latency = []
        self.miss_latency = []

    def record_hit(self, latency_ms: float, layer: str):
        self.hits += 1
        self.hit_latency.append(latency_ms)
        # Log to monitoring (assumes a StatsD-style `metrics` client is configured)
        metrics.increment("cache_hit", tags={"layer": layer})

    def record_miss(self, latency_ms: float):
        self.misses += 1
        self.miss_latency.append(latency_ms)
        metrics.increment("cache_miss")

    @property
    def hit_rate(self) -> float:
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0
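
Hit rate translates directly into savings, since each hit avoids one paid LLM call. A back-of-the-envelope estimate with assumed figures:

# Rough monthly savings estimate (all figures are illustrative)
hit_rate = 0.45            # measured cache hit rate
monthly_calls = 1_000_000  # total requests per month
cost_per_call = 0.002      # USD per avoided LLM call
savings = hit_rate * monthly_calls * cost_per_call  # => $900/month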

Interview Tip

When discussing caching in interviews, always mention:

  1. What you're caching (queries, embeddings, full responses)
  2. How you handle invalidation (TTL, event-driven, manual)
  3. Trade-offs (hit rate vs. freshness, memory vs. latency)

Next, we'll explore cost optimization strategies beyond caching.
