17. Memoization - Computational Reuse for Inference
OPTIMIZATION TIER | Difficulty: ⭐⭐⭐ (3/4) | Time: 4-5 hours
Overview
Memoization is a fundamental optimization pattern: cache computational results to avoid redundant work. You'll apply this pattern to transformers through KV (Key-Value) caching, transforming O(n²) autoregressive generation into O(n) complexity and achieving 10-15x speedup. This optimization makes production language model serving economically viable.
This is an inference-only optimization: you'll implement the caching pattern used in every production LLM, from ChatGPT to Claude to GitHub Copilot.
Learning Objectives
By the end of this module, you will be able to:
- Understand Memoization Pattern: Recognize when computational reuse through caching applies to ML problems and understand the memory-speed trade-off
- Implement KVCache Structure: Build efficient cache data structures with O(1) updates, proper memory management, and multi-layer support
- Apply Caching to Transformers: Integrate KV caching into attention layers without modifying existing transformer code (non-invasive enhancement)
- Measure Performance Gains: Profile latency improvements, measure O(n²) → O(n) complexity reduction, and understand speedup characteristics
- Analyze Production Trade-offs: Calculate cache memory costs, understand cache invalidation policies, and recognize when caching justifies its overhead
Build → Use → Optimize
This module follows TinyTorch's Build → Use → Optimize framework:
- Build: Implement KVCache data structure with efficient updates, cached attention integration, and multi-layer cache management
- Use: Apply caching to GPT text generation, measure 10-15x speedup over naive generation, and validate output correctness
- Optimize: Profile memory bandwidth bottlenecks, measure cache hit rates, and understand when memory cost exceeds latency benefit
Why This Matters
KV Cache Optimization Flow
Caching stores computed keys and values, avoiding recomputation for each new token:
```{mermaid}
graph LR
    A[Token i<br/>Compute K_i, V_i] --> B[Cache<br/>Store K_i, V_i]
    B --> C[Token i+1<br/>New computation]
    C --> D[Reuse<br/>K_i, V_i from cache]
    D --> E["Only compute<br/>K_{i+1}, V_{i+1}"]
    E --> F[10-15× speedup]

    style A fill:#e3f2fd
    style C fill:#e3f2fd
    style B fill:#f3e5f5
    style D fill:#fff3e0
    style E fill:#ffe0b2
    style F fill:#f0fdf4
```
Optimization: Compute K,V once → Cache → Reuse for all future tokens → O(n²) → O(n) complexity
The Autoregressive Generation Problem
Without caching, transformer generation has quadratic complexity:
Naive Generation (O(n²) complexity):
```text
Step 1: Generate token 1 → Compute attention for [t₀]              (1 computation)
Step 2: Generate token 2 → Compute attention for [t₀, t₁]          (2 computations, t₀ RECOMPUTED!)
Step 3: Generate token 3 → Compute attention for [t₀, t₁, t₂]      (3 computations, t₀, t₁ RECOMPUTED!)
...
Step n: Generate token n → Compute attention for [t₀, ..., tₙ₋₁]   (n computations, n−1 RECOMPUTED!)

Total: 1 + 2 + 3 + ... + n = n(n+1)/2 = O(n²) complexity!
```
For 100 tokens: 5,050 K,V computations, of which 4,950 are redundant recomputations.
The Key Insight: In causal (autoregressive) attention, the K and V vectors of previous tokens NEVER change, yet naive generation recomputes them at every step!
The Caching Solution
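Cached Generation (O(n) complexity):
```text
Step 1: Compute K₁, V₁ → Cache them → Attention with cached [K₁, V₁]
Step 2: Compute K₂, V₂ → Cache them → Attention with cached [K₁, K₂, V₁, V₂] (reuse K₁, V₁!)
Step 3: Compute K₃, V₃ → Cache them → Attention with cached [K₁, K₂, K₃, V₁, V₂, V₃] (reuse all!)
...
Total new K,V computations: 1 + 1 + 1 + ... + 1 = n (~50x fewer for n=100!)
```
Caching removes the recomputation, not the attention itself. At step i the new query still attends over all i cached keys and values, so that part of each step still grows with sequence length.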
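You can check the totals above by counting the work directly. The helper count_work is hypothetical and not part of TinyTorch. It assumes a single attention layer. At naive step i, the layer re-projects all i prefix tokens and recomputes the lower-triangular causal scores. At cached step i, it projects one token and scores one query against i cached keys.

```python
def count_work(n: int) -> dict:
    """Count K,V projections and attention-score entries for n generation steps (illustrative)."""
    # Naive: step i re-projects K,V for all i tokens in the prefix.
    naive_kv = sum(range(1, n + 1))
    # Cached: step i projects K,V for the single new token only.
    cached_kv = n
    # Naive: step i rebuilds the causal score matrix, 1 + 2 + ... + i entries.
    naive_scores = sum(i * (i + 1) // 2 for i in range(1, n + 1))
    # Cached: step i scores one new query against i cached keys.
    cached_scores = sum(range(1, n + 1))
    return {
        "naive_kv": naive_kv,
        "cached_kv": cached_kv,
        "naive_scores": naive_scores,
        "cached_scores": cached_scores,
    }


work = count_work(100)
print(work["naive_kv"], work["cached_kv"])          # 5050 100
print(work["naive_kv"] / work["cached_kv"])         # 50.5
print(work["naive_scores"], work["cached_scores"])  # 171700 5050
```

K,V projections fall from 5,050 to 100, matching the totals above. The cached score count (5,050) is the attention work caching keeps, because each new query still meets every stored key.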
Production Impact
KV caching is mandatory for all production LLM serving:
- ChatGPT/GPT-4: Would be 50-100x slower without caching, making conversational AI economically infeasible
- Claude: Caches up to 100K tokens of context, enabling long document processing
- GitHub Copilot: Real-time code completion requires sub-100ms latency - impossible without caching
- Google Gemini: Multi-level caching (KV + intermediate layers) serves billions of requests daily
Without KV caching, the computational cost would make these services prohibitively expensive.
Memory-Speed Trade-off
```text
Traditional Approach (No Cache):
  Memory:  O(1)                          Cost: Negligible
  Compute: O(n²)                         Cost: Prohibitive for long sequences

Cached Approach (KV Cache):
  Memory:  O(num_layers × n × d_model)   Cost: ~36 MiB per sequence for GPT-2 (FP16, 1024 tokens)
  Compute: O(n)                          Cost: 10-15x faster than naive
```
Trade-off Winner: Memory is cheap, compute is expensive!
Use O(n) memory to save O(n²) compute.
Implementation Guide
Core Components
You'll implement three main components:
1. KVCache Data Structure
```python
class KVCache:
    """
    Efficient key-value cache for autoregressive generation.

    Memory Layout (pre-allocated to max_seq_len):
        keys:   (num_layers, batch, num_heads, max_seq_len, d_k)
        values: (num_layers, batch, num_heads, max_seq_len, d_v)

    For GPT-2 (12 layers, 12 heads, 1024 seq, 64 dims):
        12 layers × 12 heads × 1024 seq × 64 dims ≈ 9.4M values per tensor (K or V)
        K and V together: ≈ 18.9M values
        At FP32 (4 bytes): ~72 MiB per batch item
        At FP16 (2 bytes): ~36 MiB per batch item

    Operations:
        update(layer_idx, key, value) -> None    # O(1) append
        get(layer_idx) -> (cached_k, cached_v)   # O(1) retrieval
        advance() -> None                        # Increment position
        reset() -> None                          # Clear for new sequence
    """
```
Key Design Decisions:
- Pre-allocate cache tensors to avoid dynamic resizing overhead
- Use position counter for O(1) indexed updates (no copying)
- Store per-layer caches to support multi-layer transformers
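- Advance the sequence position explicitly with advance(), after every layer has written, rather than inside update(), so all layers write the same token to the same slot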
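Here is a minimal NumPy sketch of the design decisions above. It is not the TinyTorch KVCache: it stores raw NumPy arrays instead of Tensor objects. The class name SimpleKVCache and its constructor arguments are assumptions, chosen to mirror the operations listed in the docstring.

```python
import numpy as np


class SimpleKVCache:
    """Pre-allocated per-layer K/V storage with one shared write pointer (sketch)."""

    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim,
                 dtype=np.float32):
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)
        # Allocate once up front so updates never resize or copy the cache.
        self.keys = np.zeros(shape, dtype=dtype)
        self.values = np.zeros(shape, dtype=dtype)
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.seq_pos = 0  # next slot to write; shared by all layers

    def update(self, layer_idx, key, value):
        """Write one token's K/V, each shaped (batch, heads, 1, head_dim), at seq_pos."""
        if not 0 <= layer_idx < self.num_layers:
            raise IndexError(f"layer_idx {layer_idx} out of range")
        if self.seq_pos >= self.max_seq_len:
            raise ValueError("cache is full; call reset() or raise max_seq_len")
        self.keys[layer_idx, :, :, self.seq_pos] = key[:, :, 0]
        self.values[layer_idx, :, :, self.seq_pos] = value[:, :, 0]

    def get(self, layer_idx):
        """Return views (no copy) of the valid portion: (batch, heads, seq_pos, head_dim)."""
        return (self.keys[layer_idx, :, :, :self.seq_pos],
                self.values[layer_idx, :, :, :self.seq_pos])

    def advance(self):
        """Move the write pointer once every layer has written the current token."""
        self.seq_pos += 1

    def reset(self):
        """Start a new sequence; old slots are simply overwritten later."""
        self.seq_pos = 0
```

Usage mirrors the manual testing example in this module: call update() for every layer, then advance() once per token. In this sketch a token becomes visible to get() only after advance(). An attention layer that needs the token it just wrote would therefore slice one position further, up to seq_pos + 1.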
2. Non-Invasive Cache Integration
```python
def enable_kv_cache(model):
    """
    Enable KV caching WITHOUT modifying Module 12/13 code.

    This demonstrates non-invasive optimization: adding capabilities
    to existing systems without breaking them, similar to how Module 05
    uses enable_autograd() to add gradient tracking to Tensors.

    Approach:
        1. Create a KVCache sized for the model architecture
        2. Store the cache on the model as model._kv_cache
        3. Wrap each attention layer's forward method with caching logic
        4. Intercept attention calls to manage the cache automatically

    This is composition + monkey-patching, a critical ML systems pattern!
    """
```
Why Non-Invasive?
- Modules 12-13 (Attention, Transformers) work unchanged
- Module 17 ADDS optimization, doesn't BREAK old code
- Teaches "forward-only" systems engineering: never modify earlier modules
- Matches how production systems layer optimizations (vLLM, HuggingFace)
3. Cached Attention Logic
```python
def cached_forward(x, mask=None):
    """
    Cache-aware attention with three paths:

    PATH 1: Training (seq_len > 1)
        → Use original attention (preserve gradients)
        → O(n²), but needed for backpropagation

    PATH 2: First Token (seq_len == 1, cache empty)
        → Use original attention (initialize cache)
        → O(1): just one token

    PATH 3: Cached Generation (seq_len == 1, cache populated)
        → Compute K,V for the NEW token only
        → Retrieve ALL cached K,V (includes history)
        → Attention with cached context
        → O(n): only compute new, reuse cache
        → THIS IS WHERE THE SPEEDUP HAPPENS!
    """
```
Implementation Steps
Step 1: Design KVCache Structure
- Initialize cache storage for all layers
- Pre-allocate tensors with maximum sequence length
- Track current sequence position (write pointer)
- Implement update() for O(1) append operations
- Implement get() for O(1) retrieval of valid cache portion
Step 2: Implement Cache Updates
- Validate layer index and sequence position
- Write new K,V to current position (indexed assignment)
- Advance position counter after all layers processed
- Handle batch dimension and multi-head structure
Step 3: Enable Non-Invasive Integration
- Validate model has required attributes (embed_dim, num_layers, etc.)
- Calculate head_dim from embed_dim and num_heads
- Create KVCache instance sized for model
- Store the cache on the model as model._kv_cache
- Wrap each block's attention.forward with caching logic
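Step 3 and the save-wrap-restore pattern behind enable_kv_cache() can be illustrated with toy classes. Everything here is hypothetical. ToyAttention, ToyBlock, ToyModel, enable_cache() and disable_cache() stand in for TinyTorch's attention layer, transformer block, model, enable_kv_cache() and disable_kv_cache(). The dictionary stored as model._kv_cache is a placeholder for a real KVCache.

```python
class ToyAttention:
    def forward(self, x):
        return f"original({x})"


class ToyBlock:
    def __init__(self):
        self.attention = ToyAttention()


class ToyModel:
    def __init__(self, embed_dim=8, num_heads=2, num_layers=2):
        self.embed_dim, self.num_heads, self.num_layers = embed_dim, num_heads, num_layers
        self.blocks = [ToyBlock() for _ in range(num_layers)]


def enable_cache(model):
    """Wrap each block's attention.forward without editing ToyAttention itself."""
    for attr in ("embed_dim", "num_heads", "num_layers", "blocks"):
        if not hasattr(model, attr):
            raise AttributeError(f"model is missing required attribute {attr!r}")
    if model.embed_dim % model.num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = model.embed_dim // model.num_heads
    model._kv_cache = {"head_dim": head_dim, "seq_pos": 0}  # placeholder for a real cache
    for layer_idx, block in enumerate(model.blocks):
        # Save the bound method so it can be restored later.
        block._original_attention_forward = block.attention.forward

        def cached_forward(x, _orig=block._original_attention_forward, _idx=layer_idx):
            # A real version would branch into the cached path here.
            return f"cached[layer {_idx}]({_orig(x)})"

        # Assigning on the instance shadows the class method for this object only.
        block.attention.forward = cached_forward
    return model._kv_cache


def disable_cache(model):
    for block in model.blocks:
        # Deleting the instance attribute makes lookup fall back to the class method.
        del block.attention.forward
        del block._original_attention_forward
    del model._kv_cache


model = ToyModel()
enable_cache(model)
print(model.blocks[0].attention.forward("x"))  # cached[layer 0](original(x))
disable_cache(model)
print(model.blocks[0].attention.forward("x"))  # original(x)
```

Binding _orig and _idx as default arguments avoids Python's late-binding closure pitfall. Without it, every wrapper created in the loop would see the last block's values.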
Step 4: Implement Cached Attention Forward
- Detect path: training (seq_len > 1), first token (cache empty), or cached generation
- For cached path: Compute Q,K,V projections for new token only
- Reshape to multi-head format (batch, num_heads, 1, head_dim)
- Update cache with new K,V pairs
- Retrieve ALL cached K,V (history + new)
- Compute attention: softmax(Q @ K^T / √d_k) @ V using NumPy (.data)
- Apply output projection and return
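You can check the cached path in Step 4 in isolation with NumPy. For brevity, this sketch assumes one head, no batch dimension, no output projection, and randomly initialized projection matrices. It feeds tokens one at a time through a pre-allocated cache. It then compares each step's output with the matching row of full causal attention over the whole sequence.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def full_causal_attention(X, Wq, Wk, Wv):
    """Naive path: project every token, then apply a causal mask."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    n = X.shape[0]
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    return softmax(scores) @ V


def cached_attention_step(x_new, Wq, Wk, Wv, k_cache, v_cache, pos):
    """Cached path: project only the new token, store its K/V, attend over the prefix."""
    q = x_new @ Wq
    k_cache[pos] = x_new @ Wk  # write the new K into its pre-allocated slot
    v_cache[pos] = x_new @ Wv
    K, V = k_cache[:pos + 1], v_cache[:pos + 1]  # history + new token
    scores = K @ q / np.sqrt(K.shape[-1])  # shape (pos + 1,)
    return softmax(scores) @ V


rng = np.random.default_rng(0)
n, d_model, d_k = 6, 8, 4
X = rng.standard_normal((n, d_model))
Wq, Wk, Wv = (rng.standard_normal((d_model, d_k)) for _ in range(3))

k_cache = np.zeros((n, d_k))
v_cache = np.zeros((n, d_k))
cached_out = np.stack([
    cached_attention_step(X[i], Wq, Wk, Wv, k_cache, v_cache, i) for i in range(n)
])
print(np.allclose(cached_out, full_causal_attention(X, Wq, Wk, Wv)))  # True
```

Each step projects only x_new. Every older K and V is read from k_cache and v_cache.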
Step 5: Validate Correctness
- Test cache initialization and memory calculation
- Verify single-token and multi-token updates
- Validate multi-layer cache synchronization
- Test reset functionality
- Measure speedup vs non-cached generation
Why .data Instead of Tensor Operations?
In cached attention, we use NumPy via .data for three reasons:
- Explicit Intent: Makes it crystal clear this is inference-only
  - Training: Uses Tensor operations → gradients tracked
  - Inference: Uses .data → no gradient overhead
- Performance: Avoids any autograd bookkeeping
  - Even small overheads matter in the generation hot path
  - Production inference engines (vLLM, llama.cpp) likewise run the decode path without gradient tracking
- Educational Clarity: Shows students the distinction
  - "When do I need gradients?" (training)
  - "When can I skip them?" (inference)

We COULD use Tensor operations with requires_grad=False, but .data is more explicit and follows industry patterns.
Getting Started
Prerequisites
Ensure you understand transformers and profiling:
```bash
# Activate TinyTorch environment
source bin/activate-tinytorch.sh

# Verify prerequisite modules
tito test --module transformers
tito test --module profiling
```
Required Understanding:
- Multi-head attention mechanism (Module 12)
- Transformer architecture (Module 13)
- Latency profiling techniques (Module 14)
- O(n²) complexity of attention computation
Development Workflow
- Open the development file: modules/17_memoization/memoization_dev.ipynb
- Profile naive generation: Measure O(n²) growth in latency as sequence lengthens
- Implement KVCache class: Build data structure with update(), get(), advance(), reset()
- Test cache operations: Verify single-token, multi-token, and multi-layer caching
- Implement enable_kv_cache(): Non-invasively patch model attention layers
- Build cached attention forward: Three-path logic (training, first token, cached generation)
- Measure speedup: Profile cached vs non-cached generation, validate O(n) complexity
- Export and verify:

```bash
tito module complete 17 && tito test --module memoization
```
Testing
Comprehensive Test Suite
Run the full test suite to verify memoization functionality:
```bash
# TinyTorch CLI (recommended)
tito test --module memoization

# Direct pytest execution
python -m pytest tests/ -k memoization -v
```
Test Coverage Areas
- ✅ KVCache Initialization: Validate cache creation, memory calculation, and initial state
- ✅ Cache Updates: Test single-token append, multi-token sequences, and O(1) update performance
- ✅ Multi-Layer Synchronization: Verify independent per-layer caches with correct indexing
- ✅ Cache Retrieval: Test get() returns only valid cached portion (up to seq_pos)
- ✅ Non-Invasive Integration: Validate enable_kv_cache() works without breaking model
- ✅ Correctness Validation: Compare cached vs non-cached outputs (should be identical)
- ✅ Performance Measurement: Measure speedup at different sequence lengths
- ✅ Memory Tracking: Calculate cache size and validate memory usage
Inline Testing & Profiling
The module includes comprehensive validation with performance measurement:
```text
# Unit Test: KVCache Implementation
🔬 Unit Test: KVCache Implementation...
   Cache initialized: 0.59 MB
✅ Cache initialization successful
✅ Append and retrieval work correctly
✅ Multi-layer caching validated
✅ Reset functionality verified
📈 Progress: KVCache ✓

# Integration Test: Performance Measurement
🔬 Profiling Transformer Generation (Without Caching):
   Seq Len | Latency (ms) | Growth
   --------|--------------|----------
   10      | 2.34         | baseline
   20      | 4.89         | 2.09×
   40      | 10.12        | 2.07×
   80      | 21.45        | 2.12×
   160     | 45.67        | 2.13×

💡 Key Observations:
   • Latency grows steadily with sequence length (~2.1× per doubling in this run)
   • Each new token forces recomputation of ALL previous K,V pairs
   • Summed over every generation step, this recomputation adds up to O(n²) total work

🎯 The Solution: CACHE the K,V values! (That's memoization)
✅ Speedup: 10-15× for typical generation
```
Manual Testing Examples
```python
import numpy as np

from tinytorch.core.tensor import Tensor
from tinytorch.generation.kv_cache import KVCache, enable_kv_cache

# Test cache with small transformer
cache = KVCache(
    batch_size=1,
    max_seq_len=128,
    num_layers=4,
    num_heads=8,
    head_dim=64,
)

# Simulate generation loop
for step in range(10):
    for layer_idx in range(4):
        # New key-value pairs for this step: (batch, heads, 1 new token, head_dim)
        new_k = Tensor(np.random.randn(1, 8, 1, 64))
        new_v = Tensor(np.random.randn(1, 8, 1, 64))
        # Update cache (O(1) operation)
        cache.update(layer_idx, new_k, new_v)
    # Advance position after all layers
    cache.advance()

# Retrieve cached values
cached_k, cached_v = cache.get(layer_idx=0)
print(f"Cached 10 tokens: {cached_k.shape}")  # (1, 8, 10, 64)

# Calculate memory usage
mem_info = cache.get_memory_usage()
print(f"Cache memory: {mem_info['total_mb']:.2f} MB")
```
Systems Thinking Questions
Real-World Production Challenges
Memory-Speed Trade-off Analysis:
- KV cache uses ~36 MiB per sequence for GPT-2 (FP16, 1024 tokens). For batch=32, that's ~1.1 GiB.
- On an 8GB GPU, how many concurrent users can you serve?
- What's the trade-off between batch size and cache size?
- When does memory bandwidth (cache access) become the bottleneck instead of compute?
Cache Invalidation Policies:
- In multi-turn chat, when should you clear the cache?
- What happens when context exceeds max_seq_len?
- How do production systems like ChatGPT handle context window limits?
- Compare eviction policies: LRU, FIFO, sliding window, importance-based
Distributed Caching for Large Models:
- For models too large for one GPU, you need tensor parallelism
- How do you partition the KV cache across GPUs?
- Which dimension should you shard: layers, heads, or sequence?
- What's the communication overhead for cache synchronization?
Quantized Caching:
- Storing cache in INT8 instead of FP16 saves 50% memory
- What's the accuracy impact of quantized KV cache?
- When is this trade-off worth it?
- How does quantization error accumulate over long sequences?
Production Optimization Patterns
Multi-Level Caching:
- What if you cache not just K,V but intermediate layer activations?
- How does HuggingFace's DynamicCache differ from static pre-allocation?
- When should you use persistent caching (save to disk) for very long conversations?
Speculation and Prefetching:
- What if you predict the next query and pre-compute KV cache?
- How would speculative caching improve throughput?
- What's the risk if speculation is wrong?
- When does prefetching justify its overhead?
Mathematical Foundations
Complexity Reduction:
- Why does KV caching transform O(n²) into O(n)?
- Calculate total operations for naive vs cached generation (n=100)
- What's the crossover point where caching overhead exceeds savings?
Memory Layout Optimization:
- Why pre-allocate cache instead of dynamic appending?
- How does cache contiguity affect memory bandwidth?
- Compare row-major vs column-major cache layouts for performance
Attention Computation Analysis:
- Why can we cache K,V but not Q (query)?
- What property of autoregressive generation makes caching valid?
- How would bidirectional attention (BERT) change caching strategy?
HuggingFace Cache Patterns Comparison
Static vs Dynamic Cache:
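```python
from tinytorch.generation.kv_cache import KVCache
from transformers import DynamicCache

# TinyTorch (Module 17): static pre-allocation; fixed size, O(1) updates
cache = KVCache(batch_size=1, max_seq_len=1024, num_layers=12, num_heads=12, head_dim=64)
# HuggingFace: dynamic cache (DynamicCache class); grows as needed, more flexible but slower
cache = DynamicCache()
```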
When to Use Each:
- Static (TinyTorch): Known max length, maximum performance, inference serving
- Dynamic (HuggingFace): Variable lengths, exploration, research
Production Systems (vLLM, TGI):
- Use PagedAttention for virtual memory management of KV cache
- Enables efficient memory sharing across requests
- Reduces memory fragmentation for variable-length sequences
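The paging idea behind PagedAttention can be sketched as a block allocator plus a per-sequence block table. This is a hypothetical, simplified illustration of the memory-management side only. It has no attention kernel and no copy-on-write sharing. The class name and block size are assumptions, not vLLM's API.

```python
class PagedKVAllocator:
    """Map each sequence's logical token positions onto fixed-size physical blocks."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids not in use
        self.block_tables = {}  # seq_id -> list of physical block ids, in logical order
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id: int) -> tuple:
        """Reserve a slot for one new token; return (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # last block is full, or none allocated yet
            if not self.free_blocks:
                raise MemoryError("no free KV blocks; evict or preempt a sequence")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = length + 1
        return table[length // self.block_size], length % self.block_size

    def free(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the shared pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


alloc = PagedKVAllocator(num_blocks=4, block_size=2)
slots = [alloc.append_token(seq_id=0) for _ in range(3)]  # 3 tokens need 2 blocks
print(slots)                   # [(3, 0), (3, 1), (2, 0)]
print(len(alloc.free_blocks))  # 2
```

Memory is reserved one block at a time, so a sequence wastes at most block_size - 1 slots. A statically pre-allocated cache instead reserves max_seq_len slots up front.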
Performance Characteristics
Expected Speedup by Sequence Length
```text
Speedup Characteristics (GPT-2 on CPU):
┌─────────────┬──────────────┬──────────────┬──────────┐
│ Seq Length  │ No Cache     │ With Cache   │ Speedup  │
├─────────────┼──────────────┼──────────────┼──────────┤
│ 10 tokens   │ ~80 tok/s    │ ~600 tok/s   │ 7.5x     │
│ 25 tokens   │ ~40 tok/s    │ ~500 tok/s   │ 12.5x    │
│ 50 tokens   │ ~25 tok/s    │ ~400 tok/s   │ 16.0x    │
│ 100 tokens  │ ~12 tok/s    │ ~200 tok/s   │ 16.7x    │
└─────────────┴──────────────┴──────────────┴──────────┘
```
Key Insight: Speedup increases with sequence length!
Why? Longer sequences = more redundant computation without cache.
Memory Usage by Model Size
```text
Cache Memory Requirements (FP16, batch_size=1):
┌──────────────┬────────┬────────┬─────────┬──────────────┐
│ Model        │ Layers │ Heads  │ Seq Len │ Cache Memory │
├──────────────┼────────┼────────┼─────────┼──────────────┤
│ TinyGPT      │ 4      │ 4      │ 128     │ 0.5 MiB      │
│ GPT-2 (124M) │ 12     │ 12     │ 1024    │ 36 MiB       │
│ GPT-3 (175B) │ 96     │ 96     │ 2048    │ 9.0 GiB      │
└──────────────┴────────┴────────┴─────────┴──────────────┘
```
Formula: memory = num_layers × num_heads × max_seq_len × head_dim × 2 × 2 bytes
(2× for K and V, 2 bytes for FP16; head_dim is 64 for GPT-2 and 128 for GPT-3. Multiply by batch_size for batched inference.)
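You can turn the formula above into a small calculator for your own configurations. The helper name kv_cache_bytes is hypothetical. The head_dim values follow from each model's embedding width divided by its head count: 768/12 = 64 for GPT-2 and 12288/96 = 128 for GPT-3 175B. batch_size multiplies the result because every sequence in a batch needs its own cache.

```python
def kv_cache_bytes(num_layers, num_heads, max_seq_len, head_dim,
                   batch_size=1, bytes_per_elem=2):
    """Bytes needed for a full-length K and V cache (2 tensors per layer)."""
    return num_layers * num_heads * max_seq_len * head_dim * 2 * bytes_per_elem * batch_size


MiB, GiB = 2**20, 2**30
print(kv_cache_bytes(12, 12, 1024, 64) / MiB)                    # 36.0  (GPT-2, FP16)
print(kv_cache_bytes(12, 12, 1024, 64, bytes_per_elem=4) / MiB)  # 72.0  (GPT-2, FP32)
print(kv_cache_bytes(96, 96, 2048, 128) / GiB)                   # 9.0   (GPT-3, FP16)
print(kv_cache_bytes(12, 12, 1024, 64, batch_size=32) / GiB)     # 1.125 (GPT-2, batch 32)
```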
Throughput Impact
Single Sequence Generation:
- Without cache: Throughput decreases as sequence grows (O(n²) bottleneck)
- With cache: Throughput stays relatively constant (O(n) scales well)
Batch Inference:
- Cache memory scales linearly with batch size
- Throughput increases with batching (amortize model loading)
- Memory becomes limiting factor before compute
Where This Code Lives in the Final Package
Package Export: Code exports to tinytorch.generation.kv_cache
```python
# When students install tinytorch, they import your work like this:
from tinytorch.generation.kv_cache import KVCache, enable_kv_cache, disable_kv_cache
from tinytorch.nn import MultiHeadAttention  # Base class from Module 12
from tinytorch.models.transformer import GPT  # Architecture from Module 13

# Usage in generation:
model = GPT(vocab_size=1000, embed_dim=128, num_layers=4, num_heads=4)
cache = enable_kv_cache(model)  # Non-invasively add caching

# Generate with caching enabled (10-15x faster!)
# generate_text stands for your text-generation loop; it is not imported above.
output = generate_text(model, prompt="Hello", max_new_tokens=100)

# Disable caching if needed
disable_kv_cache(model)
```
Your KV caching implementation becomes the foundation for efficient inference in the TinyTorch package, used by subsequent modules for text generation, chat applications, and deployment scenarios.
Common Challenges and Solutions
Challenge 1: Cache Synchronization Across Layers
Problem: Keeping cache consistent when different layers process at different speeds or batch items have variable lengths.
Solution:
- Use layer indexing to maintain independent per-layer caches
- Advance sequence position only after ALL layers have processed current token
- Handle variable sequence lengths with padding and attention masks
Code Pattern:
```python
# Process all layers before advancing
# (new_k, new_v: each layer's own K,V projections for the current token)
for layer_idx in range(num_layers):
    cache.update(layer_idx, new_k, new_v)

# Now advance position (all layers synchronized)
cache.advance()
```
Challenge 2: Memory Overhead for Large Models
Problem: Cache memory grows with sequence length and batch size, potentially exceeding GPU memory.
Solution:
- Implement cache size limits with eviction policies (LRU, FIFO)
- Store the cache in FP16 (50% less memory than FP32) or INT8 (75% less than FP32, 50% less than FP16)
- Consider PagedAttention for virtual memory management
- Tune max_seq_len to expected generation length
Memory Optimization:
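```python
import numpy as np

# FP16 caching (2 bytes per element): 50% savings vs FP32
cache = KVCache(...).to(dtype=np.float16)
# INT8 caching (1 byte per element): 75% savings vs FP32, accuracy trade-off.
# A plain int8 cast truncates values; a real INT8 cache also stores scale factors.
cache = KVCache(...).to(dtype=np.int8)
```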
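Casting K/V directly to np.int8 would truncate values rather than quantize them. One common approach is symmetric quantization with a stored scale. This sketch assumes one scale per (head, position) row; per-tensor or per-channel scales are also used. The function names are hypothetical.

```python
import numpy as np


def quantize_int8(x):
    """Symmetric INT8 quantization with one float32 scale per last-axis row."""
    max_abs = np.abs(x).max(axis=-1, keepdims=True)
    # Rows of all zeros get scale 1.0 to avoid dividing by zero.
    scale = np.where(max_abs > 0, max_abs / 127.0, 1.0).astype(np.float32)
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
k = rng.standard_normal((8, 128, 64)).astype(np.float16)  # (heads, seq, head_dim)
q, scale = quantize_int8(k.astype(np.float32))
k_hat = dequantize_int8(q, scale)

print(k.nbytes, q.nbytes + scale.nbytes)  # 131072 69632
print(np.abs(k_hat - k.astype(np.float32)).max() <= scale.max() / 2 + 1e-5)  # True
```

With head_dim = 64 and one float32 scale per row, INT8 data plus scales take about 53% of the FP16 bytes. Each reconstructed value is off by at most half a quantization step.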
Challenge 3: Correctness Validation
Problem: Cached generation must produce identical outputs to non-cached generation.
Solution:
- Compare cached vs non-cached outputs token-by-token
- Use deterministic sampling (temperature=0) for testing
- Validate cache retrieval returns correct sequence positions
- Test edge cases: first token, cache full, reset
Validation Pattern:
```python
import numpy as np

# generate() stands for your greedy (temperature=0) generation loop
# Generate without cache (ground truth)
output_nocache = generate(model, prompt, max_new_tokens=50)

# Generate with cache (optimized)
cache = enable_kv_cache(model)
output_cached = generate(model, prompt, max_new_tokens=50)

# Validate identical token IDs (exact match, not approximate)
assert np.array_equal(output_nocache, output_cached), "Cached output must match!"
```
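The pattern above leaves generate() and the model abstract. Here is a self-contained version of the same check on a toy one-layer, single-head language model in NumPy. The model is entirely hypothetical: random embeddings, a positional table, projections, and an output head. It exists only to show that greedy decoding with and without a KV cache yields the same token IDs.

```python
import numpy as np

# Hypothetical toy model: one attention layer, one head, random weights.
rng = np.random.default_rng(42)
VOCAB, D, MAX_LEN = 11, 8, 32
E = rng.standard_normal((VOCAB, D))      # token embeddings
P = rng.standard_normal((MAX_LEN, D))    # positional embeddings
Wq, Wk, Wv = (rng.standard_normal((D, D)) for _ in range(3))
W_out = rng.standard_normal((D, VOCAB))  # projection to vocabulary logits


def attend(q, K, V):
    """Single-query softmax attention over the rows of K and V."""
    s = K @ q / np.sqrt(D)
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V


def generate_nocache(prompt, steps):
    tokens = list(prompt)
    for _ in range(steps):
        X = E[tokens] + P[:len(tokens)]  # re-embed and re-project the whole prefix
        h = attend(X[-1] @ Wq, X @ Wk, X @ Wv)
        tokens.append(int(np.argmax(h @ W_out)))  # greedy (temperature=0)
    return tokens


def generate_cached(prompt, steps):
    K, V = np.zeros((MAX_LEN, D)), np.zeros((MAX_LEN, D))  # pre-allocated cache

    def feed(pos, token):
        x = E[token] + P[pos]  # only this token is projected
        K[pos], V[pos] = x @ Wk, x @ Wv
        return attend(x @ Wq, K[:pos + 1], V[:pos + 1])

    tokens = list(prompt)  # assumes a non-empty prompt
    for pos, token in enumerate(tokens):  # fill the cache from the prompt
        h = feed(pos, token)
    for step in range(steps):
        tokens.append(int(np.argmax(h @ W_out)))
        if step < steps - 1:  # no need to process the final token
            h = feed(len(tokens) - 1, tokens[-1])
    return tokens


prompt = [1, 5, 2]
assert generate_nocache(prompt, 10) == generate_cached(prompt, 10)  # token IDs match exactly
```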
Challenge 4: Integration Without Breaking Existing Code
Problem: Adding caching shouldn't require modifying Modules 12-13 (attention, transformer).
Solution:
- Use composition + monkey-patching (wrap, don't modify)
- Store original forward methods before patching
- Provide disable_kv_cache() to restore original behavior
- Use feature flags (model._cache_enabled) for path selection
Non-Invasive Pattern:
```python
# Save original before patching
block._original_attention_forward = block.attention.forward

# Patch with cached version
block.attention.forward = cached_forward

# Restore later if needed
block.attention.forward = block._original_attention_forward
```
Ready to Build?
You're about to implement the optimization that makes production language models economically viable! KV caching is THE technique that transformed LLMs from research toys into products used by millions daily.
This is where theory meets practice in ML systems engineering. You'll see firsthand how a simple idea - "don't recompute what never changes" - can deliver 10-15x speedup and make the impossible possible.
What makes this module special: Unlike many optimizations that require deep algorithmic changes, KV caching is conceptually simple but profoundly impactful. You'll implement it from scratch, measure the dramatic speedup, and understand the memory-speed trade-offs that guide production deployments.
Understanding this optimization from first principles - implementing it yourself, profiling the speedup, analyzing the trade-offs - will give you deep insight into how production ML systems work. This is the optimization that makes ChatGPT, Claude, and GitHub Copilot possible.
Take your time, measure thoroughly, and enjoy building production-ready ML systems!
Choose your preferred way to engage with this module:
```{grid-item-card} 🚀 Launch Binder
:link: https://mybinder.org/v2/gh/mlsysbook/TinyTorch/main?filepath=modules/17_memoization/memoization_dev.ipynb
:class-header: bg-light
Run this module interactively in your browser. No installation required!
```
```{grid-item-card} ⚡ Open in Colab
:link: https://colab.research.google.com/github/mlsysbook/TinyTorch/blob/main/modules/17_memoization/memoization_dev.ipynb
:class-header: bg-light
Use Google Colab for GPU access and cloud compute power.
```
```{grid-item-card} 📖 View Source
:link: https://github.com/mlsysbook/TinyTorch/blob/main/modules/17_memoization/memoization_dev.ipynb
:class-header: bg-light
Browse the Jupyter notebook source and understand the implementation.
```
```{tip}
**Binder sessions are temporary!** Download your completed notebook when done, or switch to local development for persistent work.
```