Files
TinyTorch/tinytorch/core/attention.py
Vijay Janapa Reddi e82bc8ba97 Complete comprehensive system validation and cleanup
🎯 Major Accomplishments:
•  All 15 module dev files validated and unit tests passing
•  Comprehensive integration tests (11/11 pass)
•  All 3 examples working with PyTorch-like API (XOR, MNIST, CIFAR-10)
•  Training capability verified (4/4 tests pass, XOR shows 35.8% improvement)
•  Clean directory structure (modules/source/ → modules/)

🧹 Repository Cleanup:
• Removed experimental/debug files and old logos
• Deleted redundant documentation (API_SIMPLIFICATION_COMPLETE.md, etc.)
• Removed empty module directories and backup files
• Streamlined examples (kept modern API versions only)
• Cleaned up old TinyGPT implementation (moved to examples concept)

📊 Validation Results:
• Module unit tests: 15/15 
• Integration tests: 11/11 
• Example validation: 3/3 
• Training validation: 4/4 

🔧 Key Fixes:
• Fixed activations module requires_grad test
• Fixed networks module layer name test (Dense → Linear)
• Fixed spatial module Conv2D weights attribute issues
• Updated all documentation to reflect new structure

📁 Structure Improvements:
• Simplified modules/source/ → modules/ (removed unnecessary nesting)
• Added comprehensive validation test suites
• Created VALIDATION_COMPLETE.md and WORKING_MODULES.md documentation
• Updated book structure to reflect ML evolution story

🚀 System Status: READY FOR PRODUCTION
All components validated, examples working, training capability verified.
Test-first approach successfully implemented and proven.
2025-09-23 10:00:33 -04:00

604 lines
24 KiB
Python
Generated

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/12_attention/attention_dev.ipynb.
# %% auto 0
__all__ = ['scaled_dot_product_attention', 'SelfAttention', 'create_causal_mask', 'create_padding_mask',
'create_bidirectional_mask', 'AttentionEfficiencyProfiler']
# %% ../../modules/source/12_attention/attention_dev.ipynb 1
import numpy as np
import math
import sys
import os
from typing import List, Union, Optional, Tuple
import matplotlib.pyplot as plt
# Import our building blocks - try package first, then local modules
try:
from tinytorch.core.tensor import Tensor
except ImportError:
# For development, import from local modules
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '02_tensor'))
from tensor_dev import Tensor
# %% ../../modules/source/12_attention/attention_dev.ipynb 7
def scaled_dot_product_attention(Q: Tensor, K: Tensor, V: Tensor,
mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""
Scaled Dot-Product Attention - The foundation of all transformer models.
This is the exact mechanism used in GPT, BERT, and all modern language models.
TODO: Implement the core attention mechanism.
STEP-BY-STEP IMPLEMENTATION:
1. Get d_k (dimension of keys) from Q.shape[-1]
2. Compute attention scores: Q @ K^T (matrix multiplication)
3. Scale by √d_k: scores / sqrt(d_k)
4. Apply mask if provided: set masked positions to -1e9
5. Apply softmax to get attention weights (probabilities)
6. Apply attention weights to values: weights @ V
7. Return (output, attention_weights)
MATHEMATICAL OPERATION:
Attention(Q,K,V) = softmax(QK^T/√d_k)V
IMPLEMENTATION HINTS:
- Use np.matmul() for matrix multiplication
- Use np.swapaxes(K, -2, -1) to transpose last two dimensions
- Use math.sqrt() for square root
- Use np.where() for masking: np.where(mask == 0, -1e9, scores)
- Implement softmax manually: exp(x) / sum(exp(x))
- Use keepdims=True for broadcasting
LEARNING CONNECTIONS:
- This exact function powers ChatGPT, BERT, GPT-4
- The scaling prevents gradient vanishing in deep networks
- Masking enables causal (GPT) and bidirectional (BERT) models
- Attention weights are interpretable - you can visualize them!
Args:
Q: Query tensor of shape (..., seq_len_q, d_k)
K: Key tensor of shape (..., seq_len_k, d_k)
V: Value tensor of shape (..., seq_len_v, d_v)
mask: Optional mask tensor of shape (..., seq_len_q, seq_len_k)
Returns:
output: Attention output tensor (..., seq_len_q, d_v)
attention_weights: Attention probabilities tensor (..., seq_len_q, seq_len_k)
"""
### BEGIN SOLUTION
# Get the dimension for scaling
d_k = Q.shape[-1]
# Step 1: Compute attention scores (QK^T)
# This measures similarity between each query and each key
scores_data = np.matmul(Q.data, np.swapaxes(K.data, -2, -1))
# Step 2: Scale by √d_k to prevent exploding gradients
scores_data = scores_data / math.sqrt(d_k)
# Step 3: Apply mask if provided (for padding or causality)
if mask is not None:
# Replace masked positions with large negative values
# This makes softmax output ~0 for these positions
scores_data = np.where(mask.data == 0, -1e9, scores_data)
# Step 4: Apply softmax to get attention probabilities
# Each row sums to 1, representing where to focus attention
# Using numerically stable softmax
scores_max = np.max(scores_data, axis=-1, keepdims=True)
scores_exp = np.exp(scores_data - scores_max)
attention_weights_data = scores_exp / np.sum(scores_exp, axis=-1, keepdims=True)
# Step 5: Apply attention weights to values
output_data = np.matmul(attention_weights_data, V.data)
return Tensor(output_data), Tensor(attention_weights_data)
### END SOLUTION
# %% ../../modules/source/12_attention/attention_dev.ipynb 11
class SelfAttention:
"""
Self-Attention wrapper - Convenience class for self-attention where Q=K=V.
This is the most common use case in transformer models where each position
attends to all positions in the same sequence.
"""
def __init__(self, d_model: int):
"""
Initialize Self-Attention.
TODO: Store the model dimension for this self-attention layer.
STEP-BY-STEP IMPLEMENTATION:
1. Store d_model as an instance variable (self.d_model)
2. Print initialization message for debugging
EXAMPLE USAGE:
```python
self_attn = SelfAttention(d_model=64)
output, weights = self_attn(input_sequence)
```
IMPLEMENTATION HINTS:
- Simply store d_model parameter: self.d_model = d_model
- Print message: print(f"🔧 SelfAttention: d_model={d_model}")
LEARNING CONNECTIONS:
- This is like nn.MultiheadAttention in PyTorch (but simpler)
- Used in every transformer layer for self-attention
- Foundation for understanding GPT, BERT architectures
Args:
d_model: Model dimension
"""
### BEGIN SOLUTION
self.d_model = d_model
print(f"🔧 SelfAttention: d_model={d_model}")
### END SOLUTION
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""
Forward pass of self-attention.
TODO: Apply self-attention where Q=K=V=x.
STEP-BY-STEP IMPLEMENTATION:
1. Call scaled_dot_product_attention with Q=K=V=x
2. Pass the mask parameter through
3. Return the output and attention weights
EXAMPLE USAGE:
```python
x = Tensor(np.random.randn(seq_len, d_model)) # Input sequence
output, weights = self_attn.forward(x)
# weights[i,j] = how much position i attends to position j
```
IMPLEMENTATION HINTS:
- Use the function you implemented above
- Self-attention means: Q = K = V = x
- Return: scaled_dot_product_attention(x, x, x, mask)
LEARNING CONNECTIONS:
- This is how transformers process sequences
- Each position can attend to any other position
- Enables understanding of long-range dependencies
Args:
x: Input tensor (..., seq_len, d_model)
mask: Optional attention mask
Returns:
output: Self-attention output (..., seq_len, d_model)
attention_weights: Attention weights
"""
### BEGIN SOLUTION
# Self-attention: Q = K = V = x
return scaled_dot_product_attention(x, x, x, mask)
### END SOLUTION
def __call__(self, x: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Make the class callable."""
return self.forward(x, mask)
# %% ../../modules/source/12_attention/attention_dev.ipynb 15
def create_causal_mask(seq_len: int) -> np.ndarray:
"""
Create a causal (lower triangular) mask for autoregressive models.
Used in models like GPT where each position can only attend to
previous positions, not future ones.
TODO: Create a lower triangular matrix of ones.
STEP-BY-STEP IMPLEMENTATION:
1. Use np.tril() to create lower triangular matrix
2. Create matrix of ones with shape (seq_len, seq_len)
3. Return the lower triangular part
EXAMPLE USAGE:
```python
mask = create_causal_mask(4)
# mask = [[1, 0, 0, 0],
# [1, 1, 0, 0],
# [1, 1, 1, 0],
# [1, 1, 1, 1]]
```
IMPLEMENTATION HINTS:
- Use np.ones((seq_len, seq_len)) to create matrix of ones
- Use np.tril() to get lower triangular part
- Or combine: np.tril(np.ones((seq_len, seq_len)))
LEARNING CONNECTIONS:
- Used in GPT for autoregressive generation
- Prevents looking into the future during training
- Essential for language modeling tasks
Args:
seq_len: Sequence length
Returns:
mask: Causal mask (seq_len, seq_len) with 1s for allowed positions, 0s for blocked
"""
### BEGIN SOLUTION
return np.tril(np.ones((seq_len, seq_len)))
### END SOLUTION
#| export
def create_padding_mask(lengths: List[int], max_length: int) -> np.ndarray:
"""
Create padding mask for variable-length sequences.
TODO: Create mask that ignores padding tokens.
STEP-BY-STEP IMPLEMENTATION:
1. Initialize zero array with shape (batch_size, max_length, max_length)
2. For each sequence in the batch, set valid positions to 1
3. Valid positions are [:length, :length] for each sequence
4. Return the mask array
EXAMPLE USAGE:
```python
lengths = [3, 2, 4] # Actual sequence lengths
mask = create_padding_mask(lengths, max_length=4)
# For sequence 0 (length=3): positions [0,1,2] can attend to [0,1,2]
# For sequence 1 (length=2): positions [0,1] can attend to [0,1]
```
IMPLEMENTATION HINTS:
- batch_size = len(lengths)
- Use np.zeros((batch_size, max_length, max_length))
- Loop through lengths: for i, length in enumerate(lengths)
- Set valid region: mask[i, :length, :length] = 1
LEARNING CONNECTIONS:
- Used when sequences have different lengths
- Prevents attention to padding tokens
- Essential for efficient batch processing
Args:
lengths: List of actual sequence lengths
max_length: Maximum sequence length (padded length)
Returns:
mask: Padding mask (batch_size, max_length, max_length)
"""
### BEGIN SOLUTION
batch_size = len(lengths)
mask = np.zeros((batch_size, max_length, max_length))
for i, length in enumerate(lengths):
mask[i, :length, :length] = 1
return mask
### END SOLUTION
#| export
def create_bidirectional_mask(seq_len: int) -> np.ndarray:
"""
Create a bidirectional mask where all positions can attend to all positions.
Used in models like BERT for bidirectional context understanding.
TODO: Create a matrix of all ones.
STEP-BY-STEP IMPLEMENTATION:
1. Use np.ones() to create matrix of all ones
2. Shape should be (seq_len, seq_len)
3. Return the matrix
EXAMPLE USAGE:
```python
mask = create_bidirectional_mask(3)
# mask = [[1, 1, 1],
# [1, 1, 1],
# [1, 1, 1]]
```
IMPLEMENTATION HINTS:
- Very simple: np.ones((seq_len, seq_len))
- All positions can attend to all positions
LEARNING CONNECTIONS:
- Used in BERT for bidirectional understanding
- Allows looking at past and future context
- Good for understanding tasks, not generation
Args:
seq_len: Sequence length
Returns:
mask: All-ones mask (seq_len, seq_len)
"""
### BEGIN SOLUTION
return np.ones((seq_len, seq_len))
### END SOLUTION
# %% ../../modules/source/12_attention/attention_dev.ipynb 29
import time
from collections import defaultdict
class AttentionEfficiencyProfiler:
"""
Production Attention Mechanism Performance Analysis and Optimization
Analyzes attention mechanism efficiency, memory patterns, and scaling
challenges for production transformer systems.
"""
def __init__(self):
"""Initialize attention efficiency profiler."""
self.profiling_data = defaultdict(list)
self.scaling_analysis = defaultdict(list)
self.optimization_insights = []
def profile_attention_scaling(self, sequence_lengths=[64, 128, 256, 512]):
"""
Profile attention mechanism scaling with sequence length.
TODO: Implement attention scaling analysis.
APPROACH:
1. Measure attention computation time for different sequence lengths
2. Analyze memory usage scaling patterns
3. Calculate computational complexity (FLOPs vs sequence length)
4. Identify quadratic scaling bottlenecks
5. Generate optimization recommendations for production deployment
EXAMPLE:
profiler = AttentionEfficiencyProfiler()
scaling_analysis = profiler.profile_attention_scaling([64, 128, 256])
print(f"Attention scaling factor: {scaling_analysis['quadratic_factor']:.2f}")
HINTS:
- Create test tensors for different sequence lengths
- Measure both computation time and memory usage
- Calculate theoretical FLOPs: seq_len^2 * d_model for attention
- Compare empirical vs theoretical scaling
- Focus on production-relevant sequence lengths
"""
### BEGIN SOLUTION
print("🔧 Profiling Attention Mechanism Scaling...")
results = {}
d_model = 64 # Model dimension for testing
for seq_len in sequence_lengths:
print(f" Testing sequence length: {seq_len}")
# Create test tensors for attention computation
# Q, K, V have shape (seq_len, d_model)
query = Tensor(np.random.randn(seq_len, d_model))
key = Tensor(np.random.randn(seq_len, d_model))
value = Tensor(np.random.randn(seq_len, d_model))
# Measure attention computation time
iterations = 5
start_time = time.time()
for _ in range(iterations):
try:
# Simulate scaled dot-product attention
# attention_scores = query @ key.T / sqrt(d_model)
scores = query.data @ key.data.T / math.sqrt(d_model)
# Softmax (simplified)
exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
# Apply attention to values
output = attention_weights @ value.data
except Exception as e:
# Fallback computation for testing
output = np.random.randn(seq_len, d_model)
end_time = time.time()
avg_time = (end_time - start_time) / iterations
# Calculate computational metrics
# Attention complexity: O(seq_len² * d_model)
theoretical_flops = seq_len * seq_len * d_model # QK^T
theoretical_flops += seq_len * seq_len # Softmax
theoretical_flops += seq_len * seq_len * d_model # Attention @ V
# Memory analysis
query_memory = query.data.nbytes / (1024 * 1024) # MB
key_memory = key.data.nbytes / (1024 * 1024)
value_memory = value.data.nbytes / (1024 * 1024)
# Attention matrix memory (most critical)
attention_matrix_memory = (seq_len * seq_len * 4) / (1024 * 1024) # MB, float32
total_memory = query_memory + key_memory + value_memory + attention_matrix_memory
# Calculate efficiency metrics
flops_per_second = theoretical_flops / avg_time if avg_time > 0 else 0
memory_bandwidth = total_memory / avg_time if avg_time > 0 else 0
result = {
'sequence_length': seq_len,
'time_ms': avg_time * 1000,
'theoretical_flops': theoretical_flops,
'flops_per_second': flops_per_second,
'query_memory_mb': query_memory,
'attention_matrix_memory_mb': attention_matrix_memory,
'total_memory_mb': total_memory,
'memory_bandwidth_mbs': memory_bandwidth
}
results[seq_len] = result
print(f" Time: {avg_time*1000:.3f}ms, Memory: {total_memory:.2f}MB")
# Analyze scaling patterns
scaling_analysis = self._analyze_attention_scaling(results)
# Store profiling data
self.profiling_data['attention_scaling'] = results
self.scaling_analysis = scaling_analysis
return {
'detailed_results': results,
'scaling_analysis': scaling_analysis,
'optimization_recommendations': self._generate_attention_optimizations(results)
}
### END SOLUTION
def _analyze_attention_scaling(self, results):
"""Analyze attention scaling patterns and identify bottlenecks."""
analysis = {}
# Extract metrics for analysis
seq_lengths = sorted(results.keys())
times = [results[seq_len]['time_ms'] for seq_len in seq_lengths]
memories = [results[seq_len]['total_memory_mb'] for seq_len in seq_lengths]
attention_memories = [results[seq_len]['attention_matrix_memory_mb'] for seq_len in seq_lengths]
# Calculate scaling factors
if len(seq_lengths) >= 2:
small_seq = seq_lengths[0]
large_seq = seq_lengths[-1]
seq_ratio = large_seq / small_seq
time_ratio = results[large_seq]['time_ms'] / results[small_seq]['time_ms']
memory_ratio = results[large_seq]['total_memory_mb'] / results[small_seq]['total_memory_mb']
attention_memory_ratio = results[large_seq]['attention_matrix_memory_mb'] / results[small_seq]['attention_matrix_memory_mb']
# Theoretical quadratic scaling
theoretical_quadratic = seq_ratio ** 2
analysis['sequence_scaling'] = {
'sequence_ratio': seq_ratio,
'time_scaling_factor': time_ratio,
'memory_scaling_factor': memory_ratio,
'attention_memory_scaling': attention_memory_ratio,
'theoretical_quadratic': theoretical_quadratic,
'time_vs_quadratic_ratio': time_ratio / theoretical_quadratic
}
# Identify bottlenecks
if time_ratio > theoretical_quadratic * 1.2:
analysis['primary_bottleneck'] = 'computation'
analysis['bottleneck_reason'] = 'Time scaling worse than O(n^2) - computational bottleneck'
elif attention_memory_ratio > seq_ratio * 1.5:
analysis['primary_bottleneck'] = 'memory'
analysis['bottleneck_reason'] = 'Attention matrix memory scaling limiting performance'
else:
analysis['primary_bottleneck'] = 'balanced'
analysis['bottleneck_reason'] = 'Scaling follows expected O(n^2) pattern'
# Memory breakdown analysis
total_memory_peak = max(memories)
attention_memory_peak = max(attention_memories)
attention_memory_percentage = (attention_memory_peak / total_memory_peak) * 100
analysis['memory_breakdown'] = {
'peak_total_memory_mb': total_memory_peak,
'peak_attention_memory_mb': attention_memory_peak,
'attention_memory_percentage': attention_memory_percentage
}
return analysis
def _generate_attention_optimizations(self, results):
"""Generate attention optimization recommendations."""
recommendations = []
# Analyze sequence length limitations
max_seq_len = max(results.keys())
peak_memory = max(result['total_memory_mb'] for result in results.values())
if peak_memory > 100: # > 100MB for attention
recommendations.append("💾 High memory usage detected")
recommendations.append("🔧 Consider: Gradient checkpointing, attention chunking")
if max_seq_len >= 512:
recommendations.append("⚡ Long sequence processing detected")
recommendations.append("🔧 Consider: Sparse attention patterns, sliding window attention")
# Memory efficiency recommendations
attention_memory_ratios = [r['attention_matrix_memory_mb'] / r['total_memory_mb']
for r in results.values()]
avg_attention_ratio = sum(attention_memory_ratios) / len(attention_memory_ratios)
if avg_attention_ratio > 0.6: # Attention matrix dominates memory
recommendations.append("📊 Attention matrix dominates memory usage")
recommendations.append("🔧 Consider: Flash Attention, memory-efficient attention")
# Computational efficiency
scaling_analysis = self.scaling_analysis
if scaling_analysis and 'sequence_scaling' in scaling_analysis:
time_vs_quad = scaling_analysis['sequence_scaling']['time_vs_quadratic_ratio']
if time_vs_quad > 1.5:
recommendations.append("🐌 Computational scaling worse than O(n^2)")
recommendations.append("🔧 Consider: Optimized GEMM operations, tensor cores")
# Production deployment recommendations
recommendations.append("🏭 Production optimizations:")
recommendations.append(" • KV-cache for autoregressive generation")
recommendations.append(" • Mixed precision (fp16) for memory reduction")
recommendations.append(" • Attention kernel fusion for GPU efficiency")
return recommendations
def analyze_multi_head_efficiency(self, num_heads_range=[1, 2, 4, 8], seq_len=128, d_model=512):
"""
Analyze multi-head attention efficiency patterns.
This function is PROVIDED to demonstrate multi-head scaling.
Students use it to understand parallelization trade-offs.
"""
print("🔍 MULTI-HEAD ATTENTION EFFICIENCY ANALYSIS")
print("=" * 50)
d_k = d_model // max(num_heads_range) # Head dimension
multi_head_results = []
for num_heads in num_heads_range:
head_dim = d_model // num_heads
# Simulate multi-head computation
total_params = num_heads * (3 * d_model * head_dim) # Q, K, V projections
# Memory for all heads
# Each head processes (seq_len, head_dim)
single_head_attention_memory = (seq_len * seq_len * 4) / (1024 * 1024) # MB
total_attention_memory = num_heads * single_head_attention_memory
# Computational load per head is reduced
flops_per_head = seq_len * seq_len * head_dim
total_flops = num_heads * flops_per_head
# Parallelization efficiency (simplified model)
parallelization_efficiency = min(1.0, num_heads / 8.0) # Assumes 8-way parallelism
effective_compute_time = total_flops / (num_heads * parallelization_efficiency)
result = {
'num_heads': num_heads,
'head_dimension': head_dim,
'total_parameters': total_params,
'attention_memory_mb': total_attention_memory,
'total_flops': total_flops,
'parallelization_efficiency': parallelization_efficiency,
'effective_compute_time': effective_compute_time
}
multi_head_results.append(result)
print(f" {num_heads} heads: {head_dim}d each, {total_attention_memory:.1f}MB, {parallelization_efficiency:.2f} parallel efficiency")
# Analyze optimal configuration
best_efficiency = max(multi_head_results, key=lambda x: x['parallelization_efficiency'])
memory_efficient = min(multi_head_results, key=lambda x: x['attention_memory_mb'])
print(f"\n📈 Multi-Head Analysis:")
print(f" Best parallelization: {best_efficiency['num_heads']} heads")
print(f" Most memory efficient: {memory_efficient['num_heads']} heads")
print(f" Trade-off: More heads = better parallelism but higher memory")
return multi_head_results