diff --git a/modules/source/14_kvcaching/kvcaching_dev.ipynb b/modules/source/14_kvcaching/kvcaching_dev.ipynb index 32a34198..b86c77f5 100644 --- a/modules/source/14_kvcaching/kvcaching_dev.ipynb +++ b/modules/source/14_kvcaching/kvcaching_dev.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "bd52e3da", + "id": "1078513e", "metadata": { "cell_marker": "\"\"\"" }, @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26b79392", + "id": "266270f3", "metadata": {}, "outputs": [], "source": [ @@ -69,7 +69,7 @@ }, { "cell_type": "markdown", - "id": "7cd54d44", + "id": "06ca957c", "metadata": { "cell_marker": "\"\"\"" }, @@ -126,7 +126,7 @@ }, { "cell_type": "markdown", - "id": "eb00159f", + "id": "dc896d3f", "metadata": { "cell_marker": "\"\"\"" }, @@ -199,7 +199,7 @@ }, { "cell_type": "markdown", - "id": "4dd6bb57", + "id": "c3feca5a", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -278,7 +278,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4b00b030", + "id": "6d054a8c", "metadata": { "lines_to_next_cell": 1, "nbgrader": { @@ -293,7 +293,7 @@ "class KVCache:\n", " \"\"\"\n", " Efficient key-value cache for autoregressive generation.\n", - " \n", + "\n", " Stores K,V matrices for each transformer layer to avoid recomputation\n", " during sequential token generation. This is THE critical optimization\n", " that makes production language model serving economically viable.\n", @@ -320,13 +320,13 @@ " ...\n", " Layer N: [Key_cache, Value_cache]\n", " ```\n", - " \n", + "\n", " Performance:\n", " - Update: O(1) - just index assignment\n", " - Get: O(1) - just slicing (no data copy)\n", " - Memory: O(num_layers × batch × heads × max_seq × head_dim)\n", " \"\"\"\n", - " \n", + "\n", " def __init__(self, batch_size: int, max_seq_len: int, num_layers: int,\n", " num_heads: int, head_dim: int):\n", " \"\"\"\n", @@ -382,7 +382,7 @@ "\n", " self.caches.append((key_cache, value_cache))\n", " ### END SOLUTION\n", - " \n", + "\n", " def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None:\n", " \"\"\"\n", " Update cache with new key-value pairs for given layer.\n", @@ -447,7 +447,7 @@ "\n", " # Note: seq_pos is advanced externally via advance() after all layers process\n", " ### END SOLUTION\n", - " \n", + "\n", " def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]:\n", " \"\"\"\n", " Retrieve cached key-value pairs for attention computation.\n", @@ -513,48 +513,48 @@ "\n", " return cached_keys, cached_values\n", " ### END SOLUTION\n", - " \n", + "\n", " def advance(self) -> None:\n", " \"\"\"\n", " Advance sequence position after processing current token.\n", - " \n", + "\n", " Call this after all layers have processed the current token and\n", " updated their caches. This moves the write pointer forward.\n", " \"\"\"\n", " self.seq_pos += 1\n", - " \n", + "\n", " def reset(self) -> None:\n", " \"\"\"\n", " Reset cache for new generation sequence.\n", - " \n", + "\n", " Call this when starting a new generation (new prompt).\n", " Resets the sequence position counter and optionally zeros cache data.\n", " \"\"\"\n", " self.seq_pos = 0\n", - " \n", + "\n", " # Zero out caches for clean state (helps with debugging)\n", " for layer_idx in range(self.num_layers):\n", " key_cache, value_cache = self.caches[layer_idx]\n", " key_cache.data.fill(0.0)\n", " value_cache.data.fill(0.0)\n", - " \n", + "\n", " def get_memory_usage(self) -> Dict[str, float]:\n", " \"\"\"\n", " Calculate memory usage of the cache system.\n", - " \n", + "\n", " Returns:\n", " Dictionary with memory statistics in MB\n", " \"\"\"\n", " # Calculate size of one cache tensor\n", " cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim\n", " bytes_per_float = 4 # float32\n", - " \n", + "\n", " # Each layer has key_cache + value_cache\n", " total_cache_tensors = self.num_layers * 2\n", " total_elements = cache_size * total_cache_tensors\n", " total_bytes = total_elements * bytes_per_float\n", " total_mb = total_bytes / (1024 * 1024)\n", - " \n", + "\n", " return {\n", " 'total_mb': total_mb,\n", " 'per_layer_mb': total_mb / self.num_layers,\n", @@ -565,7 +565,7 @@ }, { "cell_type": "markdown", - "id": "d07be1e1", + "id": "94cee9a8", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -581,7 +581,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7a1814c1", + "id": "62409497", "metadata": { "nbgrader": { "grade": true, @@ -669,7 +669,7 @@ }, { "cell_type": "markdown", - "id": "2b19f7ea", + "id": "39ea5911", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -716,7 +716,7 @@ { "cell_type": "code", "execution_count": null, - "id": "cccb9a0d", + "id": "7f453db6", "metadata": { "lines_to_next_cell": 1 }, @@ -731,14 +731,14 @@ " This function creates a properly sized cache for the model architecture.\n", " Call this before starting generation, then pass the cache to your\n", " generation loop.\n", - " \n", + "\n", " Args:\n", " batch_size: Number of sequences to generate simultaneously\n", " max_seq_len: Maximum sequence length to support\n", " num_layers: Number of transformer layers in model\n", " num_heads: Number of attention heads per layer\n", " head_dim: Dimension per attention head (usually embed_dim // num_heads)\n", - " \n", + "\n", " Returns:\n", " KVCache instance ready for use\n", " \n", @@ -778,7 +778,7 @@ }, { "cell_type": "markdown", - "id": "517247d2", + "id": "80402a25", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -794,7 +794,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e6fba64c", + "id": "fc77d324", "metadata": { "nbgrader": { "grade": true, @@ -857,7 +857,7 @@ }, { "cell_type": "markdown", - "id": "4fa0c25c", + "id": "df7728e0", "metadata": { "cell_marker": "\"\"\"" }, @@ -924,7 +924,7 @@ }, { "cell_type": "markdown", - "id": "6c76f95c", + "id": "1df5b0fc", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -960,7 +960,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ba55f283", + "id": "7a8281fd", "metadata": { "nbgrader": { "grade": false, @@ -1066,28 +1066,107 @@ " \"\"\"\n", " Cached attention forward pass with REAL speedup!\n", " \n", - " Strategy:\n", - " - Training (seq_len > 1): Use original path (full gradients)\n", - " - Generation (seq_len = 1): Use cache for 10-15x speedup\n", + " PATH SELECTION STRATEGY (Key to Understanding KV Caching):\n", + " ──────────────────────────────────────────────────────────\n", " \n", - " Cache operations use .data (inference-only, no grad tracking).\n", - " Training path unchanged (full gradient flow preserved).\n", + " We have THREE possible paths through attention:\n", + " \n", + " 1️⃣ TRAINING PATH (seq_len > 1):\n", + " - Input: Full sequence of tokens (e.g., 64 tokens)\n", + " - Action: Use ORIGINAL attention (no caching)\n", + " - Why: Need full gradient flow for backpropagation\n", + " - Complexity: O(n²) but that's fine for training\n", + " - Example: x.shape = (batch=1, seq=64, embed=128)\n", + " \n", + " 2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty):\n", + " - Input: Single token (the first one in generation)\n", + " - Action: Use ORIGINAL attention (initialize cache)\n", + " - Why: Cache is empty, nothing to retrieve yet\n", + " - Complexity: O(1) - only one token\n", + " - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0\n", + " \n", + " 3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated):\n", + " - Input: Single NEW token (during generation)\n", + " - Action: Compute K,V for new token ONLY, retrieve history from cache\n", + " - Why: This is where the speedup happens! O(n²) → O(n)\n", + " - Complexity: O(n) - only compute for new token, reuse cache\n", + " - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5\n", + " \n", + " \n", + " WHY .data INSTEAD OF TENSOR OPERATIONS?\n", + " ────────────────────────────────────────\n", + " \n", + " In the cached path, we use numpy via .data for three reasons:\n", + " \n", + " 1. **Explicit Intent**: Makes it crystal clear this is inference-only\n", + " - Training: Uses Tensor operations → gradients tracked\n", + " - Inference: Uses .data → no gradient overhead\n", + " \n", + " 2. **Performance**: Avoids any autograd bookkeeping\n", + " - Even if small, every bit counts in generation\n", + " - Production LLMs (vLLM, llama.cpp) use similar patterns\n", + " \n", + " 3. **Educational Clarity**: Shows students the distinction\n", + " - \"When do I need gradients?\" (training)\n", + " - \"When can I skip them?\" (inference)\n", + " \n", + " We COULD use Tensor operations with requires_grad=False, but .data\n", + " is more explicit and is the industry-standard pattern.\n", + " \n", + " \n", + " THE O(n²) → O(n) TRANSFORMATION:\n", + " ─────────────────────────────────\n", + " \n", + " WITHOUT Cache (Standard Attention):\n", + " Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op)\n", + " Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops)\n", + " Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops)\n", + " ...\n", + " Step N: Process tokens 1-N → Compute attention for N tokens (N² ops)\n", + " \n", + " Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps!\n", + " \n", + " WITH Cache (Our Implementation):\n", + " Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op)\n", + " Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops)\n", + " Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops)\n", + " ...\n", + " Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops)\n", + " \n", + " Total: 1 + 2 + 3 + ... + N = O(N²) across all steps!\n", + " \n", + " That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones!\n", " \"\"\"\n", " from tinytorch.core.tensor import Tensor\n", " import numpy as np\n", " \n", " seq_len = x.shape[1]\n", " \n", - " # TRAINING PATH: Full sequence, use original attention (preserves gradients)\n", + " # ═══════════════════════════════════════════════════════════════\n", + " # PATH SELECTION: Choose between training, first token, or cached\n", + " # ═══════════════════════════════════════════════════════════════\n", + " \n", + " # PATH 1: TRAINING (seq_len > 1)\n", + " # ───────────────────────────────────\n", + " # Input is a full sequence (e.g., 64 tokens during training)\n", + " # We MUST use original attention to preserve gradient flow\n", + " # No caching during training - we need backprop through everything\n", " if seq_len > 1:\n", - " return original_forward(x, mask)\n", + " return original_forward(x, mask) # O(n²) but preserves gradients\n", " \n", - " # GENERATION PATH: Single token, use KV cache for speedup\n", - " # This is inference-only, so we use .data for performance\n", - " \n", - " # Check if cache is empty (first token) - if so, use original path\n", + " # PATH 2: FIRST TOKEN (seq_len == 1, cache empty)\n", + " # ────────────────────────────────────────────────\n", + " # This is the very first token in generation (cache.seq_pos == 0)\n", + " # Cache is empty, so there's nothing to retrieve yet\n", + " # Use original attention to process this token, which will populate cache\n", " if cache_obj.seq_pos == 0:\n", - " return original_forward(x, mask)\n", + " return original_forward(x, mask) # O(1) - just one token\n", + " \n", + " # PATH 3: CACHED GENERATION (seq_len == 1, cache populated)\n", + " # ──────────────────────────────────────────────────────────\n", + " # This is a NEW token during generation (cache has history)\n", + " # We can now use the cache for massive speedup!\n", + " # Compute K,V for ONLY this new token, retrieve cached history\n", " \n", " # Get attention layer (assumes block.attention has the attention object)\n", " attention = block.attention\n", @@ -1120,13 +1199,22 @@ " K_all, V_all = cache_obj.get(layer_idx)\n", " \n", " # Step 5: Compute attention using new Q with ALL cached K, V\n", + " # ─────────────────────────────────────────────────────────\n", " # Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V\n", - " # Use numpy operations directly for batched matmul\n", + " #\n", + " # NOTE: We use .data (numpy arrays) here instead of Tensor operations\n", + " # Why? This is INFERENCE-ONLY code (no gradients needed):\n", + " # - Explicit: Makes it clear this is inference, not training\n", + " # - Fast: Avoids autograd overhead (even if small)\n", + " # - Standard: Production LLMs (vLLM, llama.cpp) do the same\n", + " #\n", + " # If this were training, we'd use Tensor operations for gradient flow.\n", + " # But in generation (inference), .data is the right choice.\n", " \n", " # Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)\n", " # → (batch, num_heads, 1, seq_len)\n", - " K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))\n", - " scores = np.matmul(Q_heads.data, K_transposed)\n", + " K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array\n", + " scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul\n", " \n", " # Scale by sqrt(head_dim)\n", " scores = scores / np.sqrt(head_dim)\n", @@ -1211,7 +1299,7 @@ }, { "cell_type": "markdown", - "id": "fb549b54", + "id": "969b4e1c", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -1227,7 +1315,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9f792a3d", + "id": "2c198422", "metadata": { "lines_to_next_cell": 2, "nbgrader": { @@ -1297,7 +1385,7 @@ }, { "cell_type": "markdown", - "id": "ce64525c", + "id": "5c56c36a", "metadata": { "cell_marker": "\"\"\"", "lines_to_next_cell": 1 @@ -1311,7 +1399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d693cda0", + "id": "fbc1c29f", "metadata": { "lines_to_next_cell": 1, "nbgrader": { @@ -1396,7 +1484,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fbaff03f", + "id": "e1b4fcb9", "metadata": { "lines_to_next_cell": 2 }, @@ -1408,7 +1496,7 @@ }, { "cell_type": "markdown", - "id": "c22b893c", + "id": "ff6d655d", "metadata": { "cell_marker": "\"\"\"" }, diff --git a/modules/source/14_kvcaching/kvcaching_dev.py b/modules/source/14_kvcaching/kvcaching_dev.py index daf16304..da678f5d 100644 --- a/modules/source/14_kvcaching/kvcaching_dev.py +++ b/modules/source/14_kvcaching/kvcaching_dev.py @@ -264,7 +264,7 @@ This design enables **O(1) updates** - just write to the next position! class KVCache: """ Efficient key-value cache for autoregressive generation. - + Stores K,V matrices for each transformer layer to avoid recomputation during sequential token generation. This is THE critical optimization that makes production language model serving economically viable. @@ -291,13 +291,13 @@ class KVCache: ... Layer N: [Key_cache, Value_cache] ``` - + Performance: - Update: O(1) - just index assignment - Get: O(1) - just slicing (no data copy) - Memory: O(num_layers × batch × heads × max_seq × head_dim) """ - + def __init__(self, batch_size: int, max_seq_len: int, num_layers: int, num_heads: int, head_dim: int): """ @@ -353,7 +353,7 @@ class KVCache: self.caches.append((key_cache, value_cache)) ### END SOLUTION - + def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None: """ Update cache with new key-value pairs for given layer. @@ -418,7 +418,7 @@ class KVCache: # Note: seq_pos is advanced externally via advance() after all layers process ### END SOLUTION - + def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]: """ Retrieve cached key-value pairs for attention computation. @@ -484,48 +484,48 @@ class KVCache: return cached_keys, cached_values ### END SOLUTION - + def advance(self) -> None: """ Advance sequence position after processing current token. - + Call this after all layers have processed the current token and updated their caches. This moves the write pointer forward. """ self.seq_pos += 1 - + def reset(self) -> None: """ Reset cache for new generation sequence. - + Call this when starting a new generation (new prompt). Resets the sequence position counter and optionally zeros cache data. """ self.seq_pos = 0 - + # Zero out caches for clean state (helps with debugging) for layer_idx in range(self.num_layers): key_cache, value_cache = self.caches[layer_idx] key_cache.data.fill(0.0) value_cache.data.fill(0.0) - + def get_memory_usage(self) -> Dict[str, float]: """ Calculate memory usage of the cache system. - + Returns: Dictionary with memory statistics in MB """ # Calculate size of one cache tensor cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim bytes_per_float = 4 # float32 - + # Each layer has key_cache + value_cache total_cache_tensors = self.num_layers * 2 total_elements = cache_size * total_cache_tensors total_bytes = total_elements * bytes_per_float total_mb = total_bytes / (1024 * 1024) - + return { 'total_mb': total_mb, 'per_layer_mb': total_mb / self.num_layers, @@ -667,14 +667,14 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int, This function creates a properly sized cache for the model architecture. Call this before starting generation, then pass the cache to your generation loop. - + Args: batch_size: Number of sequences to generate simultaneously max_seq_len: Maximum sequence length to support num_layers: Number of transformer layers in model num_heads: Number of attention heads per layer head_dim: Dimension per attention head (usually embed_dim // num_heads) - + Returns: KVCache instance ready for use @@ -958,28 +958,107 @@ def enable_kv_cache(model): """ Cached attention forward pass with REAL speedup! - Strategy: - - Training (seq_len > 1): Use original path (full gradients) - - Generation (seq_len = 1): Use cache for 10-15x speedup + PATH SELECTION STRATEGY (Key to Understanding KV Caching): + ────────────────────────────────────────────────────────── - Cache operations use .data (inference-only, no grad tracking). - Training path unchanged (full gradient flow preserved). + We have THREE possible paths through attention: + + 1️⃣ TRAINING PATH (seq_len > 1): + - Input: Full sequence of tokens (e.g., 64 tokens) + - Action: Use ORIGINAL attention (no caching) + - Why: Need full gradient flow for backpropagation + - Complexity: O(n²) but that's fine for training + - Example: x.shape = (batch=1, seq=64, embed=128) + + 2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty): + - Input: Single token (the first one in generation) + - Action: Use ORIGINAL attention (initialize cache) + - Why: Cache is empty, nothing to retrieve yet + - Complexity: O(1) - only one token + - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0 + + 3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated): + - Input: Single NEW token (during generation) + - Action: Compute K,V for new token ONLY, retrieve history from cache + - Why: This is where the speedup happens! O(n²) → O(n) + - Complexity: O(n) - only compute for new token, reuse cache + - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5 + + + WHY .data INSTEAD OF TENSOR OPERATIONS? + ──────────────────────────────────────── + + In the cached path, we use numpy via .data for three reasons: + + 1. **Explicit Intent**: Makes it crystal clear this is inference-only + - Training: Uses Tensor operations → gradients tracked + - Inference: Uses .data → no gradient overhead + + 2. **Performance**: Avoids any autograd bookkeeping + - Even if small, every bit counts in generation + - Production LLMs (vLLM, llama.cpp) use similar patterns + + 3. **Educational Clarity**: Shows students the distinction + - "When do I need gradients?" (training) + - "When can I skip them?" (inference) + + We COULD use Tensor operations with requires_grad=False, but .data + is more explicit and is the industry-standard pattern. + + + THE O(n²) → O(n) TRANSFORMATION: + ───────────────────────────────── + + WITHOUT Cache (Standard Attention): + Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op) + Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops) + Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops) + ... + Step N: Process tokens 1-N → Compute attention for N tokens (N² ops) + + Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps! + + WITH Cache (Our Implementation): + Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op) + Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops) + Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops) + ... + Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops) + + Total: 1 + 2 + 3 + ... + N = O(N²) across all steps! + + That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones! """ from tinytorch.core.tensor import Tensor import numpy as np seq_len = x.shape[1] - # TRAINING PATH: Full sequence, use original attention (preserves gradients) + # ═══════════════════════════════════════════════════════════════ + # PATH SELECTION: Choose between training, first token, or cached + # ═══════════════════════════════════════════════════════════════ + + # PATH 1: TRAINING (seq_len > 1) + # ─────────────────────────────────── + # Input is a full sequence (e.g., 64 tokens during training) + # We MUST use original attention to preserve gradient flow + # No caching during training - we need backprop through everything if seq_len > 1: - return original_forward(x, mask) + return original_forward(x, mask) # O(n²) but preserves gradients - # GENERATION PATH: Single token, use KV cache for speedup - # This is inference-only, so we use .data for performance - - # Check if cache is empty (first token) - if so, use original path + # PATH 2: FIRST TOKEN (seq_len == 1, cache empty) + # ──────────────────────────────────────────────── + # This is the very first token in generation (cache.seq_pos == 0) + # Cache is empty, so there's nothing to retrieve yet + # Use original attention to process this token, which will populate cache if cache_obj.seq_pos == 0: - return original_forward(x, mask) + return original_forward(x, mask) # O(1) - just one token + + # PATH 3: CACHED GENERATION (seq_len == 1, cache populated) + # ────────────────────────────────────────────────────────── + # This is a NEW token during generation (cache has history) + # We can now use the cache for massive speedup! + # Compute K,V for ONLY this new token, retrieve cached history # Get attention layer (assumes block.attention has the attention object) attention = block.attention @@ -1012,13 +1091,22 @@ def enable_kv_cache(model): K_all, V_all = cache_obj.get(layer_idx) # Step 5: Compute attention using new Q with ALL cached K, V + # ───────────────────────────────────────────────────────── # Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V - # Use numpy operations directly for batched matmul + # + # NOTE: We use .data (numpy arrays) here instead of Tensor operations + # Why? This is INFERENCE-ONLY code (no gradients needed): + # - Explicit: Makes it clear this is inference, not training + # - Fast: Avoids autograd overhead (even if small) + # - Standard: Production LLMs (vLLM, llama.cpp) do the same + # + # If this were training, we'd use Tensor operations for gradient flow. + # But in generation (inference), .data is the right choice. # Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len) # → (batch, num_heads, 1, seq_len) - K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) - scores = np.matmul(Q_heads.data, K_transposed) + K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array + scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul # Scale by sqrt(head_dim) scores = scores / np.sqrt(head_dim) @@ -1157,7 +1245,7 @@ def test_unit_noninvasive_integration(): # Test 4: Can re-enable print(" Test 4: Re-enable caching") - cache2 = enable_kv_cache(model) + _ = enable_kv_cache(model) assert model._cache_enabled == True, "Cache should be re-enabled" print("✅ Non-invasive cache integration works correctly!") diff --git a/tinytorch/generation/kv_cache.py b/tinytorch/generation/kv_cache.py index 49dd8407..1cbc93cf 100644 --- a/tinytorch/generation/kv_cache.py +++ b/tinytorch/generation/kv_cache.py @@ -29,7 +29,7 @@ from ..core.tensor import Tensor class KVCache: """ Efficient key-value cache for autoregressive generation. - + Stores K,V matrices for each transformer layer to avoid recomputation during sequential token generation. This is THE critical optimization that makes production language model serving economically viable. @@ -56,13 +56,13 @@ class KVCache: ... Layer N: [Key_cache, Value_cache] ``` - + Performance: - Update: O(1) - just index assignment - Get: O(1) - just slicing (no data copy) - Memory: O(num_layers × batch × heads × max_seq × head_dim) """ - + def __init__(self, batch_size: int, max_seq_len: int, num_layers: int, num_heads: int, head_dim: int): """ @@ -118,7 +118,7 @@ class KVCache: self.caches.append((key_cache, value_cache)) ### END SOLUTION - + def update(self, layer_idx: int, key: Tensor, value: Tensor) -> None: """ Update cache with new key-value pairs for given layer. @@ -183,7 +183,7 @@ class KVCache: # Note: seq_pos is advanced externally via advance() after all layers process ### END SOLUTION - + def get(self, layer_idx: int) -> Tuple[Tensor, Tensor]: """ Retrieve cached key-value pairs for attention computation. @@ -249,48 +249,48 @@ class KVCache: return cached_keys, cached_values ### END SOLUTION - + def advance(self) -> None: """ Advance sequence position after processing current token. - + Call this after all layers have processed the current token and updated their caches. This moves the write pointer forward. """ self.seq_pos += 1 - + def reset(self) -> None: """ Reset cache for new generation sequence. - + Call this when starting a new generation (new prompt). Resets the sequence position counter and optionally zeros cache data. """ self.seq_pos = 0 - + # Zero out caches for clean state (helps with debugging) for layer_idx in range(self.num_layers): key_cache, value_cache = self.caches[layer_idx] key_cache.data.fill(0.0) value_cache.data.fill(0.0) - + def get_memory_usage(self) -> Dict[str, float]: """ Calculate memory usage of the cache system. - + Returns: Dictionary with memory statistics in MB """ # Calculate size of one cache tensor cache_size = self.batch_size * self.num_heads * self.max_seq_len * self.head_dim bytes_per_float = 4 # float32 - + # Each layer has key_cache + value_cache total_cache_tensors = self.num_layers * 2 total_elements = cache_size * total_cache_tensors total_bytes = total_elements * bytes_per_float total_mb = total_bytes / (1024 * 1024) - + return { 'total_mb': total_mb, 'per_layer_mb': total_mb / self.num_layers, @@ -307,14 +307,14 @@ def enable_kv_cache(batch_size: int, max_seq_len: int, num_layers: int, This function creates a properly sized cache for the model architecture. Call this before starting generation, then pass the cache to your generation loop. - + Args: batch_size: Number of sequences to generate simultaneously max_seq_len: Maximum sequence length to support num_layers: Number of transformer layers in model num_heads: Number of attention heads per layer head_dim: Dimension per attention head (usually embed_dim // num_heads) - + Returns: KVCache instance ready for use @@ -447,28 +447,107 @@ def enable_kv_cache(model): """ Cached attention forward pass with REAL speedup! - Strategy: - - Training (seq_len > 1): Use original path (full gradients) - - Generation (seq_len = 1): Use cache for 10-15x speedup + PATH SELECTION STRATEGY (Key to Understanding KV Caching): + ────────────────────────────────────────────────────────── - Cache operations use .data (inference-only, no grad tracking). - Training path unchanged (full gradient flow preserved). + We have THREE possible paths through attention: + + 1️⃣ TRAINING PATH (seq_len > 1): + - Input: Full sequence of tokens (e.g., 64 tokens) + - Action: Use ORIGINAL attention (no caching) + - Why: Need full gradient flow for backpropagation + - Complexity: O(n²) but that's fine for training + - Example: x.shape = (batch=1, seq=64, embed=128) + + 2️⃣ FIRST TOKEN PATH (seq_len == 1 AND cache empty): + - Input: Single token (the first one in generation) + - Action: Use ORIGINAL attention (initialize cache) + - Why: Cache is empty, nothing to retrieve yet + - Complexity: O(1) - only one token + - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=0 + + 3️⃣ CACHED GENERATION PATH (seq_len == 1 AND cache populated): + - Input: Single NEW token (during generation) + - Action: Compute K,V for new token ONLY, retrieve history from cache + - Why: This is where the speedup happens! O(n²) → O(n) + - Complexity: O(n) - only compute for new token, reuse cache + - Example: x.shape = (batch=1, seq=1, embed=128), cache.seq_pos=5 + + + WHY .data INSTEAD OF TENSOR OPERATIONS? + ──────────────────────────────────────── + + In the cached path, we use numpy via .data for three reasons: + + 1. **Explicit Intent**: Makes it crystal clear this is inference-only + - Training: Uses Tensor operations → gradients tracked + - Inference: Uses .data → no gradient overhead + + 2. **Performance**: Avoids any autograd bookkeeping + - Even if small, every bit counts in generation + - Production LLMs (vLLM, llama.cpp) use similar patterns + + 3. **Educational Clarity**: Shows students the distinction + - "When do I need gradients?" (training) + - "When can I skip them?" (inference) + + We COULD use Tensor operations with requires_grad=False, but .data + is more explicit and is the industry-standard pattern. + + + THE O(n²) → O(n) TRANSFORMATION: + ───────────────────────────────── + + WITHOUT Cache (Standard Attention): + Step 1: Process token 1 → Compute attention for 1 token (1² = 1 op) + Step 2: Process tokens 1-2 → Compute attention for 2 tokens (2² = 4 ops) + Step 3: Process tokens 1-3 → Compute attention for 3 tokens (3² = 9 ops) + ... + Step N: Process tokens 1-N → Compute attention for N tokens (N² ops) + + Total: 1 + 4 + 9 + ... + N² = O(N³) across all steps! + + WITH Cache (Our Implementation): + Step 1: Process token 1 → Compute K,V for token 1, cache it (1 op) + Step 2: Process token 2 → Compute K,V for token 2, retrieve 1 (2 ops) + Step 3: Process token 3 → Compute K,V for token 3, retrieve 1-2 (3 ops) + ... + Step N: Process token N → Compute K,V for token N, retrieve 1-(N-1) (N ops) + + Total: 1 + 2 + 3 + ... + N = O(N²) across all steps! + + That's why we see 5-7x speedup on short sequences, and 10-15x on longer ones! """ from tinytorch.core.tensor import Tensor import numpy as np seq_len = x.shape[1] - # TRAINING PATH: Full sequence, use original attention (preserves gradients) + # ═══════════════════════════════════════════════════════════════ + # PATH SELECTION: Choose between training, first token, or cached + # ═══════════════════════════════════════════════════════════════ + + # PATH 1: TRAINING (seq_len > 1) + # ─────────────────────────────────── + # Input is a full sequence (e.g., 64 tokens during training) + # We MUST use original attention to preserve gradient flow + # No caching during training - we need backprop through everything if seq_len > 1: - return original_forward(x, mask) + return original_forward(x, mask) # O(n²) but preserves gradients - # GENERATION PATH: Single token, use KV cache for speedup - # This is inference-only, so we use .data for performance - - # Check if cache is empty (first token) - if so, use original path + # PATH 2: FIRST TOKEN (seq_len == 1, cache empty) + # ──────────────────────────────────────────────── + # This is the very first token in generation (cache.seq_pos == 0) + # Cache is empty, so there's nothing to retrieve yet + # Use original attention to process this token, which will populate cache if cache_obj.seq_pos == 0: - return original_forward(x, mask) + return original_forward(x, mask) # O(1) - just one token + + # PATH 3: CACHED GENERATION (seq_len == 1, cache populated) + # ────────────────────────────────────────────────────────── + # This is a NEW token during generation (cache has history) + # We can now use the cache for massive speedup! + # Compute K,V for ONLY this new token, retrieve cached history # Get attention layer (assumes block.attention has the attention object) attention = block.attention @@ -501,13 +580,22 @@ def enable_kv_cache(model): K_all, V_all = cache_obj.get(layer_idx) # Step 5: Compute attention using new Q with ALL cached K, V + # ───────────────────────────────────────────────────────── # Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V - # Use numpy operations directly for batched matmul + # + # NOTE: We use .data (numpy arrays) here instead of Tensor operations + # Why? This is INFERENCE-ONLY code (no gradients needed): + # - Explicit: Makes it clear this is inference, not training + # - Fast: Avoids autograd overhead (even if small) + # - Standard: Production LLMs (vLLM, llama.cpp) do the same + # + # If this were training, we'd use Tensor operations for gradient flow. + # But in generation (inference), .data is the right choice. # Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len) # → (batch, num_heads, 1, seq_len) - K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) - scores = np.matmul(Q_heads.data, K_transposed) + K_transposed = np.transpose(K_all.data, (0, 1, 3, 2)) # .data = numpy array + scores = np.matmul(Q_heads.data, K_transposed) # Pure numpy matmul # Scale by sqrt(head_dim) scores = scores / np.sqrt(head_dim)