feat(memoization): Add profiling motivation section

- Shows O(n²) latency growth in transformer generation
- Demonstrates problem before teaching solution
- Prepares module for reorganization to Module 15
Vijay Janapa Reddi
2025-11-09 09:16:08 -05:00
parent b52b762545
commit 976f0ed278


@@ -67,6 +67,79 @@ from typing import Tuple, Optional, Dict, List
# Import TinyTorch components from previous modules
from tinytorch.core.tensor import Tensor
# %% [markdown]
"""
## 🔬 Motivation: Why Memoization Matters for Transformers
Before we learn KV caching, let's profile transformer generation to understand
the problem we're solving. We'll see O(n²) growth in latency as we generate text.
"""
# %%
# Profile transformer generation to discover the bottleneck
from tinytorch.profiling.profiler import Profiler
import numpy as np
import matplotlib.pyplot as plt
profiler = Profiler()
def naive_attention_step(seq_len, hidden_dim=64):
    """
    Simulates one step of attention computation.
    Without caching, this processes ALL previous tokens every time.
    """
    # Q, K, V for the entire sequence
    q = Tensor(np.random.randn(1, seq_len, hidden_dim))
    k = Tensor(np.random.randn(1, seq_len, hidden_dim))
    v = Tensor(np.random.randn(1, seq_len, hidden_dim))

    # Attention: Q @ K.T then @ V
    # This is O(seq_len²) in complexity
    scores = q @ k.T  # (1, seq_len, seq_len)
    output = scores @ v
    return output
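Before measuring anything, we can predict the growth analytically. A back-of-the-envelope FLOP count for the scores matmul (assuming, as a simplification, that it dominates the step's cost) already shows why doubling the sequence length should roughly quadruple the latency:

```python
def score_flops(seq_len, hidden_dim=64):
    # A (1, n, d) @ (1, d, n) matmul costs about 2 * n * n * d FLOPs,
    # so the scores computation alone is quadratic in seq_len.
    return 2 * seq_len * seq_len * hidden_dim

for n in [10, 20, 40, 80, 160]:
    ratio = score_flops(n) / score_flops(n // 2)
    print(f"  n={n:3d}: {score_flops(n):>10,} FLOPs ({ratio:.1f}× vs n={n // 2})")
```

Each doubling of `seq_len` multiplies the FLOP count by exactly 4, which is the growth pattern the profiler measurements below should reproduce.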
# Profile at increasing sequence lengths
print("🔬 Profiling Transformer Generation (Without Caching):\n")
print(" Seq Len | Latency (ms) | Growth")
print(" ---------|----------------|----------")
sequence_lengths = [10, 20, 40, 80, 160]
latencies = []
for seq_len in sequence_lengths:
    # Measure latency for this sequence length
    latency = profiler.measure_latency(
        lambda: naive_attention_step(seq_len),
        None,
        warmup=5,
        iterations=20
    )
    latencies.append(latency)

    # Calculate growth rate vs. the previous (half-length) measurement
    if len(latencies) > 1:
        growth = latencies[-1] / latencies[-2]
        print(f"  {seq_len:7d} | {latency:14.2f} | {growth:.2f}×")
    else:
        print(f"  {seq_len:7d} | {latency:14.2f} | baseline")
print("\n💡 Key Observations:")
print(" • Latency grows QUADRATICALLY with sequence length")
print(" • Each new token forces recomputation of ALL previous K,V pairs")
print("   • For 160 tokens: ~4× the time of 80 tokens (2× length → 2² = 4× work)")
print("\n🎯 The Problem:")
print(" K and V values for previous tokens NEVER change,")
print(" yet we recompute them every single step!")
print("\n✨ The Solution:")
print(" CACHE the K,V values! (That's memoization)")
print(" • First compute: Calculate and store K,V")
print(" • Later steps: Reuse stored K,V")
print(" • Complexity: O(n²) → O(n)")
print(" • Speedup: 10-15× for typical generation\n")
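# %% [markdown]
The caching idea described above can be sketched in a few lines. This is a minimal, plain-NumPy illustration (single head, no batching, no masking); the class and method names are illustrative, not TinyTorch's actual KVCache API, which we build later in this module:

```python
import numpy as np

class KVCache:
    """Stores the K and V row for every token generated so far."""

    def __init__(self):
        self.k_rows = []  # one (hidden_dim,) key vector per token
        self.v_rows = []  # one (hidden_dim,) value vector per token

    def append(self, k_new, v_new):
        # Compute K,V for the NEW token only -- O(1) work per step,
        # instead of recomputing K,V for the whole sequence.
        self.k_rows.append(k_new)
        self.v_rows.append(v_new)

    def attend(self, q_new):
        # Attend the new query to ALL cached keys -- O(n) work per step.
        K = np.stack(self.k_rows)            # (n, hidden_dim)
        V = np.stack(self.v_rows)            # (n, hidden_dim)
        scores = K @ q_new                   # (n,)
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()             # softmax over cached positions
        return weights @ V                   # (hidden_dim,)

# Generate 5 tokens: each step adds ONE new K,V pair and reuses the rest
rng = np.random.default_rng(0)
cache = KVCache()
for _ in range(5):
    cache.append(rng.standard_normal(64), rng.standard_normal(64))
    out = cache.attend(rng.standard_normal(64))
```

Note the asymmetry: `append` touches only the new token, while `attend` touches each cached position once, so a generation step costs O(n) rather than the O(n²) of the naive version profiled above.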
# %% [markdown]
"""
## 🎯 Part 1: Understanding the Autoregressive Generation Problem