Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2026-04-29 19:57:30 -05:00)
Add jupytext to requirements and export Module 14
Requirements.txt updates:
- Added jupytext>=1.16.0 (required for tito export)
- Added nbformat>=5.10.0 (jupytext dependency)
- New section: Development Tools (Required for tito export)

Module 14 export:
- Successfully exported kvcaching_dev.py to tinytorch/generation/kv_cache.py
- Generated kvcaching_dev.ipynb (21 cells: 9 code, 12 markdown)
- KVCache class, enable_kv_cache(), and disable_kv_cache() are now in the package

Auto-generated updates:
- Added DO NOT EDIT warnings to 8 exported files
- Updated _modidx.py with Module 14 exports
- Protected core files from manual editing

Export now works with: tito export 14_kvcaching
Students can import: from tinytorch.generation.kv_cache import enable_kv_cache
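A minimal sketch of the student-facing usage this export enables, assuming the GPT model and constructor shown in the enable_kv_cache() docstring example in the diff below (the import path and arguments are taken from that example and may differ in a given checkout):

```python
# Sketch only: GPT import/constructor copied from the docstring example below.
from tinytorch.models.transformer import GPT
from tinytorch.generation.kv_cache import enable_kv_cache, disable_kv_cache

model = GPT(vocab_size=100, embed_dim=128, num_layers=4, num_heads=4)

cache = enable_kv_cache(model)               # patches attention, stores model._kv_cache
print(cache.get_memory_usage()['total_mb'])  # pre-allocated cache size in MB

# ... run cached generation here ...

disable_kv_cache(model)                      # restores the original attention forwards
```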
314  tinytorch/generation/kv_cache.py  (generated)
@@ -1,16 +1,31 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/14_kvcaching/kvcaching_dev.py (unless otherwise specified).

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/XX_kv_cache/kv_cache_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝

# %% auto 0
__all__ = ['KVCache', 'enable_kv_cache', 'disable_kv_cache']

__all__ = ['KVCache', 'enable_kv_cache']

# Cell
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 1
import numpy as np
import time
from typing import Tuple, Optional, Dict, List

# Import TinyTorch components from previous modules
from tinytorch.core.tensor import Tensor
from ..core.tensor import Tensor

# Cell
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 5
class KVCache:
    """
    Efficient key-value cache for autoregressive generation.
@@ -48,113 +63,192 @@ class KVCache:
    - Memory: O(num_layers × batch × heads × max_seq × head_dim)
    """

    def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
                 num_heads: int, head_dim: int):
        """
        Initialize KV cache for efficient generation.

        TODO: Set up pre-allocated cache storage for all transformer layers

        APPROACH:
        1. Store configuration parameters (batch_size, max_seq_len, etc.)
        2. Initialize sequence position counter to 0
        3. Create empty list for cache storage
        4. For each layer, pre-allocate zero-filled key and value caches
        5. Store each layer's (key_cache, value_cache) tuple in the list

        Args:
            batch_size: Number of sequences to generate simultaneously
            max_seq_len: Maximum sequence length to support
            num_layers: Number of transformer layers
            num_heads: Number of attention heads per layer
            head_dim: Dimension of each attention head

        EXAMPLE:
        >>> cache = KVCache(batch_size=2, max_seq_len=128, num_layers=4,
        ...                 num_heads=8, head_dim=64)
        >>> cache.seq_pos             # 0 (no tokens cached yet)
        >>> len(cache.caches)         # 4 (one per layer)
        >>> cache.caches[0][0].shape  # (2, 8, 128, 64) - key cache for layer 0

        HINTS:
        - Cache shape: (batch_size, num_heads, max_seq_len, head_dim)
        - Use Tensor(np.zeros(...)) to create cache tensors
        - Store caches as list of tuples: [(key_0, val_0), (key_1, val_1), ...]
        - Pre-allocation avoids dynamic resizing overhead during generation
        """
        ### BEGIN SOLUTION
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Current sequence position (how many tokens are cached)
        self.seq_pos = 0

        # Cache storage: list of (key_cache, value_cache) tuples per layer
        self.caches = []

        for layer_idx in range(num_layers):
            # Pre-allocate cache tensors with maximum size
            # Shape: (batch_size, num_heads, max_seq_len, head_dim)
            key_cache = Tensor(np.zeros((batch_size, num_heads, max_seq_len, head_dim)))
            value_cache = Tensor(np.zeros((batch_size, num_heads, max_seq_len, head_dim)))

            self.caches.append((key_cache, value_cache))
        ### END SOLUTION

    def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
        """
        Update cache with new key-value pairs for given layer.

        TODO: Efficiently append new K,V to cache without data copying

        APPROACH:
        1. Validate layer_idx is in range [0, num_layers-1]
        2. Validate seq_pos hasn't exceeded max_seq_len
        3. Retrieve the (key_cache, value_cache) tuple for this layer
        4. Write new key to position seq_pos in key_cache using indexed assignment
        5. Write new value to position seq_pos in value_cache using indexed assignment
        6. Note: seq_pos is advanced externally via advance() after all layers

        This is the core caching operation - efficiently append new K,V
        to the cache without recomputation. This operation is O(1) because
        it's just an indexed assignment.

        IMPORTANT: KV caching is designed for INFERENCE (generation) only,
        not training. During generation, gradients are not computed. If you
        need gradients, don't use caching (use standard forward pass instead).

        Args:
            layer_idx: Which transformer layer (0 to num_layers-1)
            key: New key tensor, shape (batch_size, num_heads, 1, head_dim)
            value: New value tensor, shape (batch_size, num_heads, 1, head_dim)

        EXAMPLE:
        >>> cache = KVCache(batch_size=1, max_seq_len=10, num_layers=2,
        ...                 num_heads=4, head_dim=64)
        >>> new_k = Tensor(np.random.randn(1, 4, 1, 64))
        >>> new_v = Tensor(np.random.randn(1, 4, 1, 64))
        >>> cache.update(layer_idx=0, key=new_k, value=new_v)
        >>> cache.seq_pos   # Still 0 (update doesn't advance position)
        >>> cache.advance()
        >>> cache.seq_pos   # Now 1

        HINTS:
        - Use slicing: cache[:, :, seq_pos:seq_pos+1, :] to write to position
        - Use .data for direct NumPy access (no gradient tracking needed)
        - Raise ValueError with helpful messages for invalid inputs
        - This is an in-place operation (modifies cache, returns None)

        Raises:
            ValueError: If layer_idx is out of range or sequence is full
        """
        ### BEGIN SOLUTION
        if layer_idx >= self.num_layers:
            raise ValueError(f"Layer index {layer_idx} >= num_layers {self.num_layers}")

        if self.seq_pos >= self.max_seq_len:
            raise ValueError(f"Sequence position {self.seq_pos} >= max_seq_len {self.max_seq_len}")

        # Get cache for this layer
        key_cache, value_cache = self.caches[layer_idx]

        # Update cache at current position (efficient O(1) write)
        # Note: We use .data here because caching is inference-only (no gradients needed)
        # This avoids gradient tracking overhead during generation
        key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data
        value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data

        # Note: seq_pos is advanced externally via advance() after all layers process
        ### END SOLUTION

    def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
        """
        Retrieve cached key-value pairs for attention computation.

        TODO: Return only the valid cached portion for this layer

        APPROACH:
        1. Validate layer_idx is in range
        2. Retrieve the (key_cache, value_cache) tuple for this layer
        3. Calculate valid_len = seq_pos (number of tokens currently cached)
        4. Slice key_cache to get [:, :, :valid_len, :] (only filled portion)
        5. Slice value_cache to get [:, :, :valid_len, :] (only filled portion)
        6. Wrap sliced data in new Tensor objects and return

        Returns only the valid portion of the cache (up to current seq_pos).
        This is O(1) because we're just slicing NumPy arrays (view, not copy).

        IMPORTANT: Returns Tensors without gradient tracking since caching
        is inference-only. The returned tensors can be used in attention
        computation but won't propagate gradients backward.

        Args:
            layer_idx: Which transformer layer to get cache for

        Returns:
            (cached_keys, cached_values): Tensors shaped for attention
                Keys: (batch_size, num_heads, seq_pos, head_dim)
                Values: (batch_size, num_heads, seq_pos, head_dim)

        EXAMPLE:
        >>> cache = KVCache(batch_size=1, max_seq_len=100, num_layers=2,
        ...                 num_heads=4, head_dim=64)
        >>> # After processing 3 tokens
        >>> cache.seq_pos = 3
        >>> cached_k, cached_v = cache.get(layer_idx=0)
        >>> cached_k.shape   # (1, 4, 3, 64) - only first 3 positions
        >>> cached_v.shape   # (1, 4, 3, 64)

        HINTS:
        - valid_len = self.seq_pos (how many tokens have been cached so far)
        - Use slicing: cache.data[:, :, :valid_len, :] to get valid portion
        - Wrap result in Tensor() for consistency with TinyTorch API
        - If seq_pos=0, returns empty cache (shape with 0 in sequence dimension)

        Raises:
            ValueError: If layer_idx is out of range
        """
        ### BEGIN SOLUTION
        if layer_idx >= self.num_layers:
            raise ValueError(f"Layer index {layer_idx} >= num_layers {self.num_layers}")

        # Get cache for this layer
        key_cache, value_cache = self.caches[layer_idx]

        # Return only the valid portion (up to current sequence position)
        # seq_pos tracks where to write next, so we have seq_pos valid tokens
        valid_len = self.seq_pos

        # Note: Creating new Tensors from .data (no gradient tracking)
        # This is correct for inference-only caching
        cached_keys = Tensor(key_cache.data[:, :, :valid_len, :])
        cached_values = Tensor(value_cache.data[:, :, :valid_len, :])

        return cached_keys, cached_values
        ### END SOLUTION

    def advance(self) -> None:
        """
@@ -204,7 +298,7 @@ class KVCache:
            'total_elements': total_elements
        }

# Cell
# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 9
def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
                    num_heads: int, head_dim: int) -> KVCache:
    """
@@ -257,3 +351,159 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
    return cache

# %% ../../modules/source/14_kvcaching/kvcaching_dev.ipynb 14
def enable_kv_cache(model):
    """
    Enable KV caching for a transformer model WITHOUT modifying Module 12/13 code.

    TODO: Create cache and non-invasively patch attention layers

    APPROACH:
    1. Validate model has required attributes (embed_dim, num_layers, num_heads, max_seq_len, blocks)
    2. Calculate head_dim from embed_dim and num_heads
    3. Create KVCache instance sized for this model's architecture
    4. Store cache on model as model._kv_cache and set model._cache_enabled flag
    5. For each transformer block, wrap its attention forward method with caching logic
    6. Print confirmation message with cache statistics
    7. Return the cache object

    This function demonstrates **non-invasive optimization** - adding capabilities
    to existing systems without breaking them. Similar to how Module 05 (Autograd)
    uses enable_autograd() to add gradient tracking to Tensors.

    Args:
        model: A GPT-style transformer model with:
            - model.embed_dim (int)
            - model.num_layers (int)
            - model.num_heads (int)
            - model.max_seq_len (int)
            - model.blocks (list of TransformerBlock objects)

    Returns:
        cache: KVCache object for this model

    EXAMPLE:
    >>> from tinytorch.models.transformer import GPT
    >>> model = GPT(vocab_size=100, embed_dim=128, num_layers=4, num_heads=4)
    >>> cache = enable_kv_cache(model)
    >>> hasattr(model, '_kv_cache')   # True
    >>> model._cache_enabled          # True
    >>> cache.num_layers              # 4 (matches model)

    HINTS:
    - Use hasattr() to validate model attributes exist
    - head_dim = model.embed_dim // model.num_heads
    - Store cache on model with model._kv_cache = cache
    - Set flag with model._cache_enabled = True
    - Save original forward with block._original_attention_forward
    - Use a factory function to create patched forwards (closure captures layer_idx)

    Pedagogical Note:
        This teaches students that optimizations can be LAYERED on top of
        working systems. Module 14 doesn't break Modules 12-13; it enhances them!
    """
    ### BEGIN SOLUTION
    import types

    # Validate model has required attributes
    required_attrs = ['embed_dim', 'num_layers', 'num_heads', 'max_seq_len', 'blocks']
    for attr in required_attrs:
        if not hasattr(model, attr):
            raise AttributeError(
                f"Model missing '{attr}' - enable_kv_cache() requires a GPT-style model "
                f"with {', '.join(required_attrs)}"
            )

    # Calculate head dimension
    head_dim = model.embed_dim // model.num_heads
    if model.embed_dim % model.num_heads != 0:
        raise ValueError(
            f"embed_dim ({model.embed_dim}) must be divisible by num_heads ({model.num_heads})"
        )

    # Create cache for this model
    cache = KVCache(
        batch_size=1,  # Default to single sequence; can be reset for batch inference
        max_seq_len=model.max_seq_len,
        num_layers=model.num_layers,
        num_heads=model.num_heads,
        head_dim=head_dim
    )

    # Store cache on model for easy access
    model._kv_cache = cache
    model._cache_enabled = True

    # Patch each transformer block's attention
    for layer_idx, block in enumerate(model.blocks):
        # Store original attention forward method
        if not hasattr(block, '_original_attention_forward'):
            block._original_attention_forward = block.attention.forward

        # Create cached version
        def make_cached_forward(layer_idx, original_forward):
            """Factory to create cached forward with correct layer_idx closure"""
            def cached_forward(x):
                """
                Cached attention forward pass.

                EDUCATIONAL NOTE: In a production implementation, this would:
                1. Check if we're generating (single new token) vs training (full sequence)
                2. For generation: only compute K,V for new token, retrieve history from cache
                3. For training: use original uncached path

                For TinyTorch simplicity, we demonstrate the concept without full implementation.
                The cache is created and tracked, showing students the architecture pattern.
                """
                # In training: use original path (no caching during backprop!)
                # In generation: this is where we'd use cache
                # For now, pass through to original to maintain correctness
                return original_forward(x)

            return cached_forward

        # Patch this block's attention
        block.attention.forward = make_cached_forward(layer_idx, block._original_attention_forward)

    print(f"⚡ KV Cache enabled for model!")
    print(f"   Architecture: {model.num_layers} layers × {model.num_heads} heads × {head_dim}D")
    print(f"   Memory: {cache.get_memory_usage()['total_mb']:.2f} MB")
    print(f"   Cache stored in: model._kv_cache")
    print()
    print(f"💡 To disable: call disable_kv_cache(model)")
    print()

    return cache
    ### END SOLUTION


#| export
def disable_kv_cache(model):
    """
    Disable KV caching and restore original attention behavior.

    Args:
        model: Model with caching enabled

    Example:
        ```python
        cache = enable_kv_cache(model)
        # ... do cached generation ...
        disable_kv_cache(model)  # Back to normal
        ```
    """
    if not hasattr(model, '_cache_enabled') or not model._cache_enabled:
        print("⚠️ KV cache not enabled on this model")
        return

    # Restore original attention forwards
    for block in model.blocks:
        if hasattr(block, '_original_attention_forward'):
            block.attention.forward = block._original_attention_forward

    # Clean up
    model._cache_enabled = False
    if hasattr(model, '_kv_cache'):
        delattr(model, '_kv_cache')

    print("✓ KV cache disabled, original attention restored")
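The KVCache docstrings above spell out an update/advance/get protocol but leave the actual generation loop to the student. A minimal sketch of that per-token protocol, assuming only the class and import path from this diff (the random K/V tensors are placeholders for the projections a real attention layer would compute):

```python
import numpy as np

from tinytorch.core.tensor import Tensor
from tinytorch.generation.kv_cache import KVCache

# Small cache: 2 layers x batch 1 x 4 heads x 16 positions x 8 dims per head.
# Pre-allocated elements: 2 * 1 * 4 * 16 * 8 = 1024 for keys, plus 1024 for values.
cache = KVCache(batch_size=1, max_seq_len=16, num_layers=2, num_heads=4, head_dim=8)

for step in range(3):                                # pretend we generate three tokens
    for layer_idx in range(cache.num_layers):
        new_k = Tensor(np.random.randn(1, 4, 1, 8))  # placeholder for this layer's K projection
        new_v = Tensor(np.random.randn(1, 4, 1, 8))  # placeholder for this layer's V projection
        cache.update(layer_idx, new_k, new_v)        # O(1) indexed write at seq_pos
        past_k, past_v = cache.get(layer_idx)        # previously cached prefix (seq_pos tokens)
    cache.advance()                                  # advance seq_pos once per token, after all layers

print(cache.seq_pos)   # 3 tokens cached
k0, v0 = cache.get(layer_idx=0)
print(k0.shape)        # (1, 4, 3, 8): only the valid prefix is returned
```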