diff --git a/modules/source/14_kvcaching/kvcaching_dev.py b/modules/source/14_kvcaching/kvcaching_dev.py
index da678f5d..532c432c 100644
--- a/modules/source/14_kvcaching/kvcaching_dev.py
+++ b/modules/source/14_kvcaching/kvcaching_dev.py
@@ -67,6 +67,79 @@ from typing import Tuple, Optional, Dict, List
 
 # Import TinyTorch components from previous modules
 from tinytorch.core.tensor import Tensor
+# %% [markdown]
+"""
+## 🔬 Motivation: Why Memoization Matters for Transformers
+
+Before we learn KV caching, let's profile transformer generation to understand
+the problem we're solving. We'll see O(n²) growth in latency as we generate text.
+"""
+
+# %%
+# Profile transformer generation to discover the bottleneck
+from tinytorch.profiling.profiler import Profiler
+import matplotlib.pyplot as plt
+
+profiler = Profiler()
+
+def naive_attention_step(seq_len, hidden_dim=64):
+    """
+    Simulates one step of attention computation.
+    Without caching, this processes ALL previous tokens every time.
+    """
+    # Q, K, V for entire sequence
+    q = Tensor(np.random.randn(1, seq_len, hidden_dim))
+    k = Tensor(np.random.randn(1, seq_len, hidden_dim))
+    v = Tensor(np.random.randn(1, seq_len, hidden_dim))
+
+    # Attention: Q @ K.T then @ V
+    # This is O(seq_len²) in complexity
+    scores = q @ k.T # (1, seq_len, seq_len)
+    output = scores @ v
+
+    return output
+
+# Profile at increasing sequence lengths
+print("🔬 Profiling Transformer Generation (Without Caching):\n")
+print(" Seq Len | Latency (ms) | Growth")
+print(" ---------|----------------|----------")
+
+sequence_lengths = [10, 20, 40, 80, 160]
+latencies = []
+
+for seq_len in sequence_lengths:
+    # Measure latency for this sequence length
+    latency = profiler.measure_latency(
+        lambda: naive_attention_step(seq_len),
+        None,
+        warmup=5,
+        iterations=20
+    )
+    latencies.append(latency)
+
+    # Calculate growth rate
+    if len(latencies) > 1:
+        growth = latencies[-1] / latencies[-2]
+        print(f" {seq_len:3d} | {latency:6.2f} | {growth:.2f}×")
+    else:
+        print(f" {seq_len:3d} | {latency:6.2f} | baseline")
+
+print("\n💡 Key Observations:")
+print(" • Latency grows QUADRATICALLY with sequence length")
+print(" • Each new token forces recomputation of ALL previous K,V pairs")
+print(" • For 160 tokens: ~4× time vs 80 tokens (2² growth)")
+
+print("\n🎯 The Problem:")
+print(" K and V values for previous tokens NEVER change,")
+print(" yet we recompute them every single step!")
+
+print("\n✨ The Solution:")
+print(" CACHE the K,V values! (That's memoization)")
+print(" • First compute: Calculate and store K,V")
+print(" • Later steps: Reuse stored K,V")
+print(" • Complexity: O(n²) → O(n)")
+print(" • Speedup: 10-15× for typical generation\n")
+
 # %% [markdown]
 """
 ## 🎯 Part 1: Understanding the Autoregressive Generation Problem