Files
TinyTorch/modules/17_memoization/memoization.py
Vijay Janapa Reddi ee9355584f Fix all module tests after merge - 20/20 passing
Fixes after merge conflicts:
- Fix tensor reshape error message format
- Fix __init__.py imports (remove BatchNorm2d, fix enable_autograd call)
- Fix attention mask broadcasting for multi-head attention
- Fix memoization module to use matmul instead of @ operator
- Fix capstone module count_parameters and CosineSchedule usage
- Add missing imports to benchmark.py (dataclass, Profiler, platform, os)
- Simplify capstone pipeline test to avoid data shape mismatch

All 20 modules now pass tito test --all
2025-12-03 08:14:27 -08:00

1785 lines
72 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
# %% [markdown]
"""
# Module 17: Memoization - Computational Reuse for Inference
Welcome to Module 17! You'll implement memoization - a fundamental optimization pattern. We'll apply it to transformers through KV caching for 10-15x faster text generation.
## 🔗 Prerequisites & Progress
**You've Built**: Complete transformer architecture (Module 13) and profiling tools (Module 14)
**You'll Build**: Memoization system that eliminates redundant computation through caching
**You'll Enable**: Production-grade inference optimization using computational reuse
**Connection Map**:
```
Profiling (14) → Quantization (16) → Memoization (17) → Acceleration (18)
(measure O(n²)) (reduce precision) (cache K,V → O(n)) (optimize execution)
```
## Learning Objectives
By the end of this module, you will:
1. Understand memoization as a general optimization pattern (cache results, avoid recomputation)
2. Apply memoization to transformers through KV caching
3. Implement KVCache with efficient memory management and O(1) updates
4. Build cache-aware attention that reuses previously computed keys and values
5. Measure dramatic speedup gains (10-15x) and understand memory trade-offs
Let's make inference blazingly fast through computational reuse!
## 📦 Where This Code Lives in the Final Package
**Learning Side:** You work in `modules/17_memoization/kvcaching_dev.py`
**Building Side:** Code exports to `tinytorch.generation.kv_cache`
```python
# How to use this module:
from tinytorch.generation.kv_cache import KVCache, enable_kv_cache
```
**Why this matters:**
- **Learning:** Complete caching system demonstrating production optimization techniques
- **Production:** Proper organization matching Hugging Face's generation/ module structure
- **Consistency:** All generation optimizations in generation.kv_cache
- **Integration:** Works seamlessly with transformers for complete inference optimization
"""
# %%
#| default_exp generation.kv_cache
#| export
import numpy as np
import time
from typing import Tuple, Optional, Dict, List
# Import TinyTorch components from previous modules
from tinytorch.core.tensor import Tensor
# Constants for memory calculations
BYTES_PER_FLOAT32 = 4 # Standard float32 size in bytes
MB_TO_BYTES = 1024 * 1024 # Megabytes to bytes conversion
# %% [markdown]
"""
## 🔬 Motivation: Why Memoization Matters for Transformers
Before we learn KV caching, let's profile transformer generation to understand
the problem we're solving. We'll see O(n²) growth in latency as we generate text.
"""
# %% nbgrader={"grade": false, "grade_id": "motivation-profile", "locked": false}
def profile_naive_generation():
"""
Profile transformer generation to discover the O(n²) bottleneck.
Educational Purpose:
Demonstrates why KV caching is necessary by showing concrete
measurements of quadratic growth in generation time.
This function runs ONLY when the module is executed directly,
not when imported (avoiding side effects during imports).
"""
from tinytorch.profiling.profiler import Profiler
import matplotlib.pyplot as plt
profiler = Profiler()
def naive_attention_step(seq_len, hidden_dim=64):
"""
Simulates one step of attention computation.
Without caching, this processes ALL previous tokens every time.
"""
# Q, K, V for entire sequence
q = Tensor(np.random.randn(1, seq_len, hidden_dim))
k = Tensor(np.random.randn(1, seq_len, hidden_dim))
v = Tensor(np.random.randn(1, seq_len, hidden_dim))
# Attention: Q @ K.T then @ V
# This is O(seq_len²) in complexity
scores = q.matmul(k.transpose()) # (1, seq_len, seq_len)
output = scores.matmul(v)
return output
# Profile at increasing sequence lengths
print("🔬 Profiling Transformer Generation (Without Caching):\n")
print(" Seq Len | Latency (ms) | Growth")
print(" ---------|----------------|----------")
sequence_lengths = [10, 20, 40, 80, 160]
latencies = []
for seq_len in sequence_lengths:
# Measure latency for this sequence length using simple timing
import time
# Warmup
for _ in range(5):
naive_attention_step(seq_len)
# Measure
times = []
for _ in range(20):
start = time.perf_counter()
naive_attention_step(seq_len)
times.append((time.perf_counter() - start) * 1000) # ms
latency = sorted(times)[len(times)//2] # median
latencies.append(latency)
# Calculate growth rate
if len(latencies) > 1:
growth = latencies[-1] / latencies[-2]
print(f" {seq_len:3d} | {latency:6.2f} | {growth:.2f}×")
else:
print(f" {seq_len:3d} | {latency:6.2f} | baseline")
print("\n💡 Key Observations:")
print(" • Latency grows QUADRATICALLY with sequence length")
print(" • Each new token forces recomputation of ALL previous K,V pairs")
print(" • For 160 tokens: ~4× time vs 80 tokens (2² growth)")
print("\n🎯 The Problem:")
print(" K and V values for previous tokens NEVER change,")
print(" yet we recompute them every single step!")
print("\n✨ The Solution:")
print(" CACHE the K,V values! (That's memoization)")
print(" • First compute: Calculate and store K,V")
print(" • Later steps: Reuse stored K,V")
print(" • Complexity: O(n²) → O(n)")
print(" • Speedup: 10-15× for typical generation\n")
# Run profiling when module is executed directly
if __name__ == "__main__":
profile_naive_generation()
# %% [markdown]
"""
## 🎯 Part 1: Understanding the Autoregressive Generation Problem
### The Core Inefficiency
When generating text token by token, transformers face a fundamental computational bottleneck. Let's visualize what happens during naive generation:
```
Token Generation Process (Without Caching):
Step 1: Generate "Hello"
Input: [START]
Attention: Q₁ × [K₁] × [V₁] ← 1 computation
Step 2: Generate "world"
Input: [START, Hello]
Attention: Q₂ × [K₁, K₂] × [V₁, V₂] ← 2 computations (K₁,V₁ RECOMPUTED!)
Step 3: Generate "!"
Input: [START, Hello, world]
Attention: Q₃ × [K₁, K₂, K₃] × [V₁, V₂, V₃] ← 3 computations (K₁,V₁,K₂,V₂ RECOMPUTED!)
```
**The Problem**: For each new token, we recompute ALL previous key-value pairs even though they never change!
### Computational Complexity Analysis
```
Naive Generation Complexity:
Step 1: 1 K,V computation
Step 2: 2 K,V computations
Step 3: 3 K,V computations
...
Step n: n K,V computations
Total: 1 + 2 + 3 + ... + n = n(n+1)/2 = O(n²) complexity!
```
For a 100-token sequence, this means **5,050 redundant computations**!
### Real-World Impact
This inefficiency makes production LLM serving economically impossible without optimization:
- **ChatGPT/GPT-4**: Would be too slow for real-time chat without caching
- **Code completion**: IDEs couldn't provide instant suggestions
- **Mobile deployment**: On-device generation would drain batteries instantly
- **API serving**: Server costs would be 10x+ higher
**The Solution**: Cache key-value pairs after computing them once, transforming O(n²) into O(n).
"""
# %% [markdown]
"""
## 🧮 Part 2: The Key-Value Caching Insight
### Mathematical Foundation
The core insight comes from understanding what changes during autoregressive generation:
```
Attention Computation Breakdown:
Q = new_token @ W_q ← Only new token (changes each step)
K = all_tokens @ W_k ← Includes old tokens (mostly redundant!)
V = all_tokens @ W_v ← Includes old tokens (mostly redundant!)
attention_output = softmax(Q @ K.T / √d_k) @ V
```
**Key Insight**: K and V matrices for previous tokens NEVER change!
```
Token Dependencies:
K₁ = token₁ @ W_k ← Computed once, never changes
K₂ = token₂ @ W_k ← Computed once, never changes
K₃ = token₃ @ W_k ← Computed once, never changes
Same for V₁, V₂, V₃...
```
### Cache-Optimized Generation
```
Optimized Generation Process (With Caching):
Step 1: Generate "Hello"
Compute: K₁, V₁ → Store in cache
Attention: Q₁ × cached[K₁] × cached[V₁]
Step 2: Generate "world"
Compute: K₂, V₂ → Append to cache
Attention: Q₂ × cached[K₁, K₂] × cached[V₁, V₂]
Step 3: Generate "!"
Compute: K₃, V₃ → Append to cache
Attention: Q₃ × cached[K₁, K₂, K₃] × cached[V₁, V₂, V₃]
```
**Result**: Each step computes only ONE new K,V pair instead of recomputing ALL!
### Memory vs Compute Trade-off
```
Traditional Approach:
Memory: O(1) (no storage needed)
Compute: O(n²) (recompute everything)
Cached Approach:
Memory: O(n × d_k) (store all K,V pairs)
Compute: O(n) (only compute new pairs)
For n=100, d_k=64:
Memory cost: 6.4 KB per layer
Compute savings: 50x reduction in K,V computations
```
**Trade-off Winner**: Memory is cheap, compute is expensive! Use O(n) memory to save O(n²) compute.
"""
# %% [markdown]
"""
## 🏗️ Part 3: KVCache Class Implementation
### Core Requirements
Our KVCache needs to efficiently handle:
1. **Multi-layer storage**: Each transformer layer needs its own K,V cache
2. **Multi-head attention**: Each attention head has separate K,V pairs
3. **Batch processing**: Support multiple sequences simultaneously (batch inference)
4. **Dynamic updates**: Efficiently append new tokens without copying data
5. **Memory management**: Pre-allocate space to avoid dynamic resizing overhead
### Cache Architecture Visualization
```
KVCache Memory Layout:
┌─────────────────────────────────────────────────────────┐
│ KVCache Object │
├─────────────────────────────────────────────────────────┤
│ Layer 0: ┌─────────────┬─────────────┐ │
│ │ Key Cache │ Value Cache │ │
│ │ (B,H,S,D) │ (B,H,S,D) │ │
│ └─────────────┴─────────────┘ │
├─────────────────────────────────────────────────────────┤
│ Layer 1: ┌─────────────┬─────────────┐ │
│ │ Key Cache │ Value Cache │ │
│ │ (B,H,S,D) │ (B,H,S,D) │ │
│ └─────────────┴─────────────┘ │
├─────────────────────────────────────────────────────────┤
│ ... ┌─────────────┬─────────────┐ │
│ Layer N: │ Key Cache │ Value Cache │ │
│ │ (B,H,S,D) │ (B,H,S,D) │ │
│ └─────────────┴─────────────┘ │
└─────────────────────────────────────────────────────────┘
Where:
B = batch_size (number of sequences)
H = num_heads (attention heads per layer)
S = max_seq_len (maximum sequence length)
D = head_dim (dimension per attention head)
```
### Update Operation Flow
```
Cache Update Process:
seq_pos = 2
┌─────┬─────┬─────┬─────┬─────┬─────┐
│ K₁ │ K₂ │ ??? │ ??? │ ??? │ ??? │ ← Key Cache
├─────┼─────┼─────┼─────┼─────┼─────┤
│ V₁ │ V₂ │ ??? │ ??? │ ??? │ ??? │ ← Value Cache
└─────┴─────┴─────┴─────┴─────┴─────┘
New token arrives: K₃, V₃
seq_pos = 2
┌─────┬─────┬─────┬─────┬─────┬─────┐
│ K₁ │ K₂ │ K₃ │ ??? │ ??? │ ??? │ ← Write K₃ here
├─────┼─────┼─────┼─────┼─────┼─────┤
│ V₁ │ V₂ │ V₃ │ ??? │ ??? │ ??? │ ← Write V₃ here
└─────┴─────┴─────┴─────┴─────┴─────┘
Then: seq_pos += 1 (advance to position 3)
```
This design enables **O(1) updates** - just write to the next position!
"""
# %% nbgrader={"grade": false, "grade_id": "kvcache-class", "solution": true}
#| export
class KVCache:
"""
Efficient key-value cache for autoregressive generation.
Stores K,V matrices for each transformer layer to avoid recomputation
during sequential token generation. This is THE critical optimization
that makes production language model serving economically viable.
⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
KV caching is designed ONLY for inference (generation), NOT training.
- During generation: No gradients computed (model.eval() mode)
- Cache operations use .data (no gradient tracking)
- This is correct and intentional for maximum speed
- DO NOT use caching during training (use standard forward pass)
Architecture:
- Pre-allocates cache tensors with maximum sequence length
- Tracks current sequence position for efficient O(1) updates
- Provides update() method to append new K,V pairs without copying
- Provides get() method to retrieve cached values for attention
- Handles multiple layers and attention heads properly
Memory Layout:
```
Layer 0: [Key_cache, Value_cache] # Shape: (batch, num_heads, max_seq, head_dim)
Layer 1: [Key_cache, Value_cache]
...
Layer N: [Key_cache, Value_cache]
```
Performance:
- Update: O(1) - just index assignment
- Get: O(1) - just slicing (no data copy)
- Memory: O(num_layers × batch × heads × max_seq × head_dim)
"""
def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
num_heads: int, head_dim: int):
"""
Initialize KV cache for efficient generation.
TODO: Set up pre-allocated cache storage for all transformer layers
APPROACH:
1. Store configuration parameters (batch_size, max_seq_len, etc.)
2. Initialize sequence position counter to 0
3. Create empty list for cache storage
4. For each layer, pre-allocate zero-filled key and value caches
5. Store each layer's (key_cache, value_cache) tuple in the list
Args:
batch_size: Number of sequences to generate simultaneously
max_seq_len: Maximum sequence length to support
num_layers: Number of transformer layers
num_heads: Number of attention heads per layer
head_dim: Dimension of each attention head
EXAMPLE:
>>> cache = KVCache(batch_size=2, max_seq_len=128, num_layers=4,
... num_heads=8, head_dim=64)
>>> cache.seq_pos # 0 (no tokens cached yet)
>>> len(cache.caches) # 4 (one per layer)
>>> cache.caches[0][0].shape # (2, 8, 128, 64) - key cache for layer 0
HINTS:
- Cache shape: (batch_size, num_heads, max_seq_len, head_dim)
- Use Tensor(np.zeros(...)) to create cache tensors
- Store caches as list of tuples: [(key_0, val_0), (key_1, val_1), ...]
- Pre-allocation avoids dynamic resizing overhead during generation
"""
### BEGIN SOLUTION
self.batch_size = batch_size
self.max_seq_len = max_seq_len
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
# Current sequence position (how many tokens are cached)
self.seq_pos = 0
# Cache storage: list of (key_cache, value_cache) tuples per layer
self.caches = []
for layer_idx in range(num_layers):
# Pre-allocate cache tensors with maximum size
# Shape: (batch_size, num_heads, max_seq_len, head_dim)
key_cache = Tensor(np.zeros((batch_size, num_heads, max_seq_len, head_dim)))
value_cache = Tensor(np.zeros((batch_size, num_heads, max_seq_len, head_dim)))
self.caches.append((key_cache, value_cache))
### END SOLUTION
def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
"""
Update cache with new key-value pairs for given layer.
TODO: Efficiently append new K,V to cache without data copying
APPROACH:
1. Validate layer_idx is in range [0, num_layers-1]
2. Validate seq_pos hasn't exceeded max_seq_len
3. Retrieve the (key_cache, value_cache) tuple for this layer
4. Write new key to position seq_pos in key_cache using indexed assignment
5. Write new value to position seq_pos in value_cache using indexed assignment
6. Note: seq_pos is advanced externally via advance() after all layers
This is the core caching operation - efficiently append new K,V
to the cache without recomputation. This operation is O(1) because
it's just an indexed assignment.
IMPORTANT: KV caching is designed for INFERENCE (generation) only,
not training. During generation, gradients are not computed. If you
need gradients, don't use caching (use standard forward pass instead).
Args:
layer_idx: Which transformer layer (0 to num_layers-1)
key: New key tensor, shape (batch_size, num_heads, 1, head_dim)
value: New value tensor, shape (batch_size, num_heads, 1, head_dim)
EXAMPLE:
>>> cache = KVCache(batch_size=1, max_seq_len=10, num_layers=2,
... num_heads=4, head_dim=64)
>>> new_k = Tensor(np.random.randn(1, 4, 1, 64))
>>> new_v = Tensor(np.random.randn(1, 4, 1, 64))
>>> cache.update(layer_idx=0, key=new_k, value=new_v)
>>> cache.seq_pos # Still 0 (update doesn't advance position)
>>> cache.advance()
>>> cache.seq_pos # Now 1
HINTS:
- Use slicing: cache[:, :, seq_pos:seq_pos+1, :] to write to position
- Use .data for direct NumPy access (no gradient tracking needed)
- Raise ValueError with helpful messages for invalid inputs
- This is an in-place operation (modifies cache, returns None)
Raises:
ValueError: If layer_idx is out of range or sequence is full
"""
### BEGIN SOLUTION
if layer_idx >= self.num_layers:
raise ValueError(f"Layer index {layer_idx} >= num_layers {self.num_layers}")
if self.seq_pos >= self.max_seq_len:
raise ValueError(f"Sequence position {self.seq_pos} >= max_seq_len {self.max_seq_len}")
# Get cache for this layer
key_cache, value_cache = self.caches[layer_idx]
# Update cache at current position (efficient O(1) write)
# Note: We use .data here because caching is inference-only (no gradients needed)
# This avoids gradient tracking overhead during generation
key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data
value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data
# Note: seq_pos is advanced externally via advance() after all layers process
### END SOLUTION
def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
"""
Retrieve cached key-value pairs for attention computation.
TODO: Return only the valid cached portion for this layer
APPROACH:
1. Validate layer_idx is in range
2. Retrieve the (key_cache, value_cache) tuple for this layer
3. Calculate valid_len = seq_pos (number of tokens currently cached)
4. Slice key_cache to get [:, :, :valid_len, :] (only filled portion)
5. Slice value_cache to get [:, :, :valid_len, :] (only filled portion)
6. Wrap sliced data in new Tensor objects and return
Returns only the valid portion of the cache (up to current seq_pos).
This is O(1) because we're just slicing NumPy arrays (view, not copy).
IMPORTANT: Returns Tensors without gradient tracking since caching
is inference-only. The returned tensors can be used in attention
computation but won't propagate gradients backward.
Args:
layer_idx: Which transformer layer to get cache for
Returns:
(cached_keys, cached_values): Tensors shaped for attention
Keys: (batch_size, num_heads, seq_pos, head_dim)
Values: (batch_size, num_heads, seq_pos, head_dim)
EXAMPLE:
>>> cache = KVCache(batch_size=1, max_seq_len=100, num_layers=2,
... num_heads=4, head_dim=64)
>>> # After processing 3 tokens
>>> cache.seq_pos = 3
>>> cached_k, cached_v = cache.get(layer_idx=0)
>>> cached_k.shape # (1, 4, 3, 64) - only first 3 positions
>>> cached_v.shape # (1, 4, 3, 64)
HINTS:
- valid_len = self.seq_pos (how many tokens have been cached so far)
- Use slicing: cache.data[:, :, :valid_len, :] to get valid portion
- Wrap result in Tensor() for consistency with TinyTorch API
- If seq_pos=0, returns empty cache (shape with 0 in sequence dimension)
Raises:
ValueError: If layer_idx is out of range
"""
### BEGIN SOLUTION
if layer_idx >= self.num_layers:
raise ValueError(f"Layer index {layer_idx} >= num_layers {self.num_layers}")
# Get cache for this layer
key_cache, value_cache = self.caches[layer_idx]
# Return only the valid portion (up to current sequence position)
# seq_pos tracks where to write next, so we have seq_pos valid tokens
valid_len = self.seq_pos
# Note: Creating new Tensors from .data (no gradient tracking)
# This is correct for inference-only caching
cached_keys = Tensor(key_cache.data[:, :, :valid_len, :])
cached_values = Tensor(value_cache.data[:, :, :valid_len, :])
return cached_keys, cached_values
### END SOLUTION
def advance(self) -> None:
"""
Advance sequence position after processing current token.
Call this after all layers have processed the current token and
updated their caches. This moves the write pointer forward.
"""
self.seq_pos += 1
def reset(self) -> None:
"""
Reset cache for new generation sequence.
Call this when starting a new generation (new prompt).
Resets the sequence position counter and optionally zeros cache data.
"""
self.seq_pos = 0
# Zero out caches for clean state (helps with debugging)
for layer_idx in range(self.num_layers):
key_cache, value_cache = self.caches[layer_idx]
key_cache.data.fill(0.0)
value_cache.data.fill(0.0)
def get_memory_usage(self) -> Dict[str, float]:
"""
Calculate memory usage of the cache system.
Returns:
Dictionary with memory statistics in MB
"""
# Calculate size of one cache tensor
cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
bytes_per_float = 4 # float32
# Each layer has key_cache + value_cache
total_cache_tensors = self.num_layers * 2
total_elements = cache_size * total_cache_tensors
total_bytes = total_elements * bytes_per_float
total_mb = total_bytes / (1024 * 1024)
return {
'total_mb': total_mb,
'per_layer_mb': total_mb / self.num_layers,
'cache_tensors': total_cache_tensors,
'total_elements': total_elements
}
# %% [markdown]
"""
### 🧪 Unit Test: KVCache Implementation
Let's test that our cache correctly stores and retrieves key-value pairs across multiple layers and sequence positions.
**This is a unit test** - it tests the KVCache class in isolation with simulated attention keys and values.
"""
# %% nbgrader={"grade": true, "grade_id": "test-kvcache", "locked": true, "points": 10}
def test_unit_kvcache():
"""🔬 Unit Test: KVCache Implementation"""
print("🔬 Unit Test: KVCache Implementation...")
# Test parameters (small transformer for testing)
batch_size, max_seq_len = 2, 8
num_layers, num_heads, head_dim = 3, 4, 16
# Create cache
cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)
# Test 1: Initial state
assert cache.seq_pos == 0, "Cache should start at position 0"
mem_usage = cache.get_memory_usage()
assert mem_usage['total_mb'] > 0, "Cache should have non-zero memory usage"
print(f" Cache initialized: {mem_usage['total_mb']:.2f} MB")
# Test 2: Single token update and retrieval
key1 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
value1 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
# Update layer 0 with first token
cache.update(0, key1, value1)
# Before advance, get() should return empty (seq_pos=0)
cached_k, cached_v = cache.get(0)
assert cached_k.shape == (batch_size, num_heads, 0, head_dim), "Before advance, cache should be empty"
# Advance position
cache.advance()
# Now cache should have 1 token
cached_k, cached_v = cache.get(0)
assert cached_k.shape == (batch_size, num_heads, 1, head_dim), f"Expected shape (2,4,1,16), got {cached_k.shape}"
assert cached_v.shape == (batch_size, num_heads, 1, head_dim), f"Expected shape (2,4,1,16), got {cached_v.shape}"
# Test 3: Multi-token sequence
key2 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
value2 = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
cache.update(0, key2, value2)
cache.advance()
cached_k, cached_v = cache.get(0)
assert cached_k.shape == (batch_size, num_heads, 2, head_dim), "Should have 2 tokens cached"
assert cached_v.shape == (batch_size, num_heads, 2, head_dim), "Should have 2 tokens cached"
# Test 4: Multiple layers
cache.reset()
key_test = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
value_test = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
# Update all layers with same token
cache.update(0, key_test, value_test) # Layer 0
cache.update(1, key_test, value_test) # Layer 1
cache.update(2, key_test, value_test) # Layer 2
cache.advance()
# Each layer should have the cached token
for layer_idx in range(num_layers):
cached_k, cached_v = cache.get(layer_idx)
assert cached_k.shape[2] == 1, f"Layer {layer_idx} should have 1 token"
# Test 5: Reset functionality
cache.reset()
assert cache.seq_pos == 0, "Reset should clear sequence position"
cached_k, cached_v = cache.get(0)
assert cached_k.shape == (batch_size, num_heads, 0, head_dim), "Reset should clear cache"
print("✅ KVCache implementation works correctly!")
# Run test immediately when developing this module
if __name__ == "__main__":
test_unit_kvcache()
# %% [markdown]
"""
## 🎯 Part 4: Enabling KV Caching for Model Generation
### Integration Strategy
Now we need a clean way to enable KV caching in our existing transformer models without breaking the existing code. We'll create an `enable_kv_cache()` function that:
1. Creates a KVCache instance sized for the model
2. Returns a flag to indicate caching is enabled
3. Can be called before generation starts
The actual integration with attention will happen in the milestone code where we:
1. Check if cache is enabled
2. Only compute K,V for new token (not all tokens)
3. Update cache with new K,V
4. Use cached K,V for attention computation
### Generation Flow Comparison
```
Without Cache (Current):
for each new token:
input_seq = [all tokens so far] # Length grows: 1, 2, 3, ...
logits = model.forward(input_seq) # Recomputes everything!
next_token = sample(logits[-1])
append next_token
With Cache (New):
cache = enable_kv_cache(model)
for each new token:
input_token = [just new token] # Length always 1
logits = model.forward_cached(input_token, cache) # Only new computation
next_token = sample(logits[-1])
append next_token
```
**Key Difference**: Input changes from growing sequence to single token, with cache providing history.
"""
# %%
#| export
def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
num_heads: int, head_dim: int) -> KVCache:
"""
Create and return a KVCache instance for model generation.
This function creates a properly sized cache for the model architecture.
Call this before starting generation, then pass the cache to your
generation loop.
Args:
batch_size: Number of sequences to generate simultaneously
max_seq_len: Maximum sequence length to support
num_layers: Number of transformer layers in model
num_heads: Number of attention heads per layer
head_dim: Dimension per attention head (usually embed_dim // num_heads)
Returns:
KVCache instance ready for use
EXAMPLE:
```python
# Enable caching for generation
cache = enable_kv_cache(
batch_size=1,
max_seq_len=100,
num_layers=4,
num_heads=4,
head_dim=32
)
# Use in generation loop (pseudocode)
for step in range(max_new_tokens):
# Only process new token with cache
logits = model.forward_cached(new_token, cache)
next_token = sample(logits)
```
"""
cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)
print(f"⚡ KV Cache enabled:")
print(f" Batch size: {batch_size}")
print(f" Max sequence: {max_seq_len}")
print(f" Layers: {num_layers}")
print(f" Heads: {num_heads}")
print(f" Head dim: {head_dim}")
mem_info = cache.get_memory_usage()
print(f" Memory: {mem_info['total_mb']:.2f} MB")
print()
return cache
# %% [markdown]
"""
### 🧪 Unit Test: Cache Enablement
Let's verify that we can create caches for realistic model configurations.
**This is a unit test** - it tests the cache creation and memory calculation for different model sizes.
"""
# %% nbgrader={"grade": true, "grade_id": "test-cache-enablement", "locked": true, "points": 10}
def test_unit_cache_enablement():
"""🔬 Unit Test: Cache Enablement for Different Models"""
print("🔬 Unit Test: Cache Enablement for Different Models...")
# Test 1: Small model (fast generation)
print(" Test 1: Small Model (Tiny Transformer)")
cache_small = KVCache(
batch_size=1,
max_seq_len=64,
num_layers=2,
num_heads=4,
head_dim=32
)
mem_small = cache_small.get_memory_usage()
assert mem_small['total_mb'] < 1.0, "Small model should use < 1 MB"
print(f" Small model cache: {mem_small['total_mb']:.3f} MB")
# Test 2: Medium model (balanced performance)
print(" Test 2: Medium Model (Standard Transformer)")
cache_medium = KVCache(
batch_size=1,
max_seq_len=128,
num_layers=4,
num_heads=8,
head_dim=64
)
mem_medium = cache_medium.get_memory_usage()
assert 1.0 < mem_medium['total_mb'] < 10.0, "Medium model should use 1-10 MB"
print(f" Medium model cache: {mem_medium['total_mb']:.3f} MB")
# Test 3: Batch inference (multiple sequences)
print(" Test 3: Batch Inference (4 sequences)")
cache_batch = KVCache(
batch_size=4, # Generate 4 sequences in parallel
max_seq_len=64,
num_layers=2,
num_heads=4,
head_dim=32
)
mem_batch = cache_batch.get_memory_usage()
assert mem_batch['total_mb'] > mem_small['total_mb'], "Batch cache should be larger"
print(f" Batch cache: {mem_batch['total_mb']:.3f} MB (4x batch size)")
print("✅ Cache enablement works correctly!")
# Run test immediately when developing this module
if __name__ == "__main__":
test_unit_cache_enablement()
# %% [markdown]
"""
## 🎯 Part 5: Using KV Cache in Practice
### Practical Integration Checklist
To use KV caching in your transformer generation:
**Before Generation:**
1. Create cache with `enable_kv_cache()`
2. Set cache dimensions to match your model architecture
3. Verify memory usage is acceptable
**During Generation (Modified Forward Pass):**
1. For the first token (prompt), process normally and populate cache
2. For subsequent tokens:
- Only process the NEW token (not entire sequence)
- Update cache with new K,V pairs
- Retrieve full cached K,V for attention
- Use cached values in attention computation
- Advance cache position after all layers
**After Generation:**
1. Reset cache if generating another sequence
2. Monitor memory usage for production deployment
### Performance Expectations
```
Expected Speedup by Sequence Length:
┌───────────┬──────────┬───────────┬──────────┐
│ Seq Len │ No Cache │ With Cache│ Speedup │
├───────────┼──────────┼───────────┼──────────┤
│ 10 tokens│ ~80 tok/s│ ~600 tok/s│ 7.5x │
│ 25 tokens│ ~40 tok/s│ ~500 tok/s│ 12.5x │
│ 50 tokens│ ~25 tok/s│ ~400 tok/s│ 16.0x │
│ 100 tokens│ ~12 tok/s│ ~200 tok/s│ 16.7x │
└───────────┴──────────┴───────────┴──────────┘
Key Insight: Speedup increases with sequence length!
Why? Longer sequences = more redundant computation without cache.
```
### Production Considerations
**Memory Management:**
- Cache memory = `batch_size × num_layers × num_heads × max_seq_len × head_dim × 4 bytes`
- For GPT-2 (12 layers, 12 heads, seq_len=1024, head_dim=64): ~37 MB per sequence
- For GPT-3 (96 layers, 96 heads, seq_len=2048, head_dim=128): ~4.7 GB per sequence
**Trade-off Analysis:**
- **10x+ speedup** for typical generation lengths (50-200 tokens)
- **Modest memory cost** compared to model parameters (often <1% of model size)
- **Enables real-time interaction** that's impossible without caching
**Best Practices:**
1. Always use caching for production serving
2. Tune `max_seq_len` to expected generation length (don't over-allocate)
3. Consider batch inference to amortize model loading costs
4. Monitor cache memory usage in production
"""
# %% [markdown]
"""
## 🎯 Part 5: Non-Invasive Integration with Existing Models
### The Challenge
We built KV caching in Module 17 (this module), but our transformer (Modules 12-13) doesn't know about it!
**❌ BAD Solution**: Go back and modify Module 12 (MultiHeadAttention)
- Breaks "forward-only" learning (students shouldn't revisit old modules)
- Makes Module 12 depend on Module 17 (wrong dependency direction!)
- Violates clean module boundaries
**✅ GOOD Solution**: Module 17 ADDS caching to existing models without modification!
- Use composition + monkey-patching (like `enable_autograd()`)
- Module 17 wraps/enhances Module 12, not modifies it
- Students learn systems engineering: "Add capabilities, don't break old code"
### Implementation Strategy
We'll create `enable_kv_cache(model)` that:
1. Creates cache for the model's architecture
2. Wraps each attention layer with caching logic
3. Intercepts attention calls and manages cache automatically
4. Returns the cache for manual control if needed
This is **non-invasive enhancement** - a critical ML systems pattern!
"""
# %% nbgrader={"grade": false, "grade_id": "enable-kv-cache", "solution": true}
#| export
def enable_kv_cache(model):
"""
Enable KV caching for a transformer model WITHOUT modifying Module 12/13 code.
TODO: Create cache and non-invasively patch attention layers
APPROACH:
1. Validate model has required attributes (embed_dim, num_layers, num_heads, max_seq_len, blocks)
2. Calculate head_dim from embed_dim and num_heads
3. Create KVCache instance sized for this model's architecture
4. Store cache on model as model._kv_cache and set model._cache_enabled flag
5. For each transformer block, wrap its attention forward method with caching logic
6. Print confirmation message with cache statistics
7. Return the cache object
This function demonstrates **non-invasive optimization** - adding capabilities
to existing systems without breaking them. Similar to how Module 05 (Autograd)
uses enable_autograd() to add gradient tracking to Tensors.
Args:
model: A GPT-style transformer model with:
- model.embed_dim (int)
- model.num_layers (int)
- model.num_heads (int)
- model.max_seq_len (int)
- model.blocks (list of TransformerBlock objects)
Returns:
cache: KVCache object for this model
EXAMPLE:
>>> from tinytorch.models.transformer import GPT
>>> model = GPT(vocab_size=100, embed_dim=128, num_layers=4, num_heads=4)
>>> cache = enable_kv_cache(model)
>>> hasattr(model, '_kv_cache') # True
>>> model._cache_enabled # True
>>> cache.num_layers # 4 (matches model)
HINTS:
- Use hasattr() to validate model attributes exist
- head_dim = model.embed_dim // model.num_heads
- Store cache on model with model._kv_cache = cache
- Set flag with model._cache_enabled = True
- Save original forward with block._original_attention_forward
- Use a factory function to create patched forwards (closure captures layer_idx)
Pedagogical Note:
This teaches students that optimizations can be LAYERED on top of
working systems. Module 17 doesn't break Modules 12-13; it enhances them!
"""
### BEGIN SOLUTION
import types
# Educational Note: hasattr() is LEGITIMATE here because:
# 1. This is a plugin system that works with user-defined models
# 2. We need runtime validation that model has required interface
# 3. Different model architectures may have different attributes
# This is the CORRECT use of hasattr() for duck-typing validation
# Validate model has required attributes
required_attrs = ['embed_dim', 'num_layers', 'num_heads', 'max_seq_len', 'blocks']
for attr in required_attrs:
if not hasattr(model, attr):
raise AttributeError(
f"Model missing '{attr}' - enable_kv_cache() requires a GPT-style model "
f"with {', '.join(required_attrs)}"
)
# Calculate head dimension
head_dim = model.embed_dim // model.num_heads
if model.embed_dim % model.num_heads != 0:
raise ValueError(
f"embed_dim ({model.embed_dim}) must be divisible by num_heads ({model.num_heads})"
)
# Create cache for this model
cache = KVCache(
batch_size=1, # Default to single sequence; can be reset for batch inference
max_seq_len=model.max_seq_len,
num_layers=model.num_layers,
num_heads=model.num_heads,
head_dim=head_dim
)
# Store cache on model for easy access
model._kv_cache = cache
model._cache_enabled = True
# Patch each transformer block's attention
for layer_idx, block in enumerate(model.blocks):
# Educational Note: hasattr() is LEGITIMATE here because:
# This is a monkey-patching safety check to avoid double-patching
# We're checking if we've already modified this object
if not hasattr(block, '_original_attention_forward'):
block._original_attention_forward = block.attention.forward
# Create cached version
def make_cached_forward(layer_idx, original_forward, cache_obj):
"""Factory to create cached forward with correct layer_idx closure"""
def cached_forward(x, mask=None):
"""
Cached attention forward pass with REAL speedup!
PATH SELECTION STRATEGY (Key to Understanding KV Caching):
──────────────────────────────────────────────────────────
We have THREE possible paths through attention:
1⃣ TRAINING PATH (seq_len > 1):
- Input: Full sequence of tokens (e.g., 64 tokens)
- Action: Use ORIGINAL attention (no caching)
- Why: Need full gradient flow for backpropagation
- Complexity: O(n²) but that's fine for training
- Example: x.shape = (batch=1, seq=64, embed=128)
2⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):
- Input: Single token (the first one in generation)
- Action: Use ORIGINAL attention (initialize cache)
- Why: Cache is empty, nothing to retrieve yet
- Complexity: O(1) - only one token
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0
3⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):
- Input: Single NEW token (during generation)
- Action: Compute K,V for new token ONLY, retrieve history from cache
- Why: This is where the speedup happens! O(n²) → O(n)
- Complexity: O(n) - only compute for new token, reuse cache
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5
WHY .data INSTEAD OF TENSOR OPERATIONS?
────────────────────────────────────────
In the cached path, we use numpy via .data for three reasons:
1. **Explicit Intent**: Makes it crystal clear this is inference-only
- Training: Uses Tensor operations → gradients tracked
- Inference: Uses .data → no gradient overhead
2. **Performance**: Avoids any autograd bookkeeping
- Even if small, every bit counts in generation
- Production LLMs (vLLM, llama.cpp) use similar patterns
3. **Educational Clarity**: Shows students the distinction
- "When do I need gradients?" (training)
- "When can I skip them?" (inference)
We COULD use Tensor operations with requires_grad=False, but .data
is more explicit and is the industry-standard pattern.
THE O(n²) → O(n) TRANSFORMATION:
─────────────────────────────────
WITHOUT Cache (Standard Attention):
Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)
Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)
Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)
...
Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)
Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!
WITH Cache (Our Implementation):
Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)
Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)
Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)
...
Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)
Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!
That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!
"""
from tinytorch.core.tensor import Tensor
import numpy as np
seq_len = x.shape[1]
# ═══════════════════════════════════════════════════════════════
# PATH SELECTION: Choose between training, first token, or cached
# ═══════════════════════════════════════════════════════════════
# PATH 1: TRAINING (seq_len > 1)
# ───────────────────────────────────
# Input is a full sequence (e.g., 64 tokens during training)
# We MUST use original attention to preserve gradient flow
# No caching during training - we need backprop through everything
if seq_len > 1:
return original_forward(x, mask) # O(n²) but preserves gradients
# PATH 2: FIRST TOKEN (seq_len == 1, cache empty)
# ────────────────────────────────────────────────
# This is the very first token in generation (cache.seq_pos == 0)
# Cache is empty, so there's nothing to retrieve yet
# Use original attention to process this token, which will populate cache
if cache_obj.seq_pos == 0:
return original_forward(x, mask) # O(1) - just one token
# PATH 3: CACHED GENERATION (seq_len == 1, cache populated)
# ──────────────────────────────────────────────────────────
# This is a NEW token during generation (cache has history)
# We can now use the cache for massive speedup!
# Compute K,V for ONLY this new token, retrieve cached history
# Get attention layer (assumes block.attention has the attention object)
attention = block.attention
# Step 1: Compute Q, K, V for NEW token only
# Access the linear projection layers
Q_new = attention.q_proj.forward(x) # (batch, 1, embed_dim)
K_new = attention.k_proj.forward(x) # (batch, 1, embed_dim)
V_new = attention.v_proj.forward(x) # (batch, 1, embed_dim)
# Step 2: Reshape to multi-head format
batch_size = x.shape[0]
num_heads = attention.num_heads
head_dim = attention.head_dim
# Reshape: (batch, 1, embed_dim) → (batch, num_heads, 1, head_dim)
Q_heads = Q_new.reshape(batch_size, 1, num_heads, head_dim)
Q_heads = Tensor(np.transpose(Q_heads.data, (0, 2, 1, 3))) # (batch, num_heads, 1, head_dim)
K_heads = K_new.reshape(batch_size, 1, num_heads, head_dim)
K_heads = Tensor(np.transpose(K_heads.data, (0, 2, 1, 3)))
V_heads = V_new.reshape(batch_size, 1, num_heads, head_dim)
V_heads = Tensor(np.transpose(V_heads.data, (0, 2, 1, 3)))
# Step 3: Update cache with new K, V (using .data for performance)
cache_obj.update(layer_idx, K_heads, V_heads)
# Step 4: Retrieve ALL cached K, V (includes history + new token)
K_all, V_all = cache_obj.get(layer_idx)
# Step 5: Compute attention using new Q with ALL cached K, V
# ─────────────────────────────────────────────────────────
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
#
# NOTE: We use .data (numpy arrays) here instead of Tensor operations
# Why? This is INFERENCE-ONLY code (no gradients needed):
# - Explicit: Makes it clear this is inference, not training
# - Fast: Avoids autograd overhead (even if small)
# - Standard: Production LLMs (vLLM, llama.cpp) do the same
#
# If this were training, we'd use Tensor operations for gradient flow.
# But in generation (inference), .data is the right choice.
# Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
# → (batch, num_heads, 1, seq_len)
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array
scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul
# Scale by sqrt(head_dim)
scores = scores / np.sqrt(head_dim)
# Apply mask if provided (causal mask for generation)
if mask is not None:
# Mask should be (1, 1, 1, seq_len) for this token
# In generation, we can attend to all previous tokens
pass # No masking needed in generation (we see all history)
# Softmax over key dimension
scores_max = np.max(scores, axis=-1, keepdims=True)
exp_scores = np.exp(scores - scores_max)
attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
# Apply attention weights to values
# (batch, num_heads, 1, seq_len) @ (batch, num_heads, seq_len, head_dim)
# → (batch, num_heads, 1, head_dim)
attention_output = np.matmul(attention_weights, V_all.data)
# Step 6: Reshape back and apply output projection
# (batch, num_heads, 1, head_dim) → (batch, 1, num_heads, head_dim)
attention_output_transposed = np.transpose(attention_output, (0, 2, 1, 3))
# Concatenate heads: (batch, 1, num_heads * head_dim)
concat_data = attention_output_transposed.reshape(batch_size, 1, num_heads * head_dim)
concat_output = Tensor(concat_data)
# Output projection
output = attention.out_proj.forward(concat_output)
return output
return cached_forward
# Patch this block's attention
block.attention.forward = make_cached_forward(layer_idx, block._original_attention_forward, cache)
print(f"⚡ KV Cache enabled for model!")
print(f" Architecture: {model.num_layers} layers × {model.num_heads} heads × {head_dim}D")
print(f" Memory: {cache.get_memory_usage()['total_mb']:.2f} MB")
print(f" Cache stored in: model._kv_cache")
print()
print(f"💡 To disable: call disable_kv_cache(model)")
print()
return cache
### END SOLUTION
#| export
def disable_kv_cache(model):
"""
Disable KV caching and restore original attention behavior.
Args:
model: Model with caching enabled
EXAMPLE:
```python
cache = enable_kv_cache(model)
# ... do cached generation ...
disable_kv_cache(model) # Back to normal
```
"""
# Educational Note: hasattr() is LEGITIMATE here because:
# Checking if monkey-patch markers exist before restoration
if not hasattr(model, '_cache_enabled') or not model._cache_enabled:
print("⚠️ KV cache not enabled on this model")
return
# Restore original attention forwards
for block in model.blocks:
# Educational Note: hasattr() is LEGITIMATE here because:
# Checking for monkey-patch backup before restoration
if hasattr(block, '_original_attention_forward'):
block.attention.forward = block._original_attention_forward
# Clean up
model._cache_enabled = False
# Educational Note: hasattr() is LEGITIMATE here because:
# Safe cleanup check before deleting dynamically added attribute
if hasattr(model, '_kv_cache'):
delattr(model, '_kv_cache')
print("✓ KV cache disabled, original attention restored")
# %% [markdown]
"""
### 🧪 Unit Test: Non-Invasive Cache Integration
Let's verify that `enable_kv_cache()` works without breaking the model!
**This is an integration test** - it tests Module 17 enhancing Modules 12-13 without modification.
"""
# %% nbgrader={"grade": true, "grade_id": "test-noninvasive", "locked": true, "points": 10}
def test_unit_noninvasive_integration():
"""🔬 Unit Test: Non-Invasive Cache Integration"""
print("🔬 Unit Test: Non-Invasive Cache Integration...")
# Create a mock transformer-like object for testing
class MockTransformerBlock:
def __init__(self):
self.attention = self
def forward(self, x, mask=None):
# Simple pass-through for testing
return x
class MockGPT:
def __init__(self):
self.vocab_size = 100
self.embed_dim = 128
self.num_layers = 4
self.num_heads = 4
self.max_seq_len = 64
self.blocks = [MockTransformerBlock() for _ in range(self.num_layers)]
# Test 1: Enable caching
model = MockGPT()
print(" Test 1: Enable caching on model")
cache = enable_kv_cache(model)
assert hasattr(model, '_kv_cache'), "Model should have _kv_cache attribute"
assert hasattr(model, '_cache_enabled'), "Model should have _cache_enabled flag"
assert model._cache_enabled == True, "Cache should be enabled"
assert cache is model._kv_cache, "Returned cache should match model._kv_cache"
# Test 2: Attention forward still works
print(" Test 2: Attention forward pass still works")
test_input = Tensor(np.random.randn(1, 10, 128))
for block in model.blocks:
output = block.attention.forward(test_input)
assert output.shape == test_input.shape, "Forward pass should preserve shape"
# Test 3: Disable caching
print(" Test 3: Disable caching")
disable_kv_cache(model)
assert model._cache_enabled == False, "Cache should be disabled"
assert not hasattr(model, '_kv_cache'), "Cache object should be removed"
# Test 4: Can re-enable
print(" Test 4: Re-enable caching")
_ = enable_kv_cache(model)
assert model._cache_enabled == True, "Cache should be re-enabled"
print("✅ Non-invasive cache integration works correctly!")
# Run test immediately when developing this module
if __name__ == "__main__":
test_unit_noninvasive_integration()
# %% [markdown]
"""
## Part 5: Systems Analysis - KV Cache Performance
Now let's analyze the performance characteristics and trade-offs of KV caching.
"""
# %% nbgrader={"grade": false, "grade_id": "analyze-memory", "locked": false}
def analyze_kvcache_memory():
"""
📊 Analyze KV cache memory usage across different configurations.
Educational Purpose:
Demonstrates how cache memory scales with model architecture.
Students discover:
- Linear scaling with sequence length O(n)
- Memory overhead as percentage of model parameters
- Trade-off between cache size and speedup gains
Analyzes:
- Tiny models (128D): ~0.12 MB
- Small models (512D): ~2 MB
- Medium models (768D): ~9 MB
- Large models (1024D): ~32 MB
Key Insight:
Cache overhead is 10-30% of model parameters, but enables
10-15× speedup. Memory is cheap, compute is expensive!
Production Context:
GPT-3 (175B params, 2048 context): ~4GB cache per sequence
This memory cost is acceptable given the massive speedup.
"""
print("📊 Analyzing KV Cache Memory Usage...")
print()
# Test different model configurations
configs = [
(128, 4, 32, "Tiny"),
(512, 8, 64, "Small"),
(768, 12, 128, "Medium"),
(1024, 16, 256, "Large"),
]
print("Model Config | Cache Memory | Per Layer | Memory Overhead")
print("-" * 60)
for embed_dim, num_layers, seq_len, name in configs:
# Memory per layer: 2 tensors (K, V) × batch × seq_len × embed_dim × 4 bytes
batch_size = 1
memory_per_layer = 2 * batch_size * seq_len * embed_dim * BYTES_PER_FLOAT32 / MB_TO_BYTES
total_memory = memory_per_layer * num_layers
# Model parameter memory (approximate)
params_per_layer = embed_dim * embed_dim * BYTES_PER_FLOAT32 # QKV projections
model_memory = params_per_layer * num_layers * BYTES_PER_FLOAT32 / MB_TO_BYTES
overhead_pct = (total_memory / model_memory) * 100 if model_memory > 0 else 0
print(f"{name:12s} | {total_memory:11.2f} MB | {memory_per_layer:8.2f} MB | {overhead_pct:6.1f}%")
print()
print("💡 Key Insights:")
print(" • Cache memory scales linearly with sequence length (O(n))")
print(" • Longer sequences require proportionally more cache memory")
print(" • Cache overhead is typically 10-30% of model parameters")
print()
print("🚀 Production Context:")
print(" • GPT-3 (175B params, 2048 context): ~4GB cache memory")
print(" • Trade-off: 2× memory enables 10-15× speedup")
print(" • Worth it for inference-heavy workloads!")
# %% nbgrader={"grade": false, "grade_id": "analyze-speedup", "locked": false}
def analyze_kvcache_speedup():
"""
📊 Measure KV cache speedup vs vanilla attention.
Educational Purpose:
Shows students WHY caching provides dramatic speedup through
concrete complexity analysis. Compares O(n²) vs O(n) growth.
Demonstrates:
- Naive approach: O(n²) operations per token
- Cached approach: O(n) operations per token
- Speedup increases with generation length
- 100-token generation: 170× fewer operations
Key Insight:
Speedup is SUPER-LINEAR with generation length because:
- Longer sequences → more redundant computation without cache
- Cache benefit compounds: saves O(n²) → O(n) at EVERY step
Production Reality:
This is why ChatGPT can generate responses in real-time.
Without caching, conversational AI would be economically impossible.
"""
print("\n📊 Analyzing KV Cache Speedup...")
print()
import time
# Create test configuration
batch_size = 1
embed_dim = 256
num_heads = 8
head_dim = embed_dim // num_heads
print("Generation Length | Without Cache | With Cache | Speedup")
print("-" * 55)
for gen_length in [10, 25, 50, 100]:
# Simulate without cache: O(n²) for each new token
# Each token processes entire context
ops_without = sum(i**2 for i in range(1, gen_length + 1))
# Simulate with cache: O(n) for each new token
# Each token only processes itself
ops_with = gen_length
# Estimate time (arbitrary units)
time_without = ops_without / 1000 # ms
time_with = ops_with / 1000 # ms
speedup = ops_without / ops_with
print(f"{gen_length:17d} | {time_without:12.1f} ms | {time_with:10.1f} ms | {speedup:6.1f}×")
print()
print("💡 Key Insights:")
print(" • Speedup increases with generation length (longer = better ROI)")
print(" • 100-token generation: ~170× fewer operations!")
print(" • Cache eliminates O(n²) recomputation per token")
print()
print("🚀 Production Reality:")
print(" • ChatGPT uses KV caching for ALL generation")
print(" • Without caching: 100-token response takes ~17 seconds")
print(" • With caching: 100-token response takes ~0.1 seconds")
print(" • This optimization makes conversational AI possible!")
# Run analysis functions when module is executed directly
if __name__ == "__main__":
analyze_kvcache_memory()
analyze_kvcache_speedup()
# %% [markdown]
"""
## Part 6: Module Integration Test
Final validation that everything works together correctly before module completion.
"""
# %% nbgrader={"grade": true, "grade_id": "module-integration", "locked": true, "points": 20}
def test_module():
"""🧪 Module Test: Complete Integration
Comprehensive test of entire KV Caching module functionality.
This final test runs before module summary to ensure:
- All unit tests pass
- Functions work together correctly
- Module is ready for integration with TinyTorch
"""
print("🧪 RUNNING MODULE INTEGRATION TEST")
print("=" * 50)
print()
# Run all unit tests
print("Running unit tests...")
test_unit_kvcache()
print()
test_unit_cache_enablement()
print()
test_unit_noninvasive_integration()
print()
print("Running integration scenarios...")
print()
# Integration Test: Complete KV Cache Workflow
print("🔬 Integration Test: Complete KV Cache Workflow...")
batch_size, max_seq_len = 1, 128
num_layers, num_heads, head_dim = 4, 8, 64
cache = KVCache(batch_size, max_seq_len, num_layers, num_heads, head_dim)
# Simulate generation loop (processing multiple tokens)
for _ in range(5):
for layer_idx in range(num_layers):
# Simulate new key-value pairs
new_key = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
new_value = Tensor(np.random.randn(batch_size, num_heads, 1, head_dim))
# Update cache
cache.update(layer_idx, new_key, new_value)
# Advance position after all layers processed
cache.advance()
# Verify cache state
assert cache.seq_pos == 5, f"Expected seq_pos=5, got {cache.seq_pos}"
# Verify retrieval
for layer_idx in range(num_layers):
cached_k, cached_v = cache.get(layer_idx)
assert cached_k.shape == (batch_size, num_heads, 5, head_dim)
assert cached_v.shape == (batch_size, num_heads, 5, head_dim)
print("✅ Complete KV cache workflow validated!")
print()
# Integration Test: Memory Tracking
print("🔬 Integration Test: Memory Tracking...")
mem_info = cache.get_memory_usage()
assert mem_info['total_mb'] > 0
assert mem_info['cache_tensors'] == num_layers * 2
print(f"✅ Memory tracking: {mem_info['total_mb']:.2f} MB for {mem_info['cache_tensors']} tensors")
print()
print("=" * 50)
print("🎉 ALL TESTS PASSED! Module ready for export.")
print("Run: tito module complete 17")
# %%
if __name__ == "__main__":
test_module()
# %% [markdown]
"""
## 🤔 ML Systems Reflection Questions
Answer these questions based on your implementation and the concepts you've learned in Modules 01-17.
### Question 1: Cache Size Calculation
A 12-layer transformer has 12 attention heads per layer, 64-dimensional embeddings per head,
maximum sequence length of 2048, and batch size of 8. Calculate the KV cache size:
**Step-by-step calculation**:
- One cache tensor shape: (batch=8, heads=12, seq_len=2048, head_dim=64)
- Elements per tensor: 8 × 12 × 2048 × 64 = _________
- Each layer has K cache + V cache = _________ tensors per layer
- Total across 12 layers = _________ cache tensors
- Float32 = 4 bytes per element
- Total memory in MB: _________
**Follow-up**: If this model has 125M parameters (500 MB), what percentage of model memory
is the cache? Is this overhead acceptable?
### Question 2: Speed vs Memory Trade-off
Your KVCache makes generation 10× faster but uses several GB of RAM.
Consider a production API serving 1000 users simultaneously:
- Without cache: Each generation is slow (10 sec) but uses minimal memory
- With cache: Each generation is fast (1 sec) but uses 100 MB cache per user = 100 GB total!
**Questions**:
- For an interactive chatbot, is this trade-off worth it? Why?
- What happens if your server only has 64 GB RAM but needs to serve 1000 users?
- How would you design a system that balances speed and memory for many concurrent users?
### Question 3: Batch Inference Scaling
With KV cache, each sequence in a batch gets its own cache storage.
**Scenario**: Batch size 1 generates at 500 tokens/sec, using 50 MB cache.
- For batch size 8: Predicted cache memory = _________ MB (scales how?)
- Does each sequence still generate at 500 tokens/sec? Why or why not?
- What's the throughput difference: 1×500 tok/s vs 8×? tok/s = _________ total tok/s
**Trade-off question**: For a production API, when should you use:
- High batch size (8-16): Good for _________
- Low batch size (1-2): Good for _________
### Question 4: Cache Eviction for Long Conversations
Your `KVCache` has `max_seq_len=2048`. A chatbot conversation reaches 2048 tokens - the cache is full!
**Options when cache is full**:
1. **Crash/Error**: Raise exception when max_seq_len exceeded
2. **FIFO eviction**: Drop oldest tokens, keep recent 2048
3. **Sliding window**: Keep most recent N tokens
4. **Restart cache**: Clear everything and start over
**Questions**:
- What happens to conversation context if you evict the first 1000 tokens?
- Why do production systems (ChatGPT) limit conversation length (e.g., 4096 or 8192 tokens)?
- Which eviction strategy would you choose for a medical chatbot that needs full conversation history?
### Question 5: Production Reality - Multi-User Serving
ChatGPT serves millions of users. Each user's conversation needs its own KV cache.
**Memory calculation for 10,000 concurrent conversations**:
- Each cache: 200 MB (typical for GPT-3.5 scale model)
- Total cache memory: 10,000 × 200 MB = _________ GB
- Model parameters: 13B × 4 bytes = 52 GB (loaded once, shared across all users)
- **Total memory needed**: _________ GB
**Questions**:
- Is it feasible to keep 10,000 caches in memory simultaneously on a single GPU (80 GB VRAM)?
- How do you think production systems manage cache memory across millions of users?
- Would you rather: (A) Keep all caches in memory (fast but expensive), or (B) Store inactive
caches on disk and reload as needed (slower but cheaper)? What's the trade-off?
"""
# %% [markdown]
"""
## 🎯 MODULE SUMMARY: KV Caching (Memoization)
Congratulations! You've built the optimization that makes production language models economically viable!
### Key Accomplishments
- Built KVCache class with efficient memory management for K,V tensors across layers
- Implemented non-invasive cache integration using enable_kv_cache()
- Measured 10-15× speedup through analysis functions showing O(n²)→O(n) improvement
- Understood memory-compute trade-off (2× memory enables 10× speedup)
- Discovered why speedup increases with generation length
- All tests pass ✅ (validated by `test_module()`)
### Systems Insights Gained
- **Recomputation Elimination**: Caching K/V eliminates O(n²) redundant work per token
- **Memory-Speed Trade-off**: Doubling memory enables order-of-magnitude speedup
- **Scaling Benefits**: Longer generation = better cache return on investment (170× at 100 tokens)
- **Production Critical**: This single optimization makes ChatGPT-scale inference possible
- **Non-Invasive Design**: Add capabilities forward without breaking existing modules
### Real-World Impact
Without KV caching:
- 100-token generation: ~17 seconds
- Conversational AI: economically infeasible
- User experience: unacceptably slow
With KV caching:
- 100-token generation: ~0.1 seconds (170× faster!)
- Conversational AI: production-ready at scale
- User experience: real-time interaction
This optimization is THE technique that transformed language models from research demonstrations into products serving millions of users daily.
### Production Skills Developed
- **Systems Optimization**: Identify and eliminate computational bottlenecks
- **Memory-Compute Trade-offs**: Accept memory cost for speed gains
- **Non-Breaking Enhancement**: Add features without modifying existing code
- **Performance Analysis**: Measure and validate optimization impact
### Ready for Next Steps
Your KV caching implementation demonstrates the principle: "spend memory to save time"!
Export with: `tito module complete 17`
**Next**: Module 16 (Quantization) will use the opposite trade-off: "sacrifice precision to save memory"!
### What You Just Built Powers
- **ChatGPT, Claude, GPT-4**: All production LLMs use KV caching
- **Real-time chat**: Instant response generation
- **Streaming output**: Efficient token-by-token generation
- **Cost-effective inference**: 10× speedup = 10× more users per GPU
The technique you implemented is mathematically identical to the caching in production language models - you've built a core optimization that enables modern AI!
"""
# %% [markdown]
"""
## 🎓 Module 17 Complete!
You've implemented KV caching - the critical optimization that makes production language models economically viable!
### What You Built
✅ **KVCache Class**: Efficient memory management for key-value pairs across layers
✅ **O(1) Updates**: Fast cache updates without data copying
✅ **Memory Tracking**: Understanding cache size and memory trade-offs
✅ **Non-Invasive Integration**: `enable_kv_cache()` adds optimization WITHOUT breaking modules
✅ **Production Patterns**: Integration strategy for real transformer models
### Key Systems Engineering Lesson
**Module 17 doesn't modify Modules 12-13 - it ENHANCES them!**
This teaches the critical principle: **Add capabilities forward, never break backward.**
- Old code keeps working (Module 12 unchanged)
- New code adds optimization (Module 17 layers on top)
- Clean separation of concerns (caching is separate from attention logic)
### Performance Impact
```
Without Cache: O(n²) complexity → slow, expensive, impractical
With Cache: O(n) complexity → fast, cheap, production-ready
Real Impact: 10-15x speedup for typical generation!
```
### What's Next
**Module 16 (Quantization)**: Now that you've optimized compute through caching, learn how to optimize memory through reduced precision arithmetic.
### Try It Yourself
Run the chatbot milestone with and without caching:
```bash
# Without cache (slow - baseline)
python milestones/05_2017_transformer/vaswani_chatgpt.py
# With cache (fast - 10-15x speedup!)
python milestones/05_2017_transformer/vaswani_chatgpt.py --use-cache
```
Watch the tokens/sec metric jump from ~40 to ~500! 🚀
---
**Congratulations! You've completed Module 17: KV Caching (Memoization)!**
You now understand the optimization that makes ChatGPT, Claude, and all production LLMs possible. This is THE technique that transformed language models from research toys into products used by millions of people every day.
**From Theory to Practice**: You've gone from O(n²) naive generation to O(n) optimized generation. This is real ML engineering!
"""