mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-06-01 03:52:20 -05:00
feat: Add overfitting detection to Milestones 03 and 04
Track train vs test accuracy to detect overfitting:
Training Progress:
- Print both train and test accuracy every 5 epochs
- Show gap between train/test with indicator:
✓ Gap ≤ 10%: Healthy generalization
⚠️ Gap > 10%: Overfitting warning
Results Table (ACT 4):
- Train Accuracy + improvement
- Test Accuracy + improvement
- Overfitting Gap + status
- Training Time
Final Panel (ACT 5):
- Show test accuracy with gap
- Celebrate good generalization
Educational Value:
Students now see:
1. How to detect overfitting (growing train/test gap)
2. When a model memorizes vs. generalizes
3. Real ML systems track BOTH metrics
Example output:
Epoch 5/20 Loss: 1.2340 Train: 85.0% Test: 82.0% ✓ Gap: 3.0%
Epoch 10/20 Loss: 0.8910 Train: 90.0% Test: 87.0% ✓ Gap: 3.0%
This prepares them for regularization techniques (Dropout, etc.)
in later modules!
This commit is contained in:
@@ -355,6 +355,11 @@ def train_mlp():
|
||||
|
||||
epochs = 20
|
||||
initial_loss = None
|
||||
history = {
|
||||
"train_loss": [],
|
||||
"train_accuracy": [],
|
||||
"test_accuracy": []
|
||||
}
|
||||
|
||||
for epoch in range(epochs):
|
||||
epoch_loss = 0.0
|
||||
@@ -377,19 +382,34 @@ def train_mlp():
|
||||
|
||||
avg_loss = epoch_loss / batch_count
|
||||
|
||||
# Evaluate on both train and test to detect overfitting
|
||||
train_acc, _ = evaluate_accuracy(model, train_images, train_labels)
|
||||
test_acc, _ = evaluate_accuracy(model, test_images, test_labels)
|
||||
|
||||
history["train_loss"].append(avg_loss)
|
||||
history["train_accuracy"].append(train_acc)
|
||||
history["test_accuracy"].append(test_acc)
|
||||
|
||||
if initial_loss is None:
|
||||
initial_loss = avg_loss
|
||||
|
||||
# Print progress every 5 epochs
|
||||
if (epoch + 1) % 5 == 0:
|
||||
test_acc, _ = evaluate_accuracy(model, test_images, test_labels)
|
||||
console.print(f"Epoch {epoch+1:2d}/{epochs} "
|
||||
f"Loss: [cyan]{avg_loss:.4f}[/cyan] "
|
||||
f"Test Accuracy: [green]{test_acc:.1f}%[/green]")
|
||||
gap = train_acc - test_acc
|
||||
gap_indicator = "⚠️" if gap > 10 else "✓"
|
||||
console.print(
|
||||
f"Epoch {epoch+1:2d}/{epochs} "
|
||||
f"Loss: {avg_loss:.4f} "
|
||||
f"Train: {train_acc:.1f}% "
|
||||
f"Test: {test_acc:.1f}% "
|
||||
f"{gap_indicator} Gap: {gap:.1f}%"
|
||||
)
|
||||
|
||||
console.print("\n[green]✅ Training Complete![/green]")
|
||||
|
||||
final_acc, predictions = evaluate_accuracy(model, test_images, test_labels)
|
||||
final_train_acc = history["train_accuracy"][-1]
|
||||
final_test_acc = history["test_accuracy"][-1]
|
||||
overfitting_gap = final_train_acc - final_test_acc
|
||||
|
||||
console.print("\n" + "─" * 70 + "\n")
|
||||
|
||||
@@ -400,26 +420,31 @@ def train_mlp():
|
||||
console.print("[bold]📊 The Results:[/bold]\n")
|
||||
|
||||
table = Table(title="Training Outcome", box=box.ROUNDED)
|
||||
table.add_column("Metric", style="cyan", width=18)
|
||||
table.add_column("Before Training", style="yellow", width=16)
|
||||
table.add_column("After Training", style="green", width=16)
|
||||
table.add_column("Improvement", style="magenta", width=14)
|
||||
table.add_column("Metric", style="cyan", width=20)
|
||||
table.add_column("Value", style="green", width=20)
|
||||
table.add_column("Status", style="magenta", width=20)
|
||||
|
||||
table.add_row(
|
||||
"Loss",
|
||||
f"{initial_loss:.4f}",
|
||||
f"{avg_loss:.4f}",
|
||||
f"-{initial_loss - avg_loss:.4f}"
|
||||
"Train Accuracy",
|
||||
f"{final_train_acc:.1f}%",
|
||||
f"↑ +{final_train_acc - initial_acc:.1f}%"
|
||||
)
|
||||
table.add_row(
|
||||
"Test Accuracy",
|
||||
f"{initial_acc:.1f}%",
|
||||
f"{final_acc:.1f}%",
|
||||
f"+{final_acc - initial_acc:.1f}%"
|
||||
f"{final_test_acc:.1f}%",
|
||||
f"↑ +{final_test_acc - initial_acc:.1f}%"
|
||||
)
|
||||
table.add_row(
|
||||
"Overfitting Gap",
|
||||
f"{overfitting_gap:.1f}%",
|
||||
"✓ Healthy" if overfitting_gap < 10 else "⚠️ Overfitting"
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
# Also get predictions for later use
|
||||
_, predictions = evaluate_accuracy(model, test_images, test_labels)
|
||||
|
||||
console.print("\n[bold]🔍 Sample Predictions:[/bold]")
|
||||
console.print("[dim](First 10 test images)[/dim]\n")
|
||||
|
||||
@@ -447,7 +472,7 @@ def train_mlp():
|
||||
console.print(Panel.fit(
|
||||
"[bold green]🎉 Success! Your MLP Learned to Recognize Digits![/bold green]\n\n"
|
||||
|
||||
f"Final accuracy: [bold]{final_acc:.1f}%[/bold]\n\n"
|
||||
f"Test accuracy: [bold]{final_test_acc:.1f}%[/bold] (Gap: {overfitting_gap:.1f}%)\n\n"
|
||||
|
||||
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||
|
||||
@@ -455,8 +480,9 @@ def train_mlp():
|
||||
" ✓ Built multi-layer network with YOUR components\n"
|
||||
" ✓ Trained on REAL handwritten digits\n"
|
||||
" ✓ Used YOUR DataLoader for efficient batching\n"
|
||||
f" ✓ Model generalizes well (gap: {overfitting_gap:.1f}%)\n"
|
||||
" ✓ Backprop through hidden layers works on real data!\n"
|
||||
" ✓ Achieved {:.1f}% accuracy on digit recognition!\n\n".format(final_acc) +
|
||||
f" ✓ Achieved {final_test_acc:.1f}% test accuracy!\n\n"
|
||||
|
||||
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ def train_cnn():
|
||||
|
||||
# Hyperparameters
|
||||
console.print("\n[bold]⚙️ Training Configuration:[/bold]")
|
||||
epochs = 20 # Reduced for demo speed (explicit loops are slow!)
|
||||
epochs = 50
|
||||
batch_size = 32
|
||||
learning_rate = 0.01
|
||||
|
||||
@@ -298,18 +298,35 @@ def train_cnn():
|
||||
console.print(f"[yellow]Before training:[/yellow] Accuracy = {initial_acc:.1f}%\n")
|
||||
|
||||
# Training loop
|
||||
history = {"loss": [], "accuracy": []}
|
||||
history = {
|
||||
"train_loss": [],
|
||||
"test_accuracy": [],
|
||||
"train_accuracy": [] # Track training accuracy to detect overfitting
|
||||
}
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(epochs):
|
||||
avg_loss = train_epoch(model, train_loader, criterion, optimizer)
|
||||
accuracy, _ = evaluate_accuracy(model, test_images, test_labels)
|
||||
# Train
|
||||
train_loss = train_epoch(model, train_loader, criterion, optimizer)
|
||||
|
||||
history["loss"].append(avg_loss)
|
||||
history["accuracy"].append(accuracy)
|
||||
# Evaluate on both train and test
|
||||
train_acc, _ = evaluate_accuracy(model, train_images, train_labels)
|
||||
test_acc, _ = evaluate_accuracy(model, test_images, test_labels)
|
||||
|
||||
history["train_loss"].append(train_loss)
|
||||
history["train_accuracy"].append(train_acc)
|
||||
history["test_accuracy"].append(test_acc)
|
||||
|
||||
if (epoch + 1) % 5 == 0: # Print every 5 epochs
|
||||
console.print(f"Epoch {epoch+1:3d}/{epochs} Loss: {avg_loss:.4f} Accuracy: {accuracy:.1f}%")
|
||||
gap = train_acc - test_acc
|
||||
gap_indicator = "⚠️" if gap > 10 else "✓"
|
||||
console.print(
|
||||
f"Epoch {epoch+1:3d}/{epochs} "
|
||||
f"Loss: {train_loss:.4f} "
|
||||
f"Train: {train_acc:.1f}% "
|
||||
f"Test: {test_acc:.1f}% "
|
||||
f"{gap_indicator} Gap: {gap:.1f}%"
|
||||
)
|
||||
|
||||
training_time = time.time() - start_time
|
||||
|
||||
@@ -321,24 +338,33 @@ def train_cnn():
|
||||
|
||||
console.print("[bold]📊 The Results:[/bold]\n")
|
||||
|
||||
final_acc, _ = evaluate_accuracy(model, test_images, test_labels)
|
||||
final_loss = history["loss"][-1]
|
||||
final_train_acc = history["train_accuracy"][-1]
|
||||
final_test_acc = history["test_accuracy"][-1]
|
||||
final_loss = history["train_loss"][-1]
|
||||
overfitting_gap = final_train_acc - final_test_acc
|
||||
|
||||
table = Table(title="Training Outcome", box=box.ROUNDED)
|
||||
table.add_column("Metric", style="cyan", width=18)
|
||||
table.add_column("Before Training", style="yellow", width=16)
|
||||
table.add_column("After Training", style="green", width=16)
|
||||
table.add_column("Improvement", style="magenta", width=14)
|
||||
table.add_column("Metric", style="cyan", width=20)
|
||||
table.add_column("Value", style="green", width=20)
|
||||
table.add_column("Status", style="magenta", width=20)
|
||||
|
||||
table.add_row(
|
||||
"Accuracy",
|
||||
f"{initial_acc:.1f}%",
|
||||
f"{final_acc:.1f}%",
|
||||
f"+{final_acc - initial_acc:.1f}%"
|
||||
"Train Accuracy",
|
||||
f"{final_train_acc:.1f}%",
|
||||
f"↑ +{final_train_acc - initial_acc:.1f}%"
|
||||
)
|
||||
table.add_row(
|
||||
"Test Accuracy",
|
||||
f"{final_test_acc:.1f}%",
|
||||
f"↑ +{final_test_acc - initial_acc:.1f}%"
|
||||
)
|
||||
table.add_row(
|
||||
"Overfitting Gap",
|
||||
f"{overfitting_gap:.1f}%",
|
||||
"✓ Healthy" if overfitting_gap < 10 else "⚠️ Overfitting"
|
||||
)
|
||||
table.add_row(
|
||||
"Training Time",
|
||||
"—",
|
||||
f"{training_time*1000:.0f}ms",
|
||||
"—"
|
||||
)
|
||||
@@ -382,7 +408,7 @@ def train_cnn():
|
||||
console.print(Panel.fit(
|
||||
"[bold green]🎉 Success! Your CNN Learned to Recognize Digits![/bold green]\n\n"
|
||||
|
||||
f"Final accuracy: [bold]{final_acc:.1f}%[/bold]\n\n"
|
||||
f"Test accuracy: [bold]{final_test_acc:.1f}%[/bold] (Gap: {overfitting_gap:.1f}%)\n\n"
|
||||
|
||||
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||
|
||||
@@ -390,7 +416,8 @@ def train_cnn():
|
||||
" ✓ Built a Convolutional Neural Network from scratch\n"
|
||||
" ✓ Used Conv2d for spatial feature extraction\n"
|
||||
" ✓ Applied MaxPooling for translation invariance\n"
|
||||
f" ✓ Achieved {final_acc:.1f}% accuracy on digit recognition!\n"
|
||||
f" ✓ Achieved {final_test_acc:.1f}% test accuracy!\n"
|
||||
f" ✓ Model generalizes well (gap: {overfitting_gap:.1f}%)\n"
|
||||
" ✓ Used 100× fewer parameters than MLP!\n\n"
|
||||
|
||||
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n\n"
|
||||
|
||||
Reference in New Issue
Block a user