mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-06-01 09:12:51 -05:00
refactor: Simplify attention module to follow TinyTorch patterns
CHANGED: Simplified attention module to focus on core concepts
- Remove multi-head attention, positional encoding, layer norm, transformer block
- Keep only: scaled_dot_product_attention, SelfAttention, masking utilities
- Reduce complexity from ⭐⭐⭐⭐ to ⭐⭐⭐ (matches CNN level)
- Cut from 885 lines to ~440 lines (aligned with other modules)
- Update dependencies: only requires tensor (not layers/activations/networks)
- Change pedagogical framework: 'Build → Use → Understand' (not Reflect)
- Focus on single concept per module (following established TinyTorch pattern)

RESULT: Clean, focused attention module teaching core mechanism
- Students master fundamental attention before advanced concepts
- Consistent with TinyTorch's one-concept-per-module approach
- Foundation for future multi-head attention and transformer modules
- All tests passing (100% success rate)
@@ -1,31 +1,30 @@
|
||||
# 🔥 Module: Attention
|
||||
|
||||
## 📊 Module Info
|
||||
- **Difficulty**: ⭐⭐⭐⭐ Advanced
|
||||
- **Time Estimate**: 6-8 hours
|
||||
- **Prerequisites**: Tensor, Activations, Layers, Networks modules
|
||||
- **Next Steps**: CNN, Autograd, Training modules
|
||||
- **Difficulty**: ⭐⭐⭐ Advanced
|
||||
- **Time Estimate**: 4-5 hours
|
||||
- **Prerequisites**: Tensor module
|
||||
- **Next Steps**: Training, Transformers modules
|
||||
|
||||
Build attention mechanisms from scratch and understand the core technology powering modern AI systems like ChatGPT, BERT, and GPT-4. This module teaches you that attention is a powerful pattern-matching mechanism that allows models to dynamically focus on relevant parts of input sequences.
|
||||
Build the core attention mechanism that powers modern AI! This module implements the fundamental scaled dot-product attention that's used in ChatGPT, BERT, GPT-4, and virtually all state-of-the-art AI systems.
|
||||
|
||||
## 🎯 Learning Objectives
|
||||
|
||||
By the end of this module, you will be able to:
|
||||
|
||||
- **Master attention mechanisms**: Understand how Query, Key, Value projections enable dynamic focus
|
||||
- **Implement self-attention**: Build the core component that powers transformer architectures
|
||||
- **Create multi-head attention**: Combine multiple attention patterns for richer representations
|
||||
- **Add positional encoding**: Give transformers the ability to understand sequence order
|
||||
- **Build transformer blocks**: Compose attention with feed-forward networks and residual connections
|
||||
- **Compare attention patterns**: Understand when to use self-attention vs cross-attention
|
||||
- **Master the attention formula**: Understand and implement `Attention(Q,K,V) = softmax(QK^T/√d_k)V`
|
||||
- **Build self-attention**: Create the core component that enables global context understanding
|
||||
- **Control information flow**: Implement masking for causal, padding, and bidirectional attention
|
||||
- **Visualize attention patterns**: See what the model "pays attention to"
|
||||
- **Understand modern AI**: Grasp the mechanism that revolutionized natural language processing
|
||||
|
||||
## 🧠 Build → Use → Reflect
|
||||
## 🧠 Build → Use → Understand
|
||||
|
||||
This module follows TinyTorch's **Build → Use → Reflect** framework:
|
||||
This module follows TinyTorch's **Build → Use → Understand** framework:
|
||||
|
||||
1. **Build**: Implement attention mechanisms and transformer components from mathematical foundations
|
||||
2. **Use**: Apply attention to sequence tasks and visualize what the model "pays attention to"
|
||||
3. **Reflect**: Compare attention's global perspective with CNN's local receptive fields
|
||||
1. **Build**: Implement the core attention mechanism and masking utilities from mathematical foundations
|
||||
2. **Use**: Apply attention to sequence tasks and visualize attention patterns
|
||||
3. **Understand**: How attention enables dynamic, global context modeling that powers modern AI
|
||||
|
||||
## 📚 What You'll Build
|
||||
|
||||
@@ -35,6 +34,8 @@ def scaled_dot_product_attention(Q, K, V, mask=None):
|
||||
"""
|
||||
The fundamental attention operation:
|
||||
Attention(Q,K,V) = softmax(QK^T/√d_k)V
|
||||
|
||||
This exact function powers ChatGPT, BERT, and all transformers.
|
||||
"""
|
||||
d_k = Q.shape[-1]
|
||||
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
|
||||
@@ -44,94 +45,96 @@ def scaled_dot_product_attention(Q, K, V, mask=None):
|
||||
return attention_weights @ V, attention_weights
|
||||
```
|
||||
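The excerpt above elides the middle of the function. As a rough standalone illustration of the same formula, the sketch below uses plain Python lists rather than TinyTorch's `Tensor`, so it runs without any dependencies. The helper names (`softmax`, `attention_sketch`) are assumptions made for this example and are not part of the module.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_sketch(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for 2-D lists shaped (seq_len, d)."""
    d_k = len(Q[0])
    scale = math.sqrt(d_k)
    # scores[i][j] = (q_i . k_j) / sqrt(d_k)
    scores = [[sum(q * k for q, k in zip(q_row, k_row)) / scale for k_row in K]
              for q_row in Q]
    # Each row of weights is a probability distribution over key positions
    weights = [softmax(row) for row in scores]
    # output[i] = sum_j weights[i][j] * V[j]
    output = [[sum(w * v_row[c] for w, v_row in zip(w_row, V))
               for c in range(len(V[0]))]
              for w_row in weights]
    return output, weights

# Hypothetical toy inputs: two positions, two features
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
output, weights = attention_sketch(Q, K, V)
```

Each query here matches its own key best, so the output for position 0 is pulled mostly toward `V[0]`, with a smaller contribution from `V[1]`.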
|
||||
### Multi-Head Attention
|
||||
### Self-Attention Wrapper
|
||||
```python
|
||||
class MultiHeadAttention:
|
||||
class SelfAttention:
|
||||
"""
|
||||
Multiple attention heads capture different types of relationships:
|
||||
- Head 1: Syntactic relationships
|
||||
- Head 2: Semantic relationships
|
||||
- Head 3: Long-range dependencies
|
||||
Convenient wrapper for self-attention where Q=K=V.
|
||||
The most common use case in transformer models.
|
||||
"""
|
||||
def __init__(self, d_model, num_heads):
|
||||
self.num_heads = num_heads
|
||||
self.d_k = d_model // num_heads
|
||||
self.W_q = Dense(d_model, d_model)
|
||||
self.W_k = Dense(d_model, d_model)
|
||||
self.W_v = Dense(d_model, d_model)
|
||||
self.W_o = Dense(d_model, d_model)
|
||||
def __init__(self, d_model):
|
||||
self.d_model = d_model
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
# Self-attention: Q = K = V = x
|
||||
return scaled_dot_product_attention(x, x, x, mask)
|
||||
```
|
||||
|
||||
### Transformer Block
|
||||
### Attention Masking
|
||||
```python
|
||||
class TransformerBlock:
|
||||
"""
|
||||
Complete transformer layer combining:
|
||||
1. Multi-head self-attention
|
||||
2. Residual connections
|
||||
3. Layer normalization
|
||||
4. Feed-forward network
|
||||
"""
|
||||
def __init__(self, d_model, num_heads, d_ff):
|
||||
self.attention = MultiHeadAttention(d_model, num_heads)
|
||||
self.feed_forward = Sequential([
|
||||
Dense(d_model, d_ff),
|
||||
ReLU(),
|
||||
Dense(d_ff, d_model)
|
||||
])
|
||||
# Causal masking (GPT-style: can't see future tokens)
|
||||
causal_mask = create_causal_mask(seq_len)
|
||||
|
||||
# Padding masking (ignore padding tokens)
|
||||
padding_mask = create_padding_mask(lengths, max_length)
|
||||
|
||||
# Bidirectional masking (BERT-style: can see all tokens)
|
||||
bidirectional_mask = create_bidirectional_mask(seq_len)
|
||||
```
|
||||
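The bodies of the masking helpers called above are not shown in this README. A minimal sketch of what such masks could look like follows, assuming a convention where `1` marks an allowed position and `0` a blocked one. The names `causal_mask_sketch`, `padding_mask_sketch`, and `masked_softmax` are hypothetical, and the real TinyTorch signatures and mask convention may differ (for instance, `create_padding_mask` takes a batch of `lengths`, while this sketch handles one sequence).

```python
import math

def causal_mask_sketch(seq_len):
    """Lower-triangular mask: position i may attend to positions j <= i."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def padding_mask_sketch(length, max_length):
    """Mask for ONE sequence: 1 for the first `length` real tokens, 0 for padding."""
    return [1 if j < length else 0 for j in range(max_length)]

def masked_softmax(scores, mask_row):
    """Softmax over allowed positions only; blocked positions get weight 0."""
    allowed = [s for s, m in zip(scores, mask_row) if m]
    peak = max(allowed)  # assumes at least one allowed position
    exps = [math.exp(s - peak) if m else 0.0 for s, m in zip(scores, mask_row)]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask_sketch(4)
# Row 1 may only see positions 0 and 1, even though later scores are larger
weights_row1 = masked_softmax([0.5, 1.0, 2.0, 3.0], mask[1])
pad = padding_mask_sketch(length=2, max_length=4)
```

Implementations commonly get the same effect by adding a very large negative number to blocked scores before the softmax, which drives their weights to zero.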
|
||||
## 🔬 Key Concepts
|
||||
|
||||
### Why Attention Matters
|
||||
- **Global context**: Unlike CNNs, attention can connect any two positions directly
|
||||
- **Dynamic weights**: Attention weights adapt based on input content, not fixed patterns
|
||||
### Why Attention Revolutionized AI
|
||||
- **Global connectivity**: Unlike CNNs, attention connects any two positions directly
|
||||
- **Dynamic weights**: Attention adapts to input content, not fixed like convolution kernels
|
||||
- **Parallel processing**: Unlike RNNs, all positions computed simultaneously
|
||||
- **Interpretability**: You can visualize what the model pays attention to
|
||||
- **Scalability**: Attention handles variable-length sequences, though its O(n²) cost means very long sequences need modifications
|
||||
|
||||
### The Attention Formula Explained
|
||||
```
|
||||
Attention(Q,K,V) = softmax(QK^T/√d_k)V
|
||||
|
||||
Where:
|
||||
- Q (Query): "What am I looking for?"
|
||||
- K (Key): "What information is available?"
|
||||
- V (Value): "What is the actual content?"
|
||||
- √d_k scaling: Keeps scores from growing with d_k, which would push softmax toward one-hot outputs
|
||||
```
|
||||
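To make the role of the √d_k term concrete, here is a small sketch that compares softmax weights with and without scaling. The vectors are made-up toy values chosen so the arithmetic is easy to follow, and the `softmax` helper is written for this example only.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64
query = [1.0] * d_k
# Three hypothetical keys: mildly aligned, orthogonal, mildly opposed
keys = [[0.1] * d_k, [0.0] * d_k, [-0.1] * d_k]

# Raw dot products grow with d_k: roughly 6.4, 0.0, -6.4
raw_scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
# Dividing by sqrt(64) = 8 brings them back to roughly 0.8, 0.0, -0.8
scaled_scores = [s / math.sqrt(d_k) for s in raw_scores]

unscaled_weights = softmax(raw_scores)
scaled_weights = softmax(scaled_scores)
```

Without scaling, the first key takes about 0.998 of the weight and the softmax is nearly one-hot. With scaling it takes about 0.61, so the other keys still contribute and gradients can still flow through them.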
|
||||
### Attention vs Convolution
|
||||
| Aspect | Convolution | Attention |
|
||||
|--------|-------------|-----------|
|
||||
| **Receptive field** | Local, grows with depth | Global from layer 1 |
|
||||
| **Computation** | O(n·k) for kernel size k | O(n²) in sequence length n |
|
||||
| **Inductive bias** | Spatial locality | Sequence relationships |
|
||||
| **Best for** | Images, spatial data | Text, sequences |
|
||||
| **Weights** | Fixed learned kernels | Dynamic input-dependent |
|
||||
| **Best for** | Spatial data (images) | Sequential data (text) |
|
||||
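The Computation row can be made tangible with a back-of-the-envelope count. The sketch below assumes a 1-D convolution with a fixed kernel size and counts only how many position pairs interact, not the arithmetic per pair. Both function names are hypothetical.

```python
def conv_interactions(seq_len, kernel_size):
    """Each output position looks at `kernel_size` neighbors: O(n * k)."""
    return seq_len * kernel_size

def attention_interactions(seq_len):
    """Each position is scored against every position: O(n^2)."""
    return seq_len * seq_len

# Print how the two counts diverge as the sequence gets longer
for n in (128, 1024, 8192):
    print(n, conv_interactions(n, kernel_size=3), attention_interactions(n))
```

Doubling the sequence length doubles the convolution count but quadruples the attention count, which is the price paid for a global receptive field in a single layer.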
|
||||
### Real-World Applications
|
||||
- **Language Models**: GPT, BERT, ChatGPT
|
||||
- **Machine Translation**: Google Translate
|
||||
- **Vision Transformers**: Image classification without convolution
|
||||
- **Multimodal AI**: CLIP, DALL-E combining text and images
|
||||
- **Language Models**: GPT, BERT, ChatGPT use self-attention to understand context
|
||||
- **Machine Translation**: Google Translate uses attention to align source and target words
|
||||
- **Image Understanding**: Vision Transformers apply attention to image patches
|
||||
- **Multimodal AI**: CLIP, DALL-E use attention to connect text and images
|
||||
|
||||
## 🚀 From Attention to Modern AI
|
||||
|
||||
This module bridges classical ML and modern AI:
|
||||
This module teaches the **core building block** of modern AI:
|
||||
|
||||
**Classical (pre-2017)**: RNNs + CNNs + LSTMs
|
||||
**Modern (post-2017)**: Transformers + Attention + Self-Supervision
|
||||
**What you're building**: The fundamental attention mechanism
|
||||
**What it enables**: Multi-head attention, positional encoding, transformer blocks
|
||||
**What it powers**: ChatGPT, BERT, GPT-4, and contemporary AI systems
|
||||
|
||||
Understanding attention mechanisms gives you the foundation to understand:
|
||||
- How ChatGPT generates text
|
||||
- How BERT understands language
|
||||
Understanding this module gives you the foundation to understand:
|
||||
- How ChatGPT generates coherent text
|
||||
- How BERT understands language bidirectionally
|
||||
- How Vision Transformers work without convolution
|
||||
- How DALL-E combines text and images
|
||||
- How modern AI achieves human-like language understanding
|
||||
|
||||
## 📈 Module Progression
|
||||
|
||||
```
|
||||
Tensors → Activations → Layers → Networks → **ATTENTION** → CNN → Training
|
||||
↑ ↑
|
||||
Foundation Modern AI Core
|
||||
Tensors → **ATTENTION** → Layers → Networks → CNNs → Training
|
||||
↑ ↑
|
||||
Foundation Modern AI Core
|
||||
```
|
||||
|
||||
After completing this module, you'll understand the mechanism that powers the AI revolution, making you ready to work with state-of-the-art models and architectures.
|
||||
After completing this module, you'll understand the mechanism that sparked the AI revolution, making you ready to work with state-of-the-art models and architectures.
|
||||
|
||||
## 🎯 Success Criteria
|
||||
|
||||
You'll know you've mastered this module when you can:
|
||||
- [ ] Explain why attention enables better long-range dependencies than RNNs
|
||||
- [ ] Implement multi-head attention from scratch
|
||||
- [ ] Visualize attention patterns and interpret what the model focuses on
|
||||
- [ ] Compare computational complexity of attention vs convolution
|
||||
- [ ] Build a complete transformer block with residual connections
|
||||
- [ ] Understand why transformers have revolutionized NLP and computer vision
|
||||
- [ ] Implement scaled dot-product attention from scratch
|
||||
- [ ] Explain why the √d_k scaling keeps softmax out of saturation, where gradients vanish
|
||||
- [ ] Create different types of attention masks for various use cases
|
||||
- [ ] Visualize and interpret attention weights
|
||||
- [ ] Understand why attention enabled the transformer revolution
|
||||
- [ ] Connect this foundation to modern AI systems like ChatGPT
|
||||
@@ -17,23 +17,22 @@ Welcome to the Attention module! This is where you'll implement the revolutionar
|
||||
## Learning Goals
|
||||
- Understand attention as dynamic pattern matching with Query, Key, Value projections
|
||||
- Implement scaled dot-product attention from mathematical foundations
|
||||
- Build multi-head attention to capture diverse relationship patterns
|
||||
- Create positional encoding to give transformers sequence awareness
|
||||
- Compose transformer blocks that combine attention with feed-forward networks
|
||||
- Compare attention's global connectivity with CNN's local receptive fields
|
||||
- Master the attention formula that powers all transformer models
|
||||
- Create masking utilities for different attention patterns
|
||||
- Build the foundation for understanding modern AI architectures
|
||||
|
||||
## Build → Use → Reflect
|
||||
1. **Build**: Implement attention mechanisms from scratch using mathematical principles
|
||||
## Build → Use → Understand
|
||||
1. **Build**: Implement the core attention mechanism from scratch using mathematical principles
|
||||
2. **Use**: Apply attention to sequence tasks and visualize attention patterns
|
||||
3. **Reflect**: Understand how attention revolutionized AI by enabling global context modeling
|
||||
3. **Understand**: How attention revolutionized AI by enabling global context modeling
|
||||
|
||||
## What You'll Learn
|
||||
By the end of this module, you'll understand:
|
||||
- How attention enables dynamic focus on relevant input parts
|
||||
- Why multi-head attention captures diverse relationship types
|
||||
- How positional encoding gives transformers sequence understanding
|
||||
- The transformer architecture that powers modern AI systems
|
||||
- Computational trade-offs between attention and convolution
|
||||
- The mathematical foundation behind all transformer models
|
||||
- Why attention is more powerful than fixed convolution kernels
|
||||
- How masking enables different attention patterns (causal, padding)
|
||||
- The building block that powers ChatGPT, BERT, and modern AI
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "attention-imports", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
@@ -47,22 +46,13 @@ import os
|
||||
from typing import List, Union, Optional, Tuple
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Import all the building blocks we need - try package first, then local modules
|
||||
# Import our building blocks - try package first, then local modules
|
||||
try:
|
||||
from tinytorch.core.tensor import Tensor
|
||||
from tinytorch.core.layers import Dense
|
||||
from tinytorch.core.activations import ReLU, Softmax
|
||||
from tinytorch.core.networks import Sequential
|
||||
except ImportError:
|
||||
# For development, import from local modules
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '02_tensor'))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '03_activations'))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '04_layers'))
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '05_networks'))
|
||||
from tensor_dev import Tensor
|
||||
from activations_dev import ReLU, Softmax
|
||||
from layers_dev import Dense
|
||||
from networks_dev import Sequential
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "attention-setup", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
#| hide
|
||||
@@ -98,20 +88,18 @@ print("Ready to build attention mechanisms that power modern AI!")
|
||||
# Final package structure:
|
||||
from tinytorch.core.attention import (
|
||||
scaled_dot_product_attention, # Core attention function
|
||||
MultiHeadAttention, # Multi-head attention layer
|
||||
PositionalEncoding, # Position information
|
||||
TransformerBlock, # Complete transformer layer
|
||||
SelfAttention # Self-attention wrapper
|
||||
SelfAttention, # Self-attention wrapper
|
||||
create_causal_mask, # Masking utilities
|
||||
create_padding_mask
|
||||
)
|
||||
from tinytorch.core.layers import Dense # Building blocks
|
||||
from tinytorch.core.tensor import Tensor # Foundation
|
||||
```
|
||||
|
||||
**Why this matters:**
|
||||
- **Learning:** Focused module for deep understanding of attention
|
||||
- **Production:** Proper organization like PyTorch's `torch.nn.MultiheadAttention`
|
||||
- **Learning:** Focused module for deep understanding of core attention
|
||||
- **Production:** Proper organization like PyTorch's attention functions
|
||||
- **Consistency:** All attention mechanisms live together in `core.attention`
|
||||
- **Integration:** Works seamlessly with tensors, layers, and networks
|
||||
- **Foundation:** Building block for future transformer modules
|
||||
"""
|
||||
|
||||
# %% [markdown]
|
||||
@@ -155,6 +143,14 @@ Attention(Q,K,V) = softmax(QK^T/√d_k)V
|
||||
- **Interpretability**: Attention weights show what the model focuses on
|
||||
- **Scalability**: Works for sequences of varying lengths
|
||||
|
||||
### Attention vs Convolution
|
||||
| Aspect | Convolution | Attention |
|
||||
|--------|-------------|-----------|
|
||||
| **Receptive field** | Local, grows with depth | Global from layer 1 |
|
||||
| **Computation** | O(n·k) for kernel size k | O(n²) in sequence length n |
|
||||
| **Weights** | Fixed learned kernels | Dynamic input-dependent |
|
||||
| **Best for** | Spatial data (images) | Sequential data (text) |
|
||||
|
||||
Let's implement this step by step!
|
||||
"""
|
||||
|
||||
@@ -283,539 +279,54 @@ print("📈 Progress: Scaled Dot-Product Attention ✓")
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 3: Multi-Head Attention - Capturing Diverse Relationships
|
||||
## Step 3: Self-Attention - The Most Common Case
|
||||
|
||||
### Why Multiple Heads?
|
||||
A single attention head captures one type of relationship. Multiple heads allow the model to attend to different types of patterns simultaneously:
|
||||
### What is Self-Attention?
|
||||
**Self-Attention** is the most common use of attention where Q, K, and V all come from the same input sequence. This is what enables models like GPT to understand relationships between words in a sentence.
|
||||
|
||||
- **Head 1**: Syntactic relationships (subject-verb)
|
||||
- **Head 2**: Semantic relationships (word meanings)
|
||||
- **Head 3**: Long-range dependencies
|
||||
- **Head 4**: Local context patterns
|
||||
### Why Self-Attention Matters
|
||||
- **Context understanding**: Each word can attend to every other word
|
||||
- **Long-range dependencies**: Connect distant related concepts
|
||||
- **Parallel processing**: Unlike RNNs, all positions computed simultaneously
|
||||
- **Foundation of GPT**: How language models understand context
|
||||
|
||||
### The Multi-Head Architecture
|
||||
```
|
||||
MultiHead(Q,K,V) = Concat(head₁, head₂, ..., headₕ)W^O
|
||||
|
||||
where headᵢ = Attention(QWᵢᵠ, KWᵢᴷ, VWᵢⱽ)
|
||||
```
|
||||
|
||||
### Implementation Strategy
|
||||
1. **Project**: Apply learned projections to create Q, K, V for each head
|
||||
2. **Split**: Divide into multiple heads with smaller dimensions
|
||||
3. **Attend**: Apply attention for each head independently
|
||||
4. **Combine**: Concatenate heads and apply output projection
|
||||
|
||||
Let's build this powerful attention mechanism!
|
||||
Let's create a convenient wrapper for self-attention!
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "multi-head-attention", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
# %% nbgrader={"grade": false, "grade_id": "self-attention", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
#| export
|
||||
class MultiHeadAttention:
|
||||
class SelfAttention:
|
||||
"""
|
||||
Multi-Head Attention - Enables models to attend to different representation
|
||||
subspaces simultaneously. This is the core component of transformer models.
|
||||
Self-Attention wrapper - Convenience class for self-attention where Q=K=V.
|
||||
|
||||
In transformers, each head learns to focus on different types of relationships:
|
||||
- Syntactic patterns
|
||||
- Semantic relationships
|
||||
- Long-range dependencies
|
||||
- Local context
|
||||
This is the most common use case in transformer models where each position
|
||||
attends to all positions in the same sequence.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, num_heads: int):
|
||||
def __init__(self, d_model: int):
|
||||
"""
|
||||
Initialize Multi-Head Attention.
|
||||
|
||||
Args:
|
||||
d_model: Model dimension (must be divisible by num_heads)
|
||||
num_heads: Number of attention heads
|
||||
"""
|
||||
assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.d_k = d_model // num_heads # Dimension per head
|
||||
|
||||
# For testing purposes, use simple weight matrices instead of Dense layers
|
||||
# In production, these would be Dense layers
|
||||
np.random.seed(42) # For reproducible testing
|
||||
self.W_q = np.random.randn(d_model, d_model) * 0.1 # Query projection
|
||||
self.W_k = np.random.randn(d_model, d_model) * 0.1 # Key projection
|
||||
self.W_v = np.random.randn(d_model, d_model) * 0.1 # Value projection
|
||||
self.W_o = np.random.randn(d_model, d_model) * 0.1 # Output projection
|
||||
|
||||
print(f"🔧 MultiHeadAttention: {num_heads} heads, {self.d_k} dims per head")
|
||||
|
||||
def forward(self, query: np.ndarray, key: np.ndarray, value: np.ndarray,
|
||||
mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Forward pass of multi-head attention.
|
||||
|
||||
Args:
|
||||
query: Query tensor (..., seq_len_q, d_model)
|
||||
key: Key tensor (..., seq_len_k, d_model)
|
||||
value: Value tensor (..., seq_len_v, d_model)
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
output: Attention output (..., seq_len_q, d_model)
|
||||
attention_weights: Average attention weights across heads
|
||||
"""
|
||||
batch_dims = query.shape[:-2]
|
||||
seq_len_q = query.shape[-2]
|
||||
seq_len_k = key.shape[-2]
|
||||
|
||||
# Step 1: Apply linear projections to get Q, K, V for all heads
|
||||
Q = np.matmul(query, self.W_q) # (..., seq_len_q, d_model)
|
||||
K = np.matmul(key, self.W_k) # (..., seq_len_k, d_model)
|
||||
V = np.matmul(value, self.W_v) # (..., seq_len_v, d_model)
|
||||
|
||||
# Step 2: Reshape and transpose for multi-head processing
|
||||
# Split d_model into num_heads * d_k
|
||||
Q = self._reshape_for_heads(Q) # (..., num_heads, seq_len_q, d_k)
|
||||
K = self._reshape_for_heads(K) # (..., num_heads, seq_len_k, d_k)
|
||||
V = self._reshape_for_heads(V) # (..., num_heads, seq_len_v, d_k)
|
||||
|
||||
# Step 3: Apply attention for each head
|
||||
attention_output, attention_weights = self._apply_attention_heads(Q, K, V, mask)
|
||||
|
||||
# Step 4: Concatenate heads and apply output projection
|
||||
# Reshape back to (..., seq_len_q, d_model)
|
||||
attention_output = self._concatenate_heads(attention_output)
|
||||
|
||||
# Final linear projection
|
||||
output = np.matmul(attention_output, self.W_o)
|
||||
|
||||
return output, attention_weights
|
||||
|
||||
def _reshape_for_heads(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Reshape tensor for multi-head processing."""
|
||||
# Input: (..., seq_len, d_model)
|
||||
# Output: (..., num_heads, seq_len, d_k)
|
||||
batch_dims = x.shape[:-2]
|
||||
seq_len = x.shape[-2]
|
||||
|
||||
# Reshape to (..., seq_len, num_heads, d_k)
|
||||
x = x.reshape(*batch_dims, seq_len, self.num_heads, self.d_k)
|
||||
|
||||
# Transpose to (..., num_heads, seq_len, d_k)
|
||||
x = x.transpose(*range(len(batch_dims)), -2, -3, -1)
|
||||
|
||||
return x
|
||||
|
||||
def _apply_attention_heads(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
|
||||
mask: Optional[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Apply scaled dot-product attention for each head."""
|
||||
# Input shapes: (..., num_heads, seq_len, d_k)
|
||||
|
||||
# Apply attention to each head
|
||||
attention_outputs = []
|
||||
attention_weights_list = []
|
||||
|
||||
for head in range(self.num_heads):
|
||||
# Extract tensors for this head
|
||||
Q_head = Q[..., head, :, :] # (..., seq_len_q, d_k)
|
||||
K_head = K[..., head, :, :] # (..., seq_len_k, d_k)
|
||||
V_head = V[..., head, :, :] # (..., seq_len_v, d_k)
|
||||
|
||||
# Apply attention for this head
|
||||
head_output, head_weights = scaled_dot_product_attention(Q_head, K_head, V_head, mask)
|
||||
|
||||
attention_outputs.append(head_output)
|
||||
attention_weights_list.append(head_weights)
|
||||
|
||||
# Stack outputs: (..., num_heads, seq_len_q, d_k)
|
||||
attention_output = np.stack(attention_outputs, axis=-3)
|
||||
|
||||
# Average attention weights across heads for visualization
|
||||
attention_weights = np.stack(attention_weights_list, axis=-3)
|
||||
attention_weights = np.mean(attention_weights, axis=-3)
|
||||
|
||||
return attention_output, attention_weights
|
||||
|
||||
def _concatenate_heads(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Concatenate attention heads back to original dimension."""
|
||||
# Input: (..., num_heads, seq_len, d_k)
|
||||
# Output: (..., seq_len, d_model)
|
||||
batch_dims = x.shape[:-3]
|
||||
seq_len = x.shape[-2]
|
||||
|
||||
# Transpose to (..., seq_len, num_heads, d_k)
|
||||
x = x.transpose(*range(len(batch_dims)), -2, -3, -1)
|
||||
|
||||
# Reshape to (..., seq_len, d_model)
|
||||
x = x.reshape(*batch_dims, seq_len, self.d_model)
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, query: np.ndarray, key: np.ndarray, value: np.ndarray,
|
||||
mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Make the class callable."""
|
||||
return self.forward(query, key, value, mask)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Unit Test: Multi-Head Attention
|
||||
|
||||
**This is a unit test** - it tests multi-head attention composition in isolation.
|
||||
|
||||
Let's verify that our multi-head attention correctly splits, processes, and combines attention heads.
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "test-multi-head", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🔬 Unit Test: Multi-Head Attention...")
|
||||
|
||||
# Test parameters
|
||||
d_model = 64
|
||||
num_heads = 8
|
||||
seq_len = 10
|
||||
np.random.seed(42)
|
||||
|
||||
# Create test data
|
||||
query = np.random.randn(seq_len, d_model) * 0.1
|
||||
key = np.random.randn(seq_len, d_model) * 0.1
|
||||
value = np.random.randn(seq_len, d_model) * 0.1
|
||||
|
||||
print(f"📊 Test setup: d_model={d_model}, num_heads={num_heads}, seq_len={seq_len}")
|
||||
|
||||
# Create multi-head attention
|
||||
mha = MultiHeadAttention(d_model, num_heads)
|
||||
|
||||
# Test forward pass
|
||||
output, weights = mha(query, key, value)
|
||||
|
||||
print(f"📊 Output shapes: output{output.shape}, weights{weights.shape}")
|
||||
|
||||
# Verify properties
|
||||
print(f"✅ Output shape correct: {output.shape == (seq_len, d_model)}")
|
||||
print(f"✅ Attention weights shape correct: {weights.shape == (seq_len, seq_len)}")
|
||||
print(f"✅ Attention weights sum to 1: {np.allclose(np.sum(weights, axis=-1), 1.0)}")
|
||||
print(f"✅ d_k per head correct: {mha.d_k == d_model // num_heads}")
|
||||
|
||||
# Test self-attention (Q = K = V)
|
||||
self_output, self_weights = mha(query, query, query)
|
||||
print(f"✅ Self-attention works: {self_output.shape == (seq_len, d_model)}")
|
||||
|
||||
print("📈 Progress: Multi-Head Attention ✓")
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 4: Positional Encoding - Teaching Transformers About Order
|
||||
|
||||
### The Position Problem
|
||||
Unlike RNNs or CNNs, attention is **position-agnostic**. The operation is permutation-equivariant: swapping two input positions simply swaps the corresponding outputs, so the result carries no information about order. This is both a strength (parallelizable) and a weakness (no understanding of order).
|
||||
|
||||
### Why Position Matters
|
||||
For sequences, order is crucial:
|
||||
- "The cat sat on the mat" ≠ "The mat sat on the cat"
|
||||
- "I didn't say she stole my money" has different meanings based on emphasis
|
||||
- Code execution order matters: `x = 1; y = x + 1` ≠ `y = x + 1; x = 1`
|
||||
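The order-blindness of attention can be checked directly. The sketch below is an illustration only: it uses plain Python lists and unprojected self-attention (Q = K = V = x), and the name `self_attention_sketch` is an assumption made for this example. Swapping two input rows swaps the matching output rows and changes nothing else.

```python
import math

def self_attention_sketch(x):
    """Unprojected self-attention (Q = K = V = x) on a list of row vectors."""
    scale = math.sqrt(len(x[0]))
    out = []
    for q in x:
        # Score this position against every position, then normalize
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in x]
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of all rows gives the output for this position
        out.append([sum(w * v[c] for w, v in zip(weights, x))
                    for c in range(len(x[0]))])
    return out

tokens = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
swapped = [tokens[1], tokens[0], tokens[2]]  # swap the first two positions

out_original = self_attention_sketch(tokens)
out_swapped = self_attention_sketch(swapped)
```

Because each token's output is the same wherever it sits, the model cannot tell "first" from "second" unless position information is added to the inputs.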
|
||||
### Sinusoidal Positional Encoding
|
||||
The Transformer paper uses sinusoidal functions to encode position:
|
||||
```
|
||||
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
|
||||
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
|
||||
```
|
||||
|
||||
### Why Sinusoidal?
|
||||
- **Deterministic**: Same position always gets same encoding
|
||||
- **Extrapolation**: Can handle sequences longer than those seen in training
|
||||
- **Smooth**: Similar positions get similar encodings
|
||||
- **Learnable patterns**: Model can learn to use positional relationships
|
||||
|
||||
Let's implement this crucial component!
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "positional-encoding", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
#| export
|
||||
class PositionalEncoding:
|
||||
"""
|
||||
Positional Encoding using sinusoidal functions.
|
||||
|
||||
Adds position information to transformer inputs so the model
|
||||
can understand sequence order. Uses the same approach as the
|
||||
original Transformer paper.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, max_length: int = 5000):
|
||||
"""
|
||||
Initialize positional encoding.
|
||||
Initialize Self-Attention.
|
||||
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
max_length: Maximum sequence length to precompute
|
||||
"""
|
||||
self.d_model = d_model
|
||||
self.max_length = max_length
|
||||
|
||||
# Precompute positional encodings
|
||||
self.pe = self._create_positional_encoding()
|
||||
print(f"🔧 PositionalEncoding: d_model={d_model}, max_length={max_length}")
|
||||
|
||||
def _create_positional_encoding(self) -> np.ndarray:
|
||||
"""
|
||||
Create sinusoidal positional encoding matrix.
|
||||
|
||||
Returns:
|
||||
pe: Positional encoding matrix (max_length, d_model)
|
||||
"""
|
||||
pe = np.zeros((self.max_length, self.d_model))
|
||||
|
||||
# Create position indices
|
||||
position = np.arange(self.max_length).reshape(-1, 1) # (max_length, 1)
|
||||
|
||||
# Create dimension indices for the sinusoidal pattern
|
||||
div_term = np.exp(np.arange(0, self.d_model, 2) *
|
||||
-(math.log(10000.0) / self.d_model)) # (d_model//2,)
|
||||
|
||||
# Apply sinusoidal functions
|
||||
pe[:, 0::2] = np.sin(position * div_term) # Even indices: sin
|
||||
pe[:, 1::2] = np.cos(position * div_term) # Odd indices: cos
|
||||
|
||||
return pe
|
||||
|
||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Add positional encoding to input embeddings.
|
||||
|
||||
Args:
|
||||
x: Input tensor (..., seq_len, d_model)
|
||||
|
||||
Returns:
|
||||
output: Input + positional encoding (..., seq_len, d_model)
|
||||
"""
|
||||
seq_len = x.shape[-2]
|
||||
|
||||
if seq_len > self.max_length:
|
||||
raise ValueError(f"Sequence length {seq_len} exceeds max_length {self.max_length}")
|
||||
|
||||
# Get positional encoding for this sequence length
|
||||
pos_encoding = self.pe[:seq_len, :] # (seq_len, d_model)
|
||||
|
||||
# Add to input (broadcasting handles batch dimensions)
|
||||
output = x + pos_encoding
|
||||
|
||||
return output
|
||||
|
||||
def __call__(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Make the class callable."""
|
||||
return self.forward(x)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Unit Test: Positional Encoding
|
||||
|
||||
**This is a unit test** - it tests positional encoding addition in isolation.
|
||||
|
||||
Let's verify that positional encoding adds meaningful position information.
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "test-positional-encoding", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🔬 Unit Test: Positional Encoding...")
|
||||
|
||||
# Test parameters
|
||||
d_model = 32
|
||||
max_length = 100
|
||||
seq_len = 8
|
||||
np.random.seed(42)
|
||||
|
||||
# Create test data (like word embeddings)
|
||||
embeddings = np.random.randn(seq_len, d_model) * 0.1
|
||||
|
||||
print(f"📊 Test setup: d_model={d_model}, seq_len={seq_len}")
|
||||
|
||||
# Create positional encoding
|
||||
pos_enc = PositionalEncoding(d_model, max_length)
|
||||
|
||||
# Test forward pass
|
||||
output = pos_enc(embeddings)
|
||||
|
||||
print(f"📊 Shapes: input{embeddings.shape}, output{output.shape}")
|
||||
|
||||
# Verify properties
|
||||
print(f"✅ Output shape preserved: {output.shape == embeddings.shape}")
|
||||
print(f"✅ Positional encoding has correct shape: {pos_enc.pe.shape == (max_length, d_model)}")
|
||||
|
||||
# Test that different positions get different encodings
|
||||
pos_0 = pos_enc.pe[0, :]
|
||||
pos_1 = pos_enc.pe[1, :]
|
||||
pos_10 = pos_enc.pe[10, :]
|
||||
|
||||
print(f"✅ Different positions have different encodings: {not np.allclose(pos_0, pos_1)}")
|
||||
print(f"✅ Position encoding bounded: {np.all(np.abs(pos_enc.pe) <= 1.1)}")
|
||||
|
||||
print("📈 Progress: Positional Encoding ✓")
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 5: Layer Normalization - Stabilizing Training
|
||||
|
||||
### Why Normalization Matters
|
||||
Deep networks can be hard to train because the distribution of each layer's inputs shifts as the parameters of earlier layers change (often called **internal covariate shift**), which can make training unstable.
|
||||
|
||||
### Layer Normalization vs Batch Normalization
|
||||
- **Batch Norm**: Normalizes each feature across the batch dimension
- **Layer Norm**: Normalizes each example (token) across its feature dimension
- **Why Layer Norm for Transformers**: Its statistics do not depend on batch size or on other sequences in the batch, so it behaves the same with variable sequence lengths and small batches
|
||||
|
||||
### The Layer Norm Operation
|
||||
```
|
||||
LayerNorm(x) = γ * (x - μ) / sqrt(σ² + ε) + β

where:
μ = mean(x, axis=-1)    # Mean across features
σ² = var(x, axis=-1)    # Variance across features
ε = small constant for numerical stability
γ, β = learnable parameters
|
||||
```
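As a quick sanity check of the formula above, here is a minimal NumPy sketch. The function name `layer_norm_sketch` is hypothetical, and it assumes γ = 1 and β = 0 by default with a small `eps` inside the square root.

```python
import numpy as np

def layer_norm_sketch(x, gamma=None, beta=None, eps=1e-6):
    """Hypothetical helper: normalize the last axis, then scale and shift."""
    d_model = x.shape[-1]
    gamma = np.ones(d_model) if gamma is None else gamma
    beta = np.zeros(d_model) if beta is None else beta
    mean = x.mean(axis=-1, keepdims=True)                  # (..., 1)
    var = x.var(axis=-1, keepdims=True)                    # (..., 1)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# Two rows that differ only by a factor of 10.
x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 20.0, 30.0, 40.0]])
y = layer_norm_sketch(x)

row_means = y.mean(axis=-1)   # close to 0 for every row
row_stds = y.std(axis=-1)     # close to 1 for every row
```

Both rows normalize to nearly the same values, which shows that the operation removes per-row scale and offset before γ and β are applied.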
|
||||
|
||||
Let's implement this essential component!
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "layer-normalization", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
#| export
|
||||
class LayerNorm:
|
||||
"""
|
||||
Layer Normalization - Normalizes inputs across the feature dimension.
|
||||
|
||||
Essential for stable transformer training. Unlike batch normalization,
|
||||
layer norm works consistently across different batch sizes and
|
||||
sequence lengths.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, eps: float = 1e-6):
|
||||
"""
|
||||
Initialize Layer Normalization.
|
||||
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
eps: Small constant for numerical stability
|
||||
"""
|
||||
self.d_model = d_model
|
||||
self.eps = eps
|
||||
|
||||
# Learnable parameters
|
||||
self.gamma = np.ones(d_model) # Scale parameter
|
||||
self.beta = np.zeros(d_model) # Shift parameter
|
||||
|
||||
print(f"🔧 LayerNorm: d_model={d_model}")
|
||||
|
||||
def forward(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Apply layer normalization.
|
||||
|
||||
Args:
|
||||
x: Input tensor (..., d_model)
|
||||
|
||||
Returns:
|
||||
output: Normalized tensor (..., d_model)
|
||||
"""
|
||||
# Compute mean and variance across the last dimension (features)
|
||||
mean = np.mean(x, axis=-1, keepdims=True) # (..., 1)
|
||||
variance = np.var(x, axis=-1, keepdims=True) # (..., 1)
|
||||
|
||||
# Normalize
|
||||
x_normalized = (x - mean) / np.sqrt(variance + self.eps)
|
||||
|
||||
# Apply learnable transformation
|
||||
output = self.gamma * x_normalized + self.beta
|
||||
|
||||
return output
|
||||
|
||||
def __call__(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Make the class callable."""
|
||||
return self.forward(x)
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 6: Complete Transformer Block
|
||||
|
||||
### The Transformer Architecture
|
||||
A transformer block combines all the components we've built:
|
||||
|
||||
1. **Multi-Head Self-Attention** - Global context modeling
|
||||
2. **Residual Connection** - Gradient flow and training stability
|
||||
3. **Layer Normalization** - Input distribution stabilization
|
||||
4. **Feed-Forward Network** - Non-linear transformation
|
||||
5. **Another Residual + LayerNorm** - More stability
|
||||
|
||||
### Pre-Norm vs Post-Norm
|
||||
We'll use **Pre-Norm** (LayerNorm before attention/FFN) as it's more stable:
|
||||
```
|
||||
x = x + MultiHeadAttention(LayerNorm(x))
|
||||
x = x + FeedForward(LayerNorm(x))
|
||||
```
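To make the difference concrete, the sketch below contrasts one Pre-Norm step with one Post-Norm step. It is illustrative only: `toy_sublayer` is a hypothetical stand-in for attention or the feed-forward network, and `layer_norm` omits the learnable γ and β.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize the last axis (no learnable scale or shift in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def toy_sublayer(h):
    # Hypothetical stand-in for attention or the feed-forward network.
    return 0.5 * h

def pre_norm_step(x, sublayer):
    # Pre-Norm: normalize the sublayer input, leave the residual path untouched.
    return x + sublayer(layer_norm(x))

def post_norm_step(x, sublayer):
    # Post-Norm: add the residual first, then normalize the sum.
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8)) * 3.0 + 2.0

pre_out = pre_norm_step(x, toy_sublayer)
post_out = post_norm_step(x, toy_sublayer)
```

In the Pre-Norm step the input `x` reaches the output unchanged through the residual path, while the Post-Norm step rescales it on every block.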
|
||||
|
||||
Let's build the complete transformer block!
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "transformer-block", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
#| export
|
||||
class TransformerBlock:
|
||||
"""
|
||||
Complete Transformer Block - The fundamental building block of transformer models.
|
||||
|
||||
Combines multi-head attention, feed-forward networks, residual connections,
and layer normalization. This is the general architecture used in GPT, BERT,
and other transformer models; details such as norm placement vary by model.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
|
||||
"""
|
||||
Initialize Transformer Block.
|
||||
|
||||
Args:
|
||||
d_model: Model dimension
|
||||
num_heads: Number of attention heads
|
||||
d_ff: Feed-forward network dimension (usually 4 * d_model)
|
||||
dropout: Dropout rate (not implemented in this version)
|
||||
"""
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.d_ff = d_ff
|
||||
|
||||
# Multi-head self-attention
|
||||
self.self_attention = MultiHeadAttention(d_model, num_heads)
|
||||
|
||||
# Feed-forward network (simplified for testing)
|
||||
np.random.seed(42)
|
||||
self.ff_w1 = np.random.randn(d_model, d_ff) * 0.1
|
||||
self.ff_b1 = np.zeros(d_ff)
|
||||
self.ff_w2 = np.random.randn(d_ff, d_model) * 0.1
|
||||
self.ff_b2 = np.zeros(d_model)
|
||||
|
||||
# Layer normalization layers
|
||||
self.ln1 = LayerNorm(d_model) # Before attention
|
||||
self.ln2 = LayerNorm(d_model) # Before feed-forward
|
||||
|
||||
print(f"🔧 TransformerBlock: d_model={d_model}, heads={num_heads}, d_ff={d_ff}")
|
||||
print(f"🔧 SelfAttention: d_model={d_model}")
|
||||
|
||||
def forward(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Forward pass of transformer block.
|
||||
Forward pass of self-attention.
|
||||
|
||||
Args:
|
||||
x: Input tensor (..., seq_len, d_model)
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
output: Transformed tensor (..., seq_len, d_model)
|
||||
attention_weights: Attention weights from self-attention
|
||||
output: Self-attention output (..., seq_len, d_model)
|
||||
attention_weights: Attention weights
|
||||
"""
|
||||
# Self-attention with residual connection and layer norm (Pre-Norm)
|
||||
ln1_output = self.ln1(x)
|
||||
attn_output, attention_weights = self.self_attention(ln1_output, ln1_output, ln1_output, mask)
|
||||
x = x + attn_output # Residual connection
|
||||
|
||||
# Feed-forward with residual connection and layer norm (Pre-Norm)
|
||||
ln2_output = self.ln2(x)
|
||||
# Simple feed-forward: Linear -> ReLU -> Linear
|
||||
ff_hidden = np.matmul(ln2_output, self.ff_w1) + self.ff_b1
|
||||
ff_hidden = np.maximum(0, ff_hidden) # ReLU activation
|
||||
ff_output = np.matmul(ff_hidden, self.ff_w2) + self.ff_b2
|
||||
x = x + ff_output # Residual connection
|
||||
|
||||
return x, attention_weights
|
||||
# Self-attention: Q = K = V = x
|
||||
return scaled_dot_product_attention(x, x, x, mask)
|
||||
|
||||
def __call__(self, x: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Make the class callable."""
|
||||
@@ -823,63 +334,298 @@ class TransformerBlock:
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Unit Test: Complete Transformer Block
|
||||
### 🧪 Unit Test: Self-Attention
|
||||
|
||||
**This is a unit test** - it tests the complete transformer block integration.
|
||||
**This is a unit test** - it tests self-attention wrapper functionality.
|
||||
|
||||
Let's verify that our transformer block properly combines all components.
|
||||
Let's verify our self-attention wrapper works correctly.
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "test-transformer-block", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🔬 Unit Test: Complete Transformer Block...")
|
||||
# %% nbgrader={"grade": false, "grade_id": "test-self-attention", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🔬 Unit Test: Self-Attention...")
|
||||
|
||||
# Test parameters
|
||||
d_model = 64
|
||||
num_heads = 8
|
||||
d_ff = 256
|
||||
seq_len = 12
|
||||
d_model = 32
|
||||
seq_len = 8
|
||||
np.random.seed(42)
|
||||
|
||||
# Create test data (like embeddings + positional encoding)
|
||||
# Create test data (like word embeddings)
|
||||
x = np.random.randn(seq_len, d_model) * 0.1
|
||||
|
||||
print(f"📊 Test setup: d_model={d_model}, heads={num_heads}, d_ff={d_ff}, seq_len={seq_len}")
|
||||
print(f"📊 Test setup: d_model={d_model}, seq_len={seq_len}")
|
||||
|
||||
# Create transformer block
|
||||
transformer = TransformerBlock(d_model, num_heads, d_ff)
|
||||
# Create self-attention
|
||||
self_attn = SelfAttention(d_model)
|
||||
|
||||
# Test forward pass
|
||||
output, attention_weights = transformer(x)
|
||||
output, weights = self_attn(x)
|
||||
|
||||
print(f"📊 Output shapes: output{output.shape}, attention{attention_weights.shape}")
|
||||
print(f"📊 Output shapes: output{output.shape}, weights{weights.shape}")
|
||||
|
||||
# Verify properties
|
||||
print(f"✅ Output shape preserved: {output.shape == x.shape}")
|
||||
print(f"✅ Attention weights correct shape: {attention_weights.shape == (seq_len, seq_len)}")
|
||||
print(f"✅ Attention weights sum to 1: {np.allclose(np.sum(attention_weights, axis=-1), 1.0)}")
|
||||
print(f"✅ Attention weights correct shape: {weights.shape == (seq_len, seq_len)}")
|
||||
print(f"✅ Attention weights sum to 1: {np.allclose(np.sum(weights, axis=-1), 1.0)}")
|
||||
print(f"✅ Attention weight matrix is square: {weights.shape[0] == weights.shape[1]}")
|
||||
|
||||
# Test with causal mask (for autoregressive models like GPT)
|
||||
causal_mask = np.tril(np.ones((seq_len, seq_len))) # Lower triangular mask
|
||||
output_masked, attention_masked = transformer(x, causal_mask)
|
||||
print("📈 Progress: Self-Attention ✓")
|
||||
|
||||
print(f"✅ Masked transformer works: {output_masked.shape == x.shape}")
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 4: Attention Masking - Controlling Information Flow
|
||||
|
||||
# Verify causal masking worked
|
||||
upper_triangle = np.triu(attention_masked, k=1) # Upper triangle should be ~0
|
||||
print(f"✅ Causal masking applied: {np.all(upper_triangle < 1e-6)}")
|
||||
### Why Masking Matters
|
||||
Masking allows us to control which positions can attend to which other positions:
|
||||
|
||||
print("📈 Progress: Complete Transformer Block ✓")
|
||||
1. **Causal Masking**: For autoregressive models (like GPT) - can't see future tokens
|
||||
2. **Padding Masking**: Ignore padding tokens in variable-length sequences
|
||||
3. **Custom Masking**: Application-specific attention patterns
|
||||
|
||||
### Types of Masks
|
||||
- **Causal (Lower Triangular)**: Position i can only attend to positions ≤ i
|
||||
- **Padding**: Mask out padding tokens so they don't affect attention
|
||||
- **Bidirectional**: All positions can attend to all positions (like BERT)
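One common way to apply such a mask is sketched below, under stated assumptions: the mask uses 1 for allowed and 0 for blocked positions, and blocked scores are replaced with a large negative number before the softmax. The helper `masked_softmax` and the `-1e9` fill value are illustrative choices, not necessarily what `scaled_dot_product_attention` does internally.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Hypothetical helper: softmax over the last axis, ignoring mask == 0 entries."""
    masked = np.where(mask == 0, -1e9, scores)             # block disallowed positions
    masked = masked - masked.max(axis=-1, keepdims=True)   # numerical stability
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)

seq_len = 4
scores = np.zeros((seq_len, seq_len))                      # equal raw scores everywhere
causal = np.tril(np.ones((seq_len, seq_len)))              # position i sees positions <= i

weights = masked_softmax(scores, causal)
```

With equal raw scores, each row spreads its weight uniformly over the positions it is allowed to see, so row `i` holds `1 / (i + 1)` in its first `i + 1` entries and zeros elsewhere.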
|
||||
|
||||
Let's implement these essential masking utilities!
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "attention-masking", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
#| export
|
||||
def create_causal_mask(seq_len: int) -> np.ndarray:
    """
    Create a causal (lower triangular) mask for autoregressive models.

    Used in models like GPT where each position can attend only to
    itself and earlier positions, not future ones.

    Args:
        seq_len: Sequence length

    Returns:
        mask: Causal mask (seq_len, seq_len) with 1s for allowed positions, 0s for blocked
    """
    return np.tril(np.ones((seq_len, seq_len)))
|
||||
|
||||
#| export
|
||||
def create_padding_mask(lengths: List[int], max_length: int) -> np.ndarray:
    """
    Create padding mask for variable-length sequences.

    Args:
        lengths: List of actual sequence lengths
        max_length: Maximum sequence length (padded length)

    Returns:
        mask: Padding mask (batch_size, max_length, max_length)
    """
    batch_size = len(lengths)
    mask = np.zeros((batch_size, max_length, max_length))

    for i, length in enumerate(lengths):
        mask[i, :length, :length] = 1

    return mask
|
||||
|
||||
#| export
|
||||
def create_bidirectional_mask(seq_len: int) -> np.ndarray:
    """
    Create a bidirectional mask where all positions can attend to all positions.

    Used in models like BERT for bidirectional context understanding.

    Args:
        seq_len: Sequence length

    Returns:
        mask: All-ones mask (seq_len, seq_len)
    """
    return np.ones((seq_len, seq_len))
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Unit Test: Attention Masking
|
||||
|
||||
**This is a unit test** - it tests all masking utilities work correctly.
|
||||
|
||||
Let's verify our masking functions create the correct patterns.
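For intuition about the padding pattern, here is a vectorized sketch that builds the same kind of mask as the loop in `create_padding_mask`. The name `padding_mask_sketch` is hypothetical. It also flags a detail worth knowing: the rows belonging to padded query positions are entirely zero.

```python
import numpy as np

def padding_mask_sketch(lengths, max_length):
    """Hypothetical vectorized variant: 1 where both query and key are real tokens."""
    lengths = np.asarray(lengths)
    valid = np.arange(max_length)[None, :] < lengths[:, None]   # (batch, max_length)
    # Outer product of validity: query i and key j must both be real tokens.
    return (valid[:, :, None] & valid[:, None, :]).astype(float)

mask = padding_mask_sketch([5, 3, 4], max_length=5)

# Rows that block every key: these are the padded query positions.
fully_blocked_rows = mask.sum(axis=-1) == 0                     # (batch, max_length)
```

How a softmax treats a fully blocked row depends on the masking implementation: a large negative fill yields a uniform row, while a fill of negative infinity yields NaN. The outputs at padded positions are normally discarded either way.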
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "test-masking", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🔬 Unit Test: Attention Masking...")
|
||||
|
||||
# Test causal mask
|
||||
seq_len = 5
|
||||
causal_mask = create_causal_mask(seq_len)
|
||||
|
||||
print(f"📊 Causal mask for seq_len={seq_len}:")
|
||||
print(causal_mask)
|
||||
|
||||
# Verify causal mask properties
|
||||
print(f"✅ Causal mask is lower triangular: {np.allclose(causal_mask, np.tril(causal_mask))}")
|
||||
print(f"✅ Causal mask has correct shape: {causal_mask.shape == (seq_len, seq_len)}")
|
||||
print(f"✅ Causal mask upper triangle is zeros: {np.all(np.triu(causal_mask, k=1) == 0)}")
|
||||
|
||||
# Test padding mask
|
||||
lengths = [5, 3, 4]
|
||||
max_length = 5
|
||||
padding_mask = create_padding_mask(lengths, max_length)
|
||||
|
||||
print(f"📊 Padding mask for lengths {lengths}, max_length={max_length}:")
|
||||
print("Mask for sequence 0 (length 5):")
|
||||
print(padding_mask[0])
|
||||
print("Mask for sequence 1 (length 3):")
|
||||
print(padding_mask[1])
|
||||
|
||||
# Verify padding mask properties
|
||||
print(f"✅ Padding mask has correct shape: {padding_mask.shape == (3, max_length, max_length)}")
|
||||
print(f"✅ Full-length sequence is all ones: {np.all(padding_mask[0] == 1)}")
|
||||
print(f"✅ Short sequence has zeros in padding area: {np.all(padding_mask[1, 3:, :] == 0)}")
|
||||
|
||||
# Test bidirectional mask
|
||||
bidirectional_mask = create_bidirectional_mask(seq_len)
|
||||
print(f"✅ Bidirectional mask is all ones: {np.all(bidirectional_mask == 1)}")
|
||||
print(f"✅ Bidirectional mask has correct shape: {bidirectional_mask.shape == (seq_len, seq_len)}")
|
||||
|
||||
print("📈 Progress: Attention Masking ✓")
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 5: Attention Visualization and Analysis
|
||||
|
||||
### Understanding What Attention Learns
|
||||
Let's create a simple example to see what attention patterns emerge and understand the behavior.
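The arithmetic behind such an example can be sketched on its own. The function `attention_sketch` below is a hypothetical, unmasked stand-in that assumes the standard definition softmax(QKᵀ / √d_k)V; it is not the module's `scaled_dot_product_attention`.

```python
import numpy as np

def attention_sketch(Q, K, V):
    """Hypothetical unmasked scaled dot-product attention."""
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)     # (seq_len, seq_len)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

# Four one-hot "tokens"; positions 0 and 3 carry identical content.
seq = np.array([[1, 0, 0, 0],
                [0, 1, 0, 0],
                [0, 0, 1, 0],
                [1, 0, 0, 0]], dtype=float)

output, weights = attention_sketch(seq, seq, seq)
```

For row 0 the scaled scores are `[0.5, 0, 0, 0.5]`, so positions 0 and 3 each receive about 0.31 of the weight and positions 1 and 2 about 0.19 each. Because the two largest weights are tied, `np.argmax` returns the first of them (index 0) for both row 0 and row 3.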
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "attention-analysis", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🎯 Attention behavior analysis:")
|
||||
|
||||
# Create a simple sequence with clear patterns
|
||||
simple_seq = np.array([
|
||||
[1, 0, 0, 0], # Position 0: [1, 0, 0, 0]
|
||||
[0, 1, 0, 0], # Position 1: [0, 1, 0, 0]
|
||||
[0, 0, 1, 0], # Position 2: [0, 0, 1, 0]
|
||||
[1, 0, 0, 0], # Position 3: [1, 0, 0, 0] (same as position 0)
|
||||
])
|
||||
|
||||
print(f"🎯 Simple test sequence shape: {simple_seq.shape}")
|
||||
|
||||
# Apply attention
|
||||
output, weights = scaled_dot_product_attention(simple_seq, simple_seq, simple_seq)
|
||||
|
||||
print(f"🎯 Attention pattern analysis:")
|
||||
print(f"Position 0 attends most to position: {np.argmax(weights[0])}")
|
||||
print(f"Position 3 attends most to position: {np.argmax(weights[3])}")
|
||||
print(f"✅ Positions with same content should attend to each other!")
|
||||
|
||||
# Test with causal masking
|
||||
causal_mask = create_causal_mask(4)
|
||||
output_causal, weights_causal = scaled_dot_product_attention(simple_seq, simple_seq, simple_seq, causal_mask)
|
||||
|
||||
print(f"🎯 With causal masking:")
|
||||
print(f"Position 1 can only attend to positions 0-1: {np.sum(weights_causal[1, :2]) > 0.99}")
|
||||
|
||||
if _should_show_plots():
|
||||
plt.figure(figsize=(12, 4))
|
||||
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.imshow(weights, cmap='Blues')
|
||||
plt.title('Full Attention Weights\n(Darker = Higher Attention)')
|
||||
plt.xlabel('Key Position')
|
||||
plt.ylabel('Query Position')
|
||||
plt.colorbar()
|
||||
|
||||
# Add text annotations
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
plt.text(j, i, f'{weights[i,j]:.2f}',
|
||||
ha='center', va='center',
|
||||
color='white' if weights[i,j] > 0.5 else 'black')
|
||||
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.imshow(weights_causal, cmap='Blues')
|
||||
plt.title('Causal Attention Weights\n(Upper triangle masked)')
|
||||
plt.xlabel('Key Position')
|
||||
plt.ylabel('Query Position')
|
||||
plt.colorbar()
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
plt.plot(weights[0], 'o-', label='Position 0 attention')
|
||||
plt.plot(weights[3], 's-', label='Position 3 attention')
|
||||
plt.xlabel('Attending to Position')
|
||||
plt.ylabel('Attention Weight')
|
||||
plt.title('Attention Distribution')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
print("🎯 Dot-product attention gives higher weight to similar content!")
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Unit Test: Complete Attention System Integration
|
||||
|
||||
**This is an integration test** - it checks that the attention components work together.
|
||||
|
||||
Let's verify all components work together seamlessly.
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "test-integration", "locked": false, "schema_version": 3, "solution": false, "task": false}
|
||||
print("🔬 Unit Test: Complete Attention System Integration...")
|
||||
|
||||
# Test parameters
|
||||
d_model = 64
|
||||
seq_len = 16
|
||||
batch_size = 2
|
||||
np.random.seed(42)
|
||||
|
||||
print(f"📊 Integration test: d_model={d_model}, seq_len={seq_len}, batch_size={batch_size}")
|
||||
|
||||
# Step 1: Create input embeddings (simulating word embeddings)
|
||||
embeddings = np.random.randn(batch_size, seq_len, d_model) * 0.1
|
||||
print(f"📊 Input embeddings: {embeddings.shape}")
|
||||
|
||||
# Step 2: Test basic attention
|
||||
output, attention_weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)
|
||||
print(f"✅ Basic attention works: {output.shape}")
|
||||
|
||||
# Step 3: Test self-attention wrapper
|
||||
self_attn = SelfAttention(d_model)
|
||||
self_output, self_weights = self_attn(embeddings[0]) # Single batch item
|
||||
print(f"✅ Self-attention output: {self_output.shape}")
|
||||
|
||||
# Step 4: Test with causal mask (like GPT)
|
||||
causal_mask = create_causal_mask(seq_len)
|
||||
causal_output, causal_weights = scaled_dot_product_attention(
|
||||
embeddings[0], embeddings[0], embeddings[0], causal_mask
|
||||
)
|
||||
print(f"✅ Causal attention works: {causal_output.shape}")
|
||||
|
||||
# Step 5: Test with padding mask (variable lengths)
|
||||
lengths = [seq_len, seq_len-3] # Different sequence lengths
|
||||
padding_mask = create_padding_mask(lengths, seq_len)
|
||||
padded_output, padded_weights = scaled_dot_product_attention(
|
||||
embeddings[0], embeddings[0], embeddings[0], padding_mask[0]
|
||||
)
|
||||
print(f"✅ Padding mask works: {padded_output.shape}")
|
||||
|
||||
# Step 6: Verify all outputs have correct properties
|
||||
print(f"✅ All attention weights sum to 1: {np.allclose(np.sum(attention_weights, axis=-1), 1.0)}")
|
||||
print(f"✅ All outputs preserve input shape: {output.shape == embeddings.shape}")
|
||||
print(f"✅ Causal masking works: {np.all(np.triu(causal_weights, k=1) < 1e-6)}")
|
||||
|
||||
print("📈 Progress: Complete Attention System ✓")
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("🔥 ATTENTION MODULE COMPLETE!")
|
||||
print("="*50)
|
||||
print("✅ Scaled dot-product attention")
|
||||
print("✅ Multi-head attention")
|
||||
print("✅ Positional encoding")
|
||||
print("✅ Layer normalization")
|
||||
print("✅ Complete transformer block")
|
||||
print("✅ Self-attention wrapper")
|
||||
print("✅ Masking utilities")
|
||||
print("✅ Integration tests")
|
||||
print("✅ Self-attention wrapper")
|
||||
print("✅ Causal masking")
|
||||
print("✅ Padding masking")
|
||||
print("✅ Bidirectional masking")
|
||||
print("✅ Attention visualization")
|
||||
print("✅ Complete integration tests")
|
||||
print("\nYou now understand the core mechanism powering modern AI! 🚀")
|
||||
print("Next: Apply these attention mechanisms to real datasets and tasks.")
|
||||
print("Next: Learn how to build complete transformer models using this foundation.")
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
|
||||
name: "attention"
|
||||
title: "Attention"
|
||||
description: "Attention mechanisms and transformer architectures"
|
||||
description: "Core attention mechanism and masking utilities"
|
||||
|
||||
# Dependencies - Used by CLI for module ordering and prerequisites
|
||||
dependencies:
|
||||
prerequisites: ["setup", "tensor", "activations", "layers", "networks"]
|
||||
enables: ["training", "cnn", "optimization", "transformers"]
|
||||
prerequisites: ["setup", "tensor"]
|
||||
enables: ["training", "transformers", "nlp"]
|
||||
|
||||
# Package Export - What gets built into tinytorch package
|
||||
exports_to: "tinytorch.core.attention"
|
||||
@@ -20,13 +20,13 @@ files:
|
||||
tests: "inline"
|
||||
|
||||
# Educational Metadata
|
||||
difficulty: "⭐⭐⭐⭐"
|
||||
time_estimate: "6-8 hours"
|
||||
difficulty: "⭐⭐⭐"
|
||||
time_estimate: "4-5 hours"
|
||||
|
||||
# Components - What's implemented in this module
|
||||
components:
|
||||
- "scaled_dot_product_attention"
|
||||
- "MultiHeadAttention"
|
||||
- "PositionalEncoding"
|
||||
- "TransformerBlock"
|
||||
- "SelfAttention"
|
||||
- "SelfAttention"
|
||||
- "create_causal_mask"
|
||||
- "create_padding_mask"
|
||||
- "create_bidirectional_mask"
|
||||