fix(attention): correct O(n²) complexity explanation and memory table bug

- Clarify that attention time complexity is O(n²×d), not O(n²), since each
  of the n² query-key pairs requires a d-dimensional dot product
- Fix Total Memory column in analyze_attention_memory_overhead() which was
  duplicating the Optimizer column instead of summing all components
- Update KEY INSIGHT multiplier from 4x to 7x to match corrected total

Fixes harvard-edge/cs249r_book#1150
Vijay Janapa Reddi
2026-02-04 08:37:32 -05:00
parent 1a80e57fa0
commit 20a4ba2379


@@ -223,12 +223,12 @@ Output: (batch_size, seq_len, d_model) ← Weighted combination of values
 ### Why O(n²) Complexity?
-For sequence length n, we compute:
-1. **QK^T**: n queries × n keys = n² similarity scores
-2. **Softmax**: n² weights to normalize
-3. **Weights×V**: n² weights × n values = operations for aggregation
+For sequence length n and embedding dimension d, we compute:
+1. **QK^T**: n queries × n keys, each a d-dimensional dot product = O(n² × d) operations
+2. **Softmax**: n² weights to normalize = O(n²) operations
+3. **Weights×V**: n² weights applied to d-dimensional values = O(n² × d) operations
-This quadratic scaling is attention's blessing (global connectivity) and curse (memory/compute limits).
+The total **time complexity** is **O(n² × d)** per attention head. The **memory complexity** is **O(n²)** for storing the attention weight matrix. This quadratic scaling in sequence length is attention's blessing (global connectivity) and curse (memory/compute limits).
 ### The Attention Matrix Visualization
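A quick sketch (not part of the commit) to check the corrected operation count: a minimal single-head attention in NumPy whose three steps match the list above, plus a helper that tallies the multiply-adds per step. The function names and the toy sizes `n=128, d=64` are illustrative assumptions.

```python
# Sketch: verify the O(n² × d) breakdown with an explicit per-step count.
# attention_flops() and single_head_attention() are hypothetical helpers,
# not part of the book's code.
import numpy as np

def attention_flops(n, d):
    """Approximate operation counts for one attention head."""
    qk_flops = n * n * d   # QK^T: n² dot products, each of length d
    softmax_ops = n * n    # normalize n² similarity scores
    av_flops = n * n * d   # weights @ V: n² weights over d-dim values
    return qk_flops, softmax_ops, av_flops

def single_head_attention(Q, K, V):
    """Minimal single-head attention matching the three-step breakdown."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                          # O(n² × d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)         # softmax: O(n²)
    return weights @ V                                     # O(n² × d)

n, d = 128, 64
qk, sm, av = attention_flops(n, d)
print(f"QK^T: {qk:,}  softmax: {sm:,}  weights@V: {av:,}")
# Doubling n quadruples every term; doubling d leaves the softmax term unchanged,
# which is why time is O(n² × d) while the stored weight matrix is only O(n²).
```

The two matmul terms dominate, so the per-head time is O(n² × d) while memory for the weight matrix is O(n²), as the corrected text states.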
@@ -919,9 +919,12 @@ def analyze_attention_memory_overhead():
 # Optimizer state (Adam: +2× for momentum and velocity)
 optimizer_memory_mb = backward_memory_mb + 2 * attention_matrix_mb
-print(f"{seq_len:7d} | {attention_matrix_mb:6.2f}MB | {backward_memory_mb:10.2f}MB | {optimizer_memory_mb:10.2f}MB | {optimizer_memory_mb:11.2f}MB")
+# Total = forward + gradients + optimizer state
+total_memory_mb = attention_matrix_mb + backward_memory_mb + optimizer_memory_mb
-print(f"\n💡 KEY INSIGHT: Training requires 4x memory of inference")
+print(f"{seq_len:7d} | {attention_matrix_mb:6.2f}MB | {backward_memory_mb:10.2f}MB | {optimizer_memory_mb:10.2f}MB | {total_memory_mb:11.2f}MB")
+print(f"\n💡 KEY INSIGHT: Training requires ~7x memory of inference (1x forward + 2x gradients + 4x optimizer state)")
 print(f"🚀 For GPT-3 (96 layers, 2048 context): ~6GB just for attention gradients!")
 # Run the analysis
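A standalone sketch of the corrected accounting (assumptions: fp32 elements and the 1x/2x/4x component ratios from the commit message; `attention_memory_rows` is a hypothetical condensed version of `analyze_attention_memory_overhead()`, whose full body is not shown in this hunk):

```python
# Sketch of the fixed Total column: sum all components instead of
# repeating the Optimizer column. Hypothetical helper, not the book's code.
def attention_memory_rows(seq_lens, bytes_per_el=4):
    """Per-sequence-length memory for one attention weight matrix (fp32)."""
    rows = []
    for n in seq_lens:
        attention_matrix_mb = n * n * bytes_per_el / 1e6        # forward: 1x
        backward_memory_mb = 2 * attention_matrix_mb            # gradients: 2x
        optimizer_memory_mb = backward_memory_mb + 2 * attention_matrix_mb  # Adam: 4x
        # Fixed Total = forward + gradients + optimizer state = 7x forward
        total_memory_mb = (attention_matrix_mb + backward_memory_mb
                           + optimizer_memory_mb)
        rows.append((n, attention_matrix_mb, backward_memory_mb,
                     optimizer_memory_mb, total_memory_mb))
    return rows

for n, fwd, bwd, opt, total in attention_memory_rows([512, 1024, 2048]):
    print(f"{n:7d} | {fwd:6.2f}MB | {bwd:10.2f}MB | {opt:10.2f}MB | {total:11.2f}MB")
```

With these ratios every Total row is exactly 7× the forward column, which is what the updated KEY INSIGHT line reports; the old code printed `optimizer_memory_mb` twice, so the table understated the total.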