From 01d9118cc8d927d47371f5532f2b7d8cd8eab01c Mon Sep 17 00:00:00 2001
From: Vijay Janapa Reddi
Date: Tue, 28 Oct 2025 15:42:47 -0400
Subject: [PATCH] feat(milestones): Add monitored training script with early
 stopping

- train_monitored.py: Smart training with early stopping and progress monitoring
- MONITORED_TRAINING.md: Complete usage guide
- Features: Test mode (10 epochs) and full mode (30 epochs)
- Automatically stops training if loss doesn't improve
- Saves time by killing bad experiments early
---
 .../05_2017_transformer/MONITORED_TRAINING.md | 168 +++++++++
 .../05_2017_transformer/train_monitored.py    | 336 ++++++++++++++++++
 2 files changed, 504 insertions(+)
 create mode 100644 milestones/05_2017_transformer/MONITORED_TRAINING.md
 create mode 100755 milestones/05_2017_transformer/train_monitored.py

diff --git a/milestones/05_2017_transformer/MONITORED_TRAINING.md b/milestones/05_2017_transformer/MONITORED_TRAINING.md
new file mode 100644
index 00000000..9d3f12c1
--- /dev/null
+++ b/milestones/05_2017_transformer/MONITORED_TRAINING.md
@@ -0,0 +1,168 @@
+# Monitored Training for TinyTalks
+
+## Problem
+Training transformers can take a long time (30+ minutes), and we don't want to waste time on experiments that aren't learning.
+
+## Solution
+`train_monitored.py` provides:
+- **Early Stopping**: Automatically stops training if loss doesn't improve
+- **Continuous Monitoring**: Shows progress every N batches
+- **Two Modes**: Quick test (10 epochs) vs. full training (30 epochs)
+
+## Usage
+
+### Quick Test (Recommended First!)
+```bash
+cd /Users/VJ/GitHub/TinyTorch
+PYTHONPATH=/Users/VJ/GitHub/TinyTorch:$PYTHONPATH \
+  .venv/bin/python milestones/05_2017_transformer/train_monitored.py --mode test
+```
+
+**What it does:**
+- Trains for 10 epochs (or until early stop)
+- Checks progress every 50 batches
+- Stops if there is no improvement for 5 consecutive checks
+- Takes ~15-20 minutes
+- Shows whether the config is working
+
+### Full Training (After the test passes)
+```bash
+PYTHONPATH=/Users/VJ/GitHub/TinyTorch:$PYTHONPATH \
+  .venv/bin/python milestones/05_2017_transformer/train_monitored.py --mode full
+```
+
+**What it does:**
+- Trains for 30 epochs (or until early stop)
+- Same monitoring as test mode
+- Takes ~45-60 minutes
+- Only run this after test mode shows good learning
+
+## Parameters
+
+### Early Stopping
+```bash
+--patience 5      # Stop after 5 checks without improvement (default)
+--min-delta 0.01  # Minimum loss decrease to count (default: 0.01)
+```
+
+### Monitoring
+```bash
+--check-interval 50  # Check every N batches (default: 50)
+```
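+
+### How the Flags Fit Together
+A condensed sketch of the `TrainingMonitor` bookkeeping in
+`train_monitored.py` (names shortened here for illustration):
+
+```python
+best, stale = float('inf'), 0
+for loss in losses_at_each_check:  # one entry per --check-interval batches
+    if best - loss > min_delta:    # improvement must exceed --min-delta
+        best, stale = loss, 0      # reset the patience counter
+    else:
+        stale += 1
+        if stale >= patience:      # --patience checks without improvement
+            break                  # this is where training stops early
+```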
+
+## Example Output
+
+```
+═══════════════════════════════════════════════════
+     Monitored TinyTalks Training - Option C
+═══════════════════════════════════════════════════
+
+┌─────────────────────────┬─────────────────────────┐
+│ Parameter               │ Value                   │
+├─────────────────────────┼─────────────────────────┤
+│ Mode                    │ TEST (Quick Validation) │
+│ Epochs                  │ 10                      │
+│ Batch Size              │ 32                      │
+│ Learning Rate           │ 0.001                   │
+│ Model Size              │ 128d, 6L, 8H            │
+│ Early Stopping Patience │ 5                       │
+│ Min Delta               │ 0.01                    │
+│ Check Interval          │ Every 50 batches        │
+└─────────────────────────┴─────────────────────────┘
+
+Starting Training with Monitoring
+  Check interval: Every 50 batches
+  Early stopping: 5 checks without improvement
+
+Epoch 1/10
+  Batch 50 | Loss: 3.2145 | ✓ Loss improved by 0.8234 | Time: 12.3s
+  Batch 100 | Loss: 2.8912 | ✓ Loss improved by 0.3233 | Time: 24.1s
+  → Epoch 1 complete: Avg Loss = 2.7234 | Time: 48.2s
+
+Epoch 2/10
+  Batch 150 | Loss: 2.3456 | ✓ Loss improved by 0.5456 | Time: 36.5s
+  ...
+```
+
+## Interpreting Results
+
+### Success Messages
+- `✓ Loss improved by X`: Training is working!
+- `✓ EXCELLENT: Model is learning well!`: Loss decreased >50%
+- `⚠ MODERATE: Model is learning but slowly`: Loss decreased 20-50%
+
+### Warning Messages
+- `⚠ No improvement (2/5)`: Still OK, but being watched
+- `✗ POOR: Model not learning effectively`: Loss decreased <20%
+- `✗ FAILED: Training stopped early`: No improvement; try a different config
+
+## Typical Results
+
+### Good Run (Continue to full training)
+```
+Initial Loss:   4.2345
+Final Loss:     1.5678
+Total Decrease: 2.6667 (62.9%)
+Status:         ✓ SUCCESS
+```
+
+### Bad Run (Stop and tune)
+```
+Initial Loss:   4.2345
+Final Loss:     4.1234
+Total Decrease: 0.1111 (2.6%)
+Status:         ✗ EARLY STOP
+```
+
+## Workflow
+
+1. **Start with test mode**: `--mode test`
+2. **Monitor console output**: Watch for loss improvement
+3. **Check the summary**: Look at the decrease percentage
+4. **Decide** (the same rule the script prints; see the sketch below):
+   - If >50% decrease → Run full training
+   - If 20-50% decrease → Consider tuning, then full training
+   - If <20% decrease → Tune hyperparameters, retest
+   - If early stop → Major changes needed
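+
+As a standalone function, the decision rule looks like this (a sketch that
+mirrors the recommendation logic at the end of `main()` in
+`train_monitored.py`; it is not part of the script itself):
+
+```python
+def recommendation(success: bool, decrease_percent: float) -> str:
+    """Map a run's outcome to the suggested next step."""
+    if success and decrease_percent > 50:
+        return "EXCELLENT: run full training"
+    if success and decrease_percent > 20:
+        return "MODERATE: consider tuning, then run full training"
+    if success:
+        return "POOR: tune hyperparameters and retest"
+    return "FAILED: stopped early; major changes needed"
+```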
"""Monitor training progress and implement early stopping""" + + def __init__(self, patience=5, min_delta=0.01): + """ + Args: + patience: Number of checks without improvement before stopping + min_delta: Minimum change in loss to count as improvement + """ + self.patience = patience + self.min_delta = min_delta + self.best_loss = float('inf') + self.checks_without_improvement = 0 + self.losses = [] + + def check(self, current_loss): + """ + Check if training should continue + + Returns: + (should_continue, message) + """ + self.losses.append(current_loss) + + # Calculate improvement + improvement = self.best_loss - current_loss + + if improvement > self.min_delta: + # Significant improvement + self.best_loss = current_loss + self.checks_without_improvement = 0 + return True, f"✓ Loss improved by {improvement:.4f}" + else: + # No significant improvement + self.checks_without_improvement += 1 + + if self.checks_without_improvement >= self.patience: + return False, f"✗ No improvement for {self.patience} checks. Stopping." + else: + return True, f"⚠ No improvement ({self.checks_without_improvement}/{self.patience})" + + def summary(self): + """Get training summary""" + if len(self.losses) < 2: + return "Not enough data" + + initial = self.losses[0] + final = self.losses[-1] + best = min(self.losses) + decrease = initial - final + decrease_pct = (decrease / initial) * 100 if initial > 0 else 0 + + return { + 'initial_loss': initial, + 'final_loss': final, + 'best_loss': best, + 'total_decrease': decrease, + 'decrease_percent': decrease_pct, + 'num_checks': len(self.losses) + } + + +def train_with_monitoring(model, dataset, optimizer, criterion, config, monitor): + """ + Train with continuous monitoring and early stopping + + Args: + model: TinyGPT model + dataset: TinyTalksDataset + optimizer: Adam optimizer + criterion: CrossEntropyLoss + config: Training configuration dict + monitor: TrainingMonitor instance + + Returns: + success: True if training completed successfully + """ + epochs = config['epochs'] + batch_size = config['batch_size'] + check_interval = config.get('check_interval', 50) # Check every N batches + + console.print(f"\n[bold cyan]Starting Training with Monitoring[/bold cyan]") + console.print(f" Check interval: Every {check_interval} batches") + console.print(f" Early stopping: {monitor.patience} checks without improvement\n") + + total_batches_processed = 0 + start_time = time.time() + + for epoch in range(epochs): + epoch_start = time.time() + epoch_loss = 0.0 + batch_count = 0 + + console.print(f"[bold]Epoch {epoch+1}/{epochs}[/bold]") + + # Create batches + num_sequences = len(dataset) + indices = np.random.permutation(num_sequences) + + for batch_start in range(0, num_sequences, batch_size): + batch_end = min(batch_start + batch_size, num_sequences) + batch_indices = indices[batch_start:batch_end] + + # Get batch data + batch_inputs = [] + batch_targets = [] + for idx in batch_indices: + input_seq, target_seq = dataset[idx] + batch_inputs.append(input_seq) + batch_targets.append(target_seq) + + # Convert to tensors + batch_input = Tensor(np.array(batch_inputs)) + batch_target = Tensor(np.array(batch_targets)) + + # Forward pass + logits = model.forward(batch_input) + + # Reshape for loss + batch_size_actual, seq_length, vocab_size = logits.shape + logits_2d = logits.reshape(batch_size_actual * seq_length, vocab_size) + targets_1d = batch_target.reshape(-1) + + # Compute loss + loss = criterion.forward(logits_2d, targets_1d) + + # Backward and optimize + loss.backward() 
+
+
+def train_with_monitoring(model, dataset, optimizer, criterion, config, monitor):
+    """
+    Train with continuous monitoring and early stopping.
+
+    Args:
+        model: TinyGPT model
+        dataset: TinyTalksDataset
+        optimizer: Adam optimizer
+        criterion: CrossEntropyLoss
+        config: Training configuration dict
+        monitor: TrainingMonitor instance
+
+    Returns:
+        success: True if training completed successfully
+    """
+    epochs = config['epochs']
+    batch_size = config['batch_size']
+    check_interval = config.get('check_interval', 50)  # Check every N batches
+
+    console.print("\n[bold cyan]Starting Training with Monitoring[/bold cyan]")
+    console.print(f"  Check interval: Every {check_interval} batches")
+    console.print(f"  Early stopping: {monitor.patience} checks without improvement\n")
+
+    total_batches_processed = 0
+    start_time = time.time()
+
+    for epoch in range(epochs):
+        epoch_start = time.time()
+        epoch_loss = 0.0
+        batch_count = 0
+
+        console.print(f"[bold]Epoch {epoch+1}/{epochs}[/bold]")
+
+        # Create batches
+        num_sequences = len(dataset)
+        indices = np.random.permutation(num_sequences)
+
+        for batch_start in range(0, num_sequences, batch_size):
+            batch_end = min(batch_start + batch_size, num_sequences)
+            batch_indices = indices[batch_start:batch_end]
+
+            # Get batch data
+            batch_inputs = []
+            batch_targets = []
+            for idx in batch_indices:
+                input_seq, target_seq = dataset[idx]
+                batch_inputs.append(input_seq)
+                batch_targets.append(target_seq)
+
+            # Convert to tensors
+            batch_input = Tensor(np.array(batch_inputs))
+            batch_target = Tensor(np.array(batch_targets))
+
+            # Forward pass
+            logits = model.forward(batch_input)
+
+            # Reshape for loss
+            batch_size_actual, seq_length, vocab_size = logits.shape
+            logits_2d = logits.reshape(batch_size_actual * seq_length, vocab_size)
+            targets_1d = batch_target.reshape(-1)
+
+            # Compute loss
+            loss = criterion.forward(logits_2d, targets_1d)
+
+            # Backward and optimize
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+            # Track loss
+            loss_value = float(loss.data)
+            epoch_loss += loss_value
+            batch_count += 1
+            total_batches_processed += 1
+
+            # Monitor progress at check intervals
+            if total_batches_processed % check_interval == 0:
+                avg_loss = epoch_loss / batch_count
+                should_continue, message = monitor.check(avg_loss)
+
+                elapsed = time.time() - start_time
+                console.print(f"  Batch {total_batches_processed} | Loss: {avg_loss:.4f} | {message} | Time: {elapsed:.1f}s")
+
+                if not should_continue:
+                    console.print(f"\n[yellow]Early stopping triggered at epoch {epoch+1}, batch {batch_count}[/yellow]")
+                    return False
+
+        # Epoch summary
+        avg_epoch_loss = epoch_loss / batch_count
+        epoch_time = time.time() - epoch_start
+        console.print(f"  → Epoch {epoch+1} complete: Avg Loss = {avg_epoch_loss:.4f} | Time: {epoch_time:.1f}s\n")
+
+    console.print("[green]✓ Training completed successfully![/green]\n")
+    return True
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Monitored TinyTalks Training')
+    parser.add_argument('--mode', choices=['test', 'full'], default='test',
+                        help='Training mode: test (10 epochs) or full (30 epochs)')
+    parser.add_argument('--patience', type=int, default=5,
+                        help='Early stopping patience (checks without improvement)')
+    parser.add_argument('--min-delta', type=float, default=0.01,
+                        help='Minimum loss decrease to count as improvement')
+    parser.add_argument('--check-interval', type=int, default=50,
+                        help='Check progress every N batches')
+
+    args = parser.parse_args()
+
+    # Enable autograd
+    enable_autograd()
+
+    # Configuration based on mode (test and full differ only in epoch count)
+    is_test = args.mode == 'test'
+    config = {
+        'epochs': 10 if is_test else 30,
+        'batch_size': 32,
+        'lr': 0.001,
+        'embed_dim': 128,
+        'num_layers': 6,
+        'num_heads': 8,
+        'check_interval': args.check_interval,
+        'mode': 'TEST (Quick Validation)' if is_test else 'FULL (Complete Training)'
+    }
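+
+    # If a test run learns poorly, MONITORED_TRAINING.md suggests edits along
+    # these lines (illustrative values, not defaults of this script):
+    #   config['lr'] = 0.003       # higher learning rate
+    #   config['embed_dim'] = 96   # smaller model for faster iteration
+    #   config['batch_size'] = 64  # larger batches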
+
+    # Display configuration
+    console.print("\n[bold cyan]═══════════════════════════════════════════════════[/bold cyan]")
+    console.print("[bold cyan]     Monitored TinyTalks Training - Option C        [/bold cyan]")
+    console.print("[bold cyan]═══════════════════════════════════════════════════[/bold cyan]\n")
+
+    table = Table(box=box.ROUNDED)
+    table.add_column("Parameter", style="cyan")
+    table.add_column("Value", style="yellow")
+
+    table.add_row("Mode", config['mode'])
+    table.add_row("Epochs", str(config['epochs']))
+    table.add_row("Batch Size", str(config['batch_size']))
+    table.add_row("Learning Rate", str(config['lr']))
+    table.add_row("Model Size", f"{config['embed_dim']}d, {config['num_layers']}L, {config['num_heads']}H")
+    table.add_row("Early Stopping Patience", str(args.patience))
+    table.add_row("Min Delta", str(args.min_delta))
+    table.add_row("Check Interval", f"Every {args.check_interval} batches")
+
+    console.print(table)
+    console.print()
+
+    # Load dataset
+    console.print("[bold]Loading TinyTalks dataset...[/bold]")
+    dataset_path = project_root / "datasets/tinytalks/splits/train.txt"
+    with open(dataset_path, 'r') as f:
+        text = f.read()
+
+    dataset = TinyTalksDataset(text, seq_length=64)
+    console.print(f"  ✓ Loaded: {len(text):,} chars, {dataset.tokenizer.vocab_size} vocab\n")
+
+    # Initialize model
+    console.print("[bold]Initializing model...[/bold]")
+    model = TinyGPT(
+        vocab_size=dataset.tokenizer.vocab_size,
+        embed_dim=config['embed_dim'],
+        num_layers=config['num_layers'],
+        num_heads=config['num_heads'],
+        max_seq_len=64
+    )
+
+    params = model.parameters()
+    param_count = sum(p.data.size for p in params)
+    console.print(f"  ✓ Model initialized: {param_count:,} parameters\n")
+
+    # Initialize training components
+    optimizer = Adam(params, lr=config['lr'])
+    criterion = CrossEntropyLoss()
+    monitor = TrainingMonitor(patience=args.patience, min_delta=args.min_delta)
+
+    # Train
+    console.print("[bold]Starting training...[/bold]\n")
+    start_time = time.time()
+
+    success = train_with_monitoring(model, dataset, optimizer, criterion, config, monitor)
+
+    total_time = time.time() - start_time
+
+    # Summary
+    console.print("\n[bold cyan]═══════════════════════════════════════════════════[/bold cyan]")
+    console.print("[bold cyan]                 Training Summary                   [/bold cyan]")
+    console.print("[bold cyan]═══════════════════════════════════════════════════[/bold cyan]\n")
+
+    summary = monitor.summary()
+    if not isinstance(summary, dict):
+        # Training ended before two monitor checks ran; nothing to tabulate
+        console.print(f"[yellow]{summary}: too few monitor checks to summarize.[/yellow]")
+        return
+
+    result_table = Table(box=box.ROUNDED)
+    result_table.add_column("Metric", style="cyan")
+    result_table.add_column("Value", style="yellow")
+
+    result_table.add_row("Status", "✓ SUCCESS" if success else "✗ EARLY STOP")
+    result_table.add_row("Total Time", f"{total_time/60:.1f} minutes")
+    result_table.add_row("Initial Loss", f"{summary['initial_loss']:.4f}")
+    result_table.add_row("Final Loss", f"{summary['final_loss']:.4f}")
+    result_table.add_row("Best Loss", f"{summary['best_loss']:.4f}")
+    result_table.add_row("Total Decrease", f"{summary['total_decrease']:.4f} ({summary['decrease_percent']:.1f}%)")
+    result_table.add_row("Checks Performed", str(summary['num_checks']))
+
+    console.print(result_table)
+    console.print()
+
+    # Recommendation
+    if success and summary['decrease_percent'] > 50:
+        console.print("[bold green]✓ EXCELLENT: Model is learning well! Continue with full training.[/bold green]")
+    elif success and summary['decrease_percent'] > 20:
+        console.print("[bold yellow]⚠ MODERATE: Model is learning but slowly. Consider tuning hyperparameters.[/bold yellow]")
+    elif success:
+        console.print("[bold red]✗ POOR: Model not learning effectively. Needs hyperparameter adjustment.[/bold red]")
+    else:
+        console.print("[bold red]✗ FAILED: Training stopped early. Try different hyperparameters.[/bold red]")
+
+
+if __name__ == "__main__":
+    main()