mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-11 18:24:24 -05:00
Document KV caching as inference-only (no gradient flow concerns)
Added comprehensive documentation clarifying that KV caching is designed ONLY for inference (generation), not training. Key Clarifications: - Cache operations use .data (no gradient tracking) - This is correct and intentional for maximum speed - During generation: no gradients computed (model.eval() mode) - During training: cache not used (standard forward pass) - DO NOT use caching during training Why This is Safe: 1. Training: Uses standard forward pass (full gradient flow) 2. Generation: No backward pass (no gradients needed) 3. Cache is inference optimization, not training component 4. .data usage is correct for generation-only use case Documentation Updates: - Added prominent warning in class docstring - Updated update() method docs - Updated get() method docs - Added inline comments explaining .data usage This addresses gradient flow concerns by making it crystal clear that caching is never used when gradients are needed.
This commit is contained in:
21
tinytorch/generation/kv_cache.py
generated
21
tinytorch/generation/kv_cache.py
generated
@@ -19,6 +19,14 @@ class KVCache:
|
||||
during sequential token generation. This is THE critical optimization
|
||||
that makes production language model serving economically viable.
|
||||
|
||||
⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking)
|
||||
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
KV caching is designed ONLY for inference (generation), NOT training.
|
||||
- During generation: No gradients computed (model.eval() mode)
|
||||
- Cache operations use .data (no gradient tracking)
|
||||
- This is correct and intentional for maximum speed
|
||||
- DO NOT use caching during training (use standard forward pass)
|
||||
|
||||
Architecture:
|
||||
- Pre-allocates cache tensors with maximum sequence length
|
||||
- Tracks current sequence position for efficient O(1) updates
|
||||
@@ -80,6 +88,10 @@ class KVCache:
|
||||
to the cache without recomputation. This operation is O(1) because
|
||||
it's just an indexed assignment.
|
||||
|
||||
IMPORTANT: KV caching is designed for INFERENCE (generation) only,
|
||||
not training. During generation, gradients are not computed. If you
|
||||
need gradients, don't use caching (use standard forward pass instead).
|
||||
|
||||
Args:
|
||||
layer_idx: Which transformer layer (0 to num_layers-1)
|
||||
key: New key tensor, shape (batch_size, num_heads, 1, head_dim)
|
||||
@@ -98,7 +110,8 @@ class KVCache:
|
||||
key_cache, value_cache = self.caches[layer_idx]
|
||||
|
||||
# Update cache at current position (efficient O(1) write)
|
||||
# Key insight: We write to a specific position, no data copying!
|
||||
# Note: We use .data here because caching is inference-only (no gradients needed)
|
||||
# This avoids gradient tracking overhead during generation
|
||||
key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data
|
||||
value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data
|
||||
|
||||
@@ -111,6 +124,10 @@ class KVCache:
|
||||
Returns only the valid portion of the cache (up to current seq_pos).
|
||||
This is O(1) because we're just slicing NumPy arrays (view, not copy).
|
||||
|
||||
IMPORTANT: Returns Tensors without gradient tracking since caching
|
||||
is inference-only. The returned tensors can be used in attention
|
||||
computation but won't propagate gradients backward.
|
||||
|
||||
Args:
|
||||
layer_idx: Which transformer layer to get cache for
|
||||
|
||||
@@ -132,6 +149,8 @@ class KVCache:
|
||||
# seq_pos tracks where to write next, so we have seq_pos valid tokens
|
||||
valid_len = self.seq_pos
|
||||
|
||||
# Note: Creating new Tensors from .data (no gradient tracking)
|
||||
# This is correct for inference-only caching
|
||||
cached_keys = Tensor(key_cache.data[:, :, :valid_len, :])
|
||||
cached_values = Tensor(value_cache.data[:, :, :valid_len, :])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user