mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-12 00:03:35 -05:00
feat(memoization): Add profiling motivation section
- Shows O(n²) latency growth in transformer generation - Demonstrates problem before teaching solution - Prepares module for reorganization to Module 15
This commit is contained in:
@@ -67,6 +67,79 @@ from typing import Tuple, Optional, Dict, List
|
||||
# Import TinyTorch components from previous modules
|
||||
from tinytorch.core.tensor import Tensor
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## 🔬 Motivation: Why Memoization Matters for Transformers
|
||||
|
||||
Before we learn KV caching, let's profile transformer generation to understand
|
||||
the problem we're solving. We'll see O(n²) growth in latency as we generate text.
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Profile transformer generation to discover the bottleneck
|
||||
from tinytorch.profiling.profiler import Profiler
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Single shared profiler used to time every attention step in this section.
profiler = Profiler()
|
||||
|
||||
def naive_attention_step(seq_len, hidden_dim=64):
    """
    Run one attention computation over the full sequence, with no caching.

    Every call regenerates Q, K, V for ALL seq_len tokens and recomputes the
    full score matrix — exactly what autoregressive generation pays per new
    token when nothing is memoized. The values are random; only the
    computational cost matters for this profiling demo.

    Args:
        seq_len: Number of tokens in the context so far.
        hidden_dim: Width of each token embedding (default 64).

    Returns:
        The attention output Tensor for the whole sequence.
    """
    # Fresh projections for the ENTIRE sequence — this is the waste a KV
    # cache eliminates: K and V for earlier tokens never change.
    shape = (1, seq_len, hidden_dim)
    queries = Tensor(np.random.randn(*shape))
    keys = Tensor(np.random.randn(*shape))
    values = Tensor(np.random.randn(*shape))

    # Attention: (Q @ K^T) @ V. The score matrix is (1, seq_len, seq_len),
    # so the work grows as O(seq_len²).
    scores = queries @ keys.T
    return scores @ values
||||
|
||||
# Profile the naive attention step at increasing sequence lengths and print
# a table showing how per-step latency grows.
print("🔬 Profiling Transformer Generation (Without Caching):\n")
print(
    " Seq Len | Latency (ms) | Growth\n"
    " ---------|----------------|----------"
)

# Each length doubles the previous one, so O(n²) cost should show ~4× growth.
sequence_lengths = [10, 20, 40, 80, 160]
latencies = []  # measured latency per length, parallel to sequence_lengths
||||
for seq_len in sequence_lengths:
    # Measure latency for this sequence length.
    # FIX: bind seq_len as a default argument. A bare `lambda:` closes over
    # the loop variable (late binding); if the profiler ever deferred the
    # call past this iteration, every callable would see the LAST seq_len.
    # Default-argument binding snapshots the current value and is identical
    # in behavior when the call happens synchronously.
    latency = profiler.measure_latency(
        lambda seq_len=seq_len: naive_attention_step(seq_len),
        # NOTE(review): second positional argument to measure_latency —
        # presumably the input payload (unused here); confirm its signature.
        None,
        warmup=5,       # discard initial runs (warm-up noise)
        iterations=20   # average over 20 timed runs
    )
    latencies.append(latency)

    # Report growth vs. the previous measurement; since each seq_len doubles
    # the last, a ratio near 4× indicates quadratic scaling.
    if len(latencies) > 1:
        growth = latencies[-1] / latencies[-2]
        print(f" {seq_len:3d} | {latency:6.2f} | {growth:.2f}×")
    else:
        print(f" {seq_len:3d} | {latency:6.2f} | baseline")
|
||||
# Summarize what the measurements show, then preview the fix (KV caching).
# Each group is emitted as one multiline print; stdout is identical to
# printing the lines individually.
print(
    "\n💡 Key Observations:\n"
    " • Latency grows QUADRATICALLY with sequence length\n"
    " • Each new token forces recomputation of ALL previous K,V pairs\n"
    " • For 160 tokens: ~4× time vs 80 tokens (2² growth)"
)

print(
    "\n🎯 The Problem:\n"
    " K and V values for previous tokens NEVER change,\n"
    " yet we recompute them every single step!"
)

print(
    "\n✨ The Solution:\n"
    " CACHE the K,V values! (That's memoization)\n"
    " • First compute: Calculate and store K,V\n"
    " • Later steps: Reuse stored K,V\n"
    " • Complexity: O(n²) → O(n)\n"
    " • Speedup: 10-15× for typical generation\n"
)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## 🎯 Part 1: Understanding the Autoregressive Generation Problem
|
||||
|
||||
Reference in New Issue
Block a user