fix(attention): correct O(n²) complexity explanation and memory table bug
- Clarify that attention time complexity is O(n²×d), not O(n²), since each of the n² query-key pairs requires a d-dimensional dot product
- Fix Total Memory column in analyze_attention_memory_overhead(), which was duplicating the Optimizer column instead of summing all components
- Update KEY INSIGHT multiplier from 4x to 7x to match the corrected total

Fixes harvard-edge/cs249r_book#1150
@@ -223,12 +223,12 @@ Output: (batch_size, seq_len, d_model) ← Weighted combination of values
 
 ### Why O(n²) Complexity?
 
-For sequence length n, we compute:
+For sequence length n and embedding dimension d, we compute:
 
-1. **QK^T**: n queries × n keys = n² similarity scores
-2. **Softmax**: n² weights to normalize
-3. **Weights×V**: n² weights × n values = n² operations for aggregation
+1. **QK^T**: n queries × n keys, each a d-dimensional dot product = O(n² × d) operations
+2. **Softmax**: n² weights to normalize = O(n²) operations
+3. **Weights×V**: n² weights applied to d-dimensional values = O(n² × d) operations
 
-This quadratic scaling is attention's blessing (global connectivity) and curse (memory/compute limits).
+The total **time complexity** is **O(n² × d)** per attention head. The **memory complexity** is **O(n²)** for storing the attention weight matrix. This quadratic scaling in sequence length is attention's blessing (global connectivity) and curse (memory/compute limits).
 
 ### The Attention Matrix Visualization
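As a sanity check on the corrected claim, here is a minimal NumPy sketch that counts the dominant operations of single-head attention and measures the attention-matrix footprint. This is an illustration, not code from the book or this commit; `naive_attention` and the chosen sizes are made up:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Single-head attention with rough FLOP accounting (illustrative only)."""
    n, d = Q.shape
    scores = (Q @ K.T) / np.sqrt(d)                          # n² dot products of length d
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over n² scores
    out = weights @ V                                        # n² weights applied to d-dim values
    flops = 2 * n * n * d + n * n + 2 * n * n * d            # QK^T + softmax + weights·V ≈ O(n²·d)
    return out, flops, weights.nbytes                        # weights is the O(n²) matrix

rng = np.random.default_rng(0)
d = 64
for n in (512, 1024, 2048):
    Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
    _, flops, attn_bytes = naive_attention(Q, K, V)
    print(f"n={n:5d}: ~{flops / 1e9:.2f} GFLOPs, attention matrix {attn_bytes / 1e6:.1f} MB")
```

Doubling n quadruples both the FLOP count and the matrix size, while doubling d only doubles the FLOPs, which matches O(n² × d) time and O(n²) memory.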
@@ -919,9 +919,12 @@ def analyze_attention_memory_overhead():
         # Optimizer state (Adam: +2× for momentum and velocity)
         optimizer_memory_mb = backward_memory_mb + 2 * attention_matrix_mb
 
-        print(f"{seq_len:7d} | {attention_matrix_mb:6.2f}MB | {backward_memory_mb:10.2f}MB | {optimizer_memory_mb:10.2f}MB | {optimizer_memory_mb:11.2f}MB")
+        # Total = forward + gradients + optimizer state
+        total_memory_mb = attention_matrix_mb + backward_memory_mb + optimizer_memory_mb
+
+        print(f"{seq_len:7d} | {attention_matrix_mb:6.2f}MB | {backward_memory_mb:10.2f}MB | {optimizer_memory_mb:10.2f}MB | {total_memory_mb:11.2f}MB")
 
-    print(f"\n💡 KEY INSIGHT: Training requires 4x memory of inference")
+    print(f"\n💡 KEY INSIGHT: Training requires ~7x memory of inference (1x forward + 2x gradients + 4x optimizer state)")
     print(f"🚀 For GPT-3 (96 layers, 2048 context): ~6GB just for attention gradients!")
 
 # Run the analysis
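Because the diff shows only fragments of `analyze_attention_memory_overhead()`, a self-contained sketch of the corrected accounting may help. This is a reconstruction under the commit's stated assumptions (gradients ≈ 2× the forward attention matrix; Adam state = gradients + 2× forward); `attention_memory_breakdown`, `num_heads=12`, and fp32 storage are guesses, not the book's exact code:

```python
def attention_memory_breakdown(seq_len, num_heads=12, bytes_per_elem=4):
    """Corrected accounting per the commit: the Total column sums all components."""
    # Forward pass: one (seq_len x seq_len) attention matrix per head
    attention_matrix_mb = num_heads * seq_len * seq_len * bytes_per_elem / 1e6
    # Backward pass: gradients roughly double the forward activations (~2x)
    backward_memory_mb = 2 * attention_matrix_mb
    # Adam optimizer: gradients plus momentum and velocity (+2x forward) -> ~4x
    optimizer_memory_mb = backward_memory_mb + 2 * attention_matrix_mb
    # The fixed Total: 1x + 2x + 4x = 7x the forward-pass memory
    total_memory_mb = attention_matrix_mb + backward_memory_mb + optimizer_memory_mb
    return attention_matrix_mb, backward_memory_mb, optimizer_memory_mb, total_memory_mb

print(f"{'SeqLen':>7} | {'Forward':>9} | {'Gradients':>10} | {'Optimizer':>10} | {'Total':>11}")
for seq_len in (512, 1024, 2048, 4096):
    fwd, bwd, opt, total = attention_memory_breakdown(seq_len)
    print(f"{seq_len:7d} | {fwd:7.2f}MB | {bwd:8.2f}MB | {opt:8.2f}MB | {total:9.2f}MB")
    assert abs(total - 7 * fwd) < 1e-6  # sanity check: ~7x, not 4x
```

The last column now sums all three components instead of repeating the optimizer column, which is exactly where the 7× (rather than 4×) multiplier comes from.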