Add comprehensive documentation for KV cache path selection

Enhanced Module 14 with extensive educational documentation explaining:

Three-Path Selection Strategy (dispatch sketched after this list):
- PATH 1: Training (seq_len > 1) - Uses original attention, preserves gradients
- PATH 2: First Token (cache empty) - Uses original attention, initializes cache
- PATH 3: Cached Generation (cache populated) - THE SPEEDUP PATH, O(n) computation
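
The dispatch is two guard clauses. A minimal, self-contained sketch (the function name is hypothetical, but the conditions mirror the patched forward pass in the diffs below):

```
def select_path(seq_len: int, cache_pos: int) -> str:
    """Mirror the two guards in the cached forward pass."""
    if seq_len > 1:
        return "PATH 1: training (original attention, gradients preserved)"
    if cache_pos == 0:
        return "PATH 2: first token (original attention, cache still empty)"
    return "PATH 3: cached generation (new K,V only, history from cache)"

# The three cases from the docstring examples:
assert select_path(seq_len=64, cache_pos=0).startswith("PATH 1")
assert select_path(seq_len=1, cache_pos=0).startswith("PATH 2")
assert select_path(seq_len=1, cache_pos=5).startswith("PATH 3")
```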

Why .data Instead of Tensor Operations (illustrated after this list):
- Explicit intent: Clear separation of training vs inference code
- Performance: Avoids autograd overhead during generation
- Industry standard: Production LLMs (vLLM, llama.cpp) use same pattern
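
For concreteness, here is the cached path's math on raw arrays (a minimal numpy sketch; the shapes and variable names are illustrative, but the transpose/matmul/scale sequence matches the .data code in the diffs, with a standard softmax appended):

```
import numpy as np

# Illustrative shapes: batch=1, heads=4, 6 cached positions, head_dim=32
Q = np.random.randn(1, 4, 1, 32).astype(np.float32)  # new token's query
K = np.random.randn(1, 4, 6, 32).astype(np.float32)  # cached keys (K_all.data)
V = np.random.randn(1, 4, 6, 32).astype(np.float32)  # cached values (V_all.data)

# Plain numpy throughout: no autograd graph is built
scores = np.matmul(Q, np.transpose(K, (0, 1, 3, 2))) / np.sqrt(32)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over keys
out = np.matmul(weights, V)  # (1, 4, 1, 32): attention for the one new token
```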

O(n²) to O(n) Transformation Explained (per-step cost; totals and worked numbers after this list):
- WITHOUT cache: O(N³) total across all steps (1² + 2² + ... + N²)
- WITH cache: O(N²) total across all steps (1 + 2 + ... + N)
- Result: 5-7x speedup on short sequences, 10-15x on longer ones
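
The totals check out with a few lines of arithmetic (worked example; N=100 chosen for illustration):

```
N = 100  # tokens generated in one sequence

without_cache = sum(t * t for t in range(1, N + 1))  # 1 + 4 + ... + N² = 338350
with_cache = N * (N + 1) // 2                        # 1 + 2 + ... + N  = 5050

print(without_cache / with_cache)  # 67.0: raw attention-op ratio at N=100
```

The measured 5-7x and 10-15x figures sit below this raw ~67x ratio because caching only removes redundant K,V and attention-score work; projections, feed-forward layers, and sampling still run on every step.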

Inline comments added at every decision point for student comprehension.
Module 14 now complete with working implementation and comprehensive pedagogy.
Vijay Janapa Reddi
2025-11-06 12:30:39 -05:00
parent 3b21687f0f
commit 80734693e8
3 changed files with 379 additions and 115 deletions

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "markdown",
"id": "bd52e3da",
"id": "1078513e",
"metadata": {
"cell_marker": "\"\"\""
},
@@ -52,7 +52,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "26b79392",
"id": "266270f3",
"metadata": {},
"outputs": [],
"source": [
@@ -69,7 +69,7 @@
},
{
"cell_type": "markdown",
"id": "7cd54d44",
"id": "06ca957c",
"metadata": {
"cell_marker": "\"\"\""
},
@@ -126,7 +126,7 @@
},
{
"cell_type": "markdown",
"id": "eb00159f",
"id": "dc896d3f",
"metadata": {
"cell_marker": "\"\"\""
},
@@ -199,7 +199,7 @@
},
{
"cell_type": "markdown",
"id": "4dd6bb57",
"id": "c3feca5a",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -278,7 +278,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "4b00b030",
"id": "6d054a8c",
"metadata": {
"lines_to_next_cell": 1,
"nbgrader": {
@@ -293,7 +293,7 @@
"class KVCache:\n",
" \"\"\"\n",
" Efficient key-value cache for autoregressive generation.\n",
" \n",
"\n",
" Stores K,V matrices for each transformer layer to avoid recomputation\n",
" during sequential token generation. This is THE critical optimization\n",
" that makes production language model serving economically viable.\n",
@@ -320,13 +320,13 @@
" ...\n",
" Layer N: [Key_cache, Value_cache]\n",
" ```\n",
" \n",
"\n",
" Performance:\n",
" - Update: O(1) - just index assignment\n",
" - Get: O(1) - just slicing (no data copy)\n",
" - Memory: O(num_layers × batch × heads × max_seq × head_dim)\n",
" \"\"\"\n",
" \n",
"\n",
" def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,\n",
" num_heads: int, head_dim: int):\n",
" \"\"\"\n",
@@ -382,7 +382,7 @@
"\n",
" self.caches.append((key_cache, value_cache))\n",
" ### END SOLUTION\n",
" \n",
"\n",
" def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:\n",
" \"\"\"\n",
" Update cache with new key-value pairs for given layer.\n",
@@ -447,7 +447,7 @@
"\n",
" # Note: seq_pos is advanced externally via advance() after all layers process\n",
" ### END SOLUTION\n",
" \n",
"\n",
" def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:\n",
" \"\"\"\n",
" Retrieve cached key-value pairs for attention computation.\n",
@@ -513,48 +513,48 @@
"\n",
" return cached_keys, cached_values\n",
" ### END SOLUTION\n",
" \n",
"\n",
" def advance(self) -> None:\n",
" \"\"\"\n",
" Advance sequence position after processing current token.\n",
" \n",
"\n",
" Call this after all layers have processed the current token and\n",
" updated their caches. This moves the write pointer forward.\n",
" \"\"\"\n",
" self.seq_pos += 1\n",
" \n",
"\n",
" def reset(self) -> None:\n",
" \"\"\"\n",
" Reset cache for new generation sequence.\n",
" \n",
"\n",
" Call this when starting a new generation (new prompt).\n",
" Resets the sequence position counter and optionally zeros cache data.\n",
" \"\"\"\n",
" self.seq_pos = 0\n",
" \n",
"\n",
" # Zero out caches for clean state (helps with debugging)\n",
" for layer_idx in range(self.num_layers):\n",
" key_cache, value_cache = self.caches[layer_idx]\n",
" key_cache.data.fill(0.0)\n",
" value_cache.data.fill(0.0)\n",
" \n",
"\n",
" def get_memory_usage(self) -> Dict[str, float]:\n",
" \"\"\"\n",
" Calculate memory usage of the cache system.\n",
" \n",
"\n",
" Returns:\n",
" Dictionary with memory statistics in MB\n",
" \"\"\"\n",
" # Calculate size of one cache tensor\n",
" cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim\n",
" bytes_per_float = 4 # float32\n",
" \n",
"\n",
" # Each layer has key_cache + value_cache\n",
" total_cache_tensors = self.num_layers * 2\n",
" total_elements = cache_size * total_cache_tensors\n",
" total_bytes = total_elements * bytes_per_float\n",
" total_mb = total_bytes / (1024 * 1024)\n",
" \n",
"\n",
" return {\n",
" 'total_mb': total_mb,\n",
" 'per_layer_mb': total_mb / self.num_layers,\n",
@@ -565,7 +565,7 @@
},
{
"cell_type": "markdown",
"id": "d07be1e1",
"id": "94cee9a8",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -581,7 +581,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "7a1814c1",
"id": "62409497",
"metadata": {
"nbgrader": {
"grade": true,
@@ -669,7 +669,7 @@
},
{
"cell_type": "markdown",
"id": "2b19f7ea",
"id": "39ea5911",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -716,7 +716,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cccb9a0d",
"id": "7f453db6",
"metadata": {
"lines_to_next_cell": 1
},
@@ -731,14 +731,14 @@
" This function creates a properly sized cache for the model architecture.\n",
" Call this before starting generation, then pass the cache to your\n",
" generation loop.\n",
" \n",
"\n",
" Args:\n",
" batch_size: Number of sequences to generate simultaneously\n",
" max_seq_len: Maximum sequence length to support\n",
" num_layers: Number of transformer layers in model\n",
" num_heads: Number of attention heads per layer\n",
" head_dim: Dimension per attention head (usually embed_dim // num_heads)\n",
" \n",
"\n",
" Returns:\n",
" KVCache instance ready for use\n",
" \n",
@@ -778,7 +778,7 @@
},
{
"cell_type": "markdown",
"id": "517247d2",
"id": "80402a25",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -794,7 +794,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e6fba64c",
"id": "fc77d324",
"metadata": {
"nbgrader": {
"grade": true,
@@ -857,7 +857,7 @@
},
{
"cell_type": "markdown",
"id": "4fa0c25c",
"id": "df7728e0",
"metadata": {
"cell_marker": "\"\"\""
},
@@ -924,7 +924,7 @@
},
{
"cell_type": "markdown",
"id": "6c76f95c",
"id": "1df5b0fc",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -960,7 +960,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ba55f283",
"id": "7a8281fd",
"metadata": {
"nbgrader": {
"grade": false,
@@ -1066,28 +1066,107 @@
" \"\"\"\n",
" Cached attention forward pass with REAL speedup!\n",
" \n",
" Strategy:\n",
" - Training (seq_len > 1): Use original path (full gradients)\n",
" - Generation (seq_len = 1): Use cache for 10-15x speedup\n",
" PATH SELECTION STRATEGY (Key to Understanding KV Caching):\n",
" ──────────────────────────────────────────────────────────\n",
" \n",
" Cache operations use .data (inference-only, no grad tracking).\n",
" Training path unchanged (full gradient flow preserved).\n",
" We have THREE possible paths through attention:\n",
" \n",
" 1⃣ TRAINING PATH (seq_len > 1):\n",
" - Input: Full sequence of tokens (e.g., 64 tokens)\n",
" - Action: Use ORIGINAL attention (no caching)\n",
" - Why: Need full gradient flow for backpropagation\n",
" - Complexity: O(n²) but that's fine for training\n",
" - Example: x.shape = (batch=1, seq=64, embed=128)\n",
" \n",
" 2⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):\n",
" - Input: Single token (the first one in generation)\n",
" - Action: Use ORIGINAL attention (initialize cache)\n",
" - Why: Cache is empty, nothing to retrieve yet\n",
" - Complexity: O(1) - only one token\n",
" - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0\n",
" \n",
" 3⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):\n",
" - Input: Single NEW token (during generation)\n",
" - Action: Compute K,V for new token ONLY, retrieve history from cache\n",
" - Why: This is where the speedup happens! O(n²) → O(n)\n",
" - Complexity: O(n) - only compute for new token, reuse cache\n",
" - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5\n",
" \n",
" \n",
" WHY .data INSTEAD OF TENSOR OPERATIONS?\n",
" ────────────────────────────────────────\n",
" \n",
" In the cached path, we use numpy via .data for three reasons:\n",
" \n",
" 1. **Explicit Intent**: Makes it crystal clear this is inference-only\n",
" - Training: Uses Tensor operations → gradients tracked\n",
" - Inference: Uses .data → no gradient overhead\n",
" \n",
" 2. **Performance**: Avoids any autograd bookkeeping\n",
" - Even if small, every bit counts in generation\n",
" - Production LLMs (vLLM, llama.cpp) use similar patterns\n",
" \n",
" 3. **Educational Clarity**: Shows students the distinction\n",
" - \"When do I need gradients?\" (training)\n",
" - \"When can I skip them?\" (inference)\n",
" \n",
" We COULD use Tensor operations with requires_grad=False, but .data\n",
" is more explicit and is the industry-standard pattern.\n",
" \n",
" \n",
" THE O(n²) → O(n) TRANSFORMATION:\n",
" ─────────────────────────────────\n",
" \n",
" WITHOUT Cache (Standard Attention):\n",
" Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)\n",
" Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)\n",
" Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)\n",
" ...\n",
" Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)\n",
" \n",
" Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!\n",
" \n",
" WITH Cache (Our Implementation):\n",
" Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)\n",
" Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)\n",
" Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)\n",
" ...\n",
" Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)\n",
" \n",
" Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!\n",
" \n",
" That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!\n",
" \"\"\"\n",
" from tinytorch.core.tensor import Tensor\n",
" import numpy as np\n",
" \n",
" seq_len = x.shape[1]\n",
" \n",
" # TRAINING PATH: Full sequence, use original attention (preserves gradients)\n",
" # ═══════════════════════════════════════════════════════════════\n",
" # PATH SELECTION: Choose between training, first token, or cached\n",
" # ═══════════════════════════════════════════════════════════════\n",
" \n",
" # PATH 1: TRAINING (seq_len > 1)\n",
" # ───────────────────────────────────\n",
" # Input is a full sequence (e.g., 64 tokens during training)\n",
" # We MUST use original attention to preserve gradient flow\n",
" # No caching during training - we need backprop through everything\n",
" if seq_len > 1:\n",
" return original_forward(x, mask)\n",
" return original_forward(x, mask) # O(n²) but preserves gradients\n",
" \n",
" # GENERATION PATH: Single token, use KV cache for speedup\n",
" # This is inference-only, so we use .data for performance\n",
" \n",
" # Check if cache is empty (first token) - if so, use original path\n",
" # PATH 2: FIRST TOKEN (seq_len == 1, cache empty)\n",
" # ────────────────────────────────────────────────\n",
" # This is the very first token in generation (cache.seq_pos == 0)\n",
" # Cache is empty, so there's nothing to retrieve yet\n",
" # Use original attention to process this token, which will populate cache\n",
" if cache_obj.seq_pos == 0:\n",
" return original_forward(x, mask)\n",
" return original_forward(x, mask) # O(1) - just one token\n",
" \n",
" # PATH 3: CACHED GENERATION (seq_len == 1, cache populated)\n",
" # ──────────────────────────────────────────────────────────\n",
" # This is a NEW token during generation (cache has history)\n",
" # We can now use the cache for massive speedup!\n",
" # Compute K,V for ONLY this new token, retrieve cached history\n",
" \n",
" # Get attention layer (assumes block.attention has the attention object)\n",
" attention = block.attention\n",
@@ -1120,13 +1199,22 @@
" K_all, V_all = cache_obj.get(layer_idx)\n",
" \n",
" # Step 5: Compute attention using new Q with ALL cached K, V\n",
" # ─────────────────────────────────────────────────────────\n",
" # Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V\n",
" # Use numpy operations directly for batched matmul\n",
" #\n",
" # NOTE: We use .data (numpy arrays) here instead of Tensor operations\n",
" # Why? This is INFERENCE-ONLY code (no gradients needed):\n",
" # - Explicit: Makes it clear this is inference, not training\n",
" # - Fast: Avoids autograd overhead (even if small)\n",
" # - Standard: Production LLMs (vLLM, llama.cpp) do the same\n",
" #\n",
" # If this were training, we'd use Tensor operations for gradient flow.\n",
" # But in generation (inference), .data is the right choice.\n",
" \n",
" # Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)\n",
" # → (batch, num_heads, 1, seq_len)\n",
" K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))\n",
" scores = np.matmul(Q_heads.data, K_transposed)\n",
" K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array\n",
" scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul\n",
" \n",
" # Scale by sqrt(head_dim)\n",
" scores = scores / np.sqrt(head_dim)\n",
@@ -1211,7 +1299,7 @@
},
{
"cell_type": "markdown",
"id": "fb549b54",
"id": "969b4e1c",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -1227,7 +1315,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9f792a3d",
"id": "2c198422",
"metadata": {
"lines_to_next_cell": 2,
"nbgrader": {
@@ -1297,7 +1385,7 @@
},
{
"cell_type": "markdown",
"id": "ce64525c",
"id": "5c56c36a",
"metadata": {
"cell_marker": "\"\"\"",
"lines_to_next_cell": 1
@@ -1311,7 +1399,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d693cda0",
"id": "fbc1c29f",
"metadata": {
"lines_to_next_cell": 1,
"nbgrader": {
@@ -1396,7 +1484,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "fbaff03f",
"id": "e1b4fcb9",
"metadata": {
"lines_to_next_cell": 2
},
@@ -1408,7 +1496,7 @@
},
{
"cell_type": "markdown",
"id": "c22b893c",
"id": "ff6d655d",
"metadata": {
"cell_marker": "\"\"\""
},

View File

@@ -264,7 +264,7 @@ This design enables **O(1) updates** - just write to the next position!
class KVCache:
"""
Efficient key-value cache for autoregressive generation.
Stores K,V matrices for each transformer layer to avoid recomputation
during sequential token generation. This is THE critical optimization
that makes production language model serving economically viable.
@@ -291,13 +291,13 @@ class KVCache:
...
Layer N: [Key_cache, Value_cache]
```
Performance:
- Update: O(1) - just index assignment
- Get: O(1) - just slicing (no data copy)
- Memory: O(num_layers × batch × heads × max_seq × head_dim)
"""
def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
num_heads: int, head_dim: int):
"""
@@ -353,7 +353,7 @@ class KVCache:
self.caches.append((key_cache, value_cache))
### END SOLUTION
def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
"""
Update cache with new key-value pairs for given layer.
@@ -418,7 +418,7 @@ class KVCache:
# Note: seq_pos is advanced externally via advance() after all layers process
### END SOLUTION
def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
"""
Retrieve cached key-value pairs for attention computation.
@@ -484,48 +484,48 @@ class KVCache:
return cached_keys, cached_values
### END SOLUTION
def advance(self) -> None:
"""
Advance sequence position after processing current token.
Call this after all layers have processed the current token and
updated their caches. This moves the write pointer forward.
"""
self.seq_pos += 1
def reset(self) -> None:
"""
Reset cache for new generation sequence.
Call this when starting a new generation (new prompt).
Resets the sequence position counter and optionally zeros cache data.
"""
self.seq_pos = 0
# Zero out caches for clean state (helps with debugging)
for layer_idx in range(self.num_layers):
key_cache, value_cache = self.caches[layer_idx]
key_cache.data.fill(0.0)
value_cache.data.fill(0.0)
def get_memory_usage(self) -> Dict[str, float]:
"""
Calculate memory usage of the cache system.
Returns:
Dictionary with memory statistics in MB
"""
# Calculate size of one cache tensor
cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
bytes_per_float = 4 # float32
# Each layer has key_cache + value_cache
total_cache_tensors = self.num_layers * 2
total_elements = cache_size * total_cache_tensors
total_bytes = total_elements * bytes_per_float
total_mb = total_bytes / (1024 * 1024)
return {
'total_mb': total_mb,
'per_layer_mb': total_mb / self.num_layers,
@@ -667,14 +667,14 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
This function creates a properly sized cache for the model architecture.
Call this before starting generation, then pass the cache to your
generation loop.
Args:
batch_size: Number of sequences to generate simultaneously
max_seq_len: Maximum sequence length to support
num_layers: Number of transformer layers in model
num_heads: Number of attention heads per layer
head_dim: Dimension per attention head (usually embed_dim // num_heads)
Returns:
KVCache instance ready for use
@@ -958,28 +958,107 @@ def enable_kv_cache(model):
"""
Cached attention forward pass with REAL speedup!
Strategy:
- Training (seq_len > 1): Use original path (full gradients)
- Generation (seq_len = 1): Use cache for 10-15x speedup
PATH SELECTION STRATEGY (Key to Understanding KV Caching):
──────────────────────────────────────────────────────────
Cache operations use .data (inference-only, no grad tracking).
Training path unchanged (full gradient flow preserved).
We have THREE possible paths through attention:
1⃣ TRAINING PATH (seq_len > 1):
- Input: Full sequence of tokens (e.g., 64 tokens)
- Action: Use ORIGINAL attention (no caching)
- Why: Need full gradient flow for backpropagation
- Complexity: O(n²) but that's fine for training
- Example: x.shape = (batch=1, seq=64, embed=128)
2⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):
- Input: Single token (the first one in generation)
- Action: Use ORIGINAL attention (initialize cache)
- Why: Cache is empty, nothing to retrieve yet
- Complexity: O(1) - only one token
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0
3⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):
- Input: Single NEW token (during generation)
- Action: Compute K,V for new token ONLY, retrieve history from cache
- Why: This is where the speedup happens! O(n²) → O(n)
- Complexity: O(n) - only compute for new token, reuse cache
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5
WHY .data INSTEAD OF TENSOR OPERATIONS?
────────────────────────────────────────
In the cached path, we use numpy via .data for three reasons:
1. **Explicit Intent**: Makes it crystal clear this is inference-only
- Training: Uses Tensor operations → gradients tracked
- Inference: Uses .data → no gradient overhead
2. **Performance**: Avoids any autograd bookkeeping
- Even if small, every bit counts in generation
- Production LLMs (vLLM, llama.cpp) use similar patterns
3. **Educational Clarity**: Shows students the distinction
- "When do I need gradients?" (training)
- "When can I skip them?" (inference)
We COULD use Tensor operations with requires_grad=False, but .data
is more explicit and is the industry-standard pattern.
THE O(n²) → O(n) TRANSFORMATION:
─────────────────────────────────
WITHOUT Cache (Standard Attention):
Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)
Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)
Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)
...
Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)
Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!
WITH Cache (Our Implementation):
Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)
Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)
Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)
...
Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)
Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!
That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!
"""
from tinytorch.core.tensor import Tensor
import numpy as np
seq_len = x.shape[1]
# TRAINING PATH: Full sequence, use original attention (preserves gradients)
# ═══════════════════════════════════════════════════════════════
# PATH SELECTION: Choose between training, first token, or cached
# ═══════════════════════════════════════════════════════════════
# PATH 1: TRAINING (seq_len > 1)
# ───────────────────────────────────
# Input is a full sequence (e.g., 64 tokens during training)
# We MUST use original attention to preserve gradient flow
# No caching during training - we need backprop through everything
if seq_len > 1:
return original_forward(x, mask)
return original_forward(x, mask) # O(n²) but preserves gradients
# GENERATION PATH: Single token, use KV cache for speedup
# This is inference-only, so we use .data for performance
# Check if cache is empty (first token) - if so, use original path
# PATH 2: FIRST TOKEN (seq_len == 1, cache empty)
# ────────────────────────────────────────────────
# This is the very first token in generation (cache.seq_pos == 0)
# Cache is empty, so there's nothing to retrieve yet
# Use original attention to process this token, which will populate cache
if cache_obj.seq_pos == 0:
return original_forward(x, mask)
return original_forward(x, mask) # O(1) - just one token
# PATH 3: CACHED GENERATION (seq_len == 1, cache populated)
# ──────────────────────────────────────────────────────────
# This is a NEW token during generation (cache has history)
# We can now use the cache for massive speedup!
# Compute K,V for ONLY this new token, retrieve cached history
# Get attention layer (assumes block.attention has the attention object)
attention = block.attention
@@ -1012,13 +1091,22 @@ def enable_kv_cache(model):
K_all, V_all = cache_obj.get(layer_idx)
# Step 5: Compute attention using new Q with ALL cached K, V
# ─────────────────────────────────────────────────────────
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
# Use numpy operations directly for batched matmul
#
# NOTE: We use .data (numpy arrays) here instead of Tensor operations
# Why? This is INFERENCE-ONLY code (no gradients needed):
# - Explicit: Makes it clear this is inference, not training
# - Fast: Avoids autograd overhead (even if small)
# - Standard: Production LLMs (vLLM, llama.cpp) do the same
#
# If this were training, we'd use Tensor operations for gradient flow.
# But in generation (inference), .data is the right choice.
# Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
# → (batch, num_heads, 1, seq_len)
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))
scores = np.matmul(Q_heads.data, K_transposed)
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array
scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul
# Scale by sqrt(head_dim)
scores = scores / np.sqrt(head_dim)
@@ -1157,7 +1245,7 @@ def test_unit_noninvasive_integration():
# Test 4: Can re-enable
print(" Test 4: Re-enable caching")
cache2 = enable_kv_cache(model)
_ = enable_kv_cache(model)
assert model._cache_enabled == True, "Cache should be re-enabled"
print("✅ Non-invasive cache integration works correctly!")

View File

@@ -29,7 +29,7 @@ from ..core.tensor import Tensor
class KVCache:
"""
Efficient key-value cache for autoregressive generation.
Stores K,V matrices for each transformer layer to avoid recomputation
during sequential token generation. This is THE critical optimization
that makes production language model serving economically viable.
@@ -56,13 +56,13 @@ class KVCache:
...
Layer N: [Key_cache, Value_cache]
```
Performance:
- Update: O(1) - just index assignment
- Get: O(1) - just slicing (no data copy)
- Memory: O(num_layers × batch × heads × max_seq × head_dim)
"""
def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
num_heads: int, head_dim: int):
"""
@@ -118,7 +118,7 @@ class KVCache:
self.caches.append((key_cache, value_cache))
### END SOLUTION
def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
"""
Update cache with new key-value pairs for given layer.
@@ -183,7 +183,7 @@ class KVCache:
# Note: seq_pos is advanced externally via advance() after all layers process
### END SOLUTION
def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
"""
Retrieve cached key-value pairs for attention computation.
@@ -249,48 +249,48 @@ class KVCache:
return cached_keys, cached_values
### END SOLUTION
def advance(self) -> None:
"""
Advance sequence position after processing current token.
Call this after all layers have processed the current token and
updated their caches. This moves the write pointer forward.
"""
self.seq_pos += 1
def reset(self) -> None:
"""
Reset cache for new generation sequence.
Call this when starting a new generation (new prompt).
Resets the sequence position counter and optionally zeros cache data.
"""
self.seq_pos = 0
# Zero out caches for clean state (helps with debugging)
for layer_idx in range(self.num_layers):
key_cache, value_cache = self.caches[layer_idx]
key_cache.data.fill(0.0)
value_cache.data.fill(0.0)
def get_memory_usage(self) -> Dict[str, float]:
"""
Calculate memory usage of the cache system.
Returns:
Dictionary with memory statistics in MB
"""
# Calculate size of one cache tensor
cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
bytes_per_float = 4 # float32
# Each layer has key_cache + value_cache
total_cache_tensors = self.num_layers * 2
total_elements = cache_size * total_cache_tensors
total_bytes = total_elements * bytes_per_float
total_mb = total_bytes / (1024 * 1024)
return {
'total_mb': total_mb,
'per_layer_mb': total_mb / self.num_layers,
@@ -307,14 +307,14 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
This function creates a properly sized cache for the model architecture.
Call this before starting generation, then pass the cache to your
generation loop.
Args:
batch_size: Number of sequences to generate simultaneously
max_seq_len: Maximum sequence length to support
num_layers: Number of transformer layers in model
num_heads: Number of attention heads per layer
head_dim: Dimension per attention head (usually embed_dim // num_heads)
Returns:
KVCache instance ready for use
@@ -447,28 +447,107 @@ def enable_kv_cache(model):
"""
Cached attention forward pass with REAL speedup!
Strategy:
- Training (seq_len > 1): Use original path (full gradients)
- Generation (seq_len = 1): Use cache for 10-15x speedup
PATH SELECTION STRATEGY (Key to Understanding KV Caching):
──────────────────────────────────────────────────────────
Cache operations use .data (inference-only, no grad tracking).
Training path unchanged (full gradient flow preserved).
We have THREE possible paths through attention:
1⃣ TRAINING PATH (seq_len > 1):
- Input: Full sequence of tokens (e.g., 64 tokens)
- Action: Use ORIGINAL attention (no caching)
- Why: Need full gradient flow for backpropagation
- Complexity: O(n²) but that's fine for training
- Example: x.shape = (batch=1, seq=64, embed=128)
2⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):
- Input: Single token (the first one in generation)
- Action: Use ORIGINAL attention (initialize cache)
- Why: Cache is empty, nothing to retrieve yet
- Complexity: O(1) - only one token
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0
3⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):
- Input: Single NEW token (during generation)
- Action: Compute K,V for new token ONLY, retrieve history from cache
- Why: This is where the speedup happens! O(n²) → O(n)
- Complexity: O(n) - only compute for new token, reuse cache
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5
WHY .data INSTEAD OF TENSOR OPERATIONS?
────────────────────────────────────────
In the cached path, we use numpy via .data for three reasons:
1. **Explicit Intent**: Makes it crystal clear this is inference-only
- Training: Uses Tensor operations → gradients tracked
- Inference: Uses .data → no gradient overhead
2. **Performance**: Avoids any autograd bookkeeping
- Even if small, every bit counts in generation
- Production LLMs (vLLM, llama.cpp) use similar patterns
3. **Educational Clarity**: Shows students the distinction
- "When do I need gradients?" (training)
- "When can I skip them?" (inference)
We COULD use Tensor operations with requires_grad=False, but .data
is more explicit and is the industry-standard pattern.
THE O(n²) → O(n) TRANSFORMATION:
─────────────────────────────────
WITHOUT Cache (Standard Attention):
Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)
Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)
Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)
...
Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)
Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!
WITH Cache (Our Implementation):
Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)
Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)
Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)
...
Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)
Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!
That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!
"""
from tinytorch.core.tensor import Tensor
import numpy as np
seq_len = x.shape[1]
# TRAINING PATH: Full sequence, use original attention (preserves gradients)
# ═══════════════════════════════════════════════════════════════
# PATH SELECTION: Choose between training, first token, or cached
# ═══════════════════════════════════════════════════════════════
# PATH 1: TRAINING (seq_len > 1)
# ───────────────────────────────────
# Input is a full sequence (e.g., 64 tokens during training)
# We MUST use original attention to preserve gradient flow
# No caching during training - we need backprop through everything
if seq_len > 1:
return original_forward(x, mask)
return original_forward(x, mask) # O(n²) but preserves gradients
# GENERATION PATH: Single token, use KV cache for speedup
# This is inference-only, so we use .data for performance
# Check if cache is empty (first token) - if so, use original path
# PATH 2: FIRST TOKEN (seq_len == 1, cache empty)
# ────────────────────────────────────────────────
# This is the very first token in generation (cache.seq_pos == 0)
# Cache is empty, so there's nothing to retrieve yet
# Use original attention to process this token, which will populate cache
if cache_obj.seq_pos == 0:
return original_forward(x, mask)
return original_forward(x, mask) # O(1) - just one token
# PATH 3: CACHED GENERATION (seq_len == 1, cache populated)
# ──────────────────────────────────────────────────────────
# This is a NEW token during generation (cache has history)
# We can now use the cache for massive speedup!
# Compute K,V for ONLY this new token, retrieve cached history
# Get attention layer (assumes block.attention has the attention object)
attention = block.attention
@@ -501,13 +580,22 @@ def enable_kv_cache(model):
K_all, V_all = cache_obj.get(layer_idx)
# Step 5: Compute attention using new Q with ALL cached K, V
# ─────────────────────────────────────────────────────────
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
# Use numpy operations directly for batched matmul
#
# NOTE: We use .data (numpy arrays) here instead of Tensor operations
# Why? This is INFERENCE-ONLY code (no gradients needed):
# - Explicit: Makes it clear this is inference, not training
# - Fast: Avoids autograd overhead (even if small)
# - Standard: Production LLMs (vLLM, llama.cpp) do the same
#
# If this were training, we'd use Tensor operations for gradient flow.
# But in generation (inference), .data is the right choice.
# Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
# → (batch, num_heads, 1, seq_len)
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))
scores = np.matmul(Q_heads.data, K_transposed)
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array
scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul
# Scale by sqrt(head_dim)
scores = scores / np.sqrt(head_dim)