# Module 17: Memoization/KV Caching - Inference Optimization

**Time**: 2-3 hours
**Difficulty**: ⭐⭐⭐⭐☆ (Advanced)

## 🎯 What You'll Build

Implement **KV caching** - the critical optimization that makes production LLM inference economically viable. Transform O(n²) naive generation into O(n) optimized generation through computational reuse.

## 📋 Prerequisites

**Required Modules**:
- ✅ Modules 01-14 (Foundation through Profiling)
- ✅ Module 12 (Multi-Head Attention) - What we'll optimize
- ✅ Module 13 (Transformer) - Architecture we'll accelerate
- ✅ Module 14 (Profiling) - How we measure speedup

**Before Starting**:
```bash
# Verify transformer implementation works
pytest modules/13_transformer/test_transformer.py

# Verify profiling tools work
pytest modules/14_profiling/test_profiling.py
```

## 🧠 Core Concept

### The Problem: O(n²) Generation

When generating text token by token, a naive transformer recomputes ALL previous key-value pairs at EVERY step:

```
Step 1: Generate "Hello" → Compute K₁, V₁ (1 computation)
Step 2: Generate "world" → Compute K₁, V₁, K₂, V₂ (2 computations, K₁,V₁ WASTED!)
Step 3: Generate "!"     → Compute K₁, V₁, K₂, V₂, K₃, V₃ (3 computations, K₁,V₁,K₂,V₂ WASTED!)

Total: 1 + 2 + 3 + ... + n = n(n+1)/2 = O(n²) complexity!
```

**For 100 tokens**: 5,050 K,V computations when only 100 are needed - 4,950 of them redundant! 😱

### The Solution: Cache & Reuse

**Key insight**: K and V for previous tokens NEVER change!

```
Step 1: Compute K₁, V₁ → CACHE them
Step 2: Compute K₂, V₂ → Append to cache, retrieve [K₁,V₁,K₂,V₂]
Step 3: Compute K₃, V₃ → Append to cache, retrieve [K₁,V₁,K₂,V₂,K₃,V₃]

Total: 1 + 1 + 1 + ... + 1 = O(n) complexity!
```

**Result**: 10-15× speedup for typical generation! 🚀
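
To make the counting concrete, here is a tiny back-of-the-envelope script (not part of the module code) that tallies how many K,V computations each strategy performs when generating n tokens:

```python
def kv_computations(n_tokens: int) -> tuple[int, int]:
    """Count K,V computations for naive vs. cached generation of n_tokens."""
    naive = sum(range(1, n_tokens + 1))   # recompute every previous K,V each step: 1 + 2 + ... + n
    cached = n_tokens                     # compute only the new token's K,V each step
    return naive, cached

for n in (10, 25, 50, 100):
    naive, cached = kv_computations(n)
    print(f"{n:>4} tokens: naive={naive:>6}  cached={cached:>4}  ratio={naive / cached:.1f}x")

# 100 tokens → naive=5050, cached=100, ratio ≈ 50x
```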

## 🏗️ What You'll Implement

### 1. KVCache Class
```python
class KVCache:
    """Efficient storage for key-value pairs across transformer layers."""

    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim):
        # Pre-allocate cache tensors for all layers
        pass

    def update(self, layer_idx, key, value):
        # O(1) append new K,V to cache (no copying!)
        pass

    def get(self, layer_idx):
        # O(1) retrieve cached K,V for attention
        pass
```
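
For orientation, here is a minimal sketch of one way the skeleton above could be filled in, assuming NumPy arrays and one pre-allocated buffer covering all layers; the class name `KVCacheSketch`, the buffer layout, and the per-layer length bookkeeping are illustrative choices, not the required solution:

```python
import numpy as np

class KVCacheSketch:
    """Pre-allocated per-layer K/V buffers; append-only along the sequence axis."""

    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim):
        # One buffer for everything: [layer, K-or-V, batch, head, position, dim]
        shape = (num_layers, 2, batch_size, num_heads, max_seq_len, head_dim)
        self.cache = np.zeros(shape, dtype=np.float32)
        self.lengths = np.zeros(num_layers, dtype=int)   # tokens cached so far, per layer

    def update(self, layer_idx, key, value):
        # key/value: (batch, num_heads, new_tokens, head_dim); written in place, old entries untouched
        new_tokens = key.shape[2]
        start = self.lengths[layer_idx]
        end = start + new_tokens
        self.cache[layer_idx, 0, :, :, start:end, :] = key
        self.cache[layer_idx, 1, :, :, start:end, :] = value
        self.lengths[layer_idx] = end

    def get(self, layer_idx):
        # Return views over the filled portion: (batch, num_heads, seq_len_so_far, head_dim)
        filled = self.lengths[layer_idx]
        return (self.cache[layer_idx, 0, :, :, :filled, :],
                self.cache[layer_idx, 1, :, :, :filled, :])
```

The pre-allocation is what makes `update` effectively O(1): new keys and values are written into an existing buffer instead of being concatenated into a fresh array every step.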

### 2. Non-Invasive Integration
```python
def enable_kv_cache(model):
    """Add caching to existing transformer WITHOUT modifying Module 12/13!"""
    # Create cache sized for model
    # Wrap attention layers with caching logic
    # Return cache for manual control
    pass
```
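
One possible shape for that wrapper, shown only to illustrate the "enhance, don't modify" pattern: it uses hypothetical attribute names (`model.layers`, `layer.attention`, `model.num_heads`, `model.head_dim`) and the `KVCacheSketch` from above, all of which will differ from the actual Module 12/13 interfaces, and it leaves the cache-aware attention step itself as a comment:

```python
def enable_kv_cache_sketch(model, batch_size, max_seq_len):
    """Illustrative pattern only: swap in cache-aware attention without editing Module 12/13."""
    cache = KVCacheSketch(batch_size, max_seq_len,
                          num_layers=len(model.layers),
                          num_heads=model.num_heads,
                          head_dim=model.head_dim)

    for layer_idx, layer in enumerate(model.layers):
        original_forward = layer.attention.forward

        def cached_forward(x, *, _orig=original_forward, _idx=layer_idx):
            # In the real module, this is where the cache-aware step goes:
            # project K,V for the new token(s) only, cache.update(_idx, k_new, v_new),
            # then attend against cache.get(_idx) instead of recomputing the full history.
            return _orig(x)   # placeholder: falls back to the original attention

        layer.attention.forward = cached_forward   # enhance the object; its class stays untouched

    return cache
```

The key property is that the wrapper adds behavior to existing objects and hands back the cache for manual control, the same compositional style as `enable_autograd()` in Module 05.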

### 3. Performance Analysis
- Measure the speedup from the O(n²) → O(n) transformation (see the timing sketch below)
- Analyze the memory trade-off: O(n) cache memory buys a ~10× speedup
- Profile scaling: longer generation = better ROI for the cache
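
A rough shape for that measurement, assuming a hypothetical `generate(model, prompt, n_tokens, cache=None)` helper rather than the module's real API; the point is the methodology (warm up, time both paths, report tokens/second), not the exact names:

```python
import time

def tokens_per_second(generate_fn, n_tokens, repeats=3):
    """Time a generation callable and report throughput in tokens/second (best of `repeats`)."""
    generate_fn(n_tokens)                      # warm-up run, excluded from timing
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        generate_fn(n_tokens)
        best = min(best, time.perf_counter() - start)
    return n_tokens / best

# Hypothetical usage:
# baseline = tokens_per_second(lambda n: generate(model, prompt, n), n_tokens=50)
# cached   = tokens_per_second(lambda n: generate(model, prompt, n, cache=enable_kv_cache(model)), n_tokens=50)
# print(f"speedup: {cached / baseline:.1f}x")
```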

## 📊 Focus: Memory-Compute Trade-offs

This module teaches THE fundamental systems trade-off:

```
WITHOUT Cache:
  Memory:  O(1)   (no storage)
  Compute: O(n²)  (recompute everything)
  Speed:   ~40 tok/s (slow!)

WITH Cache:
  Memory:  O(n)   (store all K,V pairs)
  Compute: O(n)   (compute new K,V only)
  Speed:   ~500 tok/s (10-15× faster!)
```

**Trade-off Winner**: Memory is cheap, compute is expensive! Accept O(n) memory for the O(n²) → O(n) speedup.

## 🚀 Production Technique for Real LLM Inference

This isn't a toy optimization - it's **THE** technique that makes production serving possible:

### Real-World Impact

**ChatGPT, Claude, GPT-4, LLaMA**: ALL use KV caching
- Without caching: 100-token response = ~17 seconds ❌
- With caching: 100-token response = ~0.1 seconds ✅

**Production Systems**:
- vLLM (serving framework): KV cache management is the core optimization
- llama.cpp (inference engine): implements KV caching for efficiency
- HuggingFace Transformers: `use_cache=True` in generation

### Memory Requirements

```
GPT-2 (12 layers, 12 heads, seq_len=1024, head_dim=64):
  Cache size = 12 × 12 × 1024 × 64 × 2 (K+V) × 4 bytes (float32)
             ≈ 75 MB per sequence

GPT-3 (96 layers, 96 heads, seq_len=2048, head_dim=128):
  Cache size = 96 × 96 × 2048 × 128 × 2 × 4 bytes
             ≈ 19 GB per sequence (roughly half that in float16)

Trade-off: a modest, predictable amount of extra memory buys a ~10× speedup!
```
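
The arithmetic above is easy to reproduce; here is a small helper (illustrative, not part of the module code) for estimating cache size for any configuration:

```python
def kv_cache_bytes(num_layers, num_heads, seq_len, head_dim, bytes_per_value=4):
    """Estimate KV cache size: layers × heads × seq_len × head_dim × 2 (K+V) × dtype size."""
    return num_layers * num_heads * seq_len * head_dim * 2 * bytes_per_value

gpt2 = kv_cache_bytes(12, 12, 1024, 64)    # ≈ 75 MB in float32
gpt3 = kv_cache_bytes(96, 96, 2048, 128)   # ≈ 19 GB in float32
print(f"GPT-2: {gpt2 / 1e6:.0f} MB,  GPT-3: {gpt3 / 1e9:.1f} GB")
```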

## 🎓 Learning Outcomes

By completing this module, you will:

1. **Understand memoization** as a general optimization pattern (cache results, avoid recomputation)
2. **Implement KVCache** with efficient O(1) updates and O(n) memory scaling
3. **Build cache-aware attention** that reuses previously computed keys and values (see the sketch after this list)
4. **Measure dramatic speedups** (10-15×) through systems profiling
5. **Analyze memory-compute trade-offs** in production inference systems
6. **Learn non-invasive optimization** - add capabilities without breaking old code
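
Here is roughly what outcome 3 looks like in code: attention for a single new token that reuses the cache for all previous tokens. Batch and head dimensions are dropped for readability, the projection matrices `W_q`, `W_k`, `W_v` are hypothetical placeholders, and `cache` stands for any object with the `update`/`get` interface from section 1 (with shapes simplified to match); the real module builds this on TinyTorch's own Tensor and attention classes.

```python
import numpy as np

def cached_attention_step(x_new, W_q, W_k, W_v, cache, layer_idx):
    """Attention for the newest token only: new Q,K,V are computed; old K,V come from the cache."""
    q_new = x_new @ W_q                     # (1, head_dim) query for the new token
    k_new = x_new @ W_k                     # only the NEW token's key...
    v_new = x_new @ W_v                     # ...and value are computed this step

    cache.update(layer_idx, k_new, v_new)   # O(1) append
    K, V = cache.get(layer_idx)             # (seq_len, head_dim), including the new token

    scores = q_new @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                # softmax over the cached positions
    return weights @ V                      # (1, head_dim) attention output for the new token
```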

## 🔗 Connections to Other Modules

**Builds On**:
- Module 12 (Attention): What we're optimizing
- Module 13 (Transformer): Architecture we're accelerating
- Module 14 (Profiling): How we validate speedup

**Enables**:
- Module 18 (Acceleration): Combine caching with parallelization
- Milestone 05 (Chatbot): Real-time generation with caching

**Systems Pattern**:
```
Module 05 (Autograd):    enable_autograd()  → Add gradients to Tensors
Module 17 (Memoization): enable_kv_cache()  → Add caching to Attention
                                  ↓
         Critical Pattern: ENHANCE, don't MODIFY existing code!
```

## 📈 Expected Performance

```
┌─────────────┬────────────┬─────────────┬──────────┐
│ Seq Length  │ No Cache   │ With Cache  │ Speedup  │
├─────────────┼────────────┼─────────────┼──────────┤
│ 10 tokens   │ ~80 tok/s  │ ~600 tok/s  │ 7.5×     │
│ 25 tokens   │ ~40 tok/s  │ ~500 tok/s  │ 12.5×    │
│ 50 tokens   │ ~25 tok/s  │ ~400 tok/s  │ 16.0×    │
│ 100 tokens  │ ~12 tok/s  │ ~200 tok/s  │ 16.7×    │
└─────────────┴────────────┴─────────────┴──────────┘

Key Insight: Speedup INCREASES with sequence length!
Why? Longer sequences = more redundant computation without the cache.
```

## 🧪 Testing Strategy

1. **Unit Tests**: Test KVCache in isolation (storage, retrieval, memory tracking) - see the example below
2. **Integration Tests**: Test the cache with mock transformer models
3. **Performance Tests**: Measure the O(n²) → O(n) speedup via profiling
4. **Systems Analysis**: Analyze memory usage and scaling behavior
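
To give a flavor of the unit tests, here is a small pytest-style sketch written against the illustrative `KVCacheSketch` from earlier; the real tests target the module's own `KVCache` and live in the module's test files:

```python
import numpy as np
# assumes KVCacheSketch from the sketch above is importable

def test_cache_roundtrip():
    cache = KVCacheSketch(batch_size=1, max_seq_len=8, num_layers=2, num_heads=2, head_dim=4)

    k = np.random.rand(1, 2, 1, 4).astype(np.float32)   # one new token's keys
    v = np.random.rand(1, 2, 1, 4).astype(np.float32)   # one new token's values
    cache.update(layer_idx=0, key=k, value=v)

    K, V = cache.get(layer_idx=0)
    assert K.shape == (1, 2, 1, 4)                       # only the filled portion is returned
    np.testing.assert_allclose(K, k)
    np.testing.assert_allclose(V, v)

def test_update_is_in_place():
    cache = KVCacheSketch(batch_size=1, max_seq_len=8, num_layers=1, num_heads=1, head_dim=4)
    buffer_before = cache.cache                          # the pre-allocated buffer object
    for _ in range(3):
        kv = np.ones((1, 1, 1, 4), dtype=np.float32)
        cache.update(0, kv, kv)
    assert cache.cache is buffer_before                  # appends never reallocate the buffer
```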

## 💡 Key Insights You'll Discover

1. **Recomputation is Expensive**: O(n²) growth makes naive generation impractical
2. **Memory is Cheap**: Spending O(n) memory saves O(n²) compute
3. **Scaling Matters**: 100-token generation does ~50× fewer K,V computations with the cache (100 vs. 5,050)!
4. **Production Critical**: This single optimization enables ChatGPT-scale inference
5. **Non-Invasive Design**: The best optimizations ADD capabilities - they don't BREAK old code

## 🎯 Success Criteria

- [ ] KVCache correctly stores and retrieves K,V pairs for all layers
- [ ] Cache updates are O(1) (no data copying)
- [ ] Memory usage matches theoretical predictions
- [ ] enable_kv_cache() works without modifying Module 12/13
- [ ] All unit tests pass
- [ ] Integration test validates the complete workflow
- [ ] Performance analysis shows a 10-15× speedup

## 🚀 Next Steps

After completing this module:

1. **Try it yourself**: Run the chatbot milestone with and without caching
   ```bash
   python milestones/05_2017_transformer/vaswani_chatgpt.py --use-cache
   ```

2. **Experiment**: Profile the speedup on different sequence lengths

3. **Compare**: Measure memory overhead vs. model parameters

4. **Move forward**: Module 18 (Acceleration) teaches parallelization!

---

**Ready to build the optimization that powers ChatGPT?** 🚀

Start with: `modules/17_memoization/memoization_dev.py`