From 59ebf0d385fd381d940b6423b743b9ce378f3d48 Mon Sep 17 00:00:00 2001
From: Vijay Janapa Reddi
Date: Sat, 22 Nov 2025 15:55:12 -0500
Subject: [PATCH] Add transformer quickdemo with live learning progression dashboard
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

New milestone 05 demo that shows students the model learning to "talk":

- Live dashboard with epoch-by-epoch response progression
- Systems stats panel (tokens/sec, batch time, memory)
- 3 test prompts with full history displayed
- Smaller model (110K params) for ~2 minute training time

🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
 .../05_2017_transformer/03_quickdemo.py | 481 ++++++++++++++++++
 tito/commands/milestone.py              |  97 ++--
 2 files changed, 549 insertions(+), 29 deletions(-)
 create mode 100644 milestones/05_2017_transformer/03_quickdemo.py

diff --git a/milestones/05_2017_transformer/03_quickdemo.py b/milestones/05_2017_transformer/03_quickdemo.py
new file mode 100644
index 00000000..716a2dc6
--- /dev/null
+++ b/milestones/05_2017_transformer/03_quickdemo.py
@@ -0,0 +1,481 @@
+#!/usr/bin/env python3
+"""
+TinyTalks Quick Demo - Watch Your Transformer Learn to Talk!
+=============================================================
+
+A fast, visual demonstration of transformer training.
+See the model go from gibberish to coherent answers in ~2 minutes!
+
+Features:
+- Smaller model (~110K params) for fast training
+- Live dashboard showing training progress
+- Three fixed test prompts, re-answered after every epoch
+- Learning progression display (gibberish -> coherent)
+"""
+
+import sys
+import os
+import time
+import numpy as np
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+# Rich for live dashboard
+from rich.console import Console
+from rich.layout import Layout
+from rich.panel import Panel
+from rich.table import Table
+from rich.live import Live
+from rich.text import Text
+from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
+from rich import box
+
+# TinyTorch imports
+from tinytorch.core.tensor import Tensor
+from tinytorch.core.optimizers import Adam
+from tinytorch.core.losses import CrossEntropyLoss
+from tinytorch.models.transformer import GPT
+from tinytorch.text.tokenization import CharTokenizer
+
+console = Console()
+
+# =============================================================================
+# Configuration - Optimized for ~2 minute training
+# =============================================================================
+
+CONFIG = {
+    # Model (smaller for speed)
+    "n_layer": 2,
+    "n_head": 2,
+    "n_embd": 64,
+    "max_seq_len": 32,  # Shorter sequences for speed
+
+    # Training (optimized for ~2 min on pure Python)
+    "epochs": 8,
+    "batches_per_epoch": 30,
+    "batch_size": 8,
+    "learning_rate": 0.003,  # Balanced LR for stable convergence
+
+    # Display
+    "update_interval": 5,  # Update dashboard every N batches
+}
+
+# Test prompts to show model learning (3 prompts for better progression display)
+TEST_PROMPTS = [
+    "Q: What is 2+2?\nA:",
+    "Q: What color is the sky?\nA:",
+    "Q: Say hello\nA:",
+]
+
+
+# =============================================================================
+# Dataset
+# =============================================================================
+
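+# CharTokenizer quick sketch (an assumption, inferred from the calls below:
+# build_vocab() assigns each distinct character an integer id, and
+# encode()/decode() round-trip between text and id lists):
+#
+#   tok = CharTokenizer()
+#   tok.build_vocab(["hello"])
+#   ids = tok.encode("hello")  # e.g. [1, 2, 3, 3, 4] (actual ids may differ)
+#   assert tok.decode(ids) == "hello"
+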
+class TinyTalksDataset:
+    """Simple character-level dataset from TinyTalks."""
+
+    def __init__(self, data_path: Path, seq_len: int):
+        self.seq_len = seq_len
+
+        # Load text
+        with open(data_path, 'r') as f:
+            self.text = f.read()
+
+        # Create tokenizer and build vocabulary
+        self.tokenizer = CharTokenizer()
+        self.tokenizer.build_vocab([self.text])
+
+        # Tokenize entire text
+        self.tokens = self.tokenizer.encode(self.text)
+
+    def __len__(self):
+        return len(self.tokens) - self.seq_len
+
+    def get_batch(self, batch_size: int):
+        """Get random batch of sequences."""
+        # len(self) is the number of valid start positions, and randint's
+        # upper bound is exclusive, so every position can be sampled.
+        indices = np.random.randint(0, len(self), size=batch_size)
+
+        inputs = []
+        targets = []
+
+        for idx in indices:
+            seq = self.tokens[idx:idx + self.seq_len + 1]
+            inputs.append(seq[:-1])
+            targets.append(seq[1:])
+
+        return (
+            Tensor(np.array(inputs)),
+            Tensor(np.array(targets))
+        )
+
+
+# =============================================================================
+# Text Generation
+# =============================================================================
+
+def generate_response(model, tokenizer, prompt: str, max_tokens: int = 30) -> str:
+    """Generate text from prompt."""
+    # Encode prompt
+    tokens = tokenizer.encode(prompt)
+
+    for _ in range(max_tokens):
+        # Prepare input (keep only the last max_seq_len tokens as context)
+        context = tokens[-CONFIG["max_seq_len"]:]
+        x = Tensor(np.array([context]))
+
+        # Forward pass
+        logits = model.forward(x)
+
+        # Get next token probabilities
+        last_logits = logits.data[0, -1, :]
+
+        # Temperature sampling (softmax with max-subtraction for numerical stability)
+        temperature = 0.8
+        last_logits = last_logits / temperature
+        exp_logits = np.exp(last_logits - np.max(last_logits))
+        probs = exp_logits / np.sum(exp_logits)
+
+        # Sample
+        next_token = np.random.choice(len(probs), p=probs)
+        tokens.append(next_token)
+
+        # Stop at newline (end of answer)
+        if tokenizer.decode([next_token]) == '\n':
+            break
+
+    # Decode and extract answer
+    full_text = tokenizer.decode(tokens)
+
+    # Get just the answer part
+    if "A:" in full_text:
+        answer = full_text.split("A:")[-1].strip()
+        # Clean up - take first line
+        answer = answer.split('\n')[0].strip()
+        return answer if answer else "(empty)"
+
+    return full_text[len(prompt):].strip() or "(empty)"
+
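+# Example usage (a sketch; assumes a trained `model` and the dataset's tokenizer):
+#
+#   answer = generate_response(model, dataset.tokenizer, "Q: Say hello\nA:")
+#   print(answer)  # early epochs: gibberish; later epochs: short answers
+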
+
+# =============================================================================
+# Dashboard Layout
+# =============================================================================
+
+def make_layout() -> Layout:
+    """Create the dashboard layout."""
+    layout = Layout()
+
+    layout.split_column(
+        Layout(name="header", size=3),
+        Layout(name="main", ratio=1),
+        Layout(name="footer", size=3),
+    )
+
+    layout["main"].split_row(
+        Layout(name="left", ratio=1),
+        Layout(name="outputs", ratio=2),
+    )
+
+    layout["left"].split_column(
+        Layout(name="progress", ratio=2),
+        Layout(name="stats", ratio=1),
+    )
+
+    return layout
+
+
+def make_header() -> Panel:
+    """Create header panel."""
+    return Panel(
+        Text("TinyTalks Quick Demo - Watch Your Transformer Learn!",
+             style="bold cyan", justify="center"),
+        box=box.ROUNDED,
+        style="cyan",
+    )
+
+
+def make_progress_panel(epoch: int, total_epochs: int, batch: int,
+                        total_batches: int, loss: float, elapsed: float) -> Panel:
+    """Create training progress panel."""
+    # Calculate overall progress
+    total_steps = total_epochs * total_batches
+    current_step = (epoch - 1) * total_batches + batch
+    progress_pct = (current_step / total_steps) * 100
+
+    # Progress bar
+    bar_width = 20
+    filled = int(bar_width * progress_pct / 100)
+    bar = "█" * filled + "░" * (bar_width - filled)
+
+    # Estimate time remaining
+    if current_step > 0:
+        time_per_step = elapsed / current_step
+        remaining_steps = total_steps - current_step
+        eta = remaining_steps * time_per_step
+        eta_str = f"{eta:.0f}s"
+    else:
+        eta_str = "..."
+
+    content = Text()
+    content.append(f"Epoch: {epoch}/{total_epochs}\n", style="bold")
+    content.append(f"Batch: {batch}/{total_batches}\n")
+    content.append(f"Loss: {loss:.3f}\n\n", style="yellow")
+    content.append(f"{bar} {progress_pct:.0f}%\n\n", style="green")
+    content.append(f"Elapsed: {elapsed:.0f}s\n")
+    content.append(f"ETA: {eta_str}")
+
+    return Panel(
+        content,
+        title="[bold]Training Progress[/bold]",
+        border_style="green",
+        box=box.ROUNDED,
+    )
+
+
+def make_outputs_panel(responses: dict, epoch: int) -> Panel:
+    """Create model outputs panel showing all epoch responses as a log."""
+    content = Text()
+
+    # Show all 3 prompts with full epoch history
+    for i, prompt in enumerate(TEST_PROMPTS):
+        q = prompt.split('\n')[0]
+        content.append(f"{q}\n", style="cyan bold")
+
+        # Show all epochs completed so far
+        for ep in range(1, epoch + 1):
+            key = f"epoch_{ep}_{i}"
+            response = responses.get(key, "...")
+            # Most recent epoch is highlighted
+            style = "white" if ep == epoch else "dim"
+            content.append(f" Ep{ep}: ", style="yellow")
+            # Truncate long responses to fit
+            display_response = (response[:25] + "...") if len(response) > 25 else response
+            content.append(f"{display_response}\n", style=style)
+
+        content.append("\n")
+
+    return Panel(
+        content,
+        title=f"[bold]Learning Progression (Epoch {epoch})[/bold]",
+        border_style="blue",
+        box=box.ROUNDED,
+    )
+
+
+def make_stats_panel(stats: dict) -> Panel:
+    """Create systems stats panel."""
+    content = Text()
+
+    content.append("Performance Metrics\n", style="bold")
+    content.append(f" Tokens/sec: {stats.get('tokens_per_sec', 0):.1f}\n")
+    content.append(f" Batch time: {stats.get('batch_time_ms', 0):.0f}ms\n")
+    content.append(f" Memory: {stats.get('memory_mb', 0):.1f}MB\n\n")
+
+    content.append("Model Stats\n", style="bold")
+    content.append(f" Parameters: {stats.get('params', 0):,}\n")
+    content.append(f" Vocab size: {stats.get('vocab_size', 0)}\n")
+
+    return Panel(
+        content,
+        title="[bold]Systems[/bold]",
+        border_style="magenta",
+        box=box.ROUNDED,
+    )
+
+
+def make_footer(message: str = "") -> Panel:
+    """Create footer panel."""
+    if not message:
+        message = "Training in progress... Watch the model learn to answer questions!"
+
+    return Panel(
+        Text(message, style="dim", justify="center"),
+        box=box.ROUNDED,
+        style="dim",
+    )
+
+
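+# How the dashboard refreshes (this mirrors main() below): the layout is built
+# once, then each panel's content is swapped in place on every update, e.g.
+#
+#   layout["progress"].update(make_progress_panel(epoch, total_epochs, ...))
+#   layout["stats"].update(make_stats_panel(stats))
+#
+# rich.Live then re-renders the whole layout in place, without scrolling.
+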
+# =============================================================================
+# Main Training Loop
+# =============================================================================
+
+def main():
+    """Main training function with live dashboard."""
+
+    # Welcome
+    console.print()
+    console.print(Panel.fit(
+        "[bold cyan]TinyTalks Quick Demo[/bold cyan]\n\n"
+        "Watch a transformer learn to answer questions in real-time!\n"
+        "The model starts with random weights (gibberish output)\n"
+        "and learns to produce coherent answers.\n\n"
+        "[dim]Training time: ~2 minutes[/dim]",
+        title="Welcome",
+        border_style="cyan",
+    ))
+    console.print()
+
+    # Load dataset
+    project_root = Path(__file__).parent.parent.parent
+    data_path = project_root / "datasets" / "tinytalks" / "splits" / "train.txt"
+
+    if not data_path.exists():
+        console.print(f"[red]Error: Dataset not found at {data_path}[/red]")
+        console.print("[yellow]Please ensure the TinyTalks dataset is available.[/yellow]")
+        return
+
+    console.print(f"[dim]Loading dataset from {data_path}...[/dim]")
+    dataset = TinyTalksDataset(data_path, CONFIG["max_seq_len"])
+    console.print(f"[green]✓[/green] Loaded {len(dataset.text):,} characters, vocab size: {dataset.tokenizer.vocab_size}")
+
+    # Create model
+    console.print("[dim]Creating model...[/dim]")
+    model = GPT(
+        vocab_size=dataset.tokenizer.vocab_size,
+        embed_dim=CONFIG["n_embd"],
+        num_heads=CONFIG["n_head"],
+        num_layers=CONFIG["n_layer"],
+        max_seq_len=CONFIG["max_seq_len"],
+    )
+
+    # Count parameters
+    param_count = sum(p.data.size for p in model.parameters())
+    console.print(f"[green]✓[/green] Model created: {param_count:,} parameters")
+    console.print(f"[dim] {CONFIG['n_layer']} layers, {CONFIG['n_head']} heads, {CONFIG['n_embd']} embed dim[/dim]")
+
+    # Setup training
+    optimizer = Adam(model.parameters(), lr=CONFIG["learning_rate"])
+    criterion = CrossEntropyLoss()
+
+    console.print()
+    console.print("[bold green]Starting training with live dashboard...[/bold green]")
+    console.print("[dim]Press Ctrl+C to stop early[/dim]")
+    console.print()
+    time.sleep(1)
+
+    # Storage for responses and stats
+    responses = {}
+    stats = {
+        "params": param_count,
+        "vocab_size": dataset.tokenizer.vocab_size,
+        "tokens_per_sec": 0,
+        "batch_time_ms": 0,
+        "memory_mb": param_count * 4 / (1024 * 1024),  # Rough estimate: 4 bytes per float32 param
+    }
+
+    # Create layout
+    layout = make_layout()
+
+    # Training loop with live display
+    start_time = time.time()
+    current_loss = 0.0
+    total_tokens = 0
+
+    try:
+        with Live(layout, console=console, refresh_per_second=4) as live:
+            for epoch in range(1, CONFIG["epochs"] + 1):
+                epoch_loss = 0.0
+
+                for batch_idx in range(1, CONFIG["batches_per_epoch"] + 1):
+                    batch_start = time.time()
+
+                    # Get batch
+                    inputs, targets = dataset.get_batch(CONFIG["batch_size"])
+
+                    # Forward pass
+                    logits = model.forward(inputs)
+
+                    # Reshape for loss
+                    batch_size, seq_len, vocab_size = logits.shape
+                    logits_flat = logits.reshape(batch_size * seq_len, vocab_size)
+                    targets_flat = targets.reshape(-1)
+
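+                    # Shape note: logits are (batch, seq, vocab); flattening to
+                    # N = batch * seq rows of vocab scores against N integer
+                    # targets lets a single cross-entropy call cover every
+                    # position (assuming TinyTorch's CrossEntropyLoss takes
+                    # (N, vocab) scores and N class ids, as this call suggests).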
+                    # Compute loss
+                    loss = criterion(logits_flat, targets_flat)
+
+                    # Backward pass
+                    loss.backward()
+
+                    # Update
+                    optimizer.step()
+                    optimizer.zero_grad()
+
+                    # Track loss and stats
+                    batch_loss = float(loss.data)
+                    epoch_loss += batch_loss
+                    current_loss = epoch_loss / batch_idx
+
+                    # Update systems stats
+                    batch_time = time.time() - batch_start
+                    tokens_in_batch = batch_size * seq_len
+                    total_tokens += tokens_in_batch
+                    elapsed = time.time() - start_time
+
+                    stats["batch_time_ms"] = batch_time * 1000
+                    stats["tokens_per_sec"] = total_tokens / elapsed if elapsed > 0 else 0
+
+                    # Update dashboard
+                    layout["header"].update(make_header())
+                    layout["progress"].update(make_progress_panel(
+                        epoch, CONFIG["epochs"],
+                        batch_idx, CONFIG["batches_per_epoch"],
+                        current_loss, elapsed
+                    ))
+                    layout["stats"].update(make_stats_panel(stats))
+                    layout["outputs"].update(make_outputs_panel(responses, epoch))
+                    layout["footer"].update(make_footer())
+
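+                # Note: the end-of-epoch sampling below runs one generation per
+                # test prompt (up to max_tokens forward passes each), so a brief
+                # pause between epochs is expected.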
+                # End of epoch - generate sample responses
+                for i, prompt in enumerate(TEST_PROMPTS):
+                    response = generate_response(model, dataset.tokenizer, prompt)
+                    responses[f"epoch_{epoch}_{i}"] = response
+
+                # Update display with new responses
+                layout["outputs"].update(make_outputs_panel(responses, epoch))
+
+                # Show epoch completion message
+                layout["footer"].update(make_footer(
+                    f"Epoch {epoch} complete! Loss: {current_loss:.3f}"
+                ))
+
+        # Training complete
+        total_time = time.time() - start_time
+
+        console.print()
+        console.print(Panel.fit(
+            f"[bold green]Training Complete![/bold green]\n\n"
+            f"Total time: {total_time:.1f} seconds\n"
+            f"Final loss: {current_loss:.3f}\n"
+            f"Epochs: {CONFIG['epochs']}\n\n"
+            "[cyan]Watch how your transformer learned to talk![/cyan]",
+            title="Success",
+            border_style="green",
+        ))
+
+        # Show learning progression for all prompts
+        console.print()
+        console.print("[bold]Full Learning Progression:[/bold]")
+        console.print()
+
+        for i, prompt in enumerate(TEST_PROMPTS):
+            q = prompt.split('\n')[0]
+            table = Table(box=box.ROUNDED, title=q)
+            table.add_column("Epoch", style="cyan")
+            table.add_column("Response", style="white")
+
+            for epoch in range(1, CONFIG["epochs"] + 1):
+                key = f"epoch_{epoch}_{i}"
+                resp = responses.get(key, "...")
+                table.add_row(str(epoch), resp)
+
+            console.print(table)
+            console.print()
+
+    except KeyboardInterrupt:
+        console.print("\n[yellow]Training stopped by user[/yellow]")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tito/commands/milestone.py b/tito/commands/milestone.py
index 5c343c3e..0dcc0c27 100644
--- a/tito/commands/milestone.py
+++ b/tito/commands/milestone.py
@@ -59,9 +59,9 @@ MILESTONE_SCRIPTS = {
         "name": "MLP Revival (1986)",
         "year": 1986,
         "title": "Backpropagation Breakthrough",
-        "script": "milestones/03_1986_mlp/02_rumelhart_mnist.py",
+        "script": "milestones/03_1986_mlp/01_rumelhart_tinydigits.py",
         "required_modules": [1, 2, 3, 4, 5, 6, 7],
-        "description": "Train deep networks on MNIST",
+        "description": "Train deep networks on TinyDigits",
         "historical_context": "Rumelhart, Hinton & Williams (Nature, 1986)",
         "emoji": "🎓"
     },
@@ -81,7 +81,7 @@
         "name": "Transformer Era (2017)",
         "year": 2017,
         "title": "Attention is All You Need",
-        "script": "milestones/05_2017_transformer/01_vaswani_generation.py",
+        "script": "milestones/05_2017_transformer/03_quickdemo.py",
         "required_modules": list(range(1, 14)),
         "description": "Build transformer with self-attention",
         "historical_context": "Vaswani et al. revolutionized NLP",
@@ -946,43 +946,55 @@ class MilestoneCommand(BaseCommand):
             ))
             return 1
 
-        # Check prerequisites (unless skipped)
-        completed_modules = []
+        # Check prerequisites and validate exports/tests (unless skipped)
         if not args.skip_checks:
             console.print(f"\n[bold cyan]🔍 Checking prerequisites for Milestone {milestone_id}...[/bold cyan]\n")
 
-            # Load module completion status
-            progress_file = Path(".tito") / "progress.json"
-            if progress_file.exists():
-                try:
-                    with open(progress_file, 'r') as f:
-                        progress_data = json.load(f)
-                    completed_modules = progress_data.get("completed_modules", [])
-                except (json.JSONDecodeError, IOError):
-                    pass
-
-            # Check each required module
-            missing_modules = []
-            for module_num in milestone["required_modules"]:
-                if module_num in completed_modules:
-                    console.print(f" [green]✓[/green] Module {module_num:02d} - complete")
-                else:
-                    console.print(f" [red]✗[/red] Module {module_num:02d} - NOT complete")
-                    missing_modules.append(module_num)
-
-            if missing_modules:
+            # Use unified progress tracker
+            from ..core.progress_tracker import ProgressTracker
+            from ..core.milestone_validator import MilestoneValidator
+            from .export import ExportCommand
+            from .test import TestCommand
+
+            tracker = ProgressTracker(self.config.project_root)
+            validator = MilestoneValidator(self.config.project_root, console)
+            export_cmd = ExportCommand(self.config)
+            test_cmd = TestCommand(self.config)
+
+            # Check if modules are completed
+            all_complete, missing_modules = validator.check_prerequisites(milestone_id, tracker)
+
+            if not all_complete:
                 console.print(Panel(
                     f"[bold yellow]❌ Missing Required Modules[/bold yellow]\n\n"
                     f"[yellow]Milestone {milestone_id} requires modules: {', '.join(f'{m:02d}' for m in milestone['required_modules'])}[/yellow]\n"
                     f"[red]Missing: {', '.join(f'{m:02d}' for m in missing_modules)}[/red]\n\n"
                     f"[cyan]Complete the missing modules first:[/cyan]\n" +
-                    "\n".join(f"[dim] tito module start {m:02d}[/dim]" for m in missing_modules[:3]),
+                    "\n".join(f"[dim] tito module complete {m:02d}[/dim]" for m in missing_modules[:3]),
                     title="Prerequisites Not Met",
                     border_style="yellow"
                 ))
                 return 1
-
-            console.print(f"\n[green]✅ All prerequisites met![/green]\n")
+
+            console.print(f"[green]✅ All required modules completed![/green]\n")
+
+            # Validate that all modules are exported and tested
+            console.print(f"[bold cyan]🔧 Validating exports and tests...[/bold cyan]\n")
+            success, failed = validator.validate_and_export_modules(milestone_id, export_cmd, test_cmd)
+
+            if not success:
+                console.print(Panel(
+                    f"[bold red]❌ Validation Failed[/bold red]\n\n"
+                    f"[yellow]Some required modules failed export or tests:[/yellow]\n"
+                    f"[red]{', '.join(failed)}[/red]\n\n"
+                    f"[cyan]Fix the issues and try again:[/cyan]\n"
+                    f"[dim] tito module complete XX[/dim] - Re-export and test modules",
+                    title="Validation Failed",
+                    border_style="red"
+                ))
+                return 1
+
+            console.print(f"\n[green]✅ All modules exported and tested! Ready to run milestone.[/green]\n")
 
         # Test imports work
         console.print("[bold cyan]🧪 Testing YOUR implementations...[/bold cyan]\n")
@@ -1034,7 +1046,11 @@ class MilestoneCommand(BaseCommand):
             padding=(1, 2)
         ))
 
-        input("\n[yellow]Press Enter to begin...[/yellow] ")
+        try:
+            input("\n[yellow]Press Enter to begin...[/yellow] ")
+        except EOFError:
+            # Non-interactive mode, proceed automatically
+            pass
 
         # Run the milestone script
         console.print(f"\n[bold green]🚀 Starting Milestone {milestone_id}...[/bold green]\n")
@@ -1052,6 +1068,14 @@ class MilestoneCommand(BaseCommand):
         if result.returncode == 0:
             # Success! Mark milestone as complete
             self._mark_milestone_complete(milestone_id)
+
+            # Also update unified progress tracker
+            try:
+                from ..core.progress_tracker import ProgressTracker
+                tracker = ProgressTracker(self.config.project_root)
+                tracker.mark_milestone_completed(milestone_id)
+            except Exception:
+                pass  # Non-critical
 
             console.print(Panel(
                 f"[bold green]🏆 MILESTONE ACHIEVED![/bold green]\n\n"
@@ -1074,6 +1098,21 @@ class MilestoneCommand(BaseCommand):
             console.print(f"\n[bold yellow]🎯 What's Next:[/bold yellow]")
             console.print(f"[dim]Milestone {next_id}: {next_milestone['name']} ({next_milestone['year']})[/dim]")
 
+            # Get completed modules for checking next milestone
+            progress_file = Path(".tito") / "progress.json"
+            completed_modules = []
+            if progress_file.exists():
+                try:
+                    with open(progress_file, 'r') as f:
+                        progress_data = json.load(f)
+                    for mod in progress_data.get("completed_modules", []):
+                        try:
+                            completed_modules.append(int(mod.split("_")[0]))
+                        except (ValueError, IndexError):
+                            pass
+                except (json.JSONDecodeError, IOError):
+                    pass
+
             # Check if unlocked
             missing = [m for m in next_milestone["required_modules"] if m not in completed_modules]
             if missing: