Add transformer quickdemo with live learning progression dashboard

New milestone 05 demo that shows students the model learning to "talk":
- Live dashboard with epoch-by-epoch response progression
- Systems stats panel (tokens/sec, batch time, memory)
- 3 test prompts with full history displayed
- Smaller model (110K params) for ~2 minute training time

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
Vijay Janapa Reddi
2025-11-22 15:55:12 -05:00
parent 5c3695a797
commit 59ebf0d385
2 changed files with 549 additions and 29 deletions

View File

@@ -0,0 +1,481 @@
#!/usr/bin/env python3
"""
TinyTalks Quick Demo - Watch Your Transformer Learn to Talk!
=============================================================
A fast, visual demonstration of transformer training.
See the model go from gibberish to coherent answers in ~2 minutes!
Features:
- Smaller model (~50K params) for fast training
- Live dashboard showing training progress
- Rotating prompts to show diverse capabilities
- Learning progression display (gibberish -> coherent)
"""
import sys
import os
import time
import numpy as np
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
# Rich for live dashboard
from rich.console import Console
from rich.layout import Layout
from rich.panel import Panel
from rich.table import Table
from rich.live import Live
from rich.text import Text
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich import box
# TinyTorch imports
from tinytorch.core.tensor import Tensor
from tinytorch.core.optimizers import Adam
from tinytorch.core.losses import CrossEntropyLoss
from tinytorch.models.transformer import GPT
from tinytorch.text.tokenization import CharTokenizer
# Shared Rich console used for all terminal output (banners, dashboard, tables).
console = Console()
# =============================================================================
# Configuration - Optimized for ~2 minute training
# =============================================================================
CONFIG = {
    # Model (smaller for speed)
    "n_layer": 2,       # number of transformer blocks
    "n_head": 2,        # attention heads per block
    "n_embd": 64,       # embedding / hidden width
    "max_seq_len": 32,  # Shorter sequences for speed
    # Training (optimized for ~2 min on pure Python)
    "epochs": 8,
    "batches_per_epoch": 30,
    "batch_size": 8,
    "learning_rate": 0.003,  # Balanced LR for stable convergence
    # Display
    # NOTE(review): "update_interval" is not referenced anywhere in this
    # script — the dashboard updates every batch. Confirm intended use.
    "update_interval": 5,  # Update dashboard every N batches
}
# Test prompts to show model learning (3 prompts for better progression display)
# Each prompt ends with "A:" so the model completes the answer; generation
# stops at the first newline (see generate_response).
TEST_PROMPTS = [
    "Q: What is 2+2?\nA:",
    "Q: What color is the sky?\nA:",
    "Q: Say hello\nA:",
]
# =============================================================================
# Dataset
# =============================================================================
class TinyTalksDataset:
    """Character-level language-modeling dataset over a TinyTalks text file.

    The whole file is tokenized once up front; training batches are random
    windows of ``seq_len + 1`` tokens split into (input, target) pairs
    shifted by one position.
    """

    def __init__(self, data_path: Path, seq_len: int):
        """Load *data_path*, build a character vocabulary, and tokenize.

        Args:
            data_path: Path to a plain-text file (e.g. the TinyTalks split).
            seq_len: Length of each training sequence (model context window).

        Raises:
            ValueError: If the file yields too few tokens for even one window.
        """
        self.seq_len = seq_len
        # Load text
        with open(data_path, 'r') as f:
            self.text = f.read()
        # Create tokenizer and build vocabulary
        self.tokenizer = CharTokenizer()
        self.tokenizer.build_vocab([self.text])
        # Tokenize entire text
        self.tokens = self.tokenizer.encode(self.text)
        # Fail fast on a degenerate dataset instead of producing an empty
        # sampling range (which would make get_batch misbehave).
        if len(self.tokens) <= seq_len:
            raise ValueError(
                f"Dataset too small: {len(self.tokens)} tokens, "
                f"need more than seq_len={seq_len}"
            )

    def __len__(self):
        """Number of valid window start positions (indices 0 .. len-1)."""
        return len(self.tokens) - self.seq_len

    def get_batch(self, batch_size: int):
        """Sample a random batch of (input, target) sequences.

        Returns:
            Tuple of two Tensors shaped (batch_size, seq_len); targets are
            the inputs shifted one token to the right.
        """
        # Fix: randint's upper bound is already exclusive, so the bound is
        # len(self), not len(self) - 1 — the original could never sample the
        # final valid window.
        indices = np.random.randint(0, len(self), size=batch_size)
        inputs = []
        targets = []
        for idx in indices:
            window = self.tokens[idx:idx + self.seq_len + 1]
            inputs.append(window[:-1])
            targets.append(window[1:])
        return (
            Tensor(np.array(inputs)),
            Tensor(np.array(targets))
        )
# =============================================================================
# Text Generation
# =============================================================================
def generate_response(model, tokenizer, prompt: str, max_tokens: int = 30) -> str:
    """Autoregressively sample a short answer for *prompt*.

    Samples with temperature 0.8, stops at the first newline (end of an
    answer) or after *max_tokens* tokens, and returns just the text after
    the final "A:" marker (or the raw completion if no marker is present).
    """
    token_ids = tokenizer.encode(prompt)
    temperature = 0.8
    for _ in range(max_tokens):
        # Model context is capped at the configured window size.
        window = token_ids[-CONFIG["max_seq_len"]:]
        logits = model.forward(Tensor(np.array([window])))
        # Scale the final position's logits, then take a numerically
        # stable softmax over the vocabulary.
        scores = logits.data[0, -1, :] / temperature
        shifted = np.exp(scores - np.max(scores))
        probs = shifted / np.sum(shifted)
        sampled = np.random.choice(len(probs), p=probs)
        token_ids.append(sampled)
        if tokenizer.decode([sampled]) == '\n':
            break  # newline terminates the answer
    text = tokenizer.decode(token_ids)
    if "A:" in text:
        # Keep only the first line after the last "A:" marker.
        answer = text.split("A:")[-1].strip()
        answer = answer.split('\n')[0].strip()
        return answer if answer else "(empty)"
    return text[len(prompt):].strip() or "(empty)"
# =============================================================================
# Dashboard Layout
# =============================================================================
def make_layout() -> Layout:
    """Build the dashboard skeleton: header / (progress+stats | outputs) / footer."""
    root = Layout()
    # Three horizontal bands: fixed-height header and footer, flexible middle.
    root.split_column(
        Layout(name="header", size=3),
        Layout(name="main", ratio=1),
        Layout(name="footer", size=3),
    )
    # Middle band: narrow left column, wider outputs column.
    root["main"].split_row(
        Layout(name="left", ratio=1),
        Layout(name="outputs", ratio=2),
    )
    # Left column stacks training progress above systems stats.
    root["left"].split_column(
        Layout(name="progress", ratio=2),
        Layout(name="stats", ratio=1),
    )
    return root
def make_header() -> Panel:
    """Render the dashboard title bar."""
    title = Text(
        "TinyTalks Quick Demo - Watch Your Transformer Learn!",
        style="bold cyan",
        justify="center",
    )
    return Panel(title, box=box.ROUNDED, style="cyan")
def make_progress_panel(epoch: int, total_epochs: int, batch: int,
                        total_batches: int, loss: float, elapsed: float) -> Panel:
    """Render the training-progress panel (counters, loss, bar, ETA).

    Args:
        epoch: 1-based current epoch.
        total_epochs: Total number of epochs planned.
        batch: 1-based batch index within the current epoch.
        total_batches: Batches per epoch.
        loss: Running average loss for the current epoch.
        elapsed: Seconds since training started.
    """
    # Calculate overall progress (guard against a zero-step configuration).
    total_steps = total_epochs * total_batches
    current_step = (epoch - 1) * total_batches + batch
    progress_pct = (current_step / total_steps) * 100 if total_steps else 0.0
    # Progress bar. Fix: the fill/empty glyphs were empty string literals
    # (lost in encoding), which rendered a permanently blank bar — use
    # block characters instead.
    bar_width = 20
    filled = int(bar_width * progress_pct / 100)
    bar = "█" * filled + "░" * (bar_width - filled)
    # Linear ETA from the average time per completed step so far.
    if current_step > 0:
        time_per_step = elapsed / current_step
        remaining_steps = total_steps - current_step
        eta = remaining_steps * time_per_step
        eta_str = f"{eta:.0f}s"
    else:
        eta_str = "..."
    content = Text()
    content.append(f"Epoch: {epoch}/{total_epochs}\n", style="bold")
    content.append(f"Batch: {batch}/{total_batches}\n")
    content.append(f"Loss: {loss:.3f}\n\n", style="yellow")
    content.append(f"{bar} {progress_pct:.0f}%\n\n", style="green")
    content.append(f"Elapsed: {elapsed:.0f}s\n")
    content.append(f"ETA: {eta_str}")
    return Panel(
        content,
        title="[bold]Training Progress[/bold]",
        border_style="green",
        box=box.ROUNDED,
    )
def make_outputs_panel(responses: dict, epoch: int) -> Panel:
    """Render per-prompt response history for every epoch completed so far."""
    body = Text()
    for prompt_idx, prompt in enumerate(TEST_PROMPTS):
        # The question is the first line of the prompt ("Q: ...").
        question = prompt.split('\n')[0]
        body.append(f"{question}\n", style="cyan bold")
        for ep in range(1, epoch + 1):
            reply = responses.get(f"epoch_{ep}_{prompt_idx}", "...")
            # Truncate long responses to fit the panel width.
            if len(reply) > 25:
                reply = reply[:25] + "..."
            body.append(f" Ep{ep}: ", style="yellow")
            # Latest epoch is shown bright; earlier epochs are dimmed.
            body.append(f"{reply}\n", style="white" if ep == epoch else "dim")
        body.append("\n")
    return Panel(
        body,
        title=f"[bold]Learning Progression (Epoch {epoch})[/bold]",
        border_style="blue",
        box=box.ROUNDED,
    )
def make_stats_panel(stats: dict) -> Panel:
    """Render throughput and model-size statistics."""
    # (text, style) rows; a None style falls through to the default.
    rows = [
        ("Performance Metrics\n", "bold"),
        (f" Tokens/sec: {stats.get('tokens_per_sec', 0):.1f}\n", None),
        (f" Batch time: {stats.get('batch_time_ms', 0):.0f}ms\n", None),
        (f" Memory: {stats.get('memory_mb', 0):.1f}MB\n\n", None),
        ("Model Stats\n", "bold"),
        (f" Parameters: {stats.get('params', 0):,}\n", None),
        (f" Vocab size: {stats.get('vocab_size', 0)}\n", None),
    ]
    body = Text()
    for text, style in rows:
        body.append(text, style=style)
    return Panel(
        body,
        title="[bold]Systems[/bold]",
        border_style="magenta",
        box=box.ROUNDED,
    )
def make_footer(message: str = "") -> Panel:
    """Render the status footer; falls back to a default hint when *message* is empty."""
    text = message or "Training in progress... Watch the model learn to answer questions!"
    return Panel(
        Text(text, style="dim", justify="center"),
        box=box.ROUNDED,
        style="dim",
    )
# =============================================================================
# Main Training Loop
# =============================================================================
def main():
    """Run the quick demo: train a small GPT with a live Rich dashboard.

    Loads the TinyTalks character dataset, trains for CONFIG["epochs"]
    epochs while redrawing the dashboard every batch, samples responses to
    TEST_PROMPTS at the end of each epoch, and finally prints the full
    per-epoch learning progression as tables. Ctrl+C stops training early.
    """
    # Welcome banner
    console.print()
    console.print(Panel.fit(
        "[bold cyan]TinyTalks Quick Demo[/bold cyan]\n\n"
        "Watch a transformer learn to answer questions in real-time!\n"
        "The model starts with random weights (gibberish output)\n"
        "and learns to produce coherent answers.\n\n"
        "[dim]Training time: ~2 minutes[/dim]",
        title="Welcome",
        border_style="cyan",
    ))
    console.print()
    # Load dataset; bail out early with guidance if the split is missing.
    project_root = Path(__file__).parent.parent.parent
    data_path = project_root / "datasets" / "tinytalks" / "splits" / "train.txt"
    if not data_path.exists():
        console.print(f"[red]Error: Dataset not found at {data_path}[/red]")
        console.print("[yellow]Please ensure TinyTalks dataset is available.[/yellow]")
        return
    console.print(f"[dim]Loading dataset from {data_path}...[/dim]")
    dataset = TinyTalksDataset(data_path, CONFIG["max_seq_len"])
    console.print(f"[green]✓[/green] Loaded {len(dataset.text):,} characters, vocab size: {dataset.tokenizer.vocab_size}")
    # Create model sized by CONFIG (small on purpose for fast training).
    console.print("[dim]Creating model...[/dim]")
    model = GPT(
        vocab_size=dataset.tokenizer.vocab_size,
        embed_dim=CONFIG["n_embd"],
        num_heads=CONFIG["n_head"],
        num_layers=CONFIG["n_layer"],
        max_seq_len=CONFIG["max_seq_len"],
    )
    # Count parameters (feeds the stats panel and the memory estimate).
    param_count = sum(p.data.size for p in model.parameters())
    console.print(f"[green]✓[/green] Model created: {param_count:,} parameters")
    console.print(f"[dim] {CONFIG['n_layer']} layers, {CONFIG['n_head']} heads, {CONFIG['n_embd']} embed dim[/dim]")
    # Setup training
    optimizer = Adam(model.parameters(), lr=CONFIG["learning_rate"])
    criterion = CrossEntropyLoss()
    console.print()
    console.print("[bold green]Starting training with live dashboard...[/bold green]")
    console.print("[dim]Press Ctrl+C to stop early[/dim]")
    console.print()
    time.sleep(1)
    # Storage for per-epoch responses (keyed "epoch_{e}_{i}") and live stats.
    responses = {}
    stats = {
        "params": param_count,
        "vocab_size": dataset.tokenizer.vocab_size,
        "tokens_per_sec": 0,
        "batch_time_ms": 0,
        # 4 bytes per float32 parameter; weights only, excludes activations.
        "memory_mb": param_count * 4 / (1024 * 1024),  # Rough estimate
    }
    # Create layout
    layout = make_layout()
    # Training loop with live display
    start_time = time.time()
    current_loss = 0.0
    total_tokens = 0
    try:
        # NOTE(review): the `live` handle is unused — Live re-renders the
        # shared `layout` object as the panels below mutate it.
        with Live(layout, console=console, refresh_per_second=4) as live:
            for epoch in range(1, CONFIG["epochs"] + 1):
                epoch_loss = 0.0
                for batch_idx in range(1, CONFIG["batches_per_epoch"] + 1):
                    batch_start = time.time()
                    # Get batch
                    inputs, targets = dataset.get_batch(CONFIG["batch_size"])
                    # Forward pass
                    logits = model.forward(inputs)
                    # Reshape (batch, seq, vocab) -> (batch*seq, vocab) for loss
                    batch_size, seq_len, vocab_size = logits.shape
                    logits_flat = logits.reshape(batch_size * seq_len, vocab_size)
                    targets_flat = targets.reshape(-1)
                    # Compute loss
                    loss = criterion(logits_flat, targets_flat)
                    # Backward pass
                    loss.backward()
                    # Update weights, then clear gradients for the next batch
                    optimizer.step()
                    optimizer.zero_grad()
                    # Track running average loss over the epoch so far
                    batch_loss = float(loss.data)
                    epoch_loss += batch_loss
                    current_loss = epoch_loss / batch_idx
                    # Update systems stats (throughput is a cumulative average)
                    batch_time = time.time() - batch_start
                    tokens_in_batch = batch_size * seq_len
                    total_tokens += tokens_in_batch
                    elapsed = time.time() - start_time
                    stats["batch_time_ms"] = batch_time * 1000
                    stats["tokens_per_sec"] = total_tokens / elapsed if elapsed > 0 else 0
                    # Redraw every panel of the dashboard
                    layout["header"].update(make_header())
                    layout["progress"].update(make_progress_panel(
                        epoch, CONFIG["epochs"],
                        batch_idx, CONFIG["batches_per_epoch"],
                        current_loss, elapsed
                    ))
                    layout["stats"].update(make_stats_panel(stats))
                    layout["outputs"].update(make_outputs_panel(responses, epoch))
                    layout["footer"].update(make_footer())
                # End of epoch - generate sample responses for each prompt
                for i, prompt in enumerate(TEST_PROMPTS):
                    response = generate_response(model, dataset.tokenizer, prompt)
                    responses[f"epoch_{epoch}_{i}"] = response
                # Update display with new responses
                layout["outputs"].update(make_outputs_panel(responses, epoch))
                # Show epoch completion message
                layout["footer"].update(make_footer(
                    f"Epoch {epoch} complete! Loss: {current_loss:.3f}"
                ))
        # Training complete (Live context exited; back to normal printing)
        total_time = time.time() - start_time
        console.print()
        console.print(Panel.fit(
            f"[bold green]Training Complete![/bold green]\n\n"
            f"Total time: {total_time:.1f} seconds\n"
            f"Final loss: {current_loss:.3f}\n"
            f"Epochs: {CONFIG['epochs']}\n\n"
            "[cyan]Watch how your transformer learned to talk![/cyan]",
            title="Success",
            border_style="green",
        ))
        # Show learning progression for all prompts as one table per prompt
        console.print()
        console.print("[bold]Full Learning Progression:[/bold]")
        console.print()
        for i, prompt in enumerate(TEST_PROMPTS):
            q = prompt.split('\n')[0]
            table = Table(box=box.ROUNDED, title=q)
            table.add_column("Epoch", style="cyan")
            table.add_column("Response", style="white")
            for epoch in range(1, CONFIG["epochs"] + 1):
                key = f"epoch_{epoch}_{i}"
                resp = responses.get(key, "...")
                table.add_row(str(epoch), resp)
            console.print(table)
            console.print()
    except KeyboardInterrupt:
        # Early exit is expected and advertised above; no stack trace.
        console.print("\n[yellow]Training stopped by user[/yellow]")
# Script entry point: run the demo only when executed directly.
if __name__ == "__main__":
    main()

View File

@@ -59,9 +59,9 @@ MILESTONE_SCRIPTS = {
"name": "MLP Revival (1986)",
"year": 1986,
"title": "Backpropagation Breakthrough",
"script": "milestones/03_1986_mlp/02_rumelhart_mnist.py",
"script": "milestones/03_1986_mlp/01_rumelhart_tinydigits.py",
"required_modules": [1, 2, 3, 4, 5, 6, 7],
"description": "Train deep networks on MNIST",
"description": "Train deep networks on TinyDigits",
"historical_context": "Rumelhart, Hinton & Williams (Nature, 1986)",
"emoji": "🎓"
},
@@ -81,7 +81,7 @@ MILESTONE_SCRIPTS = {
"name": "Transformer Era (2017)",
"year": 2017,
"title": "Attention is All You Need",
"script": "milestones/05_2017_transformer/01_vaswani_generation.py",
"script": "milestones/05_2017_transformer/03_quickdemo.py",
"required_modules": list(range(1, 14)),
"description": "Build transformer with self-attention",
"historical_context": "Vaswani et al. revolutionized NLP",
@@ -946,43 +946,55 @@ class MilestoneCommand(BaseCommand):
))
return 1
# Check prerequisites (unless skipped)
completed_modules = []
# Check prerequisites and validate exports/tests (unless skipped)
if not args.skip_checks:
console.print(f"\n[bold cyan]🔍 Checking prerequisites for Milestone {milestone_id}...[/bold cyan]\n")
# Load module completion status
progress_file = Path(".tito") / "progress.json"
if progress_file.exists():
try:
with open(progress_file, 'r') as f:
progress_data = json.load(f)
completed_modules = progress_data.get("completed_modules", [])
except (json.JSONDecodeError, IOError):
pass
# Check each required module
missing_modules = []
for module_num in milestone["required_modules"]:
if module_num in completed_modules:
console.print(f" [green]✓[/green] Module {module_num:02d} - complete")
else:
console.print(f" [red]✗[/red] Module {module_num:02d} - NOT complete")
missing_modules.append(module_num)
if missing_modules:
# Use unified progress tracker
from ..core.progress_tracker import ProgressTracker
from ..core.milestone_validator import MilestoneValidator
from .export import ExportCommand
from .test import TestCommand
tracker = ProgressTracker(self.config.project_root)
validator = MilestoneValidator(self.config.project_root, console)
export_cmd = ExportCommand(self.config)
test_cmd = TestCommand(self.config)
# Check if modules are completed
all_complete, missing_modules = validator.check_prerequisites(milestone_id, tracker)
if not all_complete:
console.print(Panel(
f"[bold yellow]❌ Missing Required Modules[/bold yellow]\n\n"
f"[yellow]Milestone {milestone_id} requires modules: {', '.join(f'{m:02d}' for m in milestone['required_modules'])}[/yellow]\n"
f"[red]Missing: {', '.join(f'{m:02d}' for m in missing_modules)}[/red]\n\n"
f"[cyan]Complete the missing modules first:[/cyan]\n" +
"\n".join(f"[dim] tito module start {m:02d}[/dim]" for m in missing_modules[:3]),
"\n".join(f"[dim] tito module complete {m:02d}[/dim]" for m in missing_modules[:3]),
title="Prerequisites Not Met",
border_style="yellow"
))
return 1
console.print(f"\n[green]✅ All prerequisites met![/green]\n")
console.print(f"[green]✅ All required modules completed![/green]\n")
# Validate that all modules are exported and tested
console.print(f"[bold cyan]🔧 Validating exports and tests...[/bold cyan]\n")
success, failed = validator.validate_and_export_modules(milestone_id, export_cmd, test_cmd)
if not success:
console.print(Panel(
f"[bold red]❌ Validation Failed[/bold red]\n\n"
f"[yellow]Some required modules failed export or tests:[/yellow]\n"
f"[red]{', '.join(failed)}[/red]\n\n"
f"[cyan]Fix the issues and try again:[/cyan]\n"
f"[dim] tito module complete XX[/dim] - Re-export and test modules",
title="Validation Failed",
border_style="red"
))
return 1
console.print(f"\n[green]✅ All modules exported and tested! Ready to run milestone.[/green]\n")
# Test imports work
console.print("[bold cyan]🧪 Testing YOUR implementations...[/bold cyan]\n")
@@ -1034,7 +1046,11 @@ class MilestoneCommand(BaseCommand):
padding=(1, 2)
))
input("\n[yellow]Press Enter to begin...[/yellow] ")
try:
input("\n[yellow]Press Enter to begin...[/yellow] ")
except EOFError:
# Non-interactive mode, proceed automatically
pass
# Run the milestone script
console.print(f"\n[bold green]🚀 Starting Milestone {milestone_id}...[/bold green]\n")
@@ -1052,6 +1068,14 @@ class MilestoneCommand(BaseCommand):
if result.returncode == 0:
# Success! Mark milestone as complete
self._mark_milestone_complete(milestone_id)
# Also update unified progress tracker
try:
from ..core.progress_tracker import ProgressTracker
tracker = ProgressTracker(self.config.project_root)
tracker.mark_milestone_completed(milestone_id)
except Exception:
pass # Non-critical
console.print(Panel(
f"[bold green]🏆 MILESTONE ACHIEVED![/bold green]\n\n"
@@ -1074,6 +1098,21 @@ class MilestoneCommand(BaseCommand):
console.print(f"\n[bold yellow]🎯 What's Next:[/bold yellow]")
console.print(f"[dim]Milestone {next_id}: {next_milestone['name']} ({next_milestone['year']})[/dim]")
# Get completed modules for checking next milestone
progress_file = Path(".tito") / "progress.json"
completed_modules = []
if progress_file.exists():
try:
with open(progress_file, 'r') as f:
progress_data = json.load(f)
for mod in progress_data.get("completed_modules", []):
try:
completed_modules.append(int(mod.split("_")[0]))
except (ValueError, IndexError):
pass
except (json.JSONDecodeError, IOError):
pass
# Check if unlocked
missing = [m for m in next_milestone["required_modules"] if m not in completed_modules]
if missing: