Module 17: Memoization/KV Caching - Inference Optimization
Time: 2-3 hours | Difficulty: ⭐⭐⭐⭐☆ (Advanced)
🎯 What You'll Build
Implement KV caching - the critical optimization that makes production LLM inference economically viable. Transform O(n²) naive generation into O(n) optimized generation through computational reuse.
📋 Prerequisites
Required Modules:
- ✅ Module 01-14 (Foundation through Profiling)
- ✅ Module 12 (Multi-Head Attention) - What we'll optimize
- ✅ Module 13 (Transformer) - Architecture we'll accelerate
- ✅ Module 14 (Profiling) - How we measure speedup
Before Starting:
```bash
# Verify transformer implementation works
pytest modules/13_transformer/test_transformer.py

# Verify profiling tools work
pytest modules/14_profiling/test_profiling.py
```
🧠 Core Concept
The Problem: O(n²) Generation
When generating text token-by-token, naive transformers recompute ALL previous key-value pairs at EVERY step:
Step 1: Generate "Hello" → Compute K₁, V₁ (1 computation)
Step 2: Generate "world" → Compute K₁, V₁, K₂, V₂ (2 computations, K₁,V₁ WASTED!)
Step 3: Generate "!" → Compute K₁, V₁, K₂, V₂, K₃, V₃ (3 computations, K₁,V₁,K₂,V₂ WASTED!)
Total: 1 + 2 + 3 + ... + n = O(n²) complexity!
For 100 tokens: 5,050 K/V computations instead of the 100 actually needed, so 4,950 are redundant! 😱
The Solution: Cache & Reuse
Key insight: K and V for previous tokens NEVER change!
Step 1: Compute K₁, V₁ → CACHE them
Step 2: Compute K₂, V₂ → Append to cache, retrieve [K₁,V₁,K₂,V₂]
Step 3: Compute K₃, V₃ → Append to cache, retrieve [K₁,V₁,K₂,V₂,K₃,V₃]
Total: 1 + 1 + 1 + ... + 1 = O(n) complexity!
Result: 10-15× speedup for typical generation! 🚀
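The counting above is easy to verify in a few lines (a sketch that counts K/V computations only, not the attention products themselves):

```python
# Count K/V computations for generating n tokens, with and without a cache.
# Naive: step t recomputes K,V for all t tokens. Cached: only the new token.

def kv_computations(n: int) -> tuple[int, int]:
    naive = sum(range(1, n + 1))  # 1 + 2 + ... + n = n(n+1)/2
    cached = n                    # one new K,V pair per step
    return naive, cached

print(kv_computations(100))  # (5050, 100)
```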
🏗️ What You'll Implement
1. KVCache Class
```python
class KVCache:
    """Efficient storage for key-value pairs across transformer layers."""

    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim):
        # Pre-allocate cache tensors for all layers
        pass

    def update(self, layer_idx, key, value):
        # O(1) append of new K,V into the cache (no copying!)
        pass

    def get(self, layer_idx):
        # O(1) retrieval of cached K,V for attention
        pass
```
2. Non-Invasive Integration
```python
def enable_kv_cache(model):
    """Add caching to an existing transformer WITHOUT modifying Module 12/13!"""
    # Create a cache sized for the model
    # Wrap attention layers with caching logic
    # Return the cache for manual control
    pass
```
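One way such a wrapper could work, sketched against a hypothetical model API (`model.layers`, each exposing an `attention` object with a `forward(q, k, v)` method; none of this is the module's confirmed interface). For brevity the sketch uses a simple list-based cache rather than the pre-allocated KVCache class:

```python
import numpy as np

def enable_kv_cache(model):
    """Wrap each attention layer's forward with caching logic.
    The original attention code is never edited, only wrapped."""
    # Axis 2 is assumed to be the sequence axis,
    # i.e. K,V are shaped (batch, heads, tokens, head_dim).
    cache = {i: {"k": [], "v": []} for i in range(len(model.layers))}

    for idx, layer in enumerate(model.layers):
        attn = layer.attention
        orig_forward = attn.forward

        def cached_forward(q, k, v, _idx=idx, _orig=orig_forward):
            cache[_idx]["k"].append(k)  # store only the NEW keys/values
            cache[_idx]["v"].append(v)
            k_all = np.concatenate(cache[_idx]["k"], axis=2)  # full history
            v_all = np.concatenate(cache[_idx]["v"], axis=2)
            return _orig(q, k_all, v_all)

        attn.forward = cached_forward  # non-invasive: swapped at runtime
    return cache
```

The wrapped layer receives only the new token's K,V but attends over the full cached history, which is exactly the O(n²) → O(n) transformation.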
3. Performance Analysis
- Measure speedup: O(n²) → O(n) transformation
- Analyze the memory trade-off: O(n) extra memory buys a ~10× speedup
- Profile scaling: Longer generation = better ROI
📊 Focus: Memory-Compute Trade-offs
This module teaches THE fundamental systems trade-off:
WITHOUT Cache:
Memory: O(1) (no storage)
Compute: O(n²) (recompute everything)
Speed: ~40 tok/s (slow!)
WITH Cache:
Memory: O(n) (store all K,V pairs)
Compute: O(n) (compute new K,V only)
Speed: ~500 tok/s (10-15× faster!)
Trade-off Winner: Memory is cheap, compute is expensive! Accept O(n) memory for O(n²)→O(n) speedup.
🚀 Production Technique for Real LLM Inference
This isn't a toy optimization - it's THE technique that makes production serving possible:
Real-World Impact
ChatGPT, Claude, GPT-4, LLaMA: ALL use KV caching
- Without caching: 100-token response = ~17 seconds ❌
- With caching: 100-token response = ~0.1 seconds ✅
Production Systems:
- vLLM (Serving framework): KV cache is the core optimization
- llama.cpp (Inference engine): Implements KV caching for efficiency
- HuggingFace Transformers: pass `use_cache=True` during generation
Memory Requirements
GPT-2 (12 layers, 12 heads, seq_len=1024, head_dim=64):
Cache size = 12 × 12 × 1024 × 64 × 2 (K+V) × 4 bytes (float32)
           = ~75 MB per sequence
GPT-3 (96 layers, 96 heads, seq_len=2048, head_dim=128):
Cache size = 96 × 96 × 2048 × 128 × 2 × 4 bytes
           = ~19 GB per sequence (halved in float16)
Trade-off: a modest fraction of model memory buys a 10× speedup!
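The arithmetic above as a small helper (a sketch; `dtype_bytes=4` assumes float32 storage):

```python
def kv_cache_bytes(num_layers, num_heads, seq_len, head_dim, dtype_bytes=4):
    # layers × heads × seq_len × head_dim, ×2 for K and V, × bytes per element
    return num_layers * num_heads * seq_len * head_dim * 2 * dtype_bytes

gpt2 = kv_cache_bytes(12, 12, 1024, 64)
print(gpt2 / 2**20)  # 72.0 MiB (~75 MB) per sequence
```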
🎓 Learning Outcomes
By completing this module, you will:
- Understand memoization as a general optimization pattern (cache results, avoid recomputation)
- Implement KVCache with efficient O(1) updates and O(n) memory scaling
- Build cache-aware attention that reuses previously computed keys and values
- Measure dramatic speedup gains (10-15×) through systems profiling
- Analyze memory-compute trade-offs in production inference systems
- Learn non-invasive optimization - add capabilities without breaking old code
🔗 Connections to Other Modules
Builds On:
- Module 12 (Attention): What we're optimizing
- Module 13 (Transformer): Architecture we're accelerating
- Module 14 (Profiling): How we validate speedup
Enables:
- Module 18 (Acceleration): Combine caching with parallelization
- Milestone 05 (Chatbot): Real-time generation with caching
Systems Pattern:
Module 05 (Autograd): enable_autograd() → add gradients to Tensors
Module 17 (Memoization): enable_kv_cache() → add caching to Attention
Critical pattern: ENHANCE existing code, don't MODIFY it!
📈 Expected Performance
┌─────────────┬────────────┬─────────────┬──────────┐
│ Seq Length │ No Cache │ With Cache │ Speedup │
├─────────────┼────────────┼─────────────┼──────────┤
│ 10 tokens │ ~80 tok/s │ ~600 tok/s │ 7.5× │
│ 25 tokens │ ~40 tok/s │ ~500 tok/s │ 12.5× │
│ 50 tokens │ ~25 tok/s │ ~400 tok/s │ 16.0× │
│ 100 tokens │ ~12 tok/s │ ~200 tok/s │ 16.7× │
└─────────────┴────────────┴─────────────┴──────────┘
Key Insight: Speedup INCREASES with sequence length!
Why? Longer sequences = more redundant computation without cache.
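In K/V-computation terms the theoretical speedup is (n+1)/2, which grows linearly with n; measured speedups are smaller and flatten out because every step also pays fixed per-token costs:

```python
def theoretical_speedup(n: int) -> float:
    # naive K/V work n(n+1)/2 divided by cached work n
    return (n + 1) / 2

for n in (10, 25, 50, 100):
    print(n, theoretical_speedup(n))  # 5.5, 13.0, 25.5, 50.5
```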
🧪 Testing Strategy
- Unit Tests: Test KVCache in isolation (storage, retrieval, memory tracking)
- Integration Tests: Test cache with mock transformer models
- Performance Tests: Measure O(n²)→O(n) speedup via profiling
- Systems Analysis: Analyze memory usage and scaling behavior
💡 Key Insights You'll Discover
- Recomputation is Expensive: O(n²) growth makes naive generation impractical
- Memory is Cheap: Spending O(n) memory saves O(n²) compute
- Scaling Matters: 100-token generation needs ~50× fewer K/V computations with the cache (5,050 → 100)!
- Production Critical: This single optimization enables ChatGPT-scale inference
- Non-Invasive Design: Best optimizations ADD capabilities, don't BREAK old code
🎯 Success Criteria
- KVCache correctly stores and retrieves K,V pairs for all layers
- Cache updates are O(1) (no data copying)
- Memory usage matches theoretical predictions
- enable_kv_cache() works without modifying Module 12/13
- All unit tests pass
- Integration test validates complete workflow
- Performance analysis shows 10-15× speedup
🚀 Next Steps
After completing this module:
- Try it yourself: run the chatbot milestone with and without caching:
  python milestones/05_2017_transformer/vaswani_chatgpt.py --use-cache
- Experiment: profile the speedup on different sequence lengths
- Compare: measure cache memory overhead vs model parameters
- Move forward: Module 18 (Acceleration) teaches parallelization!
Ready to build the optimization that powers ChatGPT? 🚀
Start with: modules/17_memoization/memoization_dev.py