mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-12 02:09:16 -05:00
- Shows O(n²) latency growth in transformer generation - Demonstrates problem before teaching solution - Prepares module for reorganization to Module 15
1471 lines
59 KiB
Python
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.17.1
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# %% [markdown]
"""
# Module 14: KV Caching - Optimizing Autoregressive Generation

Welcome to Module 14! You'll implement the optimization that makes production language-model serving practical: Key-Value (KV) caching, which typically makes text generation 10-15x faster.

## 🔗 Prerequisites & Progress
**You've Built**: Complete transformer architecture with multi-head attention and text generation
**You'll Build**: Memory-efficient KV caching system that eliminates redundant computation
**You'll Enable**: Production-grade inference optimization and real-world serving capabilities

**Connection Map**:
```
Transformers   →   KV Caching   →   Production Serving
(slow O(n²))       (fast O(n))      (real-world scale)
```

## Learning Objectives
By the end of this module, you will:
1. Understand why autoregressive generation has O(n²) complexity without caching
2. Implement KVCache with efficient memory management and O(1) updates
3. Build cache-aware attention that reuses previously computed keys and values
4. Measure the speedup (often 10-15x) and understand the memory trade-offs
5. Connect to production optimization patterns used in real LLM serving

Let's make inference blazingly fast!

## 📦 Where This Code Lives in the Final Package

**Learning Side:** You work in `modules/14_kvcaching/kvcaching_dev.py`
**Building Side:** Code exports to `tinytorch.generation.kv_cache`

```python
# How to use this module:
from tinytorch.generation.kv_cache import KVCache, enable_kv_cache
```

**Why this matters:**
- **Learning:** Complete caching system demonstrating production optimization techniques
- **Production:** Organization that mirrors the `generation/` module structure of Hugging Face Transformers
- **Consistency:** All generation optimizations live in `generation.kv_cache`
- **Integration:** Works with transformers for complete inference optimization
"""

# %%
#| default_exp generation.kv_cache
#| export

import numpy as np
import time
from typing import Tuple, Optional, Dict, List

# Import TinyTorch components from previous modules
from tinytorch.core.tensor import Tensor

# %% [markdown]
"""
## 🔬 Motivation: Why Memoization Matters for Transformers

Before implementing KV caching, let's profile transformer generation to understand
the problem we're solving. We'll see latency grow quadratically with sequence length.
"""

# %%
# Profile transformer generation to discover the bottleneck
from tinytorch.profiling.profiler import Profiler
import matplotlib.pyplot as plt

profiler = Profiler()

def naive_attention_step(seq_len, hidden_dim=64):
    """
    Simulates one step of attention computation.
    Without caching, this processes ALL previous tokens every time.
    """
    # Q, K, V for entire sequence
    q = Tensor(np.random.randn(1, seq_len, hidden_dim))
    k = Tensor(np.random.randn(1, seq_len, hidden_dim))
    v = Tensor(np.random.randn(1, seq_len, hidden_dim))

    # Attention: Q @ K^T, then @ V
    # This is O(seq_len²) in complexity
    # Swap only the last two axes so K^T has shape (1, hidden_dim, seq_len);
    # a plain .T on a 3-D array would reverse all three axes
    k_t = Tensor(np.swapaxes(k.data, -1, -2))
    scores = q @ k_t  # (1, seq_len, seq_len)
    output = scores @ v

    return output

# Profile at increasing sequence lengths
print("🔬 Profiling Transformer Generation (Without Caching):\n")
print("  Seq Len  |  Latency (ms)  |  Growth")
print("-----------|----------------|----------")

sequence_lengths = [10, 20, 40, 80, 160]
latencies = []

for seq_len in sequence_lengths:
    # Measure latency for this sequence length
    latency = profiler.measure_latency(
        lambda: naive_attention_step(seq_len),
        None,
        warmup=5,
        iterations=20
    )
    latencies.append(latency)

    # Calculate growth rate
    if len(latencies) > 1:
        growth = latencies[-1] / latencies[-2]
        print(f"  {seq_len:3d}      |  {latency:6.2f}        |  {growth:.2f}×")
    else:
        print(f"  {seq_len:3d}      |  {latency:6.2f}        |  baseline")

print("\n💡 Key Observations:")
print("   • Latency grows roughly QUADRATICALLY with sequence length")
print("   • Each new token forces recomputation of ALL previous K,V pairs")
print("   • Ideally, 160 tokens take ~4× as long as 80 tokens (2² growth);")
print("     at small sizes, fixed overheads can hide part of this growth")

print("\n🎯 The Problem:")
print("   K and V values for previous tokens NEVER change,")
print("   yet we recompute them every single step!")

print("\n✨ The Solution:")
print("   CACHE the K,V values! (That's memoization)")
print("   • First compute: Calculate and store K,V")
print("   • Later steps: Reuse stored K,V")
print("   • Per-step attention cost: O(n²) → O(n)")
print("   • Speedup: often 10-15× for typical generation\n")

# %% [markdown]
"""
## 🎯 Part 1: Understanding the Autoregressive Generation Problem

### The Core Inefficiency

When generating text token by token, transformers face a fundamental computational bottleneck. Let's visualize what happens during naive generation:

```
Token Generation Process (Without Caching):

Step 1: Generate "Hello"
  Input: [START]
  Attention: Q₁ × [K₁] × [V₁]                    ← 1 computation

Step 2: Generate "world"
  Input: [START, Hello]
  Attention: Q₂ × [K₁, K₂] × [V₁, V₂]            ← 2 computations (K₁,V₁ RECOMPUTED!)

Step 3: Generate "!"
  Input: [START, Hello, world]
  Attention: Q₃ × [K₁, K₂, K₃] × [V₁, V₂, V₃]    ← 3 computations (K₁,V₁,K₂,V₂ RECOMPUTED!)
```

**The Problem**: For each new token, we recompute ALL previous key-value pairs even though they never change!

### Computational Complexity Analysis

```
Naive Generation Complexity:
  Step 1: 1 K,V computation
  Step 2: 2 K,V computations
  Step 3: 3 K,V computations
  ...
  Step n: n K,V computations

Total: 1 + 2 + 3 + ... + n = n(n+1)/2 = O(n²) complexity!
```

For a 100-token sequence, this means **5,050 K,V computations**, of which 4,950 are redundant (only 100 are actually needed).

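To make the count concrete, here is a small sketch that tallies K,V projections for n generated tokens with and without a cache. It assumes one K,V projection per token per step and ignores layers and heads, which multiply both counts by the same factor.

```python
def kv_projection_counts(n):
    # Without a cache, step t re-projects K,V for all t tokens seen so far.
    naive = sum(range(1, n + 1))
    # With a cache, step t projects K,V for the single new token only.
    cached = n
    return naive, cached


naive, cached = kv_projection_counts(100)
print(naive, cached, naive - cached)  # 5050 100 4950
print(naive / cached)                 # 50.5
```

The naive total always matches the closed form n(n+1)/2, so the ratio grows linearly with n.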
### Real-World Impact

Without optimization, this inefficiency would make production LLM serving far slower and costlier:
- **ChatGPT/GPT-4**: Responses would be too slow for real-time chat without caching
- **Code completion**: IDEs couldn't provide instant suggestions
- **Mobile deployment**: On-device generation would drain batteries much faster
- **API serving**: Server costs could be 10x+ higher

**The Solution**: Cache key-value pairs after computing them once, cutting the total K,V work for n tokens from O(n²) to O(n).
"""

# %% [markdown]
"""
## 🧮 Part 2: The Key-Value Caching Insight

### Mathematical Foundation

The core insight comes from understanding what changes during autoregressive generation:

```
Attention Computation Breakdown:

Q = new_token @ W_q      ← Only the new token's query is needed (changes each step)
K = all_tokens @ W_k     ← Includes old tokens (mostly redundant!)
V = all_tokens @ W_v     ← Includes old tokens (mostly redundant!)

attention_output = softmax(Q @ K.T / √d_k) @ V
```

**Key Insight**: The K and V rows for previous tokens NEVER change! With causal masking, a token's hidden state at every layer depends only on itself and earlier tokens, so its key and value are fixed once computed.

```
Token Dependencies:
K₁ = token₁ @ W_k     ← Computed once, never changes
K₂ = token₂ @ W_k     ← Computed once, never changes
K₃ = token₃ @ W_k     ← Computed once, never changes

Same for V₁, V₂, V₃...
```

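A quick NumPy check of this insight for the first layer, where K is a row-wise projection of the token embeddings: projecting tokens one at a time as they arrive gives the same K as recomputing it for the whole sequence, and appending a token leaves the earlier rows untouched. The sizes and random weights are arbitrary stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, n = 16, 8, 5
W_k = rng.standard_normal((d_model, d_k))    # stand-in key projection
tokens = rng.standard_normal((n, d_model))   # stand-in embeddings for n tokens

# Naive: recompute K for the whole sequence at once.
K_full = tokens @ W_k

# Incremental: project each token once, as it arrives, and append the row.
K_rows = []
for i in range(n):
    K_rows.append(tokens[i] @ W_k)
K_incremental = np.stack(K_rows)

print(K_full.shape, np.allclose(K_full, K_incremental))  # (5, 8) True

# Adding a 6th token leaves the first 5 rows of K unchanged.
new_token = rng.standard_normal((1, d_model))
K_longer = np.vstack([tokens, new_token]) @ W_k
print(np.allclose(K_longer[:n], K_full))  # True
```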
### Cache-Optimized Generation

```
Optimized Generation Process (With Caching):

Step 1: Generate "Hello"
  Compute: K₁, V₁ → Store in cache
  Attention: Q₁ × cached[K₁] × cached[V₁]

Step 2: Generate "world"
  Compute: K₂, V₂ → Append to cache
  Attention: Q₂ × cached[K₁, K₂] × cached[V₁, V₂]

Step 3: Generate "!"
  Compute: K₃, V₃ → Append to cache
  Attention: Q₃ × cached[K₁, K₂, K₃] × cached[V₁, V₂, V₃]
```

**Result**: Each step computes only ONE new K,V pair instead of recomputing ALL of them!

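The sketch below checks that caching changes the cost but not the result. It compares full causal self-attention, which recomputes Q, K, V for the whole sequence, against a token-by-token loop that keeps growing lists of cached K and V rows. It uses plain NumPy, a single head, and random stand-in weights; `softmax`, `full_causal_attention` and `cached_attention` are illustrative helpers, not TinyTorch APIs.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def full_causal_attention(X, W_q, W_k, W_v):
    # Recompute Q, K, V for every token, then mask out future positions.
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    n = X.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ V


def cached_attention(X, W_q, W_k, W_v):
    # Process one token per step; K and V for old tokens come from the cache.
    K_cache, V_cache, outputs = [], [], []
    for x in X:
        q = x @ W_q
        K_cache.append(x @ W_k)  # compute K,V for the new token only
        V_cache.append(x @ W_v)
        K, V = np.stack(K_cache), np.stack(V_cache)
        weights = softmax(q @ K.T / np.sqrt(K.shape[-1]))
        outputs.append(weights @ V)
    return np.stack(outputs)


rng = np.random.default_rng(0)
d_model, d_k, n = 16, 8, 6
X = rng.standard_normal((n, d_model))
W_q, W_k, W_v = (rng.standard_normal((d_model, d_k)) for _ in range(3))

full = full_causal_attention(X, W_q, W_k, W_v)
cached = cached_attention(X, W_q, W_k, W_v)
print(full.shape, np.allclose(full, cached))  # (6, 8) True
```

Because each cached step attends from a single query to all stored keys, it reproduces one row of the masked full computation, so no causal mask is needed in the cached loop.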
### Memory vs Compute Trade-off

```
Traditional Approach:
  Memory:  O(1)          (no K,V stored between steps)
  Compute: O(n²)         (recompute everything)

Cached Approach:
  Memory:  O(n × d_k)    (store all K,V pairs)
  Compute: O(n)          (only compute new pairs)

For n=100, d_k=64 (one head, float32):
  Memory cost:     2 × 100 × 64 × 4 bytes = 51,200 bytes ≈ 51 KB per layer
  Compute savings: 5,050 → 100 K,V computations (~50x reduction)
```

**Trade-off Winner**: Memory is cheap and compute is expensive! Spend O(n) memory to avoid O(n²) K,V recomputation.

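The per-layer memory for n=100 and d_k=64 can be checked directly by allocating the two float32 arrays that a single-head, single-layer cache would hold (batch size 1 assumed). `cache_bytes` is an illustrative helper.

```python
import numpy as np

n, d_k = 100, 64
K_cache = np.zeros((n, d_k), dtype=np.float32)  # one key row per token
V_cache = np.zeros((n, d_k), dtype=np.float32)  # one value row per token

total_bytes = K_cache.nbytes + V_cache.nbytes
print(total_bytes)  # 51200


def cache_bytes(num_tokens, head_dim=64):
    # Keys + values, float32 (4 bytes per element).
    return 2 * num_tokens * head_dim * np.dtype(np.float32).itemsize


# Memory grows linearly: doubling the sequence doubles the cache.
print(cache_bytes(200) / cache_bytes(100))  # 2.0
```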
"""

# %% [markdown]
"""
## 🏗️ Part 3: KVCache Class Implementation

### Core Requirements

Our KVCache needs to efficiently handle:

1. **Multi-layer storage**: Each transformer layer needs its own K,V cache
2. **Multi-head attention**: Each attention head has separate K,V pairs
3. **Batch processing**: Support multiple sequences simultaneously (batch inference)
4. **Dynamic updates**: Efficiently append new tokens without copying existing data
5. **Memory management**: Pre-allocate space to avoid dynamic resizing overhead

### Cache Architecture Visualization

```
KVCache Memory Layout:
┌─────────────────────────────────────────────────────────┐
│                     KVCache Object                      │
├─────────────────────────────────────────────────────────┤
│ Layer 0:  ┌─────────────┬─────────────┐                 │
│           │  Key Cache  │ Value Cache │                 │
│           │  (B,H,S,D)  │  (B,H,S,D)  │                 │
│           └─────────────┴─────────────┘                 │
├─────────────────────────────────────────────────────────┤
│ Layer 1:  ┌─────────────┬─────────────┐                 │
│           │  Key Cache  │ Value Cache │                 │
│           │  (B,H,S,D)  │  (B,H,S,D)  │                 │
│           └─────────────┴─────────────┘                 │
├─────────────────────────────────────────────────────────┤
│ ...       ┌─────────────┬─────────────┐                 │
│ Layer N:  │  Key Cache  │ Value Cache │                 │
│           │  (B,H,S,D)  │  (B,H,S,D)  │                 │
│           └─────────────┴─────────────┘                 │
└─────────────────────────────────────────────────────────┘

Where:
  B = batch_size (number of sequences)
  H = num_heads (attention heads per layer)
  S = max_seq_len (maximum sequence length)
  D = head_dim (dimension per attention head)
```

### Update Operation Flow

```
Cache Update Process:
          seq_pos = 2
               ↓
┌─────┬─────┬─────┬─────┬─────┬─────┐
│ K₁  │ K₂  │ ??? │ ??? │ ??? │ ??? │  ← Key Cache
├─────┼─────┼─────┼─────┼─────┼─────┤
│ V₁  │ V₂  │ ??? │ ??? │ ??? │ ??? │  ← Value Cache
└─────┴─────┴─────┴─────┴─────┴─────┘

New token arrives: K₃, V₃

          seq_pos = 2
               ↓
┌─────┬─────┬─────┬─────┬─────┬─────┐
│ K₁  │ K₂  │ K₃  │ ??? │ ??? │ ??? │  ← Write K₃ here
├─────┼─────┼─────┼─────┼─────┼─────┤
│ V₁  │ V₂  │ V₃  │ ??? │ ??? │ ??? │  ← Write V₃ here
└─────┴─────┴─────┴─────┴─────┴─────┘

Then: seq_pos += 1 (advance to position 3)
```

Each update writes a single slot no matter how many tokens are already cached, so updates are **O(1) in sequence length**: just write to the next position!

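The update flow above maps onto two NumPy operations: an indexed write into a pre-allocated buffer, and a slice that exposes only the filled positions. This standalone sketch uses tiny illustrative sizes (batch 1, 2 heads). It also shows that the slice is a view sharing memory with the buffer, while `np.array(...)` on that slice makes a copy, which matters if a wrapper class copies its input.

```python
import numpy as np

B, H, S, D = 1, 2, 6, 4                       # batch, heads, max_seq_len, head_dim
key_cache = np.zeros((B, H, S, D), dtype=np.float32)
seq_pos = 2                                    # K₁ and K₂ already written
key_cache[:, :, :seq_pos, :] = 1.0

# New token arrives: write K₃ into slot seq_pos (touches H*D values, independent of S).
new_key = np.full((B, H, 1, D), 3.0, dtype=np.float32)
key_cache[:, :, seq_pos:seq_pos + 1, :] = new_key
seq_pos += 1                                   # advance to position 3

valid = key_cache[:, :, :seq_pos, :]           # only the filled portion
print(valid.shape)                             # (1, 2, 3, 4)
print(np.shares_memory(valid, key_cache))      # True: basic slicing returns a view
print(np.shares_memory(np.array(valid), key_cache))  # False: np.array copies by default
```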
"""

# %% nbgrader={"grade": false, "grade_id": "kvcache-class", "solution": true}
#| export
class KVCache:
    """
    Efficient key-value cache for autoregressive generation.

    Stores K,V matrices for each transformer layer to avoid recomputation
    during sequential token generation. This is THE critical optimization
    that makes production language model serving economically viable.

    ⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking)
    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
    KV caching is designed ONLY for inference (generation), NOT training.
    - During generation: No gradients computed (model.eval() mode)
    - Cache operations use .data (no gradient tracking)
    - This is correct and intentional for maximum speed
    - DO NOT use caching during training (use standard forward pass)

    Architecture:
    - Pre-allocates cache tensors with maximum sequence length
    - Tracks current sequence position for efficient O(1) updates
    - Provides update() method to write new K,V pairs into the next slot in place
      (existing entries are never reallocated or copied)
    - Provides get() method to retrieve cached values for attention
    - Handles multiple layers and attention heads properly

    Memory Layout:
    ```
    Layer 0: [Key_cache, Value_cache]  # Shape: (batch, num_heads, max_seq, head_dim)
    Layer 1: [Key_cache, Value_cache]
    ...
    Layer N: [Key_cache, Value_cache]
    ```

    Performance:
    - Update: O(1) in sequence length - just index assignment
    - Get: slicing the NumPy array is O(1) (it returns a view); wrapping the view
      in a Tensor may copy the valid portion if Tensor copies its input
    - Memory: O(num_layers × batch × heads × max_seq × head_dim)
    """

    def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
                 num_heads: int, head_dim: int):
        """
        Initialize KV cache for efficient generation.

        TODO: Set up pre-allocated cache storage for all transformer layers

        APPROACH:
        1. Store configuration parameters (batch_size, max_seq_len, etc.)
        2. Initialize sequence position counter to 0
        3. Create empty list for cache storage
        4. For each layer, pre-allocate zero-filled key and value caches
        5. Store each layer's (key_cache, value_cache) tuple in the list

        Args:
            batch_size: Number of sequences to generate simultaneously
            max_seq_len: Maximum sequence length to support
            num_layers: Number of transformer layers
            num_heads: Number of attention heads per layer
            head_dim: Dimension of each attention head

        EXAMPLE:
        >>> cache = KVCache(batch_size=2, max_seq_len=128, num_layers=4,
        ...                 num_heads=8, head_dim=64)
        >>> cache.seq_pos  # 0 (no tokens cached yet)
        >>> len(cache.caches)  # 4 (one per layer)
        >>> cache.caches[0][0].shape  # (2, 8, 128, 64) - key cache for layer 0

        HINTS:
        - Cache shape: (batch_size, num_heads, max_seq_len, head_dim)
        - Use Tensor(np.zeros(...)) to create cache tensors
        - Store caches as list of tuples: [(key_0, val_0), (key_1, val_1), ...]
        - Pre-allocation avoids dynamic resizing overhead during generation
        """
        ### BEGIN SOLUTION
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Current sequence position (how many tokens are cached)
        self.seq_pos = 0

        # Cache storage: list of (key_cache, value_cache) tuples per layer
        self.caches = []

        for layer_idx in range(num_layers):
            # Pre-allocate cache tensors with maximum size
            # Shape: (batch_size, num_heads, max_seq_len, head_dim)
            # float32 matches the 4-bytes-per-element estimate in get_memory_usage()
            shape = (batch_size, num_heads, max_seq_len, head_dim)
            key_cache = Tensor(np.zeros(shape, dtype=np.float32))
            value_cache = Tensor(np.zeros(shape, dtype=np.float32))

            self.caches.append((key_cache, value_cache))
        ### END SOLUTION

    def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
        """
        Update cache with new key-value pairs for given layer.

        TODO: Write the new K,V into the next free slot without copying existing entries

        APPROACH:
        1. Validate layer_idx is in range [0, num_layers-1]
        2. Validate seq_pos hasn't exceeded max_seq_len
        3. Retrieve the (key_cache, value_cache) tuple for this layer
        4. Write new key to position seq_pos in key_cache using indexed assignment
        5. Write new value to position seq_pos in value_cache using indexed assignment
        6. Note: seq_pos is advanced externally via advance() after all layers

        This is the core caching operation: store the new K,V so they never
        need to be recomputed. The write is O(1) with respect to sequence
        length, because it fills a single slot no matter how many tokens are
        already cached.

        IMPORTANT: KV caching is designed for INFERENCE (generation) only,
        not training. During generation, gradients are not computed. If you
        need gradients, don't use caching (use standard forward pass instead).

        Args:
            layer_idx: Which transformer layer (0 to num_layers-1)
            key: New key tensor, shape (batch_size, num_heads, 1, head_dim)
            value: New value tensor, shape (batch_size, num_heads, 1, head_dim)

        EXAMPLE:
        >>> cache = KVCache(batch_size=1, max_seq_len=10, num_layers=2,
        ...                 num_heads=4, head_dim=64)
        >>> new_k = Tensor(np.random.randn(1, 4, 1, 64))
        >>> new_v = Tensor(np.random.randn(1, 4, 1, 64))
        >>> cache.update(layer_idx=0, key=new_k, value=new_v)
        >>> cache.seq_pos  # Still 0 (update doesn't advance position)
        >>> cache.advance()
        >>> cache.seq_pos  # Now 1

        HINTS:
        - Use slicing: cache[:, :, seq_pos:seq_pos+1, :] to write to position
        - Use .data for direct NumPy access (no gradient tracking needed)
        - Raise ValueError with helpful messages for invalid inputs
        - This is an in-place operation (modifies cache, returns None)

        Raises:
            ValueError: If layer_idx is out of range or sequence is full
        """
        ### BEGIN SOLUTION
        if not 0 <= layer_idx < self.num_layers:
            raise ValueError(f"Layer index {layer_idx} out of range [0, {self.num_layers - 1}]")

        if self.seq_pos >= self.max_seq_len:
            raise ValueError(f"Sequence position {self.seq_pos} >= max_seq_len {self.max_seq_len}")

        # Get cache for this layer
        key_cache, value_cache = self.caches[layer_idx]

        # Update cache at current position (efficient O(1) write)
        # Note: We use .data here because caching is inference-only (no gradients needed)
        # This avoids gradient tracking overhead during generation
        key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data
        value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data

        # Note: seq_pos is advanced externally via advance() after all layers process
        ### END SOLUTION

    def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
        """
        Retrieve cached key-value pairs for attention computation.

        TODO: Return only the valid cached portion for this layer

        APPROACH:
        1. Validate layer_idx is in range
        2. Retrieve the (key_cache, value_cache) tuple for this layer
        3. Calculate valid_len = seq_pos (number of tokens currently cached)
        4. Slice key_cache to get [:, :, :valid_len, :] (only filled portion)
        5. Slice value_cache to get [:, :, :valid_len, :] (only filled portion)
        6. Wrap sliced data in new Tensor objects and return

        Returns only the valid portion of the cache (up to current seq_pos).
        Slicing a NumPy array is O(1) because it returns a view, not a copy;
        wrapping that view in a new Tensor may still copy it, depending on
        how Tensor stores its data.

        IMPORTANT: Returns Tensors without gradient tracking since caching
        is inference-only. The returned tensors can be used in attention
        computation but won't propagate gradients backward.

        Args:
            layer_idx: Which transformer layer to get cache for

        Returns:
            (cached_keys, cached_values): Tensors shaped for attention
                Keys: (batch_size, num_heads, seq_pos, head_dim)
                Values: (batch_size, num_heads, seq_pos, head_dim)

        EXAMPLE:
        >>> cache = KVCache(batch_size=1, max_seq_len=100, num_layers=2,
        ...                 num_heads=4, head_dim=64)
        >>> # After processing 3 tokens
        >>> cache.seq_pos = 3
        >>> cached_k, cached_v = cache.get(layer_idx=0)
        >>> cached_k.shape  # (1, 4, 3, 64) - only first 3 positions
        >>> cached_v.shape  # (1, 4, 3, 64)

        HINTS:
        - valid_len = self.seq_pos (how many tokens have been cached so far)
        - Use slicing: cache.data[:, :, :valid_len, :] to get valid portion
        - Wrap result in Tensor() for consistency with TinyTorch API
        - If seq_pos=0, returns empty cache (shape with 0 in sequence dimension)

        Raises:
            ValueError: If layer_idx is out of range
        """
        ### BEGIN SOLUTION
        if not 0 <= layer_idx < self.num_layers:
            raise ValueError(f"Layer index {layer_idx} out of range [0, {self.num_layers - 1}]")

        # Get cache for this layer
        key_cache, value_cache = self.caches[layer_idx]

        # Return only the valid portion (up to current sequence position)
        # seq_pos tracks where to write next, so we have seq_pos valid tokens
        valid_len = self.seq_pos

        # Note: Creating new Tensors from .data (no gradient tracking)
        # This is correct for inference-only caching
        cached_keys = Tensor(key_cache.data[:, :, :valid_len, :])
        cached_values = Tensor(value_cache.data[:, :, :valid_len, :])

        return cached_keys, cached_values
        ### END SOLUTION

    def advance(self) -> None:
        """
        Advance sequence position after processing current token.

        Call this after all layers have processed the current token and
        updated their caches. This moves the write pointer forward.
        """
        self.seq_pos += 1

    def reset(self) -> None:
        """
        Reset cache for new generation sequence.

        Call this when starting a new generation (new prompt).
        Resets the sequence position counter and zeros the cache data.
        """
        self.seq_pos = 0

        # Zero out caches for a clean state. Not required for correctness (get()
        # only reads positions below seq_pos), but it helps with debugging
        for layer_idx in range(self.num_layers):
            key_cache, value_cache = self.caches[layer_idx]
            key_cache.data.fill(0.0)
            value_cache.data.fill(0.0)

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Calculate memory usage of the cache system.

        Returns:
            Dictionary with memory statistics in MB (1 MB = 1024 × 1024 bytes)
        """
        # Calculate size of one cache tensor
        cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
        bytes_per_float = 4  # assumes float32 storage

        # Each layer has key_cache + value_cache
        total_cache_tensors = self.num_layers * 2
        total_elements = cache_size * total_cache_tensors
        total_bytes = total_elements * bytes_per_float
        total_mb = total_bytes / (1024 * 1024)

        return {
            'total_mb': total_mb,
            'per_layer_mb': total_mb / self.num_layers,
            'cache_tensors': total_cache_tensors,
            'total_elements': total_elements
        }

# %% [markdown]
"""
### 🧪 Unit Test: KVCache Implementation

Let's test that our cache correctly stores and retrieves key-value pairs across multiple layers and sequence positions.

**This is a unit test** - it tests the KVCache class in isolation with simulated attention keys and values.
"""

# %% nbgrader={"grade": true, "grade_id": "test-kvcache", "locked": true, "points": 10}
def test_unit_kvcache():
    """🔬 Unit Test: KVCache Implementation"""
    print("🔬 Unit Test: KVCache Implementation...")

    # Test parameters (small transformer for testing)
    batch_size, max_seq_len = 2, 8
    num_layers, num_heads, head_dim = 3, 4, 16

    # Create cache
    cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

    # Test 1: Initial state
    assert cache.seq_pos == 0, "Cache should start at position 0"
    mem_usage = cache.get_memory_usage()
    assert mem_usage['total_mb'] > 0, "Cache should have non-zero memory usage"
    print(f"   Cache initialized: {mem_usage['total_mb']:.2f} MB")

    # Test 2: Single token update and retrieval
    key1 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
    value1 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))

    # Update layer 0 with first token
    cache.update(0, key1, value1)

    # Before advance, get() should return empty (seq_pos=0)
    cached_k, cached_v = cache.get(0)
    assert cached_k.shape == (batch_size, num_heads, 0, head_dim), "Before advance, cache should be empty"

    # Advance position
    cache.advance()

    # Now cache should have 1 token
    cached_k, cached_v = cache.get(0)
    assert cached_k.shape == (batch_size, num_heads, 1, head_dim), f"Expected shape (2,4,1,16), got {cached_k.shape}"
    assert cached_v.shape == (batch_size, num_heads, 1, head_dim), f"Expected shape (2,4,1,16), got {cached_v.shape}"

    # Test 3: Multi-token sequence
    key2 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
    value2 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
    cache.update(0, key2, value2)
    cache.advance()

    cached_k, cached_v = cache.get(0)
    assert cached_k.shape == (batch_size, num_heads, 2, head_dim), "Should have 2 tokens cached"
    assert cached_v.shape == (batch_size, num_heads, 2, head_dim), "Should have 2 tokens cached"

    # Test 4: Multiple layers
    cache.reset()
    key_test = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
    value_test = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))

    # Update all layers with same token
    cache.update(0, key_test, value_test)  # Layer 0
    cache.update(1, key_test, value_test)  # Layer 1
    cache.update(2, key_test, value_test)  # Layer 2
    cache.advance()

    # Each layer should have the cached token
    for layer_idx in range(num_layers):
        cached_k, cached_v = cache.get(layer_idx)
        assert cached_k.shape[2] == 1, f"Layer {layer_idx} should have 1 token"

    # Test 5: Reset functionality
    cache.reset()
    assert cache.seq_pos == 0, "Reset should clear sequence position"
    cached_k, cached_v = cache.get(0)
    assert cached_k.shape == (batch_size, num_heads, 0, head_dim), "Reset should clear cache"

    print("✅ KVCache implementation works correctly!")

# Run test immediately when developing this module
if __name__ == "__main__":
    test_unit_kvcache()

# %% [markdown]
"""
## 🎯 Part 4: Enabling KV Caching for Model Generation

### Integration Strategy

Now we need a clean way to enable KV caching in our existing transformer models without breaking existing code. We'll create an `enable_kv_cache()` function that:

1. Creates a KVCache instance sized for the model
2. Returns that cache so the generation loop can pass it to each forward step
3. Is called once, before generation starts

The actual integration with attention will happen in the milestone code, where we:
1. Check whether a cache is enabled
2. Compute K,V only for the new token (not all tokens)
3. Update the cache with the new K,V
4. Use the cached K,V for the attention computation

### Generation Flow Comparison

```
Without Cache (Current):
  for each new token:
      input_seq = [all tokens so far]                      # Length grows: 1, 2, 3, ...
      logits = model.forward(input_seq)                    # Recomputes everything!
      next_token = sample(logits[-1])
      append next_token

With Cache (New):
  cache = enable_kv_cache(batch_size, max_seq_len, num_layers, num_heads, head_dim)
  run the prompt through the model once to fill the cache  # prefill
  for each new token:
      input_token = [just new token]                       # Length always 1
      logits = model.forward_cached(input_token, cache)    # Only new computation
      next_token = sample(logits[-1])
      append next_token
```

**Key Difference**: The input changes from a growing sequence to a single token, with the cache providing the history.

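To make the comparison concrete, here is a toy, runnable version of both loops. The "model" is a hypothetical single attention layer with an embedding table and an output head (random weights, no MLP, no multi-head split), and greedy `argmax` stands in for `sample`. With only one layer, the next-token logits depend on the newest token's query plus the K,V rows of all tokens, so the cached loop can prefill the prompt once and then project only one new token per step.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d = 20, 8
E = rng.standard_normal((vocab, d))                       # toy embedding table
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
W_out = rng.standard_normal((d, vocab))                   # toy output head


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def next_logits(q, K, V):
    # Single-query attention over all K,V rows, then the output head.
    return (softmax(q @ K.T / np.sqrt(d)) @ V) @ W_out


def generate_without_cache(prompt, steps):
    tokens = list(prompt)
    for _ in range(steps):
        X = E[tokens]                         # whole sequence, every step
        K, V = X @ W_k, X @ W_v               # recompute K,V for ALL tokens
        q = X[-1] @ W_q
        tokens.append(int(np.argmax(next_logits(q, K, V))))
    return tokens


def generate_with_cache(prompt, steps):
    tokens = list(prompt)
    K_cache = [E[t] @ W_k for t in tokens]    # prefill: project the prompt once
    V_cache = [E[t] @ W_v for t in tokens]
    for _ in range(steps):
        q = E[tokens[-1]] @ W_q
        logits = next_logits(q, np.stack(K_cache), np.stack(V_cache))
        new = int(np.argmax(logits))
        tokens.append(new)
        K_cache.append(E[new] @ W_k)          # project only the new token
        V_cache.append(E[new] @ W_v)
    return tokens


prompt = [1, 5, 7]
print(generate_without_cache(prompt, 10) == generate_with_cache(prompt, 10))  # True
```

In a multi-layer model the same idea applies per layer, which is why the cache keeps one K,V buffer for each layer.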
"""

# %%
#| export
def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
                    num_heads: int, head_dim: int) -> KVCache:
    """
    Create and return a KVCache instance for model generation.

    This function creates a properly sized cache for the model architecture.
    Call this before starting generation, then pass the cache to your
    generation loop.

    Args:
        batch_size: Number of sequences to generate simultaneously
        max_seq_len: Maximum sequence length to support
        num_layers: Number of transformer layers in model
        num_heads: Number of attention heads per layer
        head_dim: Dimension per attention head (usually embed_dim // num_heads)

    Returns:
        KVCache instance ready for use

    Example:
        ```python
        # Enable caching for generation
        cache = enable_kv_cache(
            batch_size=1,
            max_seq_len=100,
            num_layers=4,
            num_heads=4,
            head_dim=32
        )

        # Use in generation loop (pseudocode: forward_cached and sample are placeholders)
        for step in range(max_new_tokens):
            # Only process new token with cache
            logits = model.forward_cached(new_token, cache)
            next_token = sample(logits)
        ```
    """
    cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

    print("⚡ KV Cache enabled:")
    print(f"   Batch size: {batch_size}")
    print(f"   Max sequence: {max_seq_len}")
    print(f"   Layers: {num_layers}")
    print(f"   Heads: {num_heads}")
    print(f"   Head dim: {head_dim}")

    mem_info = cache.get_memory_usage()
    print(f"   Memory: {mem_info['total_mb']:.2f} MB")
    print()

    return cache

# %% [markdown]
"""
### 🧪 Unit Test: Cache Enablement

Let's verify that we can create caches for realistic model configurations.

**This is a unit test** - it tests the cache creation and memory calculation for different model sizes.
"""

# %% nbgrader={"grade": true, "grade_id": "test-cache-enablement", "locked": true, "points": 10}
def test_unit_cache_enablement():
    """🔬 Unit Test: Cache Enablement for Different Models"""
    print("🔬 Unit Test: Cache Enablement for Different Models...")

    # Test 1: Small model (fast generation)
    print("   Test 1: Small Model (Tiny Transformer)")
    cache_small = KVCache(
        batch_size=1,
        max_seq_len=64,
        num_layers=2,
        num_heads=4,
        head_dim=32
    )
    mem_small = cache_small.get_memory_usage()
    assert mem_small['total_mb'] < 1.0, "Small model should use < 1 MB"
    print(f"   Small model cache: {mem_small['total_mb']:.3f} MB")

    # Test 2: Medium model (balanced performance)
    print("   Test 2: Medium Model (Standard Transformer)")
    cache_medium = KVCache(
        batch_size=1,
        max_seq_len=128,
        num_layers=4,
        num_heads=8,
        head_dim=64
    )
    mem_medium = cache_medium.get_memory_usage()
    assert 1.0 < mem_medium['total_mb'] < 10.0, "Medium model should use 1-10 MB"
    print(f"   Medium model cache: {mem_medium['total_mb']:.3f} MB")

    # Test 3: Batch inference (multiple sequences)
    print("   Test 3: Batch Inference (4 sequences)")
    cache_batch = KVCache(
        batch_size=4,  # Generate 4 sequences in parallel
        max_seq_len=64,
        num_layers=2,
        num_heads=4,
        head_dim=32
    )
    mem_batch = cache_batch.get_memory_usage()
    assert mem_batch['total_mb'] > mem_small['total_mb'], "Batch cache should be larger"
    print(f"   Batch cache: {mem_batch['total_mb']:.3f} MB (4x batch size)")

    print("✅ Cache enablement works correctly!")

# Run test immediately when developing this module
if __name__ == "__main__":
    test_unit_cache_enablement()

# %% [markdown]
"""
## 🎯 Part 5: Using KV Cache in Practice

### Practical Integration Checklist

To use KV caching in your transformer generation:

**✅ Before Generation:**
1. Create the cache with `enable_kv_cache(model)`, which sizes it from the model's architecture (layers, heads, head dimension, `max_seq_len`)
2. Check that the reported memory usage is acceptable

**✅ During Generation (Modified Forward Pass):**
1. Process the prompt so that every prompt token's K,V pairs end up in the cache (multi-token inputs bypass the cache in `enable_kv_cache()`, so feed the prompt one token at a time)
2. For each subsequent token:
   - Process only the NEW token (not the entire sequence)
   - Update the cache with the new K,V pairs
   - Retrieve the full cached K,V for attention
   - Use the cached values in the attention computation
   - Advance the cache position once, after all layers have run

**✅ After Generation:**
1. Reset the cache before generating another sequence
2. Monitor memory usage for production deployment
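
The checklist's bookkeeping order can be sketched as a toy decoding loop. This is a minimal illustration under stated assumptions, not TinyTorch's `KVCache`: `ToyKVCache` and `run_sequence` are hypothetical names, and the "K,V projections" are placeholder arrays filled with the token id. It shows the three rules that matter: read history before writing, update every layer, and advance exactly once per token after all layers.

```python
import numpy as np


class ToyKVCache:
    # Hypothetical stand-in for a KV cache (not TinyTorch's KVCache):
    # preallocated K/V buffers per layer plus ONE write position shared
    # by all layers.
    def __init__(self, num_layers, num_heads, max_seq_len, head_dim):
        shape = (num_layers, num_heads, max_seq_len, head_dim)
        self.keys = np.zeros(shape)
        self.values = np.zeros(shape)
        self.seq_pos = 0

    def update(self, layer, k, v):
        # k, v: (num_heads, head_dim) for the single new token
        self.keys[layer, :, self.seq_pos] = k
        self.values[layer, :, self.seq_pos] = v

    def get(self, layer):
        # Only positions already committed with advance() are returned
        return (self.keys[layer, :, :self.seq_pos],
                self.values[layer, :, :self.seq_pos])

    def advance(self):
        self.seq_pos += 1

    def reset(self):
        self.seq_pos = 0


def run_sequence(cache, token_ids, num_layers, num_heads, head_dim):
    # Feed tokens one at a time (prompt first, then generated tokens)
    history_lengths = []
    for token in token_ids:
        for layer in range(num_layers):
            # Placeholder for this layer's K,V projections of the new token
            k = np.full((num_heads, head_dim), float(token))
            v = np.full((num_heads, head_dim), float(token) + 0.5)
            k_hist, v_hist = cache.get(layer)      # cached history only
            history_lengths.append(k_hist.shape[1])
            # Real attention would use concat(k_hist, k) and concat(v_hist, v)
            cache.update(layer, k, v)              # write new token at seq_pos
        cache.advance()                            # once per token, after ALL layers
    return history_lengths


cache = ToyKVCache(num_layers=2, num_heads=4, max_seq_len=16, head_dim=8)
print(run_sequence(cache, [3, 1, 4], num_layers=2, num_heads=4, head_dim=8))
print(cache.seq_pos)
cache.reset()  # before generating another sequence
print(cache.seq_pos)
```

Each layer sees 0, then 1, then 2 cached tokens, and `seq_pos` ends at 3 because `advance()` runs once per token rather than once per layer.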

### Performance Expectations

Approximate throughput; actual numbers depend on model size, hardware, and implementation.

```
Expected Speedup by Sequence Length:
┌────────────┬───────────┬────────────┬──────────┐
│ Seq Len    │ No Cache  │ With Cache │ Speedup  │
├────────────┼───────────┼────────────┼──────────┤
│ 10 tokens  │ ~80 tok/s │ ~600 tok/s │ 7.5x     │
│ 25 tokens  │ ~40 tok/s │ ~500 tok/s │ 12.5x    │
│ 50 tokens  │ ~25 tok/s │ ~400 tok/s │ 16.0x    │
│ 100 tokens │ ~12 tok/s │ ~200 tok/s │ 16.7x    │
└────────────┴───────────┴────────────┴──────────┘

Key Insight: Speedup increases with sequence length!
Why? Without a cache, every step recomputes K,V and attention for the
whole prefix, so the redundant work grows with each generated token.
```
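
A rough way to see why the speedup grows is to count attention score computations (query-key dot products, per head and per layer) needed to generate N tokens. This is a simplified model: it ignores projections, MLPs, and fixed per-step overheads, so its ratios are larger than the measured speedups above, but the trend is the same. The ratio works out to (2N+1)/3.

```python
def attention_dot_products(n_tokens, use_cache):
    # Query-key dot products needed to generate n_tokens, per head and
    # per layer (projections, MLPs, and fixed overheads are ignored).
    total = 0
    for t in range(1, n_tokens + 1):
        if use_cache:
            total += t        # 1 new query against t cached keys
        else:
            total += t * t    # t queries against t keys, recomputed from scratch
    return total


for n in (10, 25, 50, 100):
    naive = attention_dot_products(n, use_cache=False)
    cached = attention_dot_products(n, use_cache=True)
    print(f"{n:>3} tokens: {naive:>7} vs {cached:>5} dot products ({naive / cached:.1f}x fewer)")
```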

### Production Considerations

**Memory Management:**
- Cache memory = `2 × batch_size × num_layers × num_heads × max_seq_len × head_dim × 4 bytes` (the factor of 2 covers keys and values; 4 bytes per float32 element)
- For GPT-2 (12 layers, 12 heads, seq_len=1024, head_dim=64): ~75 MB per sequence
- For GPT-3 (96 layers, 96 heads, seq_len=2048, head_dim=128): ~19 GB per sequence in float32 (~9.7 GB in float16)
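
The memory formula translates directly into a small helper. `kv_cache_bytes` is a hypothetical name; it assumes float32 (4 bytes per element) by default and reports decimal MB/GB.

```python
def kv_cache_bytes(batch_size, num_layers, num_heads, max_seq_len, head_dim,
                   bytes_per_elem=4):
    # Factor of 2: one buffer for keys and one for values
    return 2 * batch_size * num_layers * num_heads * max_seq_len * head_dim * bytes_per_elem


gpt2 = kv_cache_bytes(batch_size=1, num_layers=12, num_heads=12, max_seq_len=1024, head_dim=64)
gpt3 = kv_cache_bytes(batch_size=1, num_layers=96, num_heads=96, max_seq_len=2048, head_dim=128)
print(f"GPT-2: {gpt2 / 1e6:.1f} MB")   # GPT-2: 75.5 MB
print(f"GPT-3: {gpt3 / 1e9:.1f} GB")   # GPT-3: 19.3 GB
```

Because every factor is linear, doubling the batch size or `max_seq_len` doubles the cache; this is why over-allocating `max_seq_len` wastes memory.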

**Trade-off Analysis:**
- **10x+ speedup** for typical generation lengths (the table above shows 12.5-16.7x from 25 to 100 tokens)
- **Memory cost that grows linearly** with batch size and context length; for a single sequence at moderate context lengths it is usually well below the size of the model weights
- **Enables real-time interaction** that would be impractically slow without caching

**Best Practices:**
1. Always use caching for production serving
2. Tune `max_seq_len` to the expected generation length (don't over-allocate)
3. Consider batch inference to amortize the cost of reading model weights across sequences
4. Monitor cache memory usage in production
"""

# %% [markdown]
"""
## 🎯 Part 6: Non-Invasive Integration with Existing Models

### The Challenge

We've built the KV cache in this module, but our transformer (Modules 12-13) doesn't know it exists!

**❌ BAD Solution**: Go back and modify Module 12 (MultiHeadAttention)
- Breaks "forward-only" learning (students shouldn't have to revisit old modules)
- Makes Module 12 depend on Module 14 (wrong dependency direction!)
- Violates clean module boundaries

**✅ GOOD Solution**: Module 14 ADDS caching to existing models without modifying them!
- Use composition + monkey-patching (like `enable_autograd()`)
- Module 14 wraps/enhances Module 12 instead of modifying it
- Students learn a systems-engineering principle: "Add capabilities, don't break old code"

### Implementation Strategy

We'll create `enable_kv_cache(model)` that:
1. Creates a cache sized for the model's architecture
2. Wraps each attention layer's `forward` with caching logic
3. Intercepts attention calls and manages the cache automatically
4. Returns the cache for manual control if needed

This is **non-invasive enhancement**, a critical ML systems pattern!
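
The enable/disable pattern can be sketched on toy classes. `ToyAttention`, `ToyModel`, `make_wrapped_forward`, `enable_call_logging`, and `disable_call_logging` are hypothetical; the "capability" added here is call logging instead of caching, so the structure (save the original once, wrap it through a per-layer factory, restore it later) is visible without the attention math.

```python
class ToyAttention:
    # Hypothetical stand-in for an attention layer we are not allowed to edit
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return [v * self.scale for v in x]


class ToyModel:
    def __init__(self, num_layers):
        self.blocks = [ToyAttention(scale=i + 1) for i in range(num_layers)]


def make_wrapped_forward(layer_idx, original_forward, call_log):
    # Factory: each wrapper captures its OWN layer_idx and original_forward.
    # A wrapper defined directly in the loop that read the loop variables
    # would see only their final values (the last layer) when called.
    def wrapped_forward(x):
        call_log.append(layer_idx)    # the added capability
        return original_forward(x)    # unchanged original behavior
    return wrapped_forward


def enable_call_logging(model):
    call_log = []
    for layer_idx, block in enumerate(model.blocks):
        if not hasattr(block, '_original_forward'):
            block._original_forward = block.forward   # save only once
        block.forward = make_wrapped_forward(layer_idx, block._original_forward, call_log)
    return call_log


def disable_call_logging(model):
    for block in model.blocks:
        if hasattr(block, '_original_forward'):
            block.forward = block._original_forward   # restore original


model = ToyModel(num_layers=3)
log = enable_call_logging(model)
print([block.forward([1, 2]) for block in model.blocks])   # [[1, 2], [2, 4], [3, 6]]
print(log)                                                  # [0, 1, 2]
disable_call_logging(model)
```

`ToyAttention` itself is never edited: the wrapper lives on each instance, and restoring the saved bound method removes it completely.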
"""

# %% nbgrader={"grade": false, "grade_id": "enable-kv-cache", "solution": true}
#| export
def enable_kv_cache(model):
    """
    Enable KV caching for a transformer model WITHOUT modifying Module 12/13 code.

    TODO: Create cache and non-invasively patch attention layers

    APPROACH:
    1. Validate model has required attributes (embed_dim, num_layers, num_heads, max_seq_len, blocks)
    2. Calculate head_dim from embed_dim and num_heads
    3. Create KVCache instance sized for this model's architecture
    4. Store cache on model as model._kv_cache and set model._cache_enabled flag
    5. For each transformer block, wrap its attention forward method with caching logic
    6. Print confirmation message with cache statistics
    7. Return the cache object

    This function demonstrates **non-invasive optimization** - adding capabilities
    to existing systems without breaking them, similar to how Module 05 (Autograd)
    uses enable_autograd() to add gradient tracking to Tensors.

    Args:
        model: A GPT-style transformer model with:
            - model.embed_dim (int)
            - model.num_layers (int)
            - model.num_heads (int)
            - model.max_seq_len (int)
            - model.blocks (list of TransformerBlock objects)

    Returns:
        cache: KVCache object for this model

    EXAMPLE:
        >>> from tinytorch.models.transformer import GPT
        >>> model = GPT(vocab_size=100, embed_dim=128, num_layers=4, num_heads=4)
        >>> cache = enable_kv_cache(model)
        >>> hasattr(model, '_kv_cache')  # True
        >>> model._cache_enabled  # True
        >>> cache.num_layers  # 4 (matches model)

    HINTS:
    - Use hasattr() to validate model attributes exist
    - head_dim = model.embed_dim // model.num_heads
    - Store cache on model with model._kv_cache = cache
    - Set flag with model._cache_enabled = True
    - Save original forward with block._original_attention_forward
    - Use a factory function to create patched forwards; pass in layer_idx AND
      that layer's attention object so each closure is bound to its own layer
      (a closure that reads the loop variable `block` sees only the last block)

    Pedagogical Note:
    This teaches students that optimizations can be LAYERED on top of
    working systems. Module 14 doesn't break Modules 12-13; it enhances them!
    """
    ### BEGIN SOLUTION
    # Validate model has required attributes
    required_attrs = ['embed_dim', 'num_layers', 'num_heads', 'max_seq_len', 'blocks']
    for attr in required_attrs:
        if not hasattr(model, attr):
            raise AttributeError(
                f"Model missing '{attr}' - enable_kv_cache() requires a GPT-style model "
                f"with {', '.join(required_attrs)}"
            )

    # Calculate head dimension (check divisibility first so the error is clear)
    if model.embed_dim % model.num_heads != 0:
        raise ValueError(
            f"embed_dim ({model.embed_dim}) must be divisible by num_heads ({model.num_heads})"
        )
    head_dim = model.embed_dim // model.num_heads

    # Create cache for this model
    cache = KVCache(
        batch_size=1,  # Single sequence by default; build a larger KVCache for batch inference
        max_seq_len=model.max_seq_len,
        num_layers=model.num_layers,
        num_heads=model.num_heads,
        head_dim=head_dim
    )

    # Store cache on model for easy access
    model._kv_cache = cache
    model._cache_enabled = True

    def make_cached_forward(layer_idx, attention, original_forward, cache_obj):
        """
        Factory to create a cached forward bound to ONE layer.

        layer_idx and attention are passed in explicitly so each closure
        captures its own layer. Reading the loop variable `block` inside
        cached_forward instead would make every layer use the LAST block's
        projections, because closures look up loop variables when called.
        """
        def cached_forward(x, mask=None):
            """
            Cached attention forward pass with REAL speedup!

            PATH SELECTION STRATEGY (Key to Understanding KV Caching):
            ──────────────────────────────────────────────────────────

            We have THREE possible paths through attention:

            1️⃣ FULL-SEQUENCE PATH (seq_len > 1):
               - Input: Full sequence of tokens (e.g., 64 tokens in training)
               - Action: Use ORIGINAL attention (the cache is neither read nor written)
               - Why: Training needs full gradient flow for backpropagation
               - Complexity: O(n²) but that's fine for training
               - Example: x.shape = (batch=1, seq=64, embed=128)

            2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):
               - Input: Single token (the first one in generation)
               - Action: Compute its K,V, write them into the cache, and
                 attend over just this token
               - Why: Cache is empty, so there is no history to retrieve yet
               - Complexity: O(1) - only one token
               - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0

            3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):
               - Input: Single NEW token (during generation)
               - Action: Compute K,V for new token ONLY, retrieve history from cache
               - Why: This is where the speedup happens! O(n²) → O(n) per step
               - Complexity: O(n) - only compute for new token, reuse cache
               - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5

            Paths 2 and 3 share the same code; they differ only in whether
            there is cached history to prepend. The caller must call
            cache_obj.advance() once per token, after ALL layers have run.

            WHY .data INSTEAD OF TENSOR OPERATIONS?
            ────────────────────────────────────────

            In the cached path, we use numpy via .data for three reasons:

            1. **Explicit Intent**: Makes it crystal clear this is inference-only
               - Training: Uses Tensor operations → gradients tracked
               - Inference: Uses .data → no gradient overhead

            2. **Performance**: Avoids any autograd bookkeeping
               - Even small per-operation overheads add up over many generation steps
               - Production inference stacks likewise run without gradient tracking

            3. **Educational Clarity**: Shows students the distinction
               - "When do I need gradients?" (training)
               - "When can I skip them?" (inference)

            We COULD use Tensor operations with requires_grad=False, but .data
            makes the inference-only intent explicit.

            THE O(n²) → O(n) TRANSFORMATION:
            ─────────────────────────────────

            WITHOUT Cache (Standard Attention):
            Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)
            Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)
            Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)
            ...
            Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)

            Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!

            WITH Cache (Our Implementation):
            Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)
            Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)
            Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)
            ...
            Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)

            Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!

            Per step the cost drops from O(N²) to O(N), which is why the
            speedup grows with sequence length.
            """
            from tinytorch.core.tensor import Tensor
            import numpy as np

            seq_len = x.shape[1]

            # ═══════════════════════════════════════════════════════════════
            # PATH SELECTION: Choose between full-sequence and single-token
            # ═══════════════════════════════════════════════════════════════

            # PATH 1: FULL SEQUENCE (seq_len > 1)
            # ───────────────────────────────────
            # Input is a full sequence (e.g., 64 tokens during training).
            # We MUST use original attention to preserve gradient flow.
            # This path neither reads nor writes the cache.
            if seq_len > 1:
                return original_forward(x, mask)  # O(n²) but preserves gradients

            # PATHS 2 & 3: SINGLE-TOKEN GENERATION (seq_len == 1)
            # ────────────────────────────────────────────────────
            # Compute K,V for ONLY this new token, write them into the cache,
            # and attend over cached history + this token.

            # Step 1: Compute Q, K, V for NEW token only
            Q_new = attention.q_proj.forward(x)  # (batch, 1, embed_dim)
            K_new = attention.k_proj.forward(x)  # (batch, 1, embed_dim)
            V_new = attention.v_proj.forward(x)  # (batch, 1, embed_dim)

            # Step 2: Reshape to multi-head format
            batch_size = x.shape[0]
            num_heads = attention.num_heads
            head_dim = attention.head_dim

            # Reshape: (batch, 1, embed_dim) → (batch, num_heads, 1, head_dim)
            Q_heads = Q_new.reshape(batch_size, 1, num_heads, head_dim)
            Q_heads = Tensor(np.transpose(Q_heads.data, (0, 2, 1, 3)))

            K_heads = K_new.reshape(batch_size, 1, num_heads, head_dim)
            K_heads = Tensor(np.transpose(K_heads.data, (0, 2, 1, 3)))

            V_heads = V_new.reshape(batch_size, 1, num_heads, head_dim)
            V_heads = Tensor(np.transpose(V_heads.data, (0, 2, 1, 3)))

            # Step 3: Gather cached history BEFORE writing the new token.
            # cache_obj.get() returns positions [0, seq_pos), i.e. only tokens
            # already committed with advance(), so the new token's K,V are
            # appended explicitly here.
            if cache_obj.seq_pos > 0:
                # PATH 3: history exists → prepend it
                K_hist, V_hist = cache_obj.get(layer_idx)
                K_all = np.concatenate([K_hist.data, K_heads.data], axis=2)
                V_all = np.concatenate([V_hist.data, V_heads.data], axis=2)
            else:
                # PATH 2: first token → nothing cached yet
                K_all, V_all = K_heads.data, V_heads.data

            # Step 4: Write the new token's K,V into the cache at seq_pos
            cache_obj.update(layer_idx, K_heads, V_heads)

            # Step 5: Compute attention using new Q with ALL K, V (history + new)
            # ─────────────────────────────────────────────────────────────────
            # Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
            #
            # NOTE: We use numpy arrays (.data) here instead of Tensor operations.
            # Why? This is INFERENCE-ONLY code (no gradients needed):
            #   - Explicit: Makes it clear this is inference, not training
            #   - Fast: Avoids autograd overhead (even if small)
            #
            # If this were training, we'd use Tensor operations for gradient flow.

            # Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
            #        → (batch, num_heads, 1, seq_len)
            K_transposed = np.transpose(K_all, (0, 1, 3, 2))
            scores = np.matmul(Q_heads.data, K_transposed)

            # Scale by sqrt(head_dim)
            scores = scores / np.sqrt(head_dim)

            # No causal mask needed: every cached position precedes the new
            # token, so the single query may attend to all of them. `mask`
            # is therefore ignored on this path.

            # Softmax over key dimension (subtract the max for numerical stability)
            scores_max = np.max(scores, axis=-1, keepdims=True)
            exp_scores = np.exp(scores - scores_max)
            attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

            # Apply attention weights to values
            # (batch, num_heads, 1, seq_len) @ (batch, num_heads, seq_len, head_dim)
            # → (batch, num_heads, 1, head_dim)
            attention_output = np.matmul(attention_weights, V_all)

            # Step 6: Reshape back and apply output projection
            # (batch, num_heads, 1, head_dim) → (batch, 1, num_heads, head_dim)
            attention_output_transposed = np.transpose(attention_output, (0, 2, 1, 3))

            # Concatenate heads: (batch, 1, num_heads * head_dim)
            concat_data = attention_output_transposed.reshape(batch_size, 1, num_heads * head_dim)
            concat_output = Tensor(concat_data)

            # Output projection
            output = attention.out_proj.forward(concat_output)

            return output

        return cached_forward

    # Patch each transformer block's attention
    for layer_idx, block in enumerate(model.blocks):
        # Save the original forward only once, so enabling twice never wraps
        # an already-wrapped forward
        if not hasattr(block, '_original_attention_forward'):
            block._original_attention_forward = block.attention.forward

        block.attention.forward = make_cached_forward(
            layer_idx, block.attention, block._original_attention_forward, cache
        )

    print("⚡ KV Cache enabled for model!")
    print(f" Architecture: {model.num_layers} layers × {model.num_heads} heads × {head_dim}D")
    print(f" Memory: {cache.get_memory_usage()['total_mb']:.2f} MB")
    print(" Cache stored in: model._kv_cache")
    print()
    print("💡 To disable: call disable_kv_cache(model)")
    print()

    return cache
    ### END SOLUTION


#| export
def disable_kv_cache(model):
    """
    Disable KV caching and restore original attention behavior.

    Args:
        model: Model with caching enabled

    Example:
        ```python
        cache = enable_kv_cache(model)
        # ... do cached generation ...
        disable_kv_cache(model)  # Back to normal
        ```
    """
    if not hasattr(model, '_cache_enabled') or not model._cache_enabled:
        print("⚠️ KV cache not enabled on this model")
        return

    # Restore original attention forwards
    for block in model.blocks:
        if hasattr(block, '_original_attention_forward'):
            block.attention.forward = block._original_attention_forward

    # Clean up
    model._cache_enabled = False
    if hasattr(model, '_kv_cache'):
        delattr(model, '_kv_cache')

    print("✓ KV cache disabled, original attention restored")


# %% [markdown]
"""
### 🧪 Unit Test: Non-Invasive Cache Integration

Let's verify that `enable_kv_cache()` works without breaking the model!

**This is an integration test** - it tests Module 14 enhancing Modules 12-13 without modification.
"""

# %% nbgrader={"grade": true, "grade_id": "test-noninvasive", "locked": true, "points": 10}
def test_unit_noninvasive_integration():
    """🔬 Unit Test: Non-Invasive Cache Integration"""
    print("🔬 Unit Test: Non-Invasive Cache Integration...")

    # Create a mock transformer-like object for testing
    class MockTransformerBlock:
        def __init__(self):
            self.attention = self

        def forward(self, x, mask=None):
            # Simple pass-through for testing. Accepts `mask` because the
            # patched forward calls original_forward(x, mask).
            return x

    class MockGPT:
        def __init__(self):
            self.vocab_size = 100
            self.embed_dim = 128
            self.num_layers = 4
            self.num_heads = 4
            self.max_seq_len = 64
            self.blocks = [MockTransformerBlock() for _ in range(self.num_layers)]

    # Test 1: Enable caching
    model = MockGPT()
    print(" Test 1: Enable caching on model")
    cache = enable_kv_cache(model)
    assert hasattr(model, '_kv_cache'), "Model should have _kv_cache attribute"
    assert hasattr(model, '_cache_enabled'), "Model should have _cache_enabled flag"
    assert model._cache_enabled == True, "Cache should be enabled"
    assert cache is model._kv_cache, "Returned cache should match model._kv_cache"

    # Test 2: Attention forward still works (multi-token input takes the original path)
    print(" Test 2: Attention forward pass still works")
    test_input = Tensor(np.random.randn(1, 10, 128))
    for block in model.blocks:
        output = block.attention.forward(test_input)
        assert output.shape == test_input.shape, "Forward pass should preserve shape"

    # Test 3: Disable caching
    print(" Test 3: Disable caching")
    disable_kv_cache(model)
    assert model._cache_enabled == False, "Cache should be disabled"
    assert not hasattr(model, '_kv_cache'), "Cache object should be removed"

    # Test 4: Can re-enable
    print(" Test 4: Re-enable caching")
    _ = enable_kv_cache(model)
    assert model._cache_enabled == True, "Cache should be re-enabled"

    print("✅ Non-invasive cache integration works correctly!")

# Run test immediately when developing this module
if __name__ == "__main__":
    test_unit_noninvasive_integration()


# %% [markdown]
"""
## 🧪 Module Integration Test

Final validation that everything works together correctly before module completion.
"""

# %% nbgrader={"grade": true, "grade_id": "module-integration", "locked": true, "points": 20}
def test_module():
    """
    Comprehensive test of entire KV Caching module functionality.

    This final test runs before module summary to ensure:
    - All unit tests pass
    - Functions work together correctly
    - Module is ready for integration with TinyTorch
    """
    print("🧪 RUNNING MODULE INTEGRATION TEST")
    print("=" * 50)
    print()

    # Run all unit tests
    print("Running unit tests...")
    test_unit_kvcache()
    print()
    test_unit_cache_enablement()
    print()
    test_unit_noninvasive_integration()
    print()

    print("Running integration scenarios...")
    print()

    # Integration Test: Complete KV Cache Workflow
    print("🔬 Integration Test: Complete KV Cache Workflow...")
    batch_size, max_seq_len = 1, 128
    num_layers, num_heads, head_dim = 4, 8, 64

    cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)

    # Simulate generation loop (processing multiple tokens)
    for _ in range(5):
        for layer_idx in range(num_layers):
            # Simulate new key-value pairs
            new_key = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
            new_value = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))

            # Update cache
            cache.update(layer_idx, new_key, new_value)

        # Advance position after all layers processed
        cache.advance()

    # Verify cache state
    assert cache.seq_pos == 5, f"Expected seq_pos=5, got {cache.seq_pos}"

    # Verify retrieval
    for layer_idx in range(num_layers):
        cached_k, cached_v = cache.get(layer_idx)
        assert cached_k.shape == (batch_size, num_heads, 5, head_dim)
        assert cached_v.shape == (batch_size, num_heads, 5, head_dim)

    print("✅ Complete KV cache workflow validated!")
    print()

    # Integration Test: Memory Tracking
    print("🔬 Integration Test: Memory Tracking...")
    mem_info = cache.get_memory_usage()
    assert mem_info['total_mb'] > 0
    assert mem_info['cache_tensors'] == num_layers * 2
    print(f"✅ Memory tracking: {mem_info['total_mb']:.2f} MB for {mem_info['cache_tensors']} tensors")
    print()

    print("=" * 50)
    print("🎉 ALL TESTS PASSED! Module ready for export.")
    print("Run: tito module complete 14")

# %%
if __name__ == "__main__":
    test_module()


# %% [markdown]
"""
## 🎓 Module 14 Complete!

You've implemented KV caching - the critical optimization that makes production language models economically viable!

### What You Built

✅ **KVCache Class**: Efficient memory management for key-value pairs across layers
✅ **O(1) Updates**: Each step writes one token's K,V into a preallocated slot; previously cached tokens are never copied
✅ **Memory Tracking**: Understanding cache size and memory trade-offs
✅ **Non-Invasive Integration**: `enable_kv_cache()` adds optimization WITHOUT breaking modules
✅ **Production Patterns**: Integration strategy for real transformer models

### Key Systems Engineering Lesson

**Module 14 doesn't modify Modules 12-13 - it ENHANCES them!**

This teaches the critical principle: **Add capabilities forward, never break backward.**
- Old code keeps working (Module 12 unchanged)
- New code adds optimization (Module 14 layers on top)
- Clean separation of concerns (caching is separate from attention logic)

### Performance Impact

```
Without Cache: O(n²) work per generated token → slow, expensive, impractical
With Cache:    O(n) work per generated token  → fast, cheap, production-ready

Real Impact: 10-15x speedup for typical generation!
```
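
To check that the O(n)-per-token cached path computes the same thing as full recomputation, here is a single-head numpy sketch (no batch dimension and no projections; `full_causal_attention` and `cached_step` are illustrative helpers, not TinyTorch APIs). Each cached step scores one new query against the keys cached so far, and stacking those results reproduces every row of full causal attention.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def full_causal_attention(Q, K, V):
    # Recompute attention for all n positions: n x n score matrix
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # positions j > i
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ V


def cached_step(q_new, K_cache, V_cache):
    # One new query against every cached key: only n scores
    d = q_new.shape[-1]
    scores = q_new @ K_cache.T / np.sqrt(d)   # shape (n,)
    return softmax(scores) @ V_cache


rng = np.random.default_rng(0)
n, d = 6, 8
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))

full = full_causal_attention(Q, K, V)
# At step t the cache holds K[:t+1], V[:t+1] (history plus the new token)
steps = np.stack([cached_step(Q[t], K[:t + 1], V[:t + 1]) for t in range(n)])
print(np.allclose(full, steps))   # True
```

No causal mask is needed in `cached_step`, because the cache only ever contains tokens at or before the current position.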

### What's Next

**Module 15 (Profiling)**: Now that you've seen a concrete optimization, learn how to systematically measure and find more optimizations using professional profiling tools.

### Try It Yourself

Run the chatbot milestone with and without caching:

```bash
# Without cache (slow - baseline)
python milestones/05_2017_transformer/vaswani_chatgpt.py

# With cache (fast - 10-15x speedup!)
python milestones/05_2017_transformer/vaswani_chatgpt.py --use-cache
```

Watch the tokens/sec metric jump from ~40 to ~500! 🚀

---

**Congratulations! You've completed Module 14: KV Caching!**

You now understand one of the core optimizations behind ChatGPT, Claude, and essentially every production LLM serving stack. KV caching is a key reason autoregressive generation is fast and cheap enough to serve to millions of people every day.

**From Theory to Practice**: You've gone from O(n²)-per-token naive generation to O(n)-per-token cached generation. This is real ML engineering!
"""