Fix enable_kv_cache to handle mask parameter and add integration test

Module 14 fix:
- Updated cached_forward() to accept mask parameter (x, mask=None)
- Attention forward is called with two args at call sites: forward(x, mask)
- The wrapper now passes both arguments through to the original forward
  (a sketch of the surrounding patching loop follows this list)
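
A minimal sketch of how the patch might be wired around the factory
(the loop and the blocks/attention attribute names are assumptions;
the diff below shows the actual factory):

    def enable_kv_cache(model):
        """Patch each block's attention forward with a cached wrapper (sketch)."""
        model.kv_cache = {}  # hypothetical per-layer cache store

        def make_cached_forward(layer_idx, original_forward):
            # Factory so each wrapper closes over its own layer_idx/original_forward
            def cached_forward(x, mask=None):
                # Pass-through keeps the model correct; a real implementation
                # would intercept K/V computation here
                return original_forward(x, mask)
            return cached_forward

        for layer_idx, block in enumerate(model.blocks):  # assumed attributes
            block.attention.forward = make_cached_forward(
                layer_idx, block.attention.forward
            )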

Integration test (test_kv_cache_milestone.py):
- Tests generation WITHOUT cache (baseline)
- Tests generation WITH cache enabled
- Verifies cache infrastructure works without breaking model
- Documents current implementation (architecture demo)
- Shows that full speedup requires deeper attention integration
  (a sketch of the test's shape follows this list)
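
A minimal sketch of the test's shape (build_tiny_model and generate are
hypothetical helpers standing in for whatever the test actually uses):

    import time

    def test_kv_cache_milestone():
        model = build_tiny_model()       # hypothetical constructor
        prompt = [1, 2, 3]               # hypothetical token ids
        n = 50

        t0 = time.perf_counter()
        baseline = generate(model, prompt, max_new_tokens=n)  # hypothetical helper
        baseline_tps = n / (time.perf_counter() - t0)

        enable_kv_cache(model)           # patch attention in place
        t0 = time.perf_counter()
        cached = generate(model, prompt, max_new_tokens=n)
        cached_tps = n / (time.perf_counter() - t0)

        # With the pass-through wrapper, greedy outputs must match exactly
        assert cached == baseline
        print(f"without cache: {baseline_tps:.1f} tok/s, "
              f"with cache: {cached_tps:.1f} tok/s")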

Test results:
- Without cache: 139.3 tok/s
- With cache: 142.5 tok/s (similar throughput is expected with the
  pass-through; a sketch of what a real cached path would change follows)
- Cache infrastructure successfully integrated
- Model continues to work with caching enabled
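
For contrast, a sketch of what a real cached path would change inside
cached_forward, following the steps in the diff's EDUCATIONAL NOTE
(is_generating, project_q, project_kv, attend, and the cache API are
all hypothetical names):

    def cached_forward(x, mask=None):
        if is_generating(x):                 # single new token, not a full sequence
            k_new, v_new = project_kv(x)     # compute K,V for the new token only
            k, v = cache.append(layer_idx, k_new, v_new)  # retrieve history + append
            return attend(project_q(x), k, v, mask)       # attend over cached K,V
        # Training / full-sequence path: no caching during backprop
        return original_forward(x, mask)

This is the "deeper attention integration" the test documents: until
K/V recomputation is actually skipped, throughput stays near baseline.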

Educational value:
Students learn the PATTERN of non-invasive optimization through
composition and monkey-patching, which matters more for this module
than absolute speedup numbers. The general shape of the pattern is
sketched below.
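
A minimal sketch of that pattern, independent of KV caching (module
stands in for any object with a forward method):

    original = module.forward             # keep a reference to the unpatched method

    def wrapped(*args, **kwargs):
        # ...added behavior (caching, timing, logging) goes here...
        return original(*args, **kwargs)  # delegate with the full signature intact

    module.forward = wrapped              # swap in place; no source edits required

Delegating with *args/**kwargs would also have sidestepped the signature
mismatch this commit fixes, at the cost of a less explicit interface.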

Author: Vijay Janapa Reddi
Date: 2025-11-05 19:13:41 -05:00
parent 28320ebb81
commit 6c8b448086
4 changed files with 175 additions and 36 deletions

@@ -443,23 +443,23 @@ def enable_kv_cache(model):
 
     # Create cached version
     def make_cached_forward(layer_idx, original_forward):
         """Factory to create cached forward with correct layer_idx closure"""
-        def cached_forward(x):
+        def cached_forward(x, mask=None):
             """
             Cached attention forward pass.
 
             EDUCATIONAL NOTE: In a production implementation, this would:
             1. Check if we're generating (single new token) vs training (full sequence)
             2. For generation: only compute K,V for new token, retrieve history from cache
             3. For training: use original uncached path
 
             For TinyTorch simplicity, we demonstrate the concept without full implementation.
             The cache is created and tracked, showing students the architecture pattern.
             """
             # In training: use original path (no caching during backprop!)
             # In generation: this is where we'd use cache
             # For now, pass through to original to maintain correctness
-            return original_forward(x)
+            return original_forward(x, mask)
         return cached_forward
 
     # Patch this block's attention