
Module 17: Memoization/KV Caching - Inference Optimization

Time: 2-3 hours | Difficulty: Advanced

🎯 What You'll Build

Implement KV caching - the critical optimization that makes production LLM inference economically viable. Transform O(n²) naive generation into O(n) optimized generation through computational reuse.

📋 Prerequisites

Required Modules:

  • Modules 01-14 (Foundation through Profiling)
  • Module 12 (Multi-Head Attention) - What we'll optimize
  • Module 13 (Transformer) - Architecture we'll accelerate
  • Module 14 (Profiling) - How we measure speedup

Before Starting:

# Verify transformer implementation works
pytest modules/13_transformer/test_transformer.py

# Verify profiling tools work
pytest modules/14_profiling/test_profiling.py

🧠 Core Concept

The Problem: O(n²) Generation

When generating text token-by-token, naive transformers recompute ALL previous key-value pairs at EVERY step:

Step 1: Generate "Hello"  → Compute K₁, V₁             (1 computation)
Step 2: Generate "world"  → Compute K₁, V₁, K₂, V₂     (2 computations, K₁,V₁ WASTED!)
Step 3: Generate "!"      → Compute K₁, V₁, K₂, V₂, K₃, V₃  (3 computations, K₁,V₁,K₂,V₂ WASTED!)

Total: 1 + 2 + 3 + ... + n = O(n²) complexity!

For 100 tokens: 5,050 K,V computations, when only 100 are actually needed! 😱
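
The quadratic count is easy to verify with a quick sanity check (a throwaway snippet, not part of the module's code):

```python
# Count K,V computations for naive vs cached generation of n tokens.
def naive_kv_computations(n: int) -> int:
    # Step t recomputes K,V for all t tokens seen so far: 1 + 2 + ... + n
    return n * (n + 1) // 2

def cached_kv_computations(n: int) -> int:
    # Step t computes K,V for the single new token only.
    return n

print(naive_kv_computations(100))   # 5050
print(cached_kv_computations(100))  # 100
```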

The Solution: Cache & Reuse

Key insight: K and V for previous tokens NEVER change!

Step 1: Compute K₁, V₁ → CACHE them
Step 2: Compute K₂, V₂ → Append to cache, retrieve [K₁,V₁,K₂,V₂]
Step 3: Compute K₃, V₃ → Append to cache, retrieve [K₁,V₁,K₂,V₂,K₃,V₃]

Total: 1 + 1 + 1 + ... + 1 = O(n) complexity!

Result: 10-15× speedup for typical generation! 🚀
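
A deliberately tiny sketch of one cached decode step — the projections here are stand-ins (real attention uses learned W_k, W_v matrices and multi-head shapes), but the cache-append-then-attend flow is the one you'll build:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
k_cache, v_cache = [], []   # grow by one entry per step: O(n) memory total

def decode_step(x):
    """One generation step: K,V are computed for the NEW token only."""
    k_new, v_new = x, x               # stand-in projections (real code: W_k @ x, W_v @ x)
    k_cache.append(k_new)             # O(1) append; nothing is recomputed
    v_cache.append(v_new)
    K, V = np.stack(k_cache), np.stack(v_cache)
    scores = K @ x / np.sqrt(d)       # attend over the full cached history
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()          # softmax over cached positions
    return weights @ V

for _ in range(3):
    out = decode_step(rng.standard_normal(d))

print(len(k_cache))  # 3 -- exactly one K,V pair computed per token
```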

🏗️ What You'll Implement

1. KVCache Class

class KVCache:
    """Efficient storage for key-value pairs across transformer layers."""

    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim):
        # Pre-allocate cache tensors for all layers
        pass

    def update(self, layer_idx, key, value):
        # O(1) append new K,V to cache (no copying!)
        pass

    def get(self, layer_idx):
        # O(1) retrieve cached K,V for attention
        pass
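
One possible shape of the solution, as a NumPy sketch — the buffer layout and the extra `advance` method are choices made here for illustration, not the required design:

```python
import numpy as np

class KVCache:
    """Pre-allocated per-layer K,V buffers with O(1) in-place appends (sketch)."""

    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim):
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)
        self.k = np.zeros(shape, dtype=np.float32)  # allocated once, up front
        self.v = np.zeros(shape, dtype=np.float32)
        self.seq_len = 0                            # tokens cached so far

    def update(self, layer_idx, key, value):
        # key/value: (batch, num_heads, new_tokens, head_dim)
        n = key.shape[2]
        s = self.seq_len
        self.k[layer_idx, :, :, s:s + n] = key      # write in place: no copying
        self.v[layer_idx, :, :, s:s + n] = value

    def advance(self, n=1):
        self.seq_len += n  # call once per step, after every layer has updated

    def get(self, layer_idx):
        # Views into the pre-allocated buffers: O(1), no copying.
        return (self.k[layer_idx, :, :, :self.seq_len],
                self.v[layer_idx, :, :, :self.seq_len])
```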

2. Non-Invasive Integration

def enable_kv_cache(model):
    """Add caching to existing transformer WITHOUT modifying Module 12/13!"""
    # Create cache sized for model
    # Wrap attention layers with caching logic
    # Return cache for manual control
    pass
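
The "enhance, don't modify" pattern can be shown end-to-end with toy classes. `ToyAttention`/`ToyModel` are stand-ins invented here (not Module 12/13's real classes), and the memo is a plain dict rather than a KV cache — the point is purely the wrapping technique:

```python
class ToyAttention:
    """Stand-in for an attention layer (illustrative, not the real API)."""
    def forward(self, x):
        return [t * 2 for t in x]  # pretend this is expensive

class ToyModel:
    def __init__(self, num_layers=2):
        self.layers = [ToyAttention() for _ in range(num_layers)]

def enable_memo(model):
    """Non-invasively wrap each layer's forward with a memo.
    The original class definitions are never touched."""
    for layer in model.layers:
        original = layer.forward          # capture the real method
        memo = {}                         # one memo per layer (closure state)
        def wrapped(x, _orig=original, _memo=memo):
            key = tuple(x)
            if key not in _memo:          # compute only on a cache miss
                _memo[key] = _orig(x)
            return _memo[key]
        layer.forward = wrapped           # swap in the enhanced version
    return model

model = enable_memo(ToyModel())
print(model.layers[0].forward([1, 2]))  # [2, 4] -- computed
print(model.layers[0].forward([1, 2]))  # [2, 4] -- served from the memo
```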

3. Performance Analysis

  • Measure speedup: O(n²) → O(n) transformation
  • Analyze the memory trade-off: O(n) extra memory buys a ~10× speedup
  • Profile scaling: Longer generation = better ROI

📊 Focus: Memory-Compute Trade-offs

This module teaches THE fundamental systems trade-off:

WITHOUT Cache:
Memory:  O(1)      (no storage)
Compute: O(n²)     (recompute everything)
Speed:   ~40 tok/s (slow!)

WITH Cache:
Memory:  O(n)      (store all K,V pairs)
Compute: O(n)      (compute new K,V only)
Speed:   ~500 tok/s (10-15× faster!)

Trade-off Winner: Memory is cheap, compute is expensive! Accept O(n) memory for O(n²)→O(n) speedup.

🚀 Production Technique for Real LLM Inference

This isn't a toy optimization - it's THE technique that makes production serving possible:

Real-World Impact

ChatGPT, Claude, GPT-4, LLaMA: ALL use KV caching

  • Without caching: 100-token response = ~17 seconds
  • With caching: 100-token response = ~1 second

Production Systems:

  • vLLM (Serving framework): KV cache is the core optimization
  • llama.cpp (Inference engine): Implements KV caching for efficiency
  • HuggingFace Transformers: use_cache=True in generation

Memory Requirements

GPT-2 (12 layers, 12 heads, seq_len=1024, head_dim=64):
Cache size = 12 × 12 × 1024 × 64 × 2 (K+V) × 2 bytes (float16)
          = ~37 MB per sequence

GPT-3 (96 layers, 96 heads, seq_len=2048, head_dim=128):
Cache size = 96 × 96 × 2048 × 128 × 2 × 2 bytes
          = ~9.7 GB per sequence

Trade-off: A few percent of model memory enables a 10× speedup!
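
The same arithmetic generalizes to a one-line helper (`kv_cache_bytes` is a name chosen here for convenience; the element size is a parameter so you can compare precisions):

```python
def kv_cache_bytes(num_layers, num_heads, seq_len, head_dim, bytes_per_elem=2):
    """K and V each need a (seq_len x head_dim) buffer per head per layer."""
    return num_layers * num_heads * seq_len * head_dim * 2 * bytes_per_elem

# GPT-2-sized cache in float16 (2 bytes/element):
print(kv_cache_bytes(12, 12, 1024, 64) / 2**20)   # 36.0 (MiB)
# Pass bytes_per_elem=4 for float32, which doubles the figure.
```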

🎓 Learning Outcomes

By completing this module, you will:

  1. Understand memoization as a general optimization pattern (cache results, avoid recomputation)
  2. Implement KVCache with efficient O(1) updates and O(n) memory scaling
  3. Build cache-aware attention that reuses previously computed keys and values
  4. Measure dramatic speedup gains (10-15×) through systems profiling
  5. Analyze memory-compute trade-offs in production inference systems
  6. Learn non-invasive optimization - add capabilities without breaking old code

🔗 Connections to Other Modules

Builds On:

  • Module 12 (Attention): What we're optimizing
  • Module 13 (Transformer): Architecture we're accelerating
  • Module 14 (Profiling): How we validate speedup

Enables:

  • Module 18 (Acceleration): Combine caching with parallelization
  • Milestone 05 (Chatbot): Real-time generation with caching

Systems Pattern:

Module 05 (Autograd):     enable_autograd()  → Add gradients to Tensors
Module 17 (Memoization):  enable_kv_cache()  → Add caching to Attention
                          ↓
        Critical Pattern: ENHANCE, don't MODIFY existing code!

📈 Expected Performance

┌─────────────┬────────────┬─────────────┬──────────┐
│ Seq Length  │ No Cache   │ With Cache  │ Speedup  │
├─────────────┼────────────┼─────────────┼──────────┤
│  10 tokens  │  ~80 tok/s │  ~600 tok/s │   7.5×   │
│  25 tokens  │  ~40 tok/s │  ~500 tok/s │  12.5×   │
│  50 tokens  │  ~25 tok/s │  ~400 tok/s │  16.0×   │
│ 100 tokens  │  ~12 tok/s │  ~200 tok/s │  16.7×   │
└─────────────┴────────────┴─────────────┴──────────┘

Key Insight: Speedup INCREASES with sequence length!
Why? Longer sequences = more redundant computation without cache.

🧪 Testing Strategy

  1. Unit Tests: Test KVCache in isolation (storage, retrieval, memory tracking)
  2. Integration Tests: Test cache with mock transformer models
  3. Performance Tests: Measure O(n²)→O(n) speedup via profiling
  4. Systems Analysis: Analyze memory usage and scaling behavior

💡 Key Insights You'll Discover

  1. Recomputation is Expensive: O(n²) growth makes naive generation impractical
  2. Memory is Cheap: Spending O(n) memory saves O(n²) compute
  3. Scaling Matters: 100-token generation = ~50× fewer K,V computations with cache (5,050 vs 100)!
  4. Production Critical: This single optimization enables ChatGPT-scale inference
  5. Non-Invasive Design: Best optimizations ADD capabilities, don't BREAK old code

🎯 Success Criteria

  • KVCache correctly stores and retrieves K,V pairs for all layers
  • Cache updates are O(1) (no data copying)
  • Memory usage matches theoretical predictions
  • enable_kv_cache() works without modifying Module 12/13
  • All unit tests pass
  • Integration test validates complete workflow
  • Performance analysis shows 10-15× speedup

🚀 Next Steps

After completing this module:

  1. Try it yourself: Run chatbot milestone with/without caching

    python milestones/05_2017_transformer/vaswani_chatgpt.py --use-cache
    
  2. Experiment: Profile speedup on different sequence lengths

  3. Compare: Measure memory overhead vs model parameters

  4. Move forward: Module 18 (Acceleration) teaches parallelization!


Ready to build the optimization that powers ChatGPT? 🚀

Start with: modules/17_memoization/memoization_dev.py