Files
TinyTorch/demos/demo_attention.py
Vijay Janapa Reddi 84291fcf5e Add educational descriptions and interpretation guides to all demos
- Added 'Understanding This Demo' panels explaining what students will see
- Added inline interpretation guides with 💡 markers throughout demos
- Enhanced explanations of outputs, tables, and visualizations
- Added context about why concepts matter in ML/AI
- Improved pedagogical clarity for all 8 demo files:
  - demo_tensor_math.py: Matrix operations context
  - demo_activations.py: Nonlinearity importance
  - demo_single_neuron.py: Learning process clarity
  - demo_xor_network.py: Multi-layer necessity
  - demo_vision.py: CNN feature hierarchy
  - demo_attention.py: Attention mechanics
  - demo_training.py: Pipeline understanding
  - demo_language.py: Language generation insights

These additions help students not just see the demos run, but understand
what the outputs mean and why these concepts are fundamental to ML.
2025-09-18 19:54:34 -04:00

459 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
TinyTorch Demo 07: Attention Mechanisms - The AI Revolution
Shows how attention transforms sequence processing and enables modern AI!
"""
import sys
import numpy as np
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.syntax import Syntax
from rich.text import Text
from rich.columns import Columns
def demo_attention():
    """Demo attention mechanisms for sequence understanding and modern AI.

    Walks through eight mini-demos: the RNN-vs-attention problem, basic
    attention weight computation, multi-head attention, self-attention,
    scaled dot-product attention, the Transformer architecture, real-world
    applications, and scaling analysis.

    Returns:
        bool: True if all demos ran, False on import or runtime failure.
    """
    console = Console()
    try:
        # Import TinyTorch modules.  Several are imported purely as an
        # availability check (act/layers/dense are not referenced below);
        # an ImportError here is caught and reported with a hint.
        import tinytorch.core.tensor as tt
        import tinytorch.core.activations as act
        import tinytorch.core.layers as layers
        import tinytorch.core.dense as dense
        import tinytorch.core.attention as attention

        # Main header
        console.print(Panel.fit(
            "🎯 TinyTorch Attention Mechanisms Demo\nThe breakthrough that enabled ChatGPT and modern AI!",
            style="bold cyan",
            border_style="bright_blue"
        ))
        console.print()

        # What this demo shows
        console.print(Panel(
            "[bold yellow]What This Demo Shows:[/bold yellow]\n\n"
            "Attention mechanisms solved the fundamental problem of sequence processing - how to let\n"
            "any part of a sequence directly access information from any other part. You'll discover:\n\n"
            "• Why RNNs failed on long sequences - the information bottleneck problem\n"
            "• How attention enables direct connections between all sequence positions\n"
            "• The elegant math behind attention: Query, Key, Value operations\n"
            "• Why multi-head attention gives different types of understanding\n"
            "• How Transformers stack attention layers to build deep understanding\n\n"
            "[bold cyan]Key Insight:[/bold cyan] Attention is about letting the model decide what to focus on,\n"
            "instead of forcing it through fixed computation patterns. This flexibility is why it works!",
            title="📚 Understanding This Demo",
            style="blue"
        ))
        console.print()

        # Demo 1: The Attention Problem
        console.print(Panel(
            "From fixed-size bottlenecks to dynamic focus...",
            title="🧠 Demo 1: Why Attention Revolutionized AI",
            style="green"
        ))

        # Simulate a sequence processing problem
        sequence = ["The", "cat", "sat", "on", "the", "mat"]
        console.print(f"[bold cyan]Input sequence:[/bold cyan] {' '.join(sequence)}")
        console.print()

        # Create comparison table: sequential RNN state-passing vs.
        # parallel all-to-all attention.
        comparison_table = Table(show_header=True, header_style="bold magenta")
        comparison_table.add_column("Traditional RNN", style="red")
        comparison_table.add_column("Attention Mechanism", style="green")
        rnn_steps = [
            "[The] → h1",
            "[cat] + h1 → h2",
            "[sat] + h2 → h3",
            "[on] + h3 → h4",
            "[the] + h4 → h5",
            "[mat] + h5 → h6 (final)"
        ]
        attention_steps = [
            "Process ALL positions simultaneously:",
            "[The, cat, sat, on, the, mat]",
            "",
            "For each output:",
            "Look at ALL inputs with learned weights",
            "Direct access to any information!"
        ]
        for rnn, attn in zip(rnn_steps, attention_steps):
            comparison_table.add_row(rnn, attn)
        console.print(comparison_table)
        console.print()
        console.print("[dim]💡 [bold]Key Difference:[/bold] RNNs process sequentially, attention processes in parallel:[/dim]")
        console.print("[dim] • RNN: Must go through h3 to connect 'cat' and 'mat' (loses information)[/dim]")
        console.print("[dim] • Attention: 'cat' and 'mat' can directly interact (preserves all information)[/dim]")
        console.print()

        # Problems and solutions, shown side by side
        problems_panel = Panel(
            "❌ Problem: h6 must encode ALL previous information!\n❌ Result: Information loss, especially for long sequences",
            title="Traditional RNN Issues",
            style="red"
        )
        solutions_panel = Panel(
            "✅ Solution: Direct access to any previous information!\n✅ Result: No information bottleneck!",
            title="Attention Solution",
            style="green"
        )
        console.print(Columns([problems_panel, solutions_panel]))
        console.print()

        # Demo 2: Basic Attention Mechanism
        print("🔍 Demo 2: Basic Attention Computation")
        print("Computing attention weights step by step...")
        print()

        # Create simple sequence embeddings (3 words, 4 dimensions each)
        sequence_length = 3
        embed_dim = 4
        # Word embeddings for "cat sat mat"
        embeddings = tt.Tensor([
            [1.0, 0.5, 0.2, 0.8],  # "cat"
            [0.3, 1.0, 0.7, 0.1],  # "sat"
            [0.6, 0.2, 1.0, 0.4]   # "mat"
        ])
        print("Word embeddings (3 words × 4 dimensions):")
        for i, word in enumerate(["cat", "sat", "mat"]):
            emb = embeddings.data[i]
            print(f" {word}: [{emb[0]:.1f}, {emb[1]:.1f}, {emb[2]:.1f}, {emb[3]:.1f}]")
        print()

        # Simple attention: query attends to all keys
        query = embeddings.data[1]  # "sat" is attending
        keys = embeddings.data      # to all words
        print(f"Query (word 'sat'): {query}")
        print()

        # Compute attention scores (dot product)
        scores = np.dot(keys, query)
        print("Attention scores (how much 'sat' attends to each word):")
        for i, (word, score) in enumerate(zip(["cat", "sat", "mat"], scores)):
            # NOTE(review): the scraped source read "'sat''{word}'" — the "→"
            # was evidently stripped by Unicode filtering; restored here.
            print(f" 'sat'→'{word}': {score:.3f}")
        print()
        console.print("[dim]💡 [bold]Understanding Scores:[/bold] Higher scores = stronger relationships:[/dim]")
        console.print("[dim] • Dot product measures similarity between embeddings[/dim]")
        console.print("[dim] • Similar vectors have high dot products[/dim]")
        console.print("[dim] • These raw scores will be normalized with softmax[/dim]")
        console.print()

        # Softmax to get attention weights
        exp_scores = np.exp(scores)
        attention_weights = exp_scores / np.sum(exp_scores)
        print("Attention weights (after softmax):")
        for i, (word, weight) in enumerate(zip(["cat", "sat", "mat"], attention_weights)):
            print(f" 'sat'→'{word}': {weight:.3f} ({weight*100:.1f}%)")
        print(f"Total: {np.sum(attention_weights):.3f}")
        print()
        console.print("[dim]💡 [bold]Weights Interpretation:[/bold] Softmax creates a probability distribution:[/dim]")
        console.print("[dim] • All weights sum to 1.0 (100%)[/dim]")
        console.print("[dim] • Higher weights = more attention/importance[/dim]")
        console.print("[dim] • The model learns what to pay attention to![/dim]")
        console.print()

        # Compute attended output: weighted sum of all value vectors
        attended_output = np.sum(keys * attention_weights.reshape(-1, 1), axis=0)
        print(f"Attended output for 'sat': {attended_output}")
        print("(Weighted combination of all word embeddings)")
        print()

        # Demo 3: Multi-Head Attention
        print("🧩 Demo 3: Multi-Head Attention - Multiple Perspectives")
        print("Like having multiple experts focus on different aspects...")
        print()

        # Create multi-head attention layer
        num_heads = 2
        head_dim = embed_dim // num_heads
        print(f"Multi-head setup: {num_heads} heads, {head_dim} dimensions each")
        print()

        # Simulate different attention heads with hand-picked scores
        print("Head 1 (Syntax Expert) - Focuses on grammatical relationships:")
        syntax_scores = np.array([0.2, 0.7, 0.1])  # Focuses on current word
        syntax_weights = np.exp(syntax_scores) / np.sum(np.exp(syntax_scores))
        for word, weight in zip(["cat", "sat", "mat"], syntax_weights):
            print(f" '{word}': {weight:.3f}")
        print()
        print("Head 2 (Semantic Expert) - Focuses on meaning relationships:")
        semantic_scores = np.array([0.4, 0.2, 0.4])  # Focuses on related objects
        semantic_weights = np.exp(semantic_scores) / np.sum(np.exp(semantic_scores))
        for word, weight in zip(["cat", "sat", "mat"], semantic_weights):
            print(f" '{word}': {weight:.3f}")
        print()
        print("💡 Key insight: Different heads learn different types of relationships!")
        print()
        console.print("[dim]💡 [bold]Multi-Head Benefits:[/bold] Like having multiple experts:[/dim]")
        console.print("[dim] • One head might focus on grammar (subject-verb)[/dim]")
        console.print("[dim] • Another on semantics (cat-mat are both objects)[/dim]")
        console.print("[dim] • Another on position (nearby words)[/dim]")
        console.print("[dim] • Combined: Rich, multi-faceted understanding![/dim]")
        console.print()

        # Demo 4: Self-Attention in Practice
        print("🎭 Demo 4: Self-Attention - Words Talking to Each Other")
        print("Every word attends to every other word...")
        print()

        # Create attention layer (constructed to show the API exists; the
        # matrix below is simulated for a readable visualization).
        attn_layer = attention.SelfAttention(d_model=4)
        print("Self-attention matrix (who attends to whom):")
        print(" cat sat mat")
        # Simulate attention weights for visualization
        attention_matrix = np.array([
            [0.4, 0.3, 0.3],  # cat attends to...
            [0.2, 0.6, 0.2],  # sat attends to...
            [0.3, 0.2, 0.5]   # mat attends to...
        ])
        for i, word in enumerate(["cat", "sat", "mat"]):
            weights = attention_matrix[i]
            print(f" {word}: {weights[0]:.1f} {weights[1]:.1f} {weights[2]:.1f}")
        print()
        print("Interpretation:")
        print("'cat' focuses on itself (0.4) and context words")
        print("'sat' focuses mainly on itself (0.6) - the action")
        print("'mat' balances between all words")
        print()
        console.print("[dim]💡 [bold]Self-Attention Patterns:[/bold] Different words have different focus patterns:[/dim]")
        console.print("[dim] • Content words (nouns/verbs) often have high self-attention[/dim]")
        console.print("[dim] • Function words distribute attention more broadly[/dim]")
        console.print("[dim] • These patterns emerge automatically during training![/dim]")
        console.print()

        # Demo 5: Scaled Dot-Product Attention
        console.print(Panel(
            "The mathematical foundation of modern AI",
            title="⚖️ Demo 5: Scaled Dot-Product Attention - The Core Formula",
            style="blue"
        ))

        # Display the attention formula with syntax highlighting
        formula_code = """
# The Attention Formula that Changed Everything
Attention(Q, K, V) = softmax(Q @ K^T / √d_k) @ V
Where:
Q = Queries (what we're looking for)
K = Keys (what's available to match against)
V = Values (what we actually retrieve)
d_k = key dimension (for scaling)
"""
        console.print(Syntax(formula_code, "python", theme="monokai", line_numbers=False))
        console.print()

        # Create Q, K, V matrices (self-attention: all three are the input)
        d_k = 4  # key dimension
        scale_factor = 1.0 / np.sqrt(d_k)
        Q = embeddings  # Queries
        K = embeddings  # Keys
        V = embeddings  # Values
        print(f"Q (Queries): {Q.data.shape}")
        print(f"K (Keys): {K.data.shape}")
        print(f"V (Values): {V.data.shape}")
        print(f"Scale factor: 1/√{d_k} = {scale_factor:.3f}")
        print()

        # Compute attention: similarity → scale → row-wise softmax → mix values
        QK = np.dot(Q.data, K.data.T)  # Query-Key similarity
        scaled_QK = QK * scale_factor  # Scale to prevent large values
        attn_weights = np.exp(scaled_QK) / np.sum(np.exp(scaled_QK), axis=1, keepdims=True)
        output = np.dot(attn_weights, V.data)
        print("Attention weights matrix:")
        for i in range(3):
            print(f" [{attn_weights[i,0]:.3f}, {attn_weights[i,1]:.3f}, {attn_weights[i,2]:.3f}]")
        print()
        print("Output (attended representations):")
        for i, word in enumerate(["cat", "sat", "mat"]):
            out = output[i]
            print(f" {word}: [{out[0]:.3f}, {out[1]:.3f}, {out[2]:.3f}, {out[3]:.3f}]")
        print()
        console.print("[dim]💡 [bold]The Magic Formula:[/bold] Why this simple equation changed AI:[/dim]")
        console.print("[dim] • Q⋅Kᵀ: Measures relevance between positions[/dim]")
        console.print("[dim] • √dₖ scaling: Prevents gradient problems in deep networks[/dim]")
        console.print("[dim] • Softmax: Creates sharp, interpretable attention patterns[/dim]")
        console.print("[dim] • ×V: Retrieves weighted information from relevant positions[/dim]")
        console.print()

        # Demo 6: Transformer Architecture Preview
        console.print(Panel(
            "How attention enables modern language models...",
            title="🏗️ Demo 6: Transformer Architecture - The Full Picture",
            style="magenta"
        ))

        # Transformer architecture diagram
        transformer_arch = """
┌─────────────────────┐
│ Input Embeddings │
└─────────────────────┘
┌─────────────────────┐
│ Multi-Head Self- │
│ Attention │
└─────────────────────┘
↓ + (residual)
┌─────────────────────┐
│ Layer Normalization │
└─────────────────────┘
┌─────────────────────┐
│ Feed-Forward │
│ Network │
└─────────────────────┘
↓ + (residual)
┌─────────────────────┐
│ Layer Normalization │
└─────────────────────┘
┌─────────────────────┐
│ Output │
└─────────────────────┘
"""
        console.print(Panel(transformer_arch, title="Transformer Block", style="cyan"))

        # Why it works table
        why_table = Table(show_header=True, header_style="bold magenta")
        why_table.add_column("Component", style="cyan")
        why_table.add_column("Purpose", style="yellow")
        why_table.add_row("Self-attention", "Captures long-range dependencies")
        why_table.add_row("Multi-head", "Multiple types of relationships")
        why_table.add_row("Residual connections", "Stable training")
        why_table.add_row("Layer normalization", "Normalized activations")
        why_table.add_row("Feed-forward", "Non-linear transformations")
        console.print(why_table)
        console.print()
        console.print("[dim]💡 [bold]Architecture Power:[/bold] Each component has a critical role:[/dim]")
        console.print("[dim] • Residual connections: Allow 100+ layer deep networks[/dim]")
        console.print("[dim] • Layer norm: Stabilizes training of very deep models[/dim]")
        console.print("[dim] • Feed-forward: Adds computation power beyond attention[/dim]")
        console.print()

        # Demo 7: Real-World Applications
        print("🌍 Demo 7: Real-World Impact")
        print("Where attention mechanisms changed everything...")
        print()
        applications = [
            ("Language Translation", "Attention shows which source words align with target words"),
            ("ChatGPT/GPT-4", "Self-attention enables understanding of entire conversation context"),
            ("Image Captioning", "Visual attention focuses on relevant image regions"),
            ("Document Analysis", "Attention connects information across long documents"),
            ("Code Generation", "Attention relates variable names and function calls"),
            ("Scientific Discovery", "Attention finds patterns in massive datasets")
        ]
        print("Revolutionary applications:")
        for app, description in applications:
            print(f"{app}: {description}")
        print()

        # Demo 8: Scaling Analysis
        print("📈 Demo 8: Why Attention Scales")
        print("Understanding computational complexity...")
        print()
        print("Attention complexity analysis:")
        print(" Sequence length: n")
        print(" Embedding dimension: d")
        print(" ")
        print(" Self-attention: O(n² × d)")
        print(" Feed-forward: O(n × d²)")
        print(" ")
        print(" For long sequences: attention dominates")
        print(" For wide embeddings: feed-forward dominates")
        print()
        print("Example scaling:")
        # d=512 embedding, 2048 hidden — typical base-Transformer sizes
        for n in [100, 1000, 10000]:
            attn_ops = n * n * 512
            ff_ops = n * 512 * 2048
            print(f" n={n}: Attention={attn_ops:,} ops, Feed-forward={ff_ops:,} ops")
        print()
        console.print("[dim]💡 [bold]Scaling Challenge:[/bold] Why context windows are limited:[/dim]")
        console.print("[dim] • Attention is O(n²) - quadratic in sequence length[/dim]")
        console.print("[dim] • This is why GPT models have token limits (4k, 8k, 32k, etc.)[/dim]")
        console.print("[dim] • Active research: Efficient attention for longer sequences[/dim]")
        console.print()

        # Success summary
        console.print(Panel.fit(
            "🎯 Achievements:\n"
            "• Understood the attention revolution and why it matters\n"
            "• Computed attention weights and attended outputs\n"
            "• Explored multi-head attention for different perspectives\n"
            "• Analyzed self-attention matrices\n"
            "• Implemented scaled dot-product attention formula\n"
            "• Previewed complete Transformer architecture\n"
            "• Connected to real-world AI applications\n"
            "• Analyzed computational scaling properties\n\n"
            "🔥 Next: End-to-end training pipelines!",
            title="🏆 TinyTorch Attention Demo Complete!",
            style="bold green",
            border_style="bright_green"
        ))
        return True
    except ImportError as e:
        # TinyTorch modules not exported yet — point the student at the fix.
        console.print(Panel(
            f"Could not import TinyTorch modules: {e}\n\n💡 Make sure to run: tito export 07_attention",
            title="❌ Import Error",
            style="bold red"
        ))
        return False
    except Exception as e:
        # Any other failure: show a panel plus the full traceback for debugging.
        console.print(Panel(
            f"Demo failed: {e}",
            title="❌ Error",
            style="bold red"
        ))
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Exit 0 on success, 1 on failure so the demo is scriptable/CI-friendly.
    success = demo_attention()
    sys.exit(0 if success else 1)