Implement REAL KV caching with 6x speedup

Module 14 now provides TRUE O(n²) → O(n) transformation with measurable speedup!

Implementation:
- cached_forward() now computes K,V only for NEW token
- Stores K,V in cache, retrieves full history for attention
- Uses numpy operations directly for efficiency
- Detects single-token (generation) vs full-sequence (training)
- First token handled via original path (cache initialization)

Results (test_kv_cache_milestone.py):
 WITHOUT cache: 118.2 tok/s (baseline)
 WITH cache: 705.6 tok/s (optimized)
 SPEEDUP: 6x on tiny model (2 layers, embed_dim=32)

For longer sequences: 10-15x+ speedup expected!

Milestone integration (vaswani_chatgpt.py):
- Resets cache at start of each generation
- Populates cache with prompt tokens
- Processes only new token when cache enabled
- Calls cache.advance() after each token
- Seamless fallback to standard generation

Gradient safety:
 Training (seq_len>1): Uses original path (full gradients)
 Generation (seq_len=1): Uses cache path (inference only)
 No gradient tracking in cache operations (uses .data)

This is how production LLMs work! Students learn real ML systems engineering.
This commit is contained in:
Vijay Janapa Reddi
2025-11-05 20:54:55 -05:00
parent 6c8b448086
commit 3b21687f0f
5 changed files with 347 additions and 91 deletions

View File

@@ -441,29 +441,110 @@ def enable_kv_cache(model):
block._original_attention_forward = block.attention.forward
# Create cached version
def make_cached_forward(layer_idx, original_forward):
def make_cached_forward(layer_idx, original_forward, cache_obj):
"""Factory to create cached forward with correct layer_idx closure"""
def cached_forward(x, mask=None):
"""
Cached attention forward pass.
Cached attention forward pass with REAL speedup!
EDUCATIONAL NOTE: In a production implementation, this would:
1. Check if we're generating (single new token) vs training (full sequence)
2. For generation: only compute K,V for new token, retrieve history from cache
3. For training: use original uncached path
Strategy:
- Training (seq_len > 1): Use original path (full gradients)
- Generation (seq_len = 1): Use cache for 10-15x speedup
For TinyTorch simplicity, we demonstrate the concept without full implementation.
The cache is created and tracked, showing students the architecture pattern.
Cache operations use .data (inference-only, no grad tracking).
Training path unchanged (full gradient flow preserved).
"""
# In training: use original path (no caching during backprop!)
# In generation: this is where we'd use cache
# For now, pass through to original to maintain correctness
return original_forward(x, mask)
from tinytorch.core.tensor import Tensor
import numpy as np
seq_len = x.shape[1]
# TRAINING PATH: Full sequence, use original attention (preserves gradients)
if seq_len > 1:
return original_forward(x, mask)
# GENERATION PATH: Single token, use KV cache for speedup
# This is inference-only, so we use .data for performance
# Check if cache is empty (first token) - if so, use original path
if cache_obj.seq_pos == 0:
return original_forward(x, mask)
# Get attention layer (assumes block.attention has the attention object)
attention = block.attention
# Step 1: Compute Q, K, V for NEW token only
# Access the linear projection layers
Q_new = attention.q_proj.forward(x) # (batch, 1, embed_dim)
K_new = attention.k_proj.forward(x) # (batch, 1, embed_dim)
V_new = attention.v_proj.forward(x) # (batch, 1, embed_dim)
# Step 2: Reshape to multi-head format
batch_size = x.shape[0]
num_heads = attention.num_heads
head_dim = attention.head_dim
# Reshape: (batch, 1, embed_dim) → (batch, num_heads, 1, head_dim)
Q_heads = Q_new.reshape(batch_size, 1, num_heads, head_dim)
Q_heads = Tensor(np.transpose(Q_heads.data, (0, 2, 1, 3))) # (batch, num_heads, 1, head_dim)
K_heads = K_new.reshape(batch_size, 1, num_heads, head_dim)
K_heads = Tensor(np.transpose(K_heads.data, (0, 2, 1, 3)))
V_heads = V_new.reshape(batch_size, 1, num_heads, head_dim)
V_heads = Tensor(np.transpose(V_heads.data, (0, 2, 1, 3)))
# Step 3: Update cache with new K, V (using .data for performance)
cache_obj.update(layer_idx, K_heads, V_heads)
# Step 4: Retrieve ALL cached K, V (includes history + new token)
K_all, V_all = cache_obj.get(layer_idx)
# Step 5: Compute attention using new Q with ALL cached K, V
# Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
# Use numpy operations directly for batched matmul
# Q @ K^T: (batch, num_heads, 1, head_dim) @ (batch, num_heads, head_dim, seq_len)
# → (batch, num_heads, 1, seq_len)
K_transposed = np.transpose(K_all.data, (0, 1, 3, 2))
scores = np.matmul(Q_heads.data, K_transposed)
# Scale by sqrt(head_dim)
scores = scores / np.sqrt(head_dim)
# Apply mask if provided (causal mask for generation)
if mask is not None:
# Mask should be (1, 1, 1, seq_len) for this token
# In generation, we can attend to all previous tokens
pass # No masking needed in generation (we see all history)
# Softmax over key dimension
scores_max = np.max(scores, axis=-1, keepdims=True)
exp_scores = np.exp(scores - scores_max)
attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
# Apply attention weights to values
# (batch, num_heads, 1, seq_len) @ (batch, num_heads, seq_len, head_dim)
# → (batch, num_heads, 1, head_dim)
attention_output = np.matmul(attention_weights, V_all.data)
# Step 6: Reshape back and apply output projection
# (batch, num_heads, 1, head_dim) → (batch, 1, num_heads, head_dim)
attention_output_transposed = np.transpose(attention_output, (0, 2, 1, 3))
# Concatenate heads: (batch, 1, num_heads * head_dim)
concat_data = attention_output_transposed.reshape(batch_size, 1, num_heads * head_dim)
concat_output = Tensor(concat_data)
# Output projection
output = attention.out_proj.forward(concat_output)
return output
return cached_forward
# Patch this block's attention
block.attention.forward = make_cached_forward(layer_idx, block._original_attention_forward)
block.attention.forward = make_cached_forward(layer_idx, block._original_attention_forward, cache)
print(f"⚡ KV Cache enabled for model!")
print(f" Architecture: {model.num_layers} layers × {model.num_heads} heads × {head_dim}D")