mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-28 20:13:46 -05:00
Add comprehensive documentation for KV cache path selection
Enhanced Module 14 with extensive educational documentation explaining:

Three-Path Selection Strategy:
- PATH 1: Training (seq_len > 1) - Uses original attention, preserves gradients
- PATH 2: First Token (cache empty) - Uses original attention, initializes cache
- PATH 3: Cached Generation (cache populated) - THE SPEEDUP PATH, O(n) computation

Why .data Instead of Tensor Operations:
- Explicit intent: Clear separation of training vs inference code
- Performance: Avoids autograd overhead during generation
- Industry standard: Production LLMs (vLLM, llama.cpp) use same pattern

O(n²) to O(n) Transformation Explained:
- WITHOUT cache: O(N³) total across all steps (1² + 2² + ... + N²)
- WITH cache: O(N²) total across all steps (1 + 2 + ... + N)
- Result: 5-7x speedup on short sequences, 10-15x on longer ones

Inline comments added at every decision point for student comprehension.

Module 14 now complete with working implementation and comprehensive pedagogy.
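To make the three-path strategy concrete, here is a minimal, self-contained sketch of the dispatch logic described above. All names in it (DummyCache, original_forward, cached_step, forward_with_cache) are illustrative stand-ins, not the identifiers used in Module 14 itself.

```python
# Minimal sketch of the three-path selection: training, first token, cached generation.
import numpy as np

class DummyCache:
    def __init__(self):
        self.seq_pos = 0          # how many tokens have already been cached

def original_forward(x, mask=None):
    return "full attention (O(n^2), gradients preserved)"

def cached_step(x, cache, mask=None):
    return "cached attention (O(n), new token only)"

def forward_with_cache(x, cache, mask=None):
    seq_len = x.shape[1]
    if seq_len > 1:                           # PATH 1: training, full sequence
        return original_forward(x, mask)
    if cache.seq_pos == 0:                    # PATH 2: first generated token, cache empty
        return original_forward(x, mask)
    return cached_step(x, cache, mask)        # PATH 3: cached generation

cache = DummyCache()
print(forward_with_cache(np.zeros((1, 64, 128)), cache))  # PATH 1
print(forward_with_cache(np.zeros((1, 1, 128)), cache))   # PATH 2
cache.seq_pos = 5
print(forward_with_cache(np.zeros((1, 1, 128)), cache))   # PATH 3
```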
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bd52e3da",
|
||||
"id": "1078513e",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -52,7 +52,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "26b79392",
|
||||
"id": "266270f3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -69,7 +69,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7cd54d44",
|
||||
"id": "06ca957c",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -126,7 +126,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "eb00159f",
|
||||
"id": "dc896d3f",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -199,7 +199,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4dd6bb57",
|
||||
"id": "c3feca5a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -278,7 +278,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4b00b030",
|
||||
"id": "6d054a8c",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -293,7 +293,7 @@
|
||||
"class KVCache:\n",
|
||||
" \"\"\"\n",
|
||||
" Efficient key-value cache for autoregressive generation.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Stores K,V matrices for each transformer layer to avoid recomputation\n",
|
||||
" during sequential token generation. This is THE critical optimization\n",
|
||||
" that makes production language model serving economically viable.\n",
|
||||
@@ -320,13 +320,13 @@
|
||||
" ...\n",
|
||||
" Layer N: [Key_cache, Value_cache]\n",
|
||||
" ```\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Performance:\n",
|
||||
" - Update: O(1) - just index assignment\n",
|
||||
" - Get: O(1) - just slicing (no data copy)\n",
|
||||
" - Memory: O(num_layers × batch × heads × max_seq × head_dim)\n",
|
||||
" \"\"\"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,\n",
|
||||
" num_heads: int, head_dim: int):\n",
|
||||
" \"\"\"\n",
|
||||
@@ -382,7 +382,7 @@
|
||||
"\n",
|
||||
" self.caches.append((key_cache, value_cache))\n",
|
||||
" ### END SOLUTION\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:\n",
|
||||
" \"\"\"\n",
|
||||
" Update cache with new key-value pairs for given layer.\n",
|
||||
@@ -447,7 +447,7 @@
|
||||
"\n",
|
||||
" # Note: seq_pos is advanced externally via advance() after all layers process\n",
|
||||
" ### END SOLUTION\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:\n",
|
||||
" \"\"\"\n",
|
||||
" Retrieve cached key-value pairs for attention computation.\n",
|
||||
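To illustrate the update()/get()/advance() protocol referenced in the hunks above (each layer writes its new K,V, then the position advances once per token), here is a toy sketch of the per-token calling convention. The list-backed ToyCache is an illustrative stand-in, not the real KVCache.

```python
# Toy per-token loop: update every layer, read back the history, advance once.
class ToyCache:
    def __init__(self, num_layers):
        self.seq_pos = 0
        self.layers = [([], []) for _ in range(num_layers)]   # (keys, values) per layer

    def update(self, layer_idx, key, value):
        ks, vs = self.layers[layer_idx]
        ks.append(key)
        vs.append(value)

    def get(self, layer_idx):
        return self.layers[layer_idx]

    def advance(self):
        self.seq_pos += 1

cache = ToyCache(num_layers=2)
for token in ["The", "cat", "sat"]:            # one generated token per step
    for layer_idx in range(2):
        cache.update(layer_idx, f"K({token})", f"V({token})")
        K_all, V_all = cache.get(layer_idx)    # full cached history for attention
    cache.advance()                            # move the write pointer once per token

print(cache.seq_pos)     # 3
print(cache.get(0)[0])   # ['K(The)', 'K(cat)', 'K(sat)']
```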
@@ -513,48 +513,48 @@
|
||||
"\n",
|
||||
" return cached_keys, cached_values\n",
|
||||
" ### END SOLUTION\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def advance(self) -> None:\n",
|
||||
" \"\"\"\n",
|
||||
" Advance sequence position after processing current token.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Call this after all layers have processed the current token and\n",
|
||||
" updated their caches. This moves the write pointer forward.\n",
|
||||
" \"\"\"\n",
|
||||
" self.seq_pos += 1\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def reset(self) -> None:\n",
|
||||
" \"\"\"\n",
|
||||
" Reset cache for new generation sequence.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Call this when starting a new generation (new prompt).\n",
|
||||
" Resets the sequence position counter and optionally zeros cache data.\n",
|
||||
" \"\"\"\n",
|
||||
" self.seq_pos = 0\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # Zero out caches for clean state (helps with debugging)\n",
|
||||
" for layer_idx in range(self.num_layers):\n",
|
||||
" key_cache, value_cache = self.caches[layer_idx]\n",
|
||||
" key_cache.data.fill(0.0)\n",
|
||||
" value_cache.data.fill(0.0)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" def get_memory_usage(self) -> Dict[str, float]:\n",
|
||||
" \"\"\"\n",
|
||||
" Calculate memory usage of the cache system.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" Dictionary with memory statistics in MB\n",
|
||||
" \"\"\"\n",
|
||||
" # Calculate size of one cache tensor\n",
|
||||
" cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim\n",
|
||||
" bytes_per_float = 4 # float32\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # Each layer has key_cache + value_cache\n",
|
||||
" total_cache_tensors = self.num_layers * 2\n",
|
||||
" total_elements = cache_size * total_cache_tensors\n",
|
||||
" total_bytes = total_elements * bytes_per_float\n",
|
||||
" total_mb = total_bytes / (1024 * 1024)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" 'total_mb': total_mb,\n",
|
||||
" 'per_layer_mb': total_mb / self.num_layers,\n",
|
||||
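As a sanity check of the memory formula in get_memory_usage(), here is a worked example with an assumed small configuration (the numbers are illustrative, not a TinyTorch default):

```python
# Every layer stores one key cache and one value cache of shape
# (batch, heads, max_seq, head_dim) in float32 (4 bytes per element).
batch_size, num_heads, max_seq_len, head_dim, num_layers = 1, 8, 512, 64, 6

cache_size = batch_size * num_heads * max_seq_len * head_dim   # elements per cache tensor
total_elements = cache_size * num_layers * 2                   # key + value per layer
total_mb = total_elements * 4 / (1024 * 1024)                  # bytes -> MB

print(f"{total_mb:.1f} MB total, {total_mb / num_layers:.1f} MB per layer")
# 12.0 MB total, 2.0 MB per layer
```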
@@ -565,7 +565,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d07be1e1",
|
||||
"id": "94cee9a8",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -581,7 +581,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7a1814c1",
|
||||
"id": "62409497",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -669,7 +669,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2b19f7ea",
|
||||
"id": "39ea5911",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -716,7 +716,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cccb9a0d",
|
||||
"id": "7f453db6",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1
|
||||
},
|
||||
@@ -731,14 +731,14 @@
|
||||
" This function creates a properly sized cache for the model architecture.\n",
|
||||
" Call this before starting generation, then pass the cache to your\n",
|
||||
" generation loop.\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" batch_size: Number of sequences to generate simultaneously\n",
|
||||
" max_seq_len: Maximum sequence length to support\n",
|
||||
" num_layers: Number of transformer layers in model\n",
|
||||
" num_heads: Number of attention heads per layer\n",
|
||||
" head_dim: Dimension per attention head (usually embed_dim // num_heads)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" KVCache instance ready for use\n",
|
||||
" \n",
|
||||
@@ -778,7 +778,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "517247d2",
|
||||
"id": "80402a25",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -794,7 +794,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e6fba64c",
|
||||
"id": "fc77d324",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -857,7 +857,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4fa0c25c",
|
||||
"id": "df7728e0",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -924,7 +924,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6c76f95c",
|
||||
"id": "1df5b0fc",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -960,7 +960,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ba55f283",
|
||||
"id": "7a8281fd",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1066,28 +1066,107 @@
|
||||
" \"\"\"\n",
|
||||
" Cached attention forward pass with REAL speedup!\n",
|
||||
" \n",
|
||||
" Strategy:\n",
|
||||
" - Training (seq_len > 1): Use original path (full gradients)\n",
|
||||
" - Generation (seq_len = 1): Use cache for 10-15x speedup\n",
|
||||
" PATH SELECTION STRATEGY (Key to Understanding KV Caching):\n",
|
||||
" ──────────────────────────────────────────────────────────\n",
|
||||
" \n",
|
||||
" Cache operations use .data (inference-only, no grad tracking).\n",
|
||||
" Training path unchanged (full gradient flow preserved).\n",
|
||||
" We have THREE possible paths through attention:\n",
|
||||
" \n",
|
||||
" 1️⃣ TRAINING PATH (seq_len > 1):\n",
|
||||
" - Input: Full sequence of tokens (e.g., 64 tokens)\n",
|
||||
" - Action: Use ORIGINAL attention (no caching)\n",
|
||||
" - Why: Need full gradient flow for backpropagation\n",
|
||||
" - Complexity: O(n²) but that's fine for training\n",
|
||||
" - Example: x.shape = (batch=1, seq=64, embed=128)\n",
|
||||
" \n",
|
||||
" 2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):\n",
|
||||
" - Input: Single token (the first one in generation)\n",
|
||||
" - Action: Use ORIGINAL attention (initialize cache)\n",
|
||||
" - Why: Cache is empty, nothing to retrieve yet\n",
|
||||
" - Complexity: O(1) - only one token\n",
|
||||
" - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0\n",
|
||||
" \n",
|
||||
" 3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):\n",
|
||||
" - Input: Single NEW token (during generation)\n",
|
||||
" - Action: Compute K,V for new token ONLY, retrieve history from cache\n",
|
||||
" - Why: This is where the speedup happens! O(n²) → O(n)\n",
|
||||
" - Complexity: O(n) - only compute for new token, reuse cache\n",
|
||||
" - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" WHY .data INSTEAD OF TENSOR OPERATIONS?\n",
|
||||
" ────────────────────────────────────────\n",
|
||||
" \n",
|
||||
" In the cached path, we use numpy via .data for three reasons:\n",
|
||||
" \n",
|
||||
" 1. **Explicit Intent**: Makes it crystal clear this is inference-only\n",
|
||||
" - Training: Uses Tensor operations → gradients tracked\n",
|
||||
" - Inference: Uses .data → no gradient overhead\n",
|
||||
" \n",
|
||||
" 2. **Performance**: Avoids any autograd bookkeeping\n",
|
||||
" - Even if small, every bit counts in generation\n",
|
||||
" - Production LLMs (vLLM, llama.cpp) use similar patterns\n",
|
||||
" \n",
|
||||
" 3. **Educational Clarity**: Shows students the distinction\n",
|
||||
" - \"When do I need gradients?\" (training)\n",
|
||||
" - \"When can I skip them?\" (inference)\n",
|
||||
" \n",
|
||||
" We COULD use Tensor operations with requires_grad=False, but .data\n",
|
||||
" is more explicit and is the industry-standard pattern.\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" THE O(n²) → O(n) TRANSFORMATION:\n",
|
||||
" ─────────────────────────────────\n",
|
||||
" \n",
|
||||
" WITHOUT Cache (Standard Attention):\n",
|
||||
" Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)\n",
|
||||
" Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)\n",
|
||||
" Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)\n",
|
||||
" ...\n",
|
||||
" Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)\n",
|
||||
" \n",
|
||||
" Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!\n",
|
||||
" \n",
|
||||
" WITH Cache (Our Implementation):\n",
|
||||
" Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)\n",
|
||||
" Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)\n",
|
||||
" Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)\n",
|
||||
" ...\n",
|
||||
" Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)\n",
|
||||
" \n",
|
||||
" Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!\n",
|
||||
" \n",
|
||||
" That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!\n",
|
||||
" \"\"\"\n",
|
||||
" from tinytorch.core.tensor import Tensor\n",
|
||||
" import numpy as np\n",
|
||||
" \n",
|
||||
" seq_len = x.shape[1]\n",
|
||||
" \n",
|
||||
" # TRAINING PATH: Full sequence, use original attention (preserves gradients)\n",
|
||||
" # ═══════════════════════════════════════════════════════════════\n",
|
||||
" # PATH SELECTION: Choose between training, first token, or cached\n",
|
||||
" # ═══════════════════════════════════════════════════════════════\n",
|
||||
" \n",
|
||||
" # PATH 1: TRAINING (seq_len > 1)\n",
|
||||
" # ───────────────────────────────────\n",
|
||||
" # Input is a full sequence (e.g., 64 tokens during training)\n",
|
||||
" # We MUST use original attention to preserve gradient flow\n",
|
||||
" # No caching during training - we need backprop through everything\n",
|
||||
" if seq_len > 1:\n",
|
||||
" return original_forward(x, mask)\n",
|
||||
" return original_forward(x, mask) # O(n²) but preserves gradients\n",
|
||||
" \n",
|
||||
" # GENERATION PATH: Single token, use KV cache for speedup\n",
|
||||
" # This is inference-only, so we use .data for performance\n",
|
||||
" \n",
|
||||
" # Check if cache is empty (first token) - if so, use original path\n",
|
||||
" # PATH 2: FIRST TOKEN (seq_len == 1, cache empty)\n",
|
||||
" # ────────────────────────────────────────────────\n",
|
||||
" # This is the very first token in generation (cache.seq_pos == 0)\n",
|
||||
" # Cache is empty, so there's nothing to retrieve yet\n",
|
||||
" # Use original attention to process this token, which will populate cache\n",
|
||||
" if cache_obj.seq_pos == 0:\n",
|
||||
" return original_forward(x, mask)\n",
|
||||
" return original_forward(x, mask) # O(1) - just one token\n",
|
||||
" \n",
|
||||
" # PATH 3: CACHED GENERATION (seq_len == 1, cache populated)\n",
|
||||
" # ──────────────────────────────────────────────────────────\n",
|
||||
" # This is a NEW token during generation (cache has history)\n",
|
||||
" # We can now use the cache for massive speedup!\n",
|
||||
" # Compute K,V for ONLY this new token, retrieve cached history\n",
|
||||
" \n",
|
||||
" # Get attention layer (assumes block.attention has the attention object)\n",
|
||||
" attention = block.attention\n",
|
||||
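A quick numeric illustration of the O(N³) vs O(N²) totals claimed in the docstring above, for an assumed sequence length of 100 tokens:

```python
# Without a cache, step t recomputes attention over t tokens (~t^2 ops);
# with a cache, step t only handles the new token against t cached ones (~t ops).
N = 100
without_cache = sum(t * t for t in range(1, N + 1))   # 1 + 4 + 9 + ... + N^2
with_cache = sum(t for t in range(1, N + 1))          # 1 + 2 + 3 + ... + N

print(without_cache)                # 338350, roughly N^3 / 3
print(with_cache)                   # 5050, roughly N^2 / 2
print(without_cache / with_cache)   # ~67x fewer operations at N = 100
```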
@@ -1120,13 +1199,22 @@
|
||||
" K_all, V_all = cache_obj.get(layer_idx)\n",
|
||||
" \n",
|
||||
" # Step 5: Compute attention using new Q with ALL cached K, V\n",
|
||||
" # ─────────────────────────────────────────────────────────\n",
|
||||
" # Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V\n",
|
||||
" # Use numpy operations directly for batched matmul\n",
|
||||
" #\n",
|
||||
" # NOTE: We use .data (numpy arrays) here instead of Tensor operations\n",
|
||||
" # Why? This is INFERENCE-ONLY code (no gradients needed):\n",
|
||||
" # - Explicit: Makes it clear this is inference, not training\n",
|
||||
" # - Fast: Avoids autograd overhead (even if small)\n",
|
||||
" # - Standard: Production LLMs (vLLM, llama.cpp) do the same\n",
|
||||
" #\n",
|
||||
" # If this were training, we'd use Tensor operations for gradient flow.\n",
|
||||
" # But in generation (inference), .data is the right choice.\n",
|
||||
" \n",
|
||||
" # Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)\n",
|
||||
" # → (batch, num_heads, 1, seq_len)\n",
|
||||
" K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))\n",
|
||||
" scores = np.matmul(Q_heads.data, K_transposed)\n",
|
||||
" K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array\n",
|
||||
" scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul\n",
|
||||
" \n",
|
||||
" # Scale by sqrt(head_dim)\n",
|
||||
" scores = scores / np.sqrt(head_dim)\n",
|
||||
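For reference, a self-contained numpy sketch of the cached single-token attention step shown in this hunk. Shapes and random values are purely illustrative; the real code takes Q from the new token's projection and K_all/V_all from the cache.

```python
import numpy as np

batch, num_heads, head_dim, cached_len = 1, 2, 4, 5
Q_new = np.random.randn(batch, num_heads, 1, head_dim)           # new token only
K_all = np.random.randn(batch, num_heads, cached_len, head_dim)  # cached keys
V_all = np.random.randn(batch, num_heads, cached_len, head_dim)  # cached values

scores = np.matmul(Q_new, np.transpose(K_all, (0, 1, 3, 2)))     # (b, h, 1, cached_len)
scores = scores / np.sqrt(head_dim)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))    # numerically stable softmax
weights = weights / weights.sum(axis=-1, keepdims=True)
out = np.matmul(weights, V_all)                                  # (b, h, 1, head_dim)

print(out.shape)   # (1, 2, 1, 4)
```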
@@ -1211,7 +1299,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fb549b54",
|
||||
"id": "969b4e1c",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1227,7 +1315,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f792a3d",
|
||||
"id": "2c198422",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2,
|
||||
"nbgrader": {
|
||||
@@ -1297,7 +1385,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ce64525c",
|
||||
"id": "5c56c36a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1311,7 +1399,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d693cda0",
|
||||
"id": "fbc1c29f",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1396,7 +1484,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fbaff03f",
|
||||
"id": "e1b4fcb9",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
@@ -1408,7 +1496,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c22b893c",
|
||||
"id": "ff6d655d",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
|
||||
@@ -264,7 +264,7 @@ This design enables **O(1) updates** - just write to the next position!
|
||||
class KVCache:
|
||||
"""
|
||||
Efficient key-value cache for autoregressive generation.
|
||||
|
||||
|
||||
Stores K,V matrices for each transformer layer to avoid recomputation
|
||||
during sequential token generation. This is THE critical optimization
|
||||
that makes production language model serving economically viable.
|
||||
@@ -291,13 +291,13 @@ class KVCache:
|
||||
...
|
||||
Layer N: [Key_cache, Value_cache]
|
||||
```
|
||||
|
||||
|
||||
Performance:
|
||||
- Update: O(1) - just index assignment
|
||||
- Get: O(1) - just slicing (no data copy)
|
||||
- Memory: O(num_layers × batch × heads × max_seq × head_dim)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
|
||||
num_heads: int, head_dim: int):
|
||||
"""
|
||||
@@ -353,7 +353,7 @@ class KVCache:
|
||||
|
||||
self.caches.append((key_cache, value_cache))
|
||||
### END SOLUTION
|
||||
|
||||
|
||||
def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
|
||||
"""
|
||||
Update cache with new key-value pairs for given layer.
|
||||
@@ -418,7 +418,7 @@ class KVCache:
|
||||
|
||||
# Note: seq_pos is advanced externally via advance() after all layers process
|
||||
### END SOLUTION
|
||||
|
||||
|
||||
def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Retrieve cached key-value pairs for attention computation.
|
||||
@@ -484,48 +484,48 @@ class KVCache:
|
||||
|
||||
return cached_keys, cached_values
|
||||
### END SOLUTION
|
||||
|
||||
|
||||
def advance(self) -> None:
|
||||
"""
|
||||
Advance sequence position after processing current token.
|
||||
|
||||
|
||||
Call this after all layers have processed the current token and
|
||||
updated their caches. This moves the write pointer forward.
|
||||
"""
|
||||
self.seq_pos += 1
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset cache for new generation sequence.
|
||||
|
||||
|
||||
Call this when starting a new generation (new prompt).
|
||||
Resets the sequence position counter and optionally zeros cache data.
|
||||
"""
|
||||
self.seq_pos = 0
|
||||
|
||||
|
||||
# Zero out caches for clean state (helps with debugging)
|
||||
for layer_idx in range(self.num_layers):
|
||||
key_cache, value_cache = self.caches[layer_idx]
|
||||
key_cache.data.fill(0.0)
|
||||
value_cache.data.fill(0.0)
|
||||
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, float]:
|
||||
"""
|
||||
Calculate memory usage of the cache system.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with memory statistics in MB
|
||||
"""
|
||||
# Calculate size of one cache tensor
|
||||
cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
|
||||
bytes_per_float = 4 # float32
|
||||
|
||||
|
||||
# Each layer has key_cache + value_cache
|
||||
total_cache_tensors = self.num_layers * 2
|
||||
total_elements = cache_size * total_cache_tensors
|
||||
total_bytes = total_elements * bytes_per_float
|
||||
total_mb = total_bytes / (1024 * 1024)
|
||||
|
||||
|
||||
return {
|
||||
'total_mb': total_mb,
|
||||
'per_layer_mb': total_mb / self.num_layers,
|
||||
@@ -667,14 +667,14 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
|
||||
This function creates a properly sized cache for the model architecture.
|
||||
Call this before starting generation, then pass the cache to your
|
||||
generation loop.
|
||||
|
||||
|
||||
Args:
|
||||
batch_size: Number of sequences to generate simultaneously
|
||||
max_seq_len: Maximum sequence length to support
|
||||
num_layers: Number of transformer layers in model
|
||||
num_heads: Number of attention heads per layer
|
||||
head_dim: Dimension per attention head (usually embed_dim // num_heads)
|
||||
|
||||
|
||||
Returns:
|
||||
KVCache instance ready for use
|
||||
|
||||
@@ -958,28 +958,107 @@ def enable_kv_cache(model):
|
||||
"""
|
||||
Cached attention forward pass with REAL speedup!
|
||||
|
||||
Strategy:
|
||||
- Training (seq_len > 1): Use original path (full gradients)
|
||||
- Generation (seq_len = 1): Use cache for 10-15x speedup
|
||||
PATH SELECTION STRATEGY (Key to Understanding KV Caching):
|
||||
──────────────────────────────────────────────────────────
|
||||
|
||||
Cache operations use .data (inference-only, no grad tracking).
|
||||
Training path unchanged (full gradient flow preserved).
|
||||
We have THREE possible paths through attention:
|
||||
|
||||
1️⃣ TRAINING PATH (seq_len > 1):
|
||||
- Input: Full sequence of tokens (e.g., 64 tokens)
|
||||
- Action: Use ORIGINAL attention (no caching)
|
||||
- Why: Need full gradient flow for backpropagation
|
||||
- Complexity: O(n²) but that's fine for training
|
||||
- Example: x.shape = (batch=1, seq=64, embed=128)
|
||||
|
||||
2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):
|
||||
- Input: Single token (the first one in generation)
|
||||
- Action: Use ORIGINAL attention (initialize cache)
|
||||
- Why: Cache is empty, nothing to retrieve yet
|
||||
- Complexity: O(1) - only one token
|
||||
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0
|
||||
|
||||
3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):
|
||||
- Input: Single NEW token (during generation)
|
||||
- Action: Compute K,V for new token ONLY, retrieve history from cache
|
||||
- Why: This is where the speedup happens! O(n²) → O(n)
|
||||
- Complexity: O(n) - only compute for new token, reuse cache
|
||||
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5
|
||||
|
||||
|
||||
WHY .data INSTEAD OF TENSOR OPERATIONS?
|
||||
────────────────────────────────────────
|
||||
|
||||
In the cached path, we use numpy via .data for three reasons:
|
||||
|
||||
1. **Explicit Intent**: Makes it crystal clear this is inference-only
|
||||
- Training: Uses Tensor operations → gradients tracked
|
||||
- Inference: Uses .data → no gradient overhead
|
||||
|
||||
2. **Performance**: Avoids any autograd bookkeeping
|
||||
- Even if small, every bit counts in generation
|
||||
- Production LLMs (vLLM, llama.cpp) use similar patterns
|
||||
|
||||
3. **Educational Clarity**: Shows students the distinction
|
||||
- "When do I need gradients?" (training)
|
||||
- "When can I skip them?" (inference)
|
||||
|
||||
We COULD use Tensor operations with requires_grad=False, but .data
|
||||
is more explicit and is the industry-standard pattern.
|
||||
|
||||
|
||||
THE O(n²) → O(n) TRANSFORMATION:
|
||||
─────────────────────────────────
|
||||
|
||||
WITHOUT Cache (Standard Attention):
|
||||
Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)
|
||||
Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)
|
||||
Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)
|
||||
...
|
||||
Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)
|
||||
|
||||
Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!
|
||||
|
||||
WITH Cache (Our Implementation):
|
||||
Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)
|
||||
Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)
|
||||
Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)
|
||||
...
|
||||
Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)
|
||||
|
||||
Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!
|
||||
|
||||
That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!
|
||||
"""
|
||||
from tinytorch.core.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
seq_len = x.shape[1]
|
||||
|
||||
# TRAINING PATH: Full sequence, use original attention (preserves gradients)
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PATH SELECTION: Choose between training, first token, or cached
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
# PATH 1: TRAINING (seq_len > 1)
|
||||
# ───────────────────────────────────
|
||||
# Input is a full sequence (e.g., 64 tokens during training)
|
||||
# We MUST use original attention to preserve gradient flow
|
||||
# No caching during training - we need backprop through everything
|
||||
if seq_len > 1:
|
||||
return original_forward(x, mask)
|
||||
return original_forward(x, mask) # O(n²) but preserves gradients
|
||||
|
||||
# GENERATION PATH: Single token, use KV cache for speedup
|
||||
# This is inference-only, so we use .data for performance
|
||||
|
||||
# Check if cache is empty (first token) - if so, use original path
|
||||
# PATH 2: FIRST TOKEN (seq_len == 1, cache empty)
|
||||
# ────────────────────────────────────────────────
|
||||
# This is the very first token in generation (cache.seq_pos == 0)
|
||||
# Cache is empty, so there's nothing to retrieve yet
|
||||
# Use original attention to process this token, which will populate cache
|
||||
if cache_obj.seq_pos == 0:
|
||||
return original_forward(x, mask)
|
||||
return original_forward(x, mask) # O(1) - just one token
|
||||
|
||||
# PATH 3: CACHED GENERATION (seq_len == 1, cache populated)
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# This is a NEW token during generation (cache has history)
|
||||
# We can now use the cache for massive speedup!
|
||||
# Compute K,V for ONLY this new token, retrieve cached history
|
||||
|
||||
# Get attention layer (assumes block.attention has the attention object)
|
||||
attention = block.attention
|
||||
@@ -1012,13 +1091,22 @@ def enable_kv_cache(model):
|
||||
K_all, V_all = cache_obj.get(layer_idx)
|
||||
|
||||
# Step 5: Compute attention using new Q with ALL cached K, V
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
|
||||
# Use numpy operations directly for batched matmul
|
||||
#
|
||||
# NOTE: We use .data (numpy arrays) here instead of Tensor operations
|
||||
# Why? This is INFERENCE-ONLY code (no gradients needed):
|
||||
# - Explicit: Makes it clear this is inference, not training
|
||||
# - Fast: Avoids autograd overhead (even if small)
|
||||
# - Standard: Production LLMs (vLLM, llama.cpp) do the same
|
||||
#
|
||||
# If this were training, we'd use Tensor operations for gradient flow.
|
||||
# But in generation (inference), .data is the right choice.
|
||||
|
||||
# Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
|
||||
# → (batch, num_heads, 1, seq_len)
|
||||
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))
|
||||
scores = np.matmul(Q_heads.data, K_transposed)
|
||||
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array
|
||||
scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul
|
||||
|
||||
# Scale by sqrt(head_dim)
|
||||
scores = scores / np.sqrt(head_dim)
|
||||
@@ -1157,7 +1245,7 @@ def test_unit_noninvasive_integration():
|
||||
|
||||
# Test 4: Can re-enable
|
||||
print(" Test 4: Re-enable caching")
|
||||
cache2 = enable_kv_cache(model)
|
||||
_ = enable_kv_cache(model)
|
||||
assert model._cache_enabled == True, "Cache should be re-enabled"
|
||||
|
||||
print("✅ Non-invasive cache integration works correctly!")
|
||||
|
||||
tinytorch/generation/kv_cache.py (generated, 150 lines changed)
@@ -29,7 +29,7 @@ from ..core.tensor import Tensor
|
||||
class KVCache:
|
||||
"""
|
||||
Efficient key-value cache for autoregressive generation.
|
||||
|
||||
|
||||
Stores K,V matrices for each transformer layer to avoid recomputation
|
||||
during sequential token generation. This is THE critical optimization
|
||||
that makes production language model serving economically viable.
|
||||
@@ -56,13 +56,13 @@ class KVCache:
|
||||
...
|
||||
Layer N: [Key_cache, Value_cache]
|
||||
```
|
||||
|
||||
|
||||
Performance:
|
||||
- Update: O(1) - just index assignment
|
||||
- Get: O(1) - just slicing (no data copy)
|
||||
- Memory: O(num_layers × batch × heads × max_seq × head_dim)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,
|
||||
num_heads: int, head_dim: int):
|
||||
"""
|
||||
@@ -118,7 +118,7 @@ class KVCache:
|
||||
|
||||
self.caches.append((key_cache, value_cache))
|
||||
### END SOLUTION
|
||||
|
||||
|
||||
def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:
|
||||
"""
|
||||
Update cache with new key-value pairs for given layer.
|
||||
@@ -183,7 +183,7 @@ class KVCache:
|
||||
|
||||
# Note: seq_pos is advanced externally via advance() after all layers process
|
||||
### END SOLUTION
|
||||
|
||||
|
||||
def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Retrieve cached key-value pairs for attention computation.
|
||||
@@ -249,48 +249,48 @@ class KVCache:
|
||||
|
||||
return cached_keys, cached_values
|
||||
### END SOLUTION
|
||||
|
||||
|
||||
def advance(self) -> None:
|
||||
"""
|
||||
Advance sequence position after processing current token.
|
||||
|
||||
|
||||
Call this after all layers have processed the current token and
|
||||
updated their caches. This moves the write pointer forward.
|
||||
"""
|
||||
self.seq_pos += 1
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset cache for new generation sequence.
|
||||
|
||||
|
||||
Call this when starting a new generation (new prompt).
|
||||
Resets the sequence position counter and optionally zeros cache data.
|
||||
"""
|
||||
self.seq_pos = 0
|
||||
|
||||
|
||||
# Zero out caches for clean state (helps with debugging)
|
||||
for layer_idx in range(self.num_layers):
|
||||
key_cache, value_cache = self.caches[layer_idx]
|
||||
key_cache.data.fill(0.0)
|
||||
value_cache.data.fill(0.0)
|
||||
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, float]:
|
||||
"""
|
||||
Calculate memory usage of the cache system.
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary with memory statistics in MB
|
||||
"""
|
||||
# Calculate size of one cache tensor
|
||||
cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim
|
||||
bytes_per_float = 4 # float32
|
||||
|
||||
|
||||
# Each layer has key_cache + value_cache
|
||||
total_cache_tensors = self.num_layers * 2
|
||||
total_elements = cache_size * total_cache_tensors
|
||||
total_bytes = total_elements * bytes_per_float
|
||||
total_mb = total_bytes / (1024 * 1024)
|
||||
|
||||
|
||||
return {
|
||||
'total_mb': total_mb,
|
||||
'per_layer_mb': total_mb / self.num_layers,
|
||||
@@ -307,14 +307,14 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int,
|
||||
This function creates a properly sized cache for the model architecture.
|
||||
Call this before starting generation, then pass the cache to your
|
||||
generation loop.
|
||||
|
||||
|
||||
Args:
|
||||
batch_size: Number of sequences to generate simultaneously
|
||||
max_seq_len: Maximum sequence length to support
|
||||
num_layers: Number of transformer layers in model
|
||||
num_heads: Number of attention heads per layer
|
||||
head_dim: Dimension per attention head (usually embed_dim // num_heads)
|
||||
|
||||
|
||||
Returns:
|
||||
KVCache instance ready for use
|
||||
|
||||
@@ -447,28 +447,107 @@ def enable_kv_cache(model):
|
||||
"""
|
||||
Cached attention forward pass with REAL speedup!
|
||||
|
||||
Strategy:
|
||||
- Training (seq_len > 1): Use original path (full gradients)
|
||||
- Generation (seq_len = 1): Use cache for 10-15x speedup
|
||||
PATH SELECTION STRATEGY (Key to Understanding KV Caching):
|
||||
──────────────────────────────────────────────────────────
|
||||
|
||||
Cache operations use .data (inference-only, no grad tracking).
|
||||
Training path unchanged (full gradient flow preserved).
|
||||
We have THREE possible paths through attention:
|
||||
|
||||
1️⃣ TRAINING PATH (seq_len > 1):
|
||||
- Input: Full sequence of tokens (e.g., 64 tokens)
|
||||
- Action: Use ORIGINAL attention (no caching)
|
||||
- Why: Need full gradient flow for backpropagation
|
||||
- Complexity: O(n²) but that's fine for training
|
||||
- Example: x.shape = (batch=1, seq=64, embed=128)
|
||||
|
||||
2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):
|
||||
- Input: Single token (the first one in generation)
|
||||
- Action: Use ORIGINAL attention (initialize cache)
|
||||
- Why: Cache is empty, nothing to retrieve yet
|
||||
- Complexity: O(1) - only one token
|
||||
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0
|
||||
|
||||
3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):
|
||||
- Input: Single NEW token (during generation)
|
||||
- Action: Compute K,V for new token ONLY, retrieve history from cache
|
||||
- Why: This is where the speedup happens! O(n²) → O(n)
|
||||
- Complexity: O(n) - only compute for new token, reuse cache
|
||||
- Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5
|
||||
|
||||
|
||||
WHY .data INSTEAD OF TENSOR OPERATIONS?
|
||||
────────────────────────────────────────
|
||||
|
||||
In the cached path, we use numpy via .data for three reasons:
|
||||
|
||||
1. **Explicit Intent**: Makes it crystal clear this is inference-only
|
||||
- Training: Uses Tensor operations → gradients tracked
|
||||
- Inference: Uses .data → no gradient overhead
|
||||
|
||||
2. **Performance**: Avoids any autograd bookkeeping
|
||||
- Even if small, every bit counts in generation
|
||||
- Production LLMs (vLLM, llama.cpp) use similar patterns
|
||||
|
||||
3. **Educational Clarity**: Shows students the distinction
|
||||
- "When do I need gradients?" (training)
|
||||
- "When can I skip them?" (inference)
|
||||
|
||||
We COULD use Tensor operations with requires_grad=False, but .data
|
||||
is more explicit and is the industry-standard pattern.
|
||||
|
||||
|
||||
THE O(n²) → O(n) TRANSFORMATION:
|
||||
─────────────────────────────────
|
||||
|
||||
WITHOUT Cache (Standard Attention):
|
||||
Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)
|
||||
Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)
|
||||
Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)
|
||||
...
|
||||
Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)
|
||||
|
||||
Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!
|
||||
|
||||
WITH Cache (Our Implementation):
|
||||
Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)
|
||||
Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)
|
||||
Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)
|
||||
...
|
||||
Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)
|
||||
|
||||
Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!
|
||||
|
||||
That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!
|
||||
"""
|
||||
from tinytorch.core.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
seq_len = x.shape[1]
|
||||
|
||||
# TRAINING PATH: Full sequence, use original attention (preserves gradients)
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PATH SELECTION: Choose between training, first token, or cached
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
# PATH 1: TRAINING (seq_len > 1)
|
||||
# ───────────────────────────────────
|
||||
# Input is a full sequence (e.g., 64 tokens during training)
|
||||
# We MUST use original attention to preserve gradient flow
|
||||
# No caching during training - we need backprop through everything
|
||||
if seq_len > 1:
|
||||
return original_forward(x, mask)
|
||||
return original_forward(x, mask) # O(n²) but preserves gradients
|
||||
|
||||
# GENERATION PATH: Single token, use KV cache for speedup
|
||||
# This is inference-only, so we use .data for performance
|
||||
|
||||
# Check if cache is empty (first token) - if so, use original path
|
||||
# PATH 2: FIRST TOKEN (seq_len == 1, cache empty)
|
||||
# ────────────────────────────────────────────────
|
||||
# This is the very first token in generation (cache.seq_pos == 0)
|
||||
# Cache is empty, so there's nothing to retrieve yet
|
||||
# Use original attention to process this token, which will populate cache
|
||||
if cache_obj.seq_pos == 0:
|
||||
return original_forward(x, mask)
|
||||
return original_forward(x, mask) # O(1) - just one token
|
||||
|
||||
# PATH 3: CACHED GENERATION (seq_len == 1, cache populated)
|
||||
# ──────────────────────────────────────────────────────────
|
||||
# This is a NEW token during generation (cache has history)
|
||||
# We can now use the cache for massive speedup!
|
||||
# Compute K,V for ONLY this new token, retrieve cached history
|
||||
|
||||
# Get attention layer (assumes block.attention has the attention object)
|
||||
attention = block.attention
|
||||
@@ -501,13 +580,22 @@ def enable_kv_cache(model):
|
||||
K_all, V_all = cache_obj.get(layer_idx)
|
||||
|
||||
# Step 5: Compute attention using new Q with ALL cached K, V
|
||||
# ─────────────────────────────────────────────────────────
|
||||
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
|
||||
# Use numpy operations directly for batched matmul
|
||||
#
|
||||
# NOTE: We use .data (numpy arrays) here instead of Tensor operations
|
||||
# Why? This is INFERENCE-ONLY code (no gradients needed):
|
||||
# - Explicit: Makes it clear this is inference, not training
|
||||
# - Fast: Avoids autograd overhead (even if small)
|
||||
# - Standard: Production LLMs (vLLM, llama.cpp) do the same
|
||||
#
|
||||
# If this were training, we'd use Tensor operations for gradient flow.
|
||||
# But in generation (inference), .data is the right choice.
|
||||
|
||||
# Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
|
||||
# → (batch, num_heads, 1, seq_len)
|
||||
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))
|
||||
scores = np.matmul(Q_heads.data, K_transposed)
|
||||
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array
|
||||
scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul
|
||||
|
||||
# Scale by sqrt(head_dim)
|
||||
scores = scores / np.sqrt(head_dim)
|
||||
|
||||