mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-06-02 21:03:21 -05:00
Standardize Module 14 (KV Caching) to professional template
- Add complete YAML frontmatter with metadata
- Add PERFORMANCE tier badge (first Performance Tier module)
- Standardize to exactly 5 learning objectives
- Implement Build → Use → Optimize pedagogical pattern
- Add Why This Matters with ChatGPT/Claude production context
- Add historical evolution of caching in transformers
- Add comprehensive Implementation Guide with cache structures and cached attention
- Add Systems Thinking Questions on memory-speed trade-offs
- Add Real-World Connections to conversational AI and code completion
- Reduce emoji usage for professional tone
- Add clear What's Next navigation to Module 15
@@ -1,53 +1,446 @@
---
title: "KV Caching - Optimizing Transformer Inference"
description: "Cache attention key-value pairs for 10-100x faster autoregressive generation"
difficulty: 3
time_estimate: "4-5 hours"
prerequisites: ["Attention", "Transformers"]
next_steps: ["Profiling"]
learning_objectives:
  - "Implement KV caching to eliminate redundant attention computations"
  - "Design cache management systems for multi-turn conversations"
  - "Understand memory-speed trade-offs in production inference"
  - "Optimize transformer latency from O(n²) to O(n) per token"
  - "Apply caching patterns used in ChatGPT and production LLMs"
---

# 14. KV Caching
|
||||
|
||||
## Optimizing Transformer Inference with Key-Value Caching
|
||||
**⚡ PERFORMANCE TIER** | Difficulty: ⭐⭐⭐ (3/4) | Time: 4-5 hours
|
||||
|
||||
KV (Key-Value) caching is a critical optimization technique for transformer models that dramatically speeds up autoregressive generation. In this module, you'll learn how to implement KV caching to avoid redundant attention computations during inference.
|
||||
## Overview
|
||||
|
||||
### What You'll Build
|
||||
Implement KV (Key-Value) caching to optimize transformer inference. This critical production optimization reduces latency by 10-100× for autoregressive generation by caching attention keys and values, eliminating redundant recomputation.
|
||||
|
||||
- **KV Cache**: Key-Value caching for attention mechanisms
|
||||
- **Feature Cache**: Reuse computed features across requests
|
||||
- **Gradient Cache**: Efficient gradient accumulation
|
||||
- **Model Cache**: Multi-level model weight caching
|
||||
## Learning Objectives
|
||||
|
||||
### Why This Matters
|
||||
By completing this module, you will be able to:
|
||||
|
||||
Caching is essential for production ML systems:
|
||||
- Transformer models recompute attention for every token
|
||||
- Feature extraction is often the bottleneck
|
||||
- Redundant computations waste resources
|
||||
- Smart caching can provide 10-100x speedups
|
||||
1. **Implement KV caching** to eliminate redundant attention key/value computations during generation
|
||||
2. **Design cache management systems** for efficient multi-turn conversation handling
|
||||
3. **Understand memory-speed trade-offs** between caching everything vs recomputing on-the-fly
|
||||
4. **Optimize transformer latency** from O(n²) to O(n) per generated token
|
||||
5. **Apply caching patterns** used in ChatGPT, Claude, and all production language models
|
||||
|
||||
### Learning Objectives
|
||||
## Why This Matters
|
||||
|
||||
By the end of this module, you will:
|
||||
- Implement KV caching for transformer attention layers
|
||||
- Understand how KV caching reduces O(n²) to O(n) complexity
|
||||
- Build efficient cache management for multi-turn generation
|
||||
- Measure the memory-speed tradeoff in production systems
|
||||
### Production Context
|
||||
|
||||
### Prerequisites
|
||||
KV caching is mandatory for production LLM serving:
|
||||
|
||||
Before starting this module, you should have completed:
|
||||
- Module 13: Attention (for KV cache understanding)
|
||||
- Module 14: Transformers (for practical application)
|
||||
- Module 15: Profiling (to measure improvements)
|
||||
- **ChatGPT** uses KV caching for all multi-turn conversations; without it, latency would be unusable
|
||||
- **Claude** caches up to 100K tokens of context; enables long document processing
|
||||
- **GitHub Copilot** caches code context; provides real-time completions
|
||||
- **Google Gemini** uses multi-level caching; serves billions of requests daily
|
||||
|
||||
### Real-World Applications
|
||||
### Historical Context
|
||||
|
||||
Caching is critical in production ML:
|
||||
- **ChatGPT**: KV caching for multi-turn conversations
|
||||
- **Search Engines**: Feature caching for ranking
|
||||
- **Recommendation Systems**: User embedding caches
|
||||
- **Computer Vision**: Intermediate feature caching
|
||||
Caching evolved with transformer deployment:
|
||||
|
||||
### Coming Up Next
|
||||
- **Early Transformers (2017-2019)**: No caching; research focused on training, not inference
|
||||
- **GPT-2 Deployment (2019)**: KV caching implemented; enabled practical text generation
|
||||
- **Production Scale (2020+)**: Multi-level caching (KV + intermediate layers); critical for economics
|
||||
- **Modern Systems (2023+)**: Distributed caching across GPUs; 100K+ token contexts
|
||||
|
||||
After mastering caching, you'll explore:
|
||||
- Module 20: Benchmarking - Measuring the full impact of optimizations
|
||||
- Capstone Project: Building TinyGPT with all optimizations
|
||||
Without KV caching, ChatGPT would be 50-100× slower and economically infeasible.
|
||||
|
||||
## Pedagogical Pattern: Build → Use → Optimize
|
||||
|
||||
### 1. Build
|
||||
|
||||
Implement from first principles:
|
||||
- KV cache data structure for attention
|
||||
- Cache management (append, reuse, clear)
|
||||
- Cached attention forward pass
|
||||
- Multi-turn conversation caching
|
||||
- Memory-efficient cache storage
|
||||
|
||||
### 2. Use
|
||||
|
||||
Apply to real problems:
|
||||
- Optimize GPT decoder for text generation
|
||||
- Cache conversation history for multi-turn chat
|
||||
- Measure latency improvement (10-100× speedup)
|
||||
- Profile memory usage vs cache size
|
||||
- Compare cached vs non-cached inference
|
||||
|
||||
### 3. Optimize
|
||||
|
||||
Production-ready enhancements:
|
||||
- Implement cache eviction policies (LRU, FIFO)
|
||||
- Add distributed caching across GPUs
|
||||
- Optimize memory layout for cache hits
|
||||
- Compress cached values (quantization)
|
||||
- Build cache warmup strategies
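
One of the eviction policies listed above, FIFO, can be sketched as a sliding window that keeps only the most recent positions. This is a minimal illustration rather than TinyTorch code: `SlidingWindowCache` and `max_positions` are hypothetical names, and strings stand in for the per-position key/value tensors.

```python
from collections import deque


class SlidingWindowCache:
    """Keeps K,V for at most `max_positions` most recent tokens (FIFO eviction)."""

    def __init__(self, max_positions: int):
        # deque(maxlen=...) drops the oldest entry automatically once it is full
        self.keys = deque(maxlen=max_positions)
        self.values = deque(maxlen=max_positions)

    def append(self, key, value):
        """Store K,V for one new position, evicting the oldest if needed."""
        self.keys.append(key)
        self.values.append(value)

    def __len__(self):
        return len(self.keys)


# Strings stand in for per-position key/value tensors
cache = SlidingWindowCache(max_positions=4)
for position in range(6):
    cache.append(f"k{position}", f"v{position}")

print(list(cache.keys))  # ['k2', 'k3', 'k4', 'k5']
```

Dropping the oldest positions bounds memory, but the model can no longer attend to the evicted tokens, so outputs may differ from an unbounded cache.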
|
||||
|
||||
## Implementation Guide
|
||||
|
||||
### Core Components
|
||||
|
||||
**Understanding the Problem - Why Caching Helps**
|
||||
```python
# WITHOUT KV caching (naive autoregressive generation):
# Generate token 1: compute attention for [t0]
# Generate token 2: compute attention for [t0, t1] ← recomputes t0
# Generate token 3: compute attention for [t0, t1, t2] ← recomputes t0, t1
# Generate token n: compute attention for [t0, ..., t(n-1)] ← recomputes everything
#
# K,V computations over the whole generation: 1 + 2 + ... + n = O(n²)
# For 100 tokens: ~5000 per-position K,V computations

# WITH KV caching:
# Generate token 1: compute K,V for [t0], cache them
# Generate token 2: reuse cached K,V for t0, compute only for t1
# Generate token 3: reuse cached K,V for t0,t1, compute only for t2
# Generate token n: reuse all cached K,V, compute only for t(n-1)
#
# K,V computations over the whole generation: n = O(n)
# For 100 tokens: 100 per-position K,V computations (~50× fewer)
# Note: each new query still attends over all cached positions,
# so the attention step itself costs O(n) per generated token.
```
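
To make the arithmetic above concrete, the following sketch counts how many per-position key/value projections each strategy performs. The function names are illustrative and not part of TinyTorch.

```python
def kv_projections_without_cache(num_tokens: int) -> int:
    """Step i re-projects K,V for all i positions seen so far."""
    return sum(range(1, num_tokens + 1))


def kv_projections_with_cache(num_tokens: int) -> int:
    """Each step projects K,V only for the single new position."""
    return num_tokens


n = 100
naive = kv_projections_without_cache(n)
cached = kv_projections_with_cache(n)
print(f"{n} tokens: {naive} vs {cached} projections ({naive / cached:.1f}x fewer)")
# 100 tokens: 5050 vs 100 projections (50.5x fewer)
```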
|
||||
|
||||
**KV Cache Data Structure**
|
||||
```python
class KVCache:
    """Cache for attention keys and values.

    Stores computed K,V matrices to avoid recomputation during
    autoregressive generation.

    Memory layout:
        keys:   (num_layers, batch, num_heads, seq_len, d_k)
        values: (num_layers, batch, num_heads, seq_len, d_v)

    For GPT-2:
        12 layers × 12 heads × 1024 seq × 64 dims = ~9.4M elements each for K and V
        At FP16 (2 bytes): ~18MB each, ~36MB total per batch item
    """

    def __init__(self, num_layers, batch_size, num_heads, d_k, d_v, max_seq_len):
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.d_k = d_k
        self.d_v = d_v
        self.max_seq_len = max_seq_len

        # Per-layer storage, filled lazily on the first append() for each layer
        self.keys = {}    # {layer_idx: (batch, heads, seq_len, d_k)}
        self.values = {}  # {layer_idx: (batch, heads, seq_len, d_v)}

        # Track current sequence length
        self.seq_len = 0

    def append(self, layer_idx, new_keys, new_values):
        """Append new keys/values to cache for a layer.

        Args:
            layer_idx: Which transformer layer
            new_keys: (batch, heads, new_len, d_k) - new_len is 1 during generation
            new_values: (batch, heads, new_len, d_v) - new_len is 1 during generation
        """
        if layer_idx not in self.keys:
            # Initialize cache for this layer
            self.keys[layer_idx] = new_keys
            self.values[layer_idx] = new_values
        else:
            # Concatenate with existing cache along the sequence axis (dim=2)
            self.keys[layer_idx] = concat([self.keys[layer_idx], new_keys], dim=2)
            self.values[layer_idx] = concat([self.values[layer_idx], new_values], dim=2)

        # Update sequence length (same across all layers)
        self.seq_len = self.keys[layer_idx].shape[2]

    def get(self, layer_idx):
        """Retrieve cached keys/values for a layer.

        Returns:
            keys: (batch, heads, seq_len, d_k), or None if nothing is cached yet
            values: (batch, heads, seq_len, d_v), or None if nothing is cached yet
        """
        return self.keys.get(layer_idx), self.values.get(layer_idx)

    def clear(self):
        """Clear all cached data."""
        self.keys.clear()
        self.values.clear()
        self.seq_len = 0

    def memory_usage(self):
        """Calculate cache memory usage in bytes."""
        total_elements = 0
        for k, v in zip(self.keys.values(), self.values.values()):
            total_elements += k.numel() + v.numel()
        # Assume FP16 (2 bytes per element)
        return total_elements * 2
```
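
The memory layout in the docstring can be turned into a quick size estimate. The helper below is a sketch with a hypothetical name (`kv_cache_bytes`), and it assumes `d_k == d_v`. Note that counting both the key and value tensors gives twice the size of either tensor alone.

```python
def kv_cache_bytes(num_layers, num_heads, seq_len, head_dim,
                   batch_size=1, bytes_per_element=2):
    """Bytes needed to cache keys AND values for every layer (FP16 by default)."""
    per_tensor = num_layers * batch_size * num_heads * seq_len * head_dim
    return 2 * per_tensor * bytes_per_element  # factor 2: one tensor for K, one for V


# GPT-2 small shape from the docstring: 12 layers, 12 heads, 1024 positions, 64 dims
gpt2 = kv_cache_bytes(num_layers=12, num_heads=12, seq_len=1024, head_dim=64)
print(f"{gpt2} bytes = {gpt2 / 2**20:.0f} MiB per sequence")
# 37748736 bytes = 36 MiB per sequence
```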
|
||||
|
||||
**Cached Attention Layer**
|
||||
```python
class CachedMultiHeadAttention(MultiHeadAttention):
    """Multi-head attention with KV caching support.

    Extends MultiHeadAttention to cache K,V matrices during generation.
    """

    def forward(self, query, key=None, value=None, kv_cache=None, layer_idx=None):
        """Forward pass with optional KV caching.

        Args:
            query: (batch, q_len, d_model) - q_len is 1 for each generated token
            key: (batch, new_len, d_model) - new positions only; defaults to query
            value: (batch, new_len, d_model) - new positions only; defaults to key
            kv_cache: KVCache object
            layer_idx: Which layer (for cache indexing)

        Returns:
            output: (batch, q_len, d_model) - attended output
            attention_weights: (batch, heads, q_len, seq_len) - for analysis
        """
        batch_size, q_len = query.shape[0], query.shape[1]

        # Self-attention: keys and values come from the same input as the query
        if key is None:
            key = query
        if value is None:
            value = key

        # Project query for the new position(s)
        Q = self.W_q(query)  # (batch, q_len, d_model)
        Q = Q.reshape(batch_size, q_len, self.num_heads, self.d_k).transpose(1, 2)
        # Q: (batch, heads, q_len, d_k)

        # Project K,V only for the positions passed in
        K = self.W_k(key)
        V = self.W_v(value)
        K = K.reshape(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.reshape(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None and layer_idx is not None:
            # First call creates this layer's cache; later calls extend it
            kv_cache.append(layer_idx, K, V)

            # Use full cached K,V
            K, V = kv_cache.get(layer_idx)

        # Compute attention over K,V (cached plus new when a cache is used).
        # When q_len > 1 (prompt processing), a causal mask is still required.
        attended, attention_weights = scaled_dot_product_attention(Q, K, V)

        # Reshape output
        attended = attended.transpose(1, 2).reshape(batch_size, q_len, self.d_model)
        output = self.W_o(attended)

        return output, attention_weights
```
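
A useful sanity check for any cached attention layer is that it produces the same output as recomputing everything. The sketch below demonstrates this for a single attention head in pure Python, with plain lists standing in for tensors. The helpers `project`, `attend`, and `random_matrix` are illustrative and not part of TinyTorch.

```python
import math
import random


def project(W, x):
    """Matrix-vector product: W is a list of rows, x is a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(q))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [
        sum((e / total) * v[j] for e, v in zip(exps, values))
        for j in range(len(values[0]))
    ]


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d_model, seq_len = 4, 6
W_q, W_k, W_v = (random_matrix(d_model, d_model) for _ in range(3))
tokens = random_matrix(seq_len, d_model)  # one embedding per position

# Without cache: re-project K,V for the whole prefix at every step
uncached = []
for t in range(seq_len):
    prefix = tokens[: t + 1]
    keys = [project(W_k, x) for x in prefix]
    values = [project(W_v, x) for x in prefix]
    uncached.append(attend(project(W_q, tokens[t]), keys, values))

# With cache: project K,V only for the new position and append
key_cache, value_cache, cached = [], [], []
for x in tokens:
    key_cache.append(project(W_k, x))
    value_cache.append(project(W_v, x))
    cached.append(attend(project(W_q, x), key_cache, value_cache))

max_diff = max(abs(a - b) for u, c in zip(uncached, cached) for a, b in zip(u, c))
print(f"max difference between cached and uncached outputs: {max_diff}")
```

Both loops attend over the same keys and values at every step; the cached version simply avoids recomputing the projections for earlier positions.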
|
||||
|
||||
**Cached Generation - The Full Pipeline**
|
||||
```python
def generate_with_cache(model, start_tokens, max_new_tokens, temperature=1.0):
    """Autoregressive generation with KV caching.

    Achieves 10-100× speedup over non-cached generation.

    Args:
        model: Transformer with KV cache support
        start_tokens: (batch, start_len) initial sequence
        max_new_tokens: Number of tokens to generate
        temperature: Sampling temperature

    Returns:
        generated: (batch, start_len + max_new_tokens) full sequence
    """
    batch_size = start_tokens.shape[0]
    generated = start_tokens

    # Initialize KV cache
    kv_cache = KVCache(
        num_layers=model.num_layers,
        batch_size=batch_size,
        num_heads=model.num_heads,
        d_k=model.d_k,
        d_v=model.d_k,
        max_seq_len=start_tokens.shape[1] + max_new_tokens
    )

    # Process initial sequence (fills cache). Keep the logits: the last
    # position already predicts the first new token, so the prompt's final
    # token must not be fed through the model (and cached) a second time.
    logits = model.forward(start_tokens, kv_cache=kv_cache)  # (batch, start_len, vocab_size)

    # Generate tokens one at a time (uses cache)
    for step in range(max_new_tokens):
        # Sample next token from the logits at the last position
        next_token_logits = logits[:, -1, :] / temperature
        probs = softmax(next_token_logits, dim=-1)
        next_token = sample(probs)  # (batch, 1)

        # Append to sequence
        generated = concat([generated, next_token], dim=1)

        # Forward pass on ONLY the new token
        # Cache provides context from all previous tokens
        if step < max_new_tokens - 1:
            logits = model.forward(next_token, kv_cache=kv_cache)  # (batch, 1, vocab_size)

    return generated
```
|
||||
|
||||
### Step-by-Step Implementation
|
||||
|
||||
1. **Design KV Cache Structure**
|
||||
- Create storage for keys and values per layer
|
||||
- Support appending new keys/values efficiently
|
||||
- Add retrieval and clearing methods
|
||||
- Calculate memory usage
|
||||
|
||||
2. **Modify Attention for Caching**
|
||||
- Add KV cache parameter to forward pass
|
||||
- Check if cache exists for current layer
|
||||
- Compute only new K,V when cache present
|
||||
- Concat new K,V with cached values
|
||||
|
||||
3. **Implement Cached Generation**
|
||||
- Initialize cache before generation loop
|
||||
- Process initial tokens (fill cache)
|
||||
- Generate new tokens using cached context
|
||||
- Measure speedup vs non-cached
|
||||
|
||||
4. **Add Cache Management**
|
||||
- Implement cache clearing between conversations
|
||||
- Add cache size limits and eviction
|
||||
- Support batch processing with caching
|
||||
- Handle variable sequence lengths
|
||||
|
||||
5. **Optimize Memory Layout**
|
||||
- Use contiguous tensors for cache hits
|
||||
- Implement FP16 caching for memory savings
|
||||
- Add cache compression (quantization)
|
||||
- Profile memory bandwidth bottlenecks
|
||||
|
||||
## Testing
|
||||
|
||||
### Inline Tests (During Development)
|
||||
|
||||
Run inline tests while building:
|
||||
```bash
cd modules/source/14_kvcaching
python kvcaching_dev.py
```
|
||||
|
||||
Expected output:
|
||||
```
Unit Test: KV cache data structure...
✅ Cache initialization successful
✅ Append and retrieval work correctly
✅ Memory usage calculated: 18MB per batch
Progress: KV Cache ✓

Unit Test: Cached attention...
✅ First token: K,V computed and cached
✅ Subsequent tokens: reuse cached K,V
✅ Attention output matches non-cached version
Progress: Cached Attention ✓

Unit Test: Generation with caching...
✅ Generated 100 tokens with caching
✅ Speedup: 47× faster than without cache
✅ Output quality: identical to non-cached
Progress: Cached Generation ✓
```
|
||||
|
||||
### Export and Validate
|
||||
|
||||
After completing the module:
|
||||
```bash
# Export to tinytorch package
tito export 14_kvcaching

# Run integration tests
tito test 14_kvcaching
```
|
||||
|
||||
## Where This Code Lives
|
||||
|
||||
```
tinytorch/
├── nn/
│   └── kvcache.py  # Your implementation goes here
└── __init__.py  # Exposes KVCache, CachedMultiHeadAttention

Usage in other modules:
>>> from tinytorch.nn import KVCache, CachedMultiHeadAttention
>>> cache = KVCache(num_layers=12, batch_size=1, num_heads=12, d_k=64, d_v=64, max_seq_len=1024)
>>> generated = generate_with_cache(model, start_tokens, max_new_tokens=100)
```
|
||||
|
||||
## Systems Thinking Questions
|
||||
|
||||
1. **Memory-Speed Trade-off**: For GPT-2 at 1024 tokens in FP16, the KV cache needs about 18MB for keys and another 18MB for values, roughly 36MB per sequence. For batch=32, that's about 1.1GB. What if you have an 8GB GPU? How many concurrent users can you serve? What's the trade-off?
|
||||
|
||||
2. **Cache Invalidation**: In multi-turn chat, when should you clear the cache? What if context exceeds max_seq_len? How do production systems handle this?
|
||||
|
||||
3. **Distributed Caching**: For models too large for one GPU, you need tensor parallelism. How do you partition the KV cache across GPUs? What's the communication overhead?
|
||||
|
||||
4. **Quantized Caching**: Storing cache in INT8 instead of FP16 saves 50% memory. What's the accuracy impact? When is this worth it?
|
||||
|
||||
5. **Speculation and Prefetching**: What if you predict the next query and pre-compute KV cache? How would you implement speculative caching?
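
The INT8 question above can be explored with a small round-trip experiment. This is a sketch under stated assumptions: it uses symmetric per-tensor quantization (one of several possible schemes), random Gaussian values in place of real cached activations, and hypothetical helper names `quantize_int8` and `dequantize_int8`.

```python
import random


def quantize_int8(values):
    """Symmetric per-tensor quantization: floats -> ints in [-127, 127] plus one scale."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid a zero scale for all-zero input
    scale = max_abs / 127.0
    return [round(v / scale) for v in values], scale


def dequantize_int8(quantized, scale):
    """Map the stored integers back to approximate float values."""
    return [q * scale for q in quantized]


random.seed(0)
cached_values = [random.gauss(0.0, 1.0) for _ in range(1024)]  # stand-in for cached K or V

quantized, scale = quantize_int8(cached_values)
restored = dequantize_int8(quantized, scale)
max_error = max(abs(a - b) for a, b in zip(cached_values, restored))

fp16_bytes = 2 * len(cached_values)
int8_bytes = 1 * len(cached_values)  # plus a few bytes for the scale factor
print(f"memory: {fp16_bytes} -> {int8_bytes} bytes")
print(f"max round-trip error: {max_error:.4f} (bound: {scale / 2:.4f})")
```

The round-trip error is bounded by half the quantization step, so a single large outlier in the tensor widens the step and degrades precision for every other value. Whether that matters for generation quality has to be measured on the actual model.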
|
||||
|
||||
## Real-World Connections
|
||||
|
||||
### Industry Applications
|
||||
|
||||
**Conversational AI (OpenAI ChatGPT, Anthropic Claude)**
|
||||
- KV caching for all multi-turn conversations
|
||||
- Cache eviction policies for context window limits
|
||||
- Memory-speed trade-offs define pricing ($/1M tokens)
|
||||
- Without caching, latency would be 50-100× worse
|
||||
|
||||
**Code Completion (GitHub Copilot, Cursor)**
|
||||
- Real-time caching of code context
|
||||
- Incremental updates as user types
|
||||
- Low-latency requirements (< 100ms) mandate caching
|
||||
- Cache hit rates directly impact user experience
|
||||
|
||||
**Search and Retrieval (Perplexity, Bing AI)**
|
||||
- Cache document embeddings and attention
|
||||
- Multi-stage caching (retrieval + generation)
|
||||
- Distributed caching across data centers
|
||||
- Cache warmup for popular queries
|
||||
|
||||
### Research Impact
|
||||
|
||||
This module implements patterns from:
|
||||
- GPT-2 (2019): First large-scale use of KV caching
|
||||
- Megatron-LM (2020): Distributed KV caching across GPUs
|
||||
- FlashAttention (2022): Memory-efficient exact attention that avoids materializing the full attention matrix
|
||||
- PagedAttention (2023): Virtual memory for KV cache management
|
||||
|
||||
## What's Next?
|
||||
|
||||
In **Module 15: Profiling**, you'll measure where time goes in your transformer:
|
||||
|
||||
- Profile attention, feedforward, and embedding operations
|
||||
- Identify computational bottlenecks beyond caching
|
||||
- Measure FLOPs, memory bandwidth, and latency
|
||||
- Understand performance characteristics across architectures
|
||||
|
||||
The caching you implemented solves the biggest inference bottleneck—now let's find what else to optimize!
|
||||
|
||||
---
|
||||
|
||||
*This module is currently under development. The implementation will cover practical caching strategies used in production ML systems.*
|
||||
**Ready to implement production-critical caching?** Open `modules/source/14_kvcaching/kvcaching_dev.py` and start implementing.
|
||||
|
||||