Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2026-04-28 19:24:28 -05:00)
Document KV caching as inference-only (no gradient flow concerns)
Added comprehensive documentation clarifying that KV caching is designed ONLY for inference (generation), not training.

Key Clarifications:
- Cache operations use .data (no gradient tracking)
- This is correct and intentional for maximum speed
- During generation: no gradients computed (model.eval() mode)
- During training: cache not used (standard forward pass)
- DO NOT use caching during training

Why This is Safe:
1. Training: Uses standard forward pass (full gradient flow)
2. Generation: No backward pass (no gradients needed)
3. Cache is inference optimization, not training component
4. .data usage is correct for generation-only use case

Documentation Updates:
- Added prominent warning in class docstring
- Updated update() method docs
- Updated get() method docs
- Added inline comments explaining .data usage

This addresses gradient flow concerns by making it crystal clear that caching is never used when gradients are needed.
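To make the commit's premise concrete, here is a minimal NumPy sketch of cached single-token attention (illustrative shapes and names, not the TinyTorch API). Because the buffers are plain arrays, nothing here tracks gradients, which is exactly the inference-only contract this commit documents: step t attends over t+1 cached keys instead of recomputing the whole prefix.

    import numpy as np

    # Toy cached decoding loop: each step writes the new token's K/V into a
    # pre-allocated buffer, then attends over only the valid prefix (a view).
    batch, heads, head_dim, steps = 1, 4, 8, 16
    k_cache = np.zeros((batch, heads, steps, head_dim), dtype=np.float32)
    v_cache = np.zeros_like(k_cache)

    for t in range(steps):
        q = np.random.randn(batch, heads, 1, head_dim).astype(np.float32)
        k_cache[:, :, t:t + 1, :] = np.random.randn(batch, heads, 1, head_dim)
        v_cache[:, :, t:t + 1, :] = np.random.randn(batch, heads, 1, head_dim)
        k = k_cache[:, :, :t + 1, :]                 # O(1) view, no copy
        v = v_cache[:, :, :t + 1, :]
        scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (..., 1, t+1)
        attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
        attn /= attn.sum(axis=-1, keepdims=True)
        out = attn @ v                               # (batch, heads, 1, head_dim)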
@@ -269,6 +269,14 @@ class KVCache:
during sequential token generation. This is THE critical optimization
that makes production language model serving economically viable.

⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
KV caching is designed ONLY for inference (generation), NOT training.
- During generation: No gradients computed (model.eval() mode)
- Cache operations use .data (no gradient tracking)
- This is correct and intentional for maximum speed
- DO NOT use caching during training (use standard forward pass)

Architecture:
- Pre-allocates cache tensors with maximum sequence length
- Tracks current sequence position for efficient O(1) updates
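A compact sketch of the layout the Architecture notes describe (hypothetical sizes; the (batch, heads, seq, head_dim) convention mirrors the docstring): all buffers are allocated once at max_seq_len, and a single seq_pos cursor marks the next write slot for every layer.

    import numpy as np

    # Pre-allocate full-length K and V buffers for every layer up front;
    # generation then only writes into them, never reallocates or grows them.
    num_layers, batch, heads, max_seq_len, head_dim = 4, 1, 8, 256, 32
    caches = [
        (np.zeros((batch, heads, max_seq_len, head_dim), dtype=np.float32),
         np.zeros((batch, heads, max_seq_len, head_dim), dtype=np.float32))
        for _ in range(num_layers)
    ]
    seq_pos = 0  # next write position; also the count of valid cached tokens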
@@ -330,6 +338,10 @@ class KVCache:
to the cache without recomputation. This operation is O(1) because
it's just an indexed assignment.

IMPORTANT: KV caching is designed for INFERENCE (generation) only,
not training. During generation, gradients are not computed. If you
need gradients, don't use caching (use standard forward pass instead).

Args:
    layer_idx: Which transformer layer (0 to num_layers-1)
    key: New key tensor, shape (batch_size, num_heads, 1, head_dim)
@@ -348,7 +360,8 @@
key_cache, value_cache = self.caches[layer_idx]

# Update cache at current position (efficient O(1) write)
# Key insight: We write to a specific position, no data copying!
# Note: We use .data here because caching is inference-only (no gradients needed)
# This avoids gradient tracking overhead during generation
key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data
value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data
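A quick standalone check of the O(1) claim (assumed shapes, plain NumPy standing in for the .data arrays): the indexed assignment writes one token's slice in place, and the cache's underlying buffer address is unchanged afterward, i.e. nothing was copied or reallocated.

    import numpy as np

    key_cache = np.zeros((1, 8, 256, 32), dtype=np.float32)
    new_key = np.random.randn(1, 8, 1, 32).astype(np.float32)
    seq_pos = 5

    addr_before = key_cache.__array_interface__["data"][0]  # buffer address
    key_cache[:, :, seq_pos:seq_pos + 1, :] = new_key       # in-place write
    assert key_cache.__array_interface__["data"][0] == addr_before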
@@ -361,6 +374,10 @@
Returns only the valid portion of the cache (up to current seq_pos).
This is O(1) because we're just slicing NumPy arrays (view, not copy).

IMPORTANT: Returns Tensors without gradient tracking since caching
is inference-only. The returned tensors can be used in attention
computation but won't propagate gradients backward.

Args:
    layer_idx: Which transformer layer to get cache for
@@ -382,6 +399,8 @@
# seq_pos tracks where to write next, so we have seq_pos valid tokens
valid_len = self.seq_pos

# Note: Creating new Tensors from .data (no gradient tracking)
# This is correct for inference-only caching
cached_keys = Tensor(key_cache.data[:, :, :valid_len, :])
cached_values = Tensor(value_cache.data[:, :, :valid_len, :])
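And a matching check for the read path (plain NumPy; whether the Tensor wrapper above copies the slice depends on the Tensor constructor): basic slicing up to valid_len returns a view that shares memory with the cache, so no token data is copied.

    import numpy as np

    key_cache = np.random.randn(1, 8, 256, 32).astype(np.float32)
    valid_len = 7

    cached_keys = key_cache[:, :, :valid_len, :]
    assert cached_keys.base is key_cache        # a view, not a copy
    assert cached_keys.shape == (1, 8, 7, 32)   # only the written positions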