From 6d0afe4949c72e1307b69fc472e45f47ccb492dd Mon Sep 17 00:00:00 2001 From: Vijay Janapa Reddi Date: Wed, 5 Nov 2025 14:05:47 -0500 Subject: [PATCH] Document KV caching as inference-only (no gradient flow concerns) Added comprehensive documentation clarifying that KV caching is designed ONLY for inference (generation), not training. Key Clarifications: - Cache operations use .data (no gradient tracking) - This is correct and intentional for maximum speed - During generation: no gradients computed (model.eval() mode) - During training: cache not used (standard forward pass) - DO NOT use caching during training Why This is Safe: 1. Training: Uses standard forward pass (full gradient flow) 2. Generation: No backward pass (no gradients needed) 3. Cache is inference optimization, not training component 4. .data usage is correct for generation-only use case Documentation Updates: - Added prominent warning in class docstring - Updated update() method docs - Updated get() method docs - Added inline comments explaining .data usage This addresses gradient flow concerns by making it crystal clear that caching is never used when gradients are needed. --- modules/source/14_kvcaching/kvcaching_dev.py | 21 +++++++++++++++++++- tinytorch/generation/kv_cache.py | 21 +++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/modules/source/14_kvcaching/kvcaching_dev.py b/modules/source/14_kvcaching/kvcaching_dev.py index 1f21fda8..9ea715f1 100644 --- a/modules/source/14_kvcaching/kvcaching_dev.py +++ b/modules/source/14_kvcaching/kvcaching_dev.py @@ -269,6 +269,14 @@ class KVCache: during sequential token generation. This is THE critical optimization that makes production language model serving economically viable. + ⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking) + ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + KV caching is designed ONLY for inference (generation), NOT training. + - During generation: No gradients computed (model.eval() mode) + - Cache operations use .data (no gradient tracking) + - This is correct and intentional for maximum speed + - DO NOT use caching during training (use standard forward pass) + Architecture: - Pre-allocates cache tensors with maximum sequence length - Tracks current sequence position for efficient O(1) updates @@ -330,6 +338,10 @@ class KVCache: to the cache without recomputation. This operation is O(1) because it's just an indexed assignment. + IMPORTANT: KV caching is designed for INFERENCE (generation) only, + not training. During generation, gradients are not computed. If you + need gradients, don't use caching (use standard forward pass instead). + Args: layer_idx: Which transformer layer (0 to num_layers-1) key: New key tensor, shape (batch_size, num_heads, 1, head_dim) @@ -348,7 +360,8 @@ class KVCache: key_cache, value_cache = self.caches[layer_idx] # Update cache at current position (efficient O(1) write) - # Key insight: We write to a specific position, no data copying! + # Note: We use .data here because caching is inference-only (no gradients needed) + # This avoids gradient tracking overhead during generation key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data @@ -361,6 +374,10 @@ class KVCache: Returns only the valid portion of the cache (up to current seq_pos). This is O(1) because we're just slicing NumPy arrays (view, not copy). + IMPORTANT: Returns Tensors without gradient tracking since caching + is inference-only. The returned tensors can be used in attention + computation but won't propagate gradients backward. + Args: layer_idx: Which transformer layer to get cache for @@ -382,6 +399,8 @@ class KVCache: # seq_pos tracks where to write next, so we have seq_pos valid tokens valid_len = self.seq_pos + # Note: Creating new Tensors from .data (no gradient tracking) + # This is correct for inference-only caching cached_keys = Tensor(key_cache.data[:, :, :valid_len, :]) cached_values = Tensor(value_cache.data[:, :, :valid_len, :]) diff --git a/tinytorch/generation/kv_cache.py b/tinytorch/generation/kv_cache.py index 44d64d1d..0ca362b8 100644 --- a/tinytorch/generation/kv_cache.py +++ b/tinytorch/generation/kv_cache.py @@ -19,6 +19,14 @@ class KVCache: during sequential token generation. This is THE critical optimization that makes production language model serving economically viable. + ⚠️ IMPORTANT: INFERENCE-ONLY (No Gradient Tracking) + ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + KV caching is designed ONLY for inference (generation), NOT training. + - During generation: No gradients computed (model.eval() mode) + - Cache operations use .data (no gradient tracking) + - This is correct and intentional for maximum speed + - DO NOT use caching during training (use standard forward pass) + Architecture: - Pre-allocates cache tensors with maximum sequence length - Tracks current sequence position for efficient O(1) updates @@ -80,6 +88,10 @@ class KVCache: to the cache without recomputation. This operation is O(1) because it's just an indexed assignment. + IMPORTANT: KV caching is designed for INFERENCE (generation) only, + not training. During generation, gradients are not computed. If you + need gradients, don't use caching (use standard forward pass instead). + Args: layer_idx: Which transformer layer (0 to num_layers-1) key: New key tensor, shape (batch_size, num_heads, 1, head_dim) @@ -98,7 +110,8 @@ class KVCache: key_cache, value_cache = self.caches[layer_idx] # Update cache at current position (efficient O(1) write) - # Key insight: We write to a specific position, no data copying! + # Note: We use .data here because caching is inference-only (no gradients needed) + # This avoids gradient tracking overhead during generation key_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = key.data value_cache.data[:, :, self.seq_pos:self.seq_pos+1, :] = value.data @@ -111,6 +124,10 @@ class KVCache: Returns only the valid portion of the cache (up to current seq_pos). This is O(1) because we're just slicing NumPy arrays (view, not copy). + IMPORTANT: Returns Tensors without gradient tracking since caching + is inference-only. The returned tensors can be used in attention + computation but won't propagate gradients backward. + Args: layer_idx: Which transformer layer to get cache for @@ -132,6 +149,8 @@ class KVCache: # seq_pos tracks where to write next, so we have seq_pos valid tokens valid_len = self.seq_pos + # Note: Creating new Tensors from .data (no gradient tracking) + # This is correct for inference-only caching cached_keys = Tensor(key_cache.data[:, :, :valid_len, :]) cached_values = Tensor(value_cache.data[:, :, :valid_len, :])