Mirror of https://github.com/harvard-edge/cs249r_book.git (synced 2026-04-29 00:59:07 -05:00)
refactor(milestone): extract challenges into separate functions
Break the monolithic main() into clean, documented functions:

- CONFIG dict for shared hyperparameters
- build_model() for creating fresh model/optimizer/loss
- challenge_1_reversal() - anti-diagonal attention patterns
- challenge_2_copying() - diagonal attention patterns
- challenge_3_mixed() - prefix-conditioned behavior (fresh model)
- print_final_results() - summary table and messages

This makes the code much easier for students to understand and clearly shows why challenge 3 needs a fresh model.
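For context while reading the diff below: the helpers it calls (generate_reversal_data, generate_copy_data, generate_mixed_data, run_challenge, AttentionTransformer) are defined elsewhere in the file and are untouched by this commit. As a reading aid, here is a minimal sketch of what generate_mixed_data plausibly produces, inferred only from the token comments in CONFIG (0=pad, 1-26=A-Z, 27=[R], 28=[C]); it is illustrative, not the repository's actual implementation:

import random

# Token conventions taken from CONFIG in the diff:
#   0 = pad, 1-26 = A-Z, 27 = [R] (reverse prefix), 28 = [C] (copy prefix)
R_TOKEN, C_TOKEN = 27, 28

def generate_mixed_data(num_samples, seq_len):
    """Hypothetical sketch of the mixed-task generator.

    Each input is seq_len + 1 tokens long (task prefix + letters),
    which is why build_model() passes seq_len + 1 to AttentionTransformer.
    """
    data = []
    for _ in range(num_samples):
        letters = [random.randint(1, 26) for _ in range(seq_len)]
        if random.random() < 0.5:
            prefix, target = R_TOKEN, letters[::-1]  # [R]ABC -> CBA
        else:
            prefix, target = C_TOKEN, list(letters)  # [C]ABC -> ABC
        data.append(([prefix] + letters, target))
    return data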
@@ -468,107 +468,161 @@ def run_challenge(name, model, train_data, test_data, optimizer, loss_fn, epochs
     return passed, final_acc
 
 
-def main():
-    """Main training loop with three challenges."""
+# =============================================================================
+# CONFIGURATION
+# =============================================================================
 
-    # Banner
-    console.print()
-    console.print(Panel.fit(
-        "[bold cyan]MILESTONE 05: ATTENTION IS ALL YOU NEED[/bold cyan]\n\n"
-        "[yellow]Prove your attention mechanism works by passing THREE challenges.[/yellow]\n\n"
-        "Challenge 1: Sequence Reversal (PYTHON -> NOHTYP)\n"
-        "Challenge 2: Sequence Copying (TENSOR -> TENSOR)\n"
-        "Challenge 3: Mixed Tasks ([R]ABC -> CBA, [C]ABC -> ABC)",
-        border_style="cyan",
-        title="The Transformer Challenge"
-    ))
-    console.print()
+# Model hyperparameters (shared across all challenges)
+CONFIG = {
+    'vocab_size': 29,   # 0=pad, 1-26=A-Z, 27=[R], 28=[C]
+    'seq_len': 6,       # Sequence length for tasks
+    'embed_dim': 64,    # Embedding dimensions
+    'num_heads': 4,     # Attention heads
+    'num_layers': 2,    # Transformer blocks
+    'lr': 0.001,        # Learning rate
+}
 
-    # Configuration
-    vocab_size = 29  # 0=pad, 1-26=A-Z, 27=[R], 28=[C]
-    seq_len = 6
-    embed_dim = 64
-    num_heads = 4
-    num_layers = 2
-    lr = 0.001
-
-    console.print(Panel(
-        f"[bold]Model Configuration[/bold]\n"
-        f"  Vocabulary: {vocab_size} tokens (A-Z + special)\n"
-        f"  Sequence: {seq_len} letters\n"
-        f"  Embedding: {embed_dim} dimensions\n"
-        f"  Attention: {num_heads} heads\n"
-        f"  Layers: {num_layers} transformer blocks\n"
-        f"  Learning: {lr}",
-        title="Configuration",
-        border_style="blue"
-    ))
+def build_model(config=CONFIG):
+    """
+    Build a fresh transformer model with optimizer.
+
-    # Build model
-    console.print("\n[bold]Building Transformer...[/bold]")
+    Returns:
+        model: AttentionTransformer instance
+        optimizer: Adam optimizer
+        loss_fn: CrossEntropyLoss
+    """
     model = AttentionTransformer(
-        vocab_size=vocab_size,
-        embed_dim=embed_dim,
-        num_heads=num_heads,
-        seq_len=seq_len + 1,  # +1 for task prefix in challenge 3
-        num_layers=num_layers
+        vocab_size=config['vocab_size'],
+        embed_dim=config['embed_dim'],
+        num_heads=config['num_heads'],
+        seq_len=config['seq_len'] + 1,  # +1 for task prefix in challenge 3
+        num_layers=config['num_layers']
     )
-    console.print(f"  Total parameters: {model.total_params:,}")
 
     for param in model.parameters():
         param.requires_grad = True
 
-    optimizer = Adam(model.parameters(), lr=lr)
+    optimizer = Adam(model.parameters(), lr=config['lr'])
     loss_fn = CrossEntropyLoss()
 
-    results = {}
+    return model, optimizer, loss_fn
 
-    # Challenge 1: Sequence Reversal
-    train_rev = generate_reversal_data(600, seq_len)
-    test_rev = generate_reversal_data(200, seq_len)
-    passed1, acc1 = run_challenge(
+
+# =============================================================================
+# CHALLENGE 1: SEQUENCE REVERSAL
+# =============================================================================
+
+def challenge_1_reversal(model, optimizer, loss_fn, config=CONFIG):
+    """
+    Challenge 1: Learn to reverse sequences (PYTHON -> NOHTYP).
+
+    Tests if attention can learn anti-diagonal patterns where
+    output position i attends to input position (n-1-i).
+
+    Returns:
+        passed: bool - whether target accuracy was achieved
+        accuracy: float - final test accuracy
+    """
+    seq_len = config['seq_len']
+
+    train_data = generate_reversal_data(600, seq_len)
+    test_data = generate_reversal_data(200, seq_len)
+
+    return run_challenge(
         "CHALLENGE 1: SEQUENCE REVERSAL",
-        model, train_rev, test_rev, optimizer, loss_fn,
+        model, train_data, test_data, optimizer, loss_fn,
         epochs=50, target_acc=95
     )
-    results['reversal'] = (passed1, acc1)
-
-    # Challenge 2: Sequence Copying (same model, different task)
-    train_copy = generate_copy_data(600, seq_len)
-    test_copy = generate_copy_data(200, seq_len)
-    passed2, acc2 = run_challenge(
+
+
+# =============================================================================
+# CHALLENGE 2: SEQUENCE COPYING
+# =============================================================================
+
+def challenge_2_copying(model, optimizer, loss_fn, config=CONFIG):
+    """
+    Challenge 2: Learn to copy sequences (TENSOR -> TENSOR).
+
+    Tests if attention can learn diagonal patterns where
+    output position i attends to input position i.
+
+    Note: Uses the SAME model from Challenge 1, demonstrating
+    that transformers can adapt to new tasks.
+
+    Returns:
+        passed: bool - whether target accuracy was achieved
+        accuracy: float - final test accuracy
+    """
+    seq_len = config['seq_len']
+
+    train_data = generate_copy_data(600, seq_len)
+    test_data = generate_copy_data(200, seq_len)
+
+    return run_challenge(
         "CHALLENGE 2: SEQUENCE COPYING",
-        model, train_copy, test_copy, optimizer, loss_fn,
+        model, train_data, test_data, optimizer, loss_fn,
         epochs=50, target_acc=95
     )
-    results['copying'] = (passed2, acc2)
-
-    # Challenge 3: Mixed Tasks (the real test)
-    # Reset model and optimizer so it learns both tasks from scratch
-    # with prefix conditioning (otherwise it's stuck in "copy mode")
-    console.print("\n[dim]Reinitializing model for mixed task learning...[/dim]")
-    model = AttentionTransformer(
-        vocab_size=vocab_size,
-        embed_dim=embed_dim,
-        num_heads=num_heads,
-        seq_len=seq_len + 1,
-        num_layers=num_layers
-    )
-    for param in model.parameters():
-        param.requires_grad = True
-    optimizer = Adam(model.parameters(), lr=lr)
-
-    train_mixed = generate_mixed_data(800, seq_len)
-    test_mixed = generate_mixed_data(300, seq_len)
-    passed3, acc3 = run_challenge(
+# =============================================================================
+# CHALLENGE 3: MIXED TASK INFERENCE
+# =============================================================================
+
+def challenge_3_mixed(config=CONFIG):
+    """
+    Challenge 3: Learn BOTH tasks with prefix conditioning.
+
+    This is the real test! The model must learn:
+        [R]ABC -> CBA  (reverse when prefix is [R])
+        [C]ABC -> ABC  (copy when prefix is [C])
+
+    IMPORTANT: We build a FRESH model here because:
+    - After challenges 1 & 2, the model is "stuck" in copy mode
+    - To learn conditional behavior, it needs to see both tasks together
+    - This tests the model's ability to use the prefix token to route behavior
+
+    Returns:
+        passed: bool - whether target accuracy was achieved
+        accuracy: float - final test accuracy
+    """
+    seq_len = config['seq_len']
+
+    console.print("\n[dim]Building fresh model for mixed task learning...[/dim]")
+    model, optimizer, loss_fn = build_model(config)
+    console.print(f"[dim]  Total parameters: {model.total_params:,}[/dim]")
+
+    train_data = generate_mixed_data(800, seq_len)
+    test_data = generate_mixed_data(300, seq_len)
+
+    return run_challenge(
         "CHALLENGE 3: MIXED TASK INFERENCE",
-        model, train_mixed, test_mixed, optimizer, loss_fn,
+        model, train_data, test_data, optimizer, loss_fn,
         epochs=60, target_acc=90
     )
-    results['mixed'] = (passed3, acc3)
-
-    # Final Summary
-    console.print("\n" + "="*60)
+
+
+# =============================================================================
+# RESULTS DISPLAY
+# =============================================================================
+
+def print_final_results(results):
+    """
+    Print the final results table and success/failure message.
+
+    Args:
+        results: dict with keys 'reversal', 'copying', 'mixed'
+                 each containing (passed: bool, accuracy: float)
+
+    Returns:
+        0 if all passed, 1 otherwise
+    """
+    passed1, acc1 = results['reversal']
+    passed2, acc2 = results['copying']
+    passed3, acc3 = results['mixed']
+
+    console.print("\n" + "=" * 60)
     console.print(Panel.fit("[bold]FINAL RESULTS[/bold]", border_style="cyan"))
 
     table = Table(box=box.ROUNDED)
@@ -636,5 +690,80 @@ def main():
     return 1
 
 
+# =============================================================================
+# MAIN ENTRY POINT
+# =============================================================================
+
+def main():
+    """
+    Main entry point: Run all three transformer challenges.
+
+    Challenge Structure:
+    ────────────────────
+    1. REVERSAL - Can attention learn anti-diagonal patterns?
+    2. COPYING - Can attention learn diagonal patterns? (same model)
+    3. MIXED - Can attention learn conditional behavior? (fresh model)
+
+    The first two challenges use the same model to show adaptability.
+    The third uses a fresh model to properly test prefix conditioning.
+    """
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # BANNER
+    # ─────────────────────────────────────────────────────────────────────────
+    console.print()
+    console.print(Panel.fit(
+        "[bold cyan]MILESTONE 05: ATTENTION IS ALL YOU NEED[/bold cyan]\n\n"
+        "[yellow]Prove your attention mechanism works by passing THREE challenges.[/yellow]\n\n"
+        "Challenge 1: Sequence Reversal (PYTHON -> NOHTYP)\n"
+        "Challenge 2: Sequence Copying (TENSOR -> TENSOR)\n"
+        "Challenge 3: Mixed Tasks ([R]ABC -> CBA, [C]ABC -> ABC)",
+        border_style="cyan",
+        title="The Transformer Challenge"
+    ))
+    console.print()
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # CONFIGURATION
+    # ─────────────────────────────────────────────────────────────────────────
+    console.print(Panel(
+        f"[bold]Model Configuration[/bold]\n"
+        f"  Vocabulary: {CONFIG['vocab_size']} tokens (A-Z + special)\n"
+        f"  Sequence: {CONFIG['seq_len']} letters\n"
+        f"  Embedding: {CONFIG['embed_dim']} dimensions\n"
+        f"  Attention: {CONFIG['num_heads']} heads\n"
+        f"  Layers: {CONFIG['num_layers']} transformer blocks\n"
+        f"  Learning: {CONFIG['lr']}",
+        title="Configuration",
+        border_style="blue"
+    ))
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # BUILD MODEL (shared for challenges 1 & 2)
+    # ─────────────────────────────────────────────────────────────────────────
+    console.print("\n[bold]Building Transformer...[/bold]")
+    model, optimizer, loss_fn = build_model()
+    console.print(f"  Total parameters: {model.total_params:,}")
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # RUN CHALLENGES
+    # ─────────────────────────────────────────────────────────────────────────
+    results = {}
+
+    # Challenge 1: Sequence Reversal
+    results['reversal'] = challenge_1_reversal(model, optimizer, loss_fn)
+
+    # Challenge 2: Sequence Copying (same model, different task)
+    results['copying'] = challenge_2_copying(model, optimizer, loss_fn)
+
+    # Challenge 3: Mixed Tasks (fresh model - see docstring for why)
+    results['mixed'] = challenge_3_mixed()
+
+    # ─────────────────────────────────────────────────────────────────────────
+    # FINAL RESULTS
+    # ─────────────────────────────────────────────────────────────────────────
+    return print_final_results(results)
+
+
 if __name__ == "__main__":
     sys.exit(main())
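The challenge docstrings describe the attention patterns in words; the following standalone snippet prints the two target patterns they name (diagonal for copying, anti-diagonal for reversal) for the CONFIG value seq_len = 6. It is independent of the milestone code:

seq_len = 6

# Copying: output position i attends to input position i (the diagonal).
copy_pattern = [[1 if j == i else 0 for j in range(seq_len)]
                for i in range(seq_len)]

# Reversal: output position i attends to input position seq_len - 1 - i
# (the anti-diagonal), the pattern challenge_1_reversal tests for.
reverse_pattern = [[1 if j == seq_len - 1 - i else 0 for j in range(seq_len)]
                   for i in range(seq_len)]

for name, pattern in (("copy (diagonal)", copy_pattern),
                      ("reverse (anti-diagonal)", reverse_pattern)):
    print(name)
    for row in pattern:
        print(" ".join(map(str, row)))

In the mixed challenge a single model has to realize both patterns and let the [R]/[C] prefix select between them; that is the conditional behavior the challenge_3_mixed docstring says a model already settled into "copy mode" cannot simply be nudged into, hence the fresh model.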