Document KV caching as inference-only (no gradient flow concerns)

Added comprehensive documentation clarifying that KV caching is designed
ONLY for inference (generation), not training.

Key Clarifications:
- Cache operations use .data (no gradient tracking)
- This is correct and intentional: it avoids autograd overhead during generation
- During generation: no gradients are computed (model.eval() mode)
- During training: the cache is not used (standard forward pass)
- DO NOT use caching during training

Why This is Safe (see the sketch after this list):
1. Training: uses the standard forward pass (full gradient flow)
2. Generation: runs no backward pass (no gradients needed)
3. The cache is an inference optimization, not a training component
4. Using .data is correct for the generation-only use case
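
To make the two paths concrete, a minimal sketch (the model, loss, and
kv_cache names are illustrative assumptions, not this repo's actual API):

# Training: standard forward pass, full gradient flow, NO cache.
model.train()
logits = model(batch)                 # autograd tracks every op
loss = loss_fn(logits, targets)
loss.backward()                       # gradients flow normally

# Generation: eval mode, no backward pass, cache IS used.
model.eval()
cache = KVCache(...)                  # hypothetical constructor args
for _ in range(max_new_tokens):
    logits = model(next_token, kv_cache=cache)  # .data writes, no autograd
    next_token = argmax(logits)       # backward() is never called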

Documentation Updates:
- Added prominent warning in class docstring
- Updated update() method docs
- Updated get() method docs
- Added inline comments explaining .data usage

This addresses gradient flow concerns by making it crystal clear that
caching is never used when gradients are needed.
Vijay Janapa Reddi
2025-11-05 14:05:47 -05:00
parent b3f63d7ccf
commit 6d0afe4949
2 changed files with 40 additions and 2 deletions


@@ -19,6 +19,14 @@ class KVCache:
during sequential token generation. This is THE critical optimization
that makes production language model serving economically viable.
⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
KV caching is designed ONLY for inference (generation), NOT training.
- During generation: No gradients computed (model.eval() mode)
- Cache operations use .data (no gradient tracking)
- This is correct and intentional for maximum speed
- DO NOT use caching during training (use standard forward pass)
Architecture:
- Pre-allocates cache tensors with maximum sequence length
- Tracks current sequence position for efficient O(1) updates
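For reference, a plausible shape for that pre-allocation (a sketch under
assumed dimensions; the actual __init__ is not part of this diff):

import numpy as np

# Sketch only: one fixed-size (key, value) buffer pair per layer, written
# in place as tokens arrive; seq_pos marks the next write position.
def make_kv_buffers(num_layers, batch_size, num_heads, max_seq_len, head_dim):
    shape = (batch_size, num_heads, max_seq_len, head_dim)
    caches = [(np.zeros(shape, dtype=np.float32),
               np.zeros(shape, dtype=np.float32))
              for _ in range(num_layers)]
    return caches  # paired with seq_pos = 0 at the start of generation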
@@ -80,6 +88,10 @@ class KVCache:
to the cache without recomputation. This operation is O(1) because
it's just an indexed assignment.
IMPORTANT: KV caching is designed for INFERENCE (generation) only,
not training. During generation, gradients are not computed. If you
need gradients, don't use caching; use the standard forward pass instead.
Args:
layer_idx: Which transformer layer (0 to num_layers-1)
key: New key tensor, shape (batch_size, num_heads, 1, head_dim)
@@ -98,7 +110,8 @@ class KVCache:
key_cache, value_cache = self.caches[layer_idx]
# Update cache at current position (efficient O(1) write)
# Key insight: We write to a specific position, no data copying!
# Note: We use .data here because caching is inference-only (no gradients needed)
# This avoids gradient tracking overhead during generation
key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data
value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data
@@ -111,6 +124,10 @@ class KVCache:
Returns only the valid portion of the cache (up to current seq_pos).
This is O(1) because we're just slicing NumPy arrays (view, not copy).
IMPORTANT: Returns Tensors without gradient tracking since caching
is inference-only. The returned tensors can be used in attention
computation but won't propagate gradients backward.
Args:
layer_idx: Which transformer layer to get cache for
@@ -132,6 +149,8 @@ class KVCache:
# seq_pos tracks where to write next, so we have seq_pos valid tokens
valid_len = self.seq_pos
# Note: Creating new Tensors from .data (no gradient tracking)
# This is correct for inference-only caching
cached_keys = Tensor(key_cache.data[:, :, :valid_len, :])
cached_values = Tensor(value_cache.data[:, :, :valid_len, :])
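
Putting update() and get() together, a per-token generation sketch (the
constructor arguments, project_kv, and attention below are illustrative
assumptions, and this diff does not show where seq_pos is advanced):

cache = KVCache(num_layers=12, batch_size=1, num_heads=8,
                max_seq_len=1024, head_dim=64)   # assumed signature
for _ in range(max_new_tokens):
    for layer_idx in range(12):
        key, value = project_kv(x, layer_idx)    # (B, H, 1, D) for the new token
        cache.update(layer_idx, key, value)      # O(1) write at seq_pos, via .data
        keys, values = cache.get(layer_idx)      # O(1) slice of valid positions
        x = attention(query, keys, values)       # no gradients tracked here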