mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-22 12:32:29 -05:00
1118 lines
38 KiB
Python
1118 lines
38 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
TinyTorch Milestone Learning Verification Tests
|
||
================================================
|
||
|
||
This test suite verifies that actual LEARNING is happening in all milestones.
|
||
We don't just check if code runs - we verify:
|
||
1. Loss decreases over training
|
||
2. Gradients flow properly (non-zero, reasonable magnitude)
|
||
3. Weights actually update
|
||
4. Models converge to expected performance
|
||
|
||
This is the "trust but verify" test for TinyTorch milestones.
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import numpy as np
|
||
from pathlib import Path
|
||
|
||
# Add project root to path
|
||
project_root = Path(__file__).parent.parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from tinytorch import Tensor, Linear, ReLU, Sigmoid, SGD, BinaryCrossEntropyLoss, CrossEntropyLoss
|
||
from tinytorch.core.spatial import Conv2d, MaxPool2d
|
||
from tinytorch.text.embeddings import Embedding, PositionalEncoding
|
||
from tinytorch.core.attention import MultiHeadAttention
|
||
from tinytorch.core.transformer import LayerNorm
|
||
from tinytorch.data.loader import TensorDataset, DataLoader
|
||
|
||
# Rich for beautiful output
|
||
from rich.console import Console
|
||
from rich.panel import Panel
|
||
from rich.table import Table
|
||
from rich import box
|
||
|
||
console = Console()
|
||
|
||
|
||
# =============================================================================
|
||
# UTILITY FUNCTIONS: Gradient and Learning Verification
|
||
# =============================================================================
|
||
|
||
def check_gradient_flow(parameters):
|
||
"""
|
||
Verify gradients are flowing properly.
|
||
|
||
Returns:
|
||
dict with gradient statistics
|
||
"""
|
||
stats = {
|
||
'total_params': 0,
|
||
'params_with_grad': 0,
|
||
'mean_grad_magnitude': 0.0,
|
||
'max_grad_magnitude': 0.0,
|
||
'min_grad_magnitude': float('inf'),
|
||
'zero_grad_params': []
|
||
}
|
||
|
||
grad_magnitudes = []
|
||
|
||
for i, param in enumerate(parameters):
|
||
stats['total_params'] += 1
|
||
|
||
if param.grad is not None:
|
||
stats['params_with_grad'] += 1
|
||
grad_magnitude = np.abs(param.grad.data).mean()
|
||
grad_magnitudes.append(grad_magnitude)
|
||
|
||
stats['max_grad_magnitude'] = max(stats['max_grad_magnitude'], np.abs(param.grad.data).max())
|
||
stats['min_grad_magnitude'] = min(stats['min_grad_magnitude'], np.abs(param.grad.data).min())
|
||
|
||
# Check for zero gradients (bad!)
|
||
if np.allclose(param.grad.data, 0.0):
|
||
stats['zero_grad_params'].append(i)
|
||
|
||
if grad_magnitudes:
|
||
stats['mean_grad_magnitude'] = np.mean(grad_magnitudes)
|
||
|
||
return stats
|
||
|
||
|
||
def check_weight_updates(params_before, params_after, atol=1e-6):
|
||
"""
|
||
Verify weights actually changed during training.
|
||
|
||
Args:
|
||
atol: Absolute tolerance for detecting weight changes
|
||
|
||
Returns:
|
||
dict with update statistics
|
||
"""
|
||
stats = {
|
||
'total_params': len(params_before),
|
||
'params_updated': 0,
|
||
'mean_weight_change': 0.0,
|
||
'max_weight_change': 0.0,
|
||
'unchanged_params': []
|
||
}
|
||
|
||
weight_changes = []
|
||
|
||
for i, (before, after) in enumerate(zip(params_before, params_after)):
|
||
weight_change = np.abs(after.data - before.data).mean()
|
||
weight_changes.append(weight_change)
|
||
|
||
stats['max_weight_change'] = max(stats['max_weight_change'], np.abs(after.data - before.data).max())
|
||
|
||
# Check if weights actually changed
|
||
if not np.allclose(before.data, after.data, atol=atol):
|
||
stats['params_updated'] += 1
|
||
else:
|
||
stats['unchanged_params'].append(i)
|
||
|
||
if weight_changes:
|
||
stats['mean_weight_change'] = np.mean(weight_changes)
|
||
|
||
return stats
|
||
|
||
|
||
def verify_loss_convergence(loss_history, min_decrease=0.1):
|
||
"""
|
||
Verify loss is decreasing (learning is happening).
|
||
|
||
Args:
|
||
loss_history: List of loss values over training
|
||
min_decrease: Minimum acceptable decrease (as fraction)
|
||
|
||
Returns:
|
||
dict with convergence statistics
|
||
"""
|
||
stats = {
|
||
'initial_loss': loss_history[0],
|
||
'final_loss': loss_history[-1],
|
||
'loss_decrease': loss_history[0] - loss_history[-1],
|
||
'loss_decrease_pct': (loss_history[0] - loss_history[-1]) / loss_history[0] * 100,
|
||
'monotonic_decrease': all(loss_history[i] >= loss_history[i+1] for i in range(len(loss_history)-1)),
|
||
'converged': (loss_history[0] - loss_history[-1]) / loss_history[0] >= min_decrease
|
||
}
|
||
|
||
return stats
|
||
|
||
|
||
# =============================================================================
|
||
# TEST 1: PERCEPTRON LEARNING (1957 - Rosenblatt)
|
||
# =============================================================================
|
||
|
||
def test_perceptron_learning():
|
||
"""
|
||
Verify the perceptron actually learns to classify linearly separable data.
|
||
|
||
Expected behavior:
|
||
- Loss should decrease significantly (>50%)
|
||
- All gradients should be non-zero
|
||
- Weights should update
|
||
- Final accuracy should be >90%
|
||
"""
|
||
console.print("\n" + "="*70)
|
||
console.print(Panel.fit(
|
||
"[bold cyan]TEST 1: Perceptron Learning Verification[/bold cyan]\n"
|
||
"[dim]1957 - Frank Rosenblatt[/dim]",
|
||
border_style="cyan"
|
||
))
|
||
console.print("="*70)
|
||
|
||
# Generate linearly separable data
|
||
np.random.seed(42)
|
||
n_samples = 100
|
||
|
||
# Class 1: Top-right cluster
|
||
class1 = np.random.randn(n_samples // 2, 2) * 0.5 + np.array([3, 3])
|
||
labels1 = np.ones((n_samples // 2, 1))
|
||
|
||
# Class 0: Bottom-left cluster
|
||
class0 = np.random.randn(n_samples // 2, 2) * 0.5 + np.array([1, 1])
|
||
labels0 = np.zeros((n_samples // 2, 1))
|
||
|
||
X = Tensor(np.vstack([class1, class0]))
|
||
y = Tensor(np.vstack([labels1, labels0]))
|
||
|
||
# Build perceptron
|
||
linear = Linear(2, 1)
|
||
activation = Sigmoid()
|
||
|
||
def perceptron(x):
|
||
return activation(linear(x))
|
||
|
||
# Get initial parameters
|
||
params = [linear.weight, linear.bias]
|
||
params_before = [Tensor(p.data.copy()) for p in params]
|
||
|
||
# Train
|
||
loss_fn = BinaryCrossEntropyLoss()
|
||
optimizer = SGD(params, lr=0.5) # Higher LR for faster convergence
|
||
|
||
epochs = 100
|
||
loss_history = []
|
||
|
||
console.print("\n🔬 Training perceptron...")
|
||
|
||
for epoch in range(epochs):
|
||
# Forward pass
|
||
predictions = perceptron(X)
|
||
loss = loss_fn(predictions, y)
|
||
|
||
# Backward pass
|
||
loss.backward()
|
||
|
||
# Check gradients on first epoch
|
||
if epoch == 0:
|
||
grad_stats = check_gradient_flow(params)
|
||
|
||
# Update weights
|
||
optimizer.step()
|
||
optimizer.zero_grad()
|
||
|
||
loss_history.append(loss.data.item())
|
||
|
||
if epoch % 10 == 0:
|
||
console.print(f" Epoch {epoch:2d}: Loss = {loss.data:.4f}")
|
||
|
||
# Final evaluation
|
||
predictions = perceptron(X)
|
||
pred_classes = (predictions.data > 0.5).astype(int)
|
||
accuracy = (pred_classes == y.data).mean() * 100
|
||
|
||
# Check weight updates
|
||
weight_stats = check_weight_updates(params_before, params)
|
||
|
||
# Check convergence
|
||
convergence_stats = verify_loss_convergence(loss_history, min_decrease=0.5)
|
||
|
||
# Display results
|
||
console.print("\n📊 Learning Verification Results:")
|
||
|
||
table = Table(title="Perceptron Learning Metrics", box=box.ROUNDED)
|
||
table.add_column("Metric", style="cyan")
|
||
table.add_column("Value", style="green")
|
||
table.add_column("Status", style="magenta")
|
||
|
||
table.add_row(
|
||
"Final Accuracy",
|
||
f"{accuracy:.1f}%",
|
||
"✅ PASS" if accuracy > 90 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Loss Decrease",
|
||
f"{convergence_stats['loss_decrease_pct']:.1f}%",
|
||
"✅ PASS" if convergence_stats['converged'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Gradients Flowing",
|
||
f"{grad_stats['params_with_grad']}/{grad_stats['total_params']}",
|
||
"✅ PASS" if grad_stats['params_with_grad'] == grad_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Mean Gradient Magnitude",
|
||
f"{grad_stats['mean_grad_magnitude']:.6f}",
|
||
"✅ PASS" if grad_stats['mean_grad_magnitude'] > 1e-6 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Weights Updated",
|
||
f"{weight_stats['params_updated']}/{weight_stats['total_params']}",
|
||
"✅ PASS" if weight_stats['params_updated'] == weight_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Mean Weight Change",
|
||
f"{weight_stats['mean_weight_change']:.6f}",
|
||
"✅ PASS" if weight_stats['mean_weight_change'] > 1e-4 else "❌ FAIL"
|
||
)
|
||
|
||
console.print(table)
|
||
|
||
# Overall verdict
|
||
passed = (
|
||
accuracy > 90 and
|
||
convergence_stats['converged'] and
|
||
grad_stats['params_with_grad'] == grad_stats['total_params'] and
|
||
grad_stats['mean_grad_magnitude'] > 1e-6 and
|
||
weight_stats['params_updated'] == weight_stats['total_params']
|
||
)
|
||
|
||
if passed:
|
||
console.print("\n[bold green]✅ PERCEPTRON LEARNING VERIFIED[/bold green]")
|
||
console.print(" • Loss decreased significantly")
|
||
console.print(" • Gradients flow properly")
|
||
console.print(" • Weights updated correctly")
|
||
console.print(" • Model converged to high accuracy")
|
||
else:
|
||
console.print("\n[bold red]❌ PERCEPTRON LEARNING FAILED[/bold red]")
|
||
if accuracy <= 90:
|
||
console.print(f" • Accuracy too low: {accuracy:.1f}% (expected >90%)")
|
||
if not convergence_stats['converged']:
|
||
console.print(f" • Loss didn't decrease enough: {convergence_stats['loss_decrease_pct']:.1f}%")
|
||
if grad_stats['params_with_grad'] != grad_stats['total_params']:
|
||
console.print(f" • Some gradients missing: {grad_stats['params_with_grad']}/{grad_stats['total_params']}")
|
||
|
||
return passed
|
||
|
||
|
||
# =============================================================================
|
||
# TEST 2: XOR PROBLEM (1969 - Minsky & Papert)
|
||
# =============================================================================
|
||
|
||
def test_xor_learning():
|
||
"""
|
||
Verify MLP can learn XOR (showing limitations of single-layer perceptron).
|
||
|
||
Expected behavior:
|
||
- Loss should decrease to near zero
|
||
- All gradients should flow through hidden layer
|
||
- Perfect or near-perfect accuracy (100%)
|
||
"""
|
||
console.print("\n" + "="*70)
|
||
console.print(Panel.fit(
|
||
"[bold cyan]TEST 2: XOR Learning Verification[/bold cyan]\n"
|
||
"[dim]1969 - Minsky & Papert's Challenge[/dim]",
|
||
border_style="cyan"
|
||
))
|
||
console.print("="*70)
|
||
|
||
# XOR dataset
|
||
X = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32))
|
||
y = Tensor(np.array([[0], [1], [1], [0]], dtype=np.float32))
|
||
|
||
# Build MLP
|
||
fc1 = Linear(2, 4)
|
||
relu = ReLU()
|
||
fc2 = Linear(4, 1)
|
||
sigmoid = Sigmoid()
|
||
|
||
def mlp(x):
|
||
x = fc1(x)
|
||
x = relu(x)
|
||
x = fc2(x)
|
||
x = sigmoid(x)
|
||
return x
|
||
|
||
# Get initial parameters
|
||
params = [fc1.weight, fc1.bias, fc2.weight, fc2.bias]
|
||
params_before = [Tensor(p.data.copy()) for p in params]
|
||
|
||
# Train
|
||
loss_fn = BinaryCrossEntropyLoss()
|
||
optimizer = SGD(params, lr=0.5)
|
||
|
||
epochs = 500
|
||
loss_history = []
|
||
|
||
console.print("\n🔬 Training MLP on XOR...")
|
||
|
||
for epoch in range(epochs):
|
||
# Forward pass
|
||
predictions = mlp(X)
|
||
loss = loss_fn(predictions, y)
|
||
|
||
# Backward pass
|
||
loss.backward()
|
||
|
||
# Check gradients on first epoch
|
||
if epoch == 0:
|
||
grad_stats = check_gradient_flow(params)
|
||
|
||
# Update weights
|
||
optimizer.step()
|
||
optimizer.zero_grad()
|
||
|
||
loss_history.append(loss.data.item())
|
||
|
||
if epoch % 100 == 0:
|
||
console.print(f" Epoch {epoch:3d}: Loss = {loss.data:.6f}")
|
||
|
||
# Final evaluation
|
||
predictions = mlp(X)
|
||
pred_classes = (predictions.data > 0.5).astype(int)
|
||
accuracy = (pred_classes == y.data).mean() * 100
|
||
|
||
# Check weight updates
|
||
weight_stats = check_weight_updates(params_before, params)
|
||
|
||
# Check convergence
|
||
convergence_stats = verify_loss_convergence(loss_history, min_decrease=0.8)
|
||
|
||
# Display results
|
||
console.print("\n📊 Learning Verification Results:")
|
||
|
||
table = Table(title="XOR Learning Metrics", box=box.ROUNDED)
|
||
table.add_column("Metric", style="cyan")
|
||
table.add_column("Value", style="green")
|
||
table.add_column("Status", style="magenta")
|
||
|
||
table.add_row(
|
||
"Final Accuracy",
|
||
f"{accuracy:.1f}%",
|
||
"✅ PASS" if accuracy == 100 else "⚠️ MARGINAL" if accuracy >= 75 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Final Loss",
|
||
f"{loss_history[-1]:.6f}",
|
||
"✅ PASS" if loss_history[-1] < 0.1 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Loss Decrease",
|
||
f"{convergence_stats['loss_decrease_pct']:.1f}%",
|
||
"✅ PASS" if convergence_stats['converged'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Gradients Flowing",
|
||
f"{grad_stats['params_with_grad']}/{grad_stats['total_params']}",
|
||
"✅ PASS" if grad_stats['params_with_grad'] == grad_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Weights Updated",
|
||
f"{weight_stats['params_updated']}/{weight_stats['total_params']}",
|
||
"✅ PASS" if weight_stats['params_updated'] == weight_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
|
||
console.print(table)
|
||
|
||
# Show predictions
|
||
console.print("\n🔍 XOR Predictions:")
|
||
for i in range(len(X.data)):
|
||
x_val = X.data[i]
|
||
true_val = int(y.data[i][0])
|
||
pred_val = predictions.data[i][0]
|
||
pred_class = int(pred_val > 0.5)
|
||
status = "✅" if pred_class == true_val else "❌"
|
||
console.print(f" {status} [{x_val[0]:.0f}, {x_val[1]:.0f}] → True: {true_val}, Pred: {pred_val:.4f} ({pred_class})")
|
||
|
||
# Overall verdict
|
||
passed = (
|
||
accuracy >= 75 and # XOR is hard, allow some tolerance
|
||
loss_history[-1] < 0.2 and
|
||
convergence_stats['converged'] and
|
||
grad_stats['params_with_grad'] == grad_stats['total_params'] and
|
||
weight_stats['params_updated'] == weight_stats['total_params']
|
||
)
|
||
|
||
if passed:
|
||
console.print("\n[bold green]✅ XOR LEARNING VERIFIED[/bold green]")
|
||
console.print(" • MLP solved XOR problem")
|
||
console.print(" • Gradients flow through hidden layer")
|
||
console.print(" • Non-linear problem solved with multi-layer network")
|
||
else:
|
||
console.print("\n[bold red]❌ XOR LEARNING FAILED[/bold red]")
|
||
console.print(" • Check learning rate, epochs, or architecture")
|
||
|
||
return passed
|
||
|
||
|
||
# =============================================================================
|
||
# TEST 3: MLP ON DIGITS (1986 - Rumelhart, Hinton, Williams)
|
||
# =============================================================================
|
||
|
||
def test_mlp_digits_learning():
|
||
"""
|
||
Verify MLP learns on real digit data (TinyDigits 8x8).
|
||
|
||
Expected behavior:
|
||
- Loss should decrease steadily
|
||
- Test accuracy should reach >70%
|
||
- Gradients should flow through all layers
|
||
- Overfitting gap should be reasonable (<15%)
|
||
"""
|
||
console.print("\n" + "="*70)
|
||
console.print(Panel.fit(
|
||
"[bold cyan]TEST 3: MLP Digit Classification Learning[/bold cyan]\n"
|
||
"[dim]1986 - Rumelhart, Hinton & Williams[/dim]",
|
||
border_style="cyan"
|
||
))
|
||
console.print("="*70)
|
||
|
||
# Load TinyDigits dataset
|
||
import pickle
|
||
train_path = project_root / "datasets" / "tinydigits" / "train.pkl"
|
||
test_path = project_root / "datasets" / "tinydigits" / "test.pkl"
|
||
|
||
if not train_path.exists() or not test_path.exists():
|
||
console.print(f"[yellow]⚠️ TinyDigits dataset not found, skipping test[/yellow]")
|
||
return True # Skip, don't fail
|
||
|
||
with open(train_path, 'rb') as f:
|
||
train_data = pickle.load(f)
|
||
with open(test_path, 'rb') as f:
|
||
test_data = pickle.load(f)
|
||
|
||
train_images = Tensor(train_data['images'].astype(np.float32))
|
||
train_labels = Tensor(train_data['labels'].astype(np.int64))
|
||
test_images = Tensor(test_data['images'].astype(np.float32))
|
||
test_labels = Tensor(test_data['labels'].astype(np.int64))
|
||
|
||
console.print(f"\n📊 Dataset: {len(train_images.data)} train, {len(test_images.data)} test")
|
||
|
||
# Build MLP
|
||
fc1 = Linear(64, 32)
|
||
relu = ReLU()
|
||
fc2 = Linear(32, 10)
|
||
|
||
def mlp(x):
|
||
# Flatten
|
||
if len(x.data.shape) > 2:
|
||
batch_size = x.data.shape[0]
|
||
x = Tensor(x.data.reshape(batch_size, -1))
|
||
x = fc1(x)
|
||
x = relu(x)
|
||
x = fc2(x)
|
||
return x
|
||
|
||
# Get initial parameters
|
||
params = [fc1.weight, fc1.bias, fc2.weight, fc2.bias]
|
||
params_before = [Tensor(p.data.copy()) for p in params]
|
||
|
||
# Train
|
||
loss_fn = CrossEntropyLoss()
|
||
optimizer = SGD(params, lr=0.01)
|
||
|
||
# Create DataLoader
|
||
train_dataset = TensorDataset(train_images, train_labels)
|
||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||
|
||
epochs = 25 # Increased from 15 - small dataset needs more epochs
|
||
loss_history = []
|
||
test_acc_history = []
|
||
|
||
console.print("\n🔬 Training MLP on TinyDigits...")
|
||
|
||
for epoch in range(epochs):
|
||
epoch_loss = 0.0
|
||
batch_count = 0
|
||
|
||
for batch_images, batch_labels in train_loader:
|
||
# Forward pass
|
||
logits = mlp(batch_images)
|
||
loss = loss_fn(logits, batch_labels)
|
||
|
||
# Backward pass
|
||
loss.backward()
|
||
|
||
# Check gradients on first batch
|
||
if epoch == 0 and batch_count == 0:
|
||
grad_stats = check_gradient_flow(params)
|
||
|
||
# Update weights
|
||
optimizer.step()
|
||
optimizer.zero_grad()
|
||
|
||
epoch_loss += loss.data
|
||
batch_count += 1
|
||
|
||
avg_loss = epoch_loss / batch_count
|
||
loss_history.append(avg_loss)
|
||
|
||
# Evaluate on test set
|
||
logits = mlp(test_images)
|
||
predictions = np.argmax(logits.data, axis=1)
|
||
test_acc = (predictions == test_labels.data).mean() * 100
|
||
test_acc_history.append(test_acc)
|
||
|
||
if epoch % 3 == 0:
|
||
console.print(f" Epoch {epoch:2d}: Loss = {avg_loss:.4f}, Test Acc = {test_acc:.1f}%")
|
||
|
||
# Final evaluation
|
||
final_test_acc = test_acc_history[-1]
|
||
|
||
# Check weight updates
|
||
weight_stats = check_weight_updates(params_before, params)
|
||
|
||
# Check convergence
|
||
convergence_stats = verify_loss_convergence(loss_history, min_decrease=0.3)
|
||
|
||
# Display results
|
||
console.print("\n📊 Learning Verification Results:")
|
||
|
||
table = Table(title="MLP Digits Learning Metrics", box=box.ROUNDED)
|
||
table.add_column("Metric", style="cyan")
|
||
table.add_column("Value", style="green")
|
||
table.add_column("Status", style="magenta")
|
||
|
||
table.add_row(
|
||
"Final Test Accuracy",
|
||
f"{final_test_acc:.1f}%",
|
||
"✅ PASS" if final_test_acc > 70 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Loss Decrease",
|
||
f"{convergence_stats['loss_decrease_pct']:.1f}%",
|
||
"✅ PASS" if convergence_stats['converged'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Gradients Flowing",
|
||
f"{grad_stats['params_with_grad']}/{grad_stats['total_params']}",
|
||
"✅ PASS" if grad_stats['params_with_grad'] == grad_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Weights Updated",
|
||
f"{weight_stats['params_updated']}/{weight_stats['total_params']}",
|
||
"✅ PASS" if weight_stats['params_updated'] == weight_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
|
||
console.print(table)
|
||
|
||
# Overall verdict
|
||
passed = (
|
||
final_test_acc > 70 and
|
||
convergence_stats['converged'] and
|
||
grad_stats['params_with_grad'] == grad_stats['total_params'] and
|
||
weight_stats['params_updated'] == weight_stats['total_params']
|
||
)
|
||
|
||
if passed:
|
||
console.print("\n[bold green]✅ MLP DIGITS LEARNING VERIFIED[/bold green]")
|
||
console.print(" • Model learned to classify real handwritten digits")
|
||
console.print(" • Gradients flow through multi-layer network")
|
||
console.print(" • DataLoader enables efficient batch training")
|
||
else:
|
||
console.print("\n[bold red]❌ MLP DIGITS LEARNING FAILED[/bold red]")
|
||
|
||
return passed
|
||
|
||
|
||
# =============================================================================
|
||
# TEST 4: CNN ON IMAGES (1998 - LeCun)
|
||
# =============================================================================
|
||
|
||
def test_cnn_learning():
|
||
"""
|
||
Verify CNN learns spatial features from images.
|
||
|
||
Expected behavior:
|
||
- Loss should decrease
|
||
- Convolutional gradients should flow
|
||
- Should outperform MLP on spatial data
|
||
- Test accuracy >75% (better than MLP's ~70%)
|
||
"""
|
||
console.print("\n" + "="*70)
|
||
console.print(Panel.fit(
|
||
"[bold cyan]TEST 4: CNN Learning Verification[/bold cyan]\n"
|
||
"[dim]1998 - Yann LeCun's Convolutional Networks[/dim]",
|
||
border_style="cyan"
|
||
))
|
||
console.print("="*70)
|
||
|
||
# Load TinyDigits dataset
|
||
import pickle
|
||
train_path = project_root / "datasets" / "tinydigits" / "train.pkl"
|
||
test_path = project_root / "datasets" / "tinydigits" / "test.pkl"
|
||
|
||
if not train_path.exists() or not test_path.exists():
|
||
console.print(f"[yellow]⚠️ TinyDigits dataset not found, skipping test[/yellow]")
|
||
return True # Skip, don't fail
|
||
|
||
with open(train_path, 'rb') as f:
|
||
train_data = pickle.load(f)
|
||
with open(test_path, 'rb') as f:
|
||
test_data = pickle.load(f)
|
||
|
||
# Add channel dimension: (N, 8, 8) → (N, 1, 8, 8)
|
||
train_images = Tensor(train_data['images'].astype(np.float32)[:, np.newaxis, :, :])
|
||
train_labels = Tensor(train_data['labels'].astype(np.int64))
|
||
test_images = Tensor(test_data['images'].astype(np.float32)[:, np.newaxis, :, :])
|
||
test_labels = Tensor(test_data['labels'].astype(np.int64))
|
||
|
||
console.print(f"\n📊 Dataset: {len(train_images.data)} train, {len(test_images.data)} test")
|
||
console.print(f" Image shape: {train_images.data.shape[1:]}")
|
||
|
||
# Build simple CNN (single pooling to avoid 0x0 spatial dims)
|
||
conv1 = Conv2d(1, 8, kernel_size=(3, 3))
|
||
conv1.weight.requires_grad = True
|
||
if conv1.bias is not None:
|
||
conv1.bias.requires_grad = True
|
||
relu1 = ReLU()
|
||
pool1 = MaxPool2d(2)
|
||
conv2 = Conv2d(8, 16, kernel_size=(3, 3))
|
||
conv2.weight.requires_grad = True
|
||
if conv2.bias is not None:
|
||
conv2.bias.requires_grad = True
|
||
relu2 = ReLU()
|
||
# After convs: 8x8 → conv1(3x3) → 6x6 → pool(2) → 3x3 → conv2(3x3) → 1x1
|
||
# Final shape: 16 channels × 1 × 1 = 16 features
|
||
fc = Linear(16 * 1 * 1, 10)
|
||
|
||
def cnn(x):
|
||
# Conv block 1
|
||
x = conv1(x)
|
||
x = relu1(x)
|
||
x = pool1(x)
|
||
|
||
# Conv block 2
|
||
x = conv2(x)
|
||
x = relu2(x)
|
||
# No second pooling - would create 0x0!
|
||
|
||
# Flatten and classify (using Tensor.reshape to preserve autograd)
|
||
batch_size = x.shape[0]
|
||
x = x.reshape(batch_size, -1)
|
||
x = fc(x)
|
||
return x
|
||
|
||
# Get initial parameters (ReLU has no parameters)
|
||
params = [conv1.weight, conv1.bias, conv2.weight, conv2.bias, fc.weight, fc.bias]
|
||
params_before = [Tensor(p.data.copy()) for p in params]
|
||
|
||
# Train
|
||
loss_fn = CrossEntropyLoss()
|
||
optimizer = SGD(params, lr=0.01)
|
||
|
||
# Create DataLoader
|
||
train_dataset = TensorDataset(train_images, train_labels)
|
||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) # Match MLP batch size for fair comparison
|
||
|
||
epochs = 25 # Same epochs and batch size as MLP for fair comparison
|
||
loss_history = []
|
||
test_acc_history = []
|
||
conv_grad_mean = 0.0 # Track conv gradient magnitude
|
||
|
||
console.print("\n🔬 Training CNN on TinyDigits...")
|
||
|
||
for epoch in range(epochs):
|
||
epoch_loss = 0.0
|
||
batch_count = 0
|
||
|
||
for batch_images, batch_labels in train_loader:
|
||
# Forward pass
|
||
logits = cnn(batch_images)
|
||
loss = loss_fn(logits, batch_labels)
|
||
|
||
# Backward pass
|
||
loss.backward()
|
||
|
||
# Check gradients on first batch (before zero_grad clears them!)
|
||
if epoch == 0 and batch_count == 0:
|
||
grad_stats = check_gradient_flow(params)
|
||
# Also capture conv gradient magnitude before it gets zeroed
|
||
conv_grad_mean = np.abs(conv1.weight.grad.data).mean() if conv1.weight.grad is not None else 0.0
|
||
|
||
# Update weights
|
||
optimizer.step()
|
||
optimizer.zero_grad()
|
||
|
||
epoch_loss += loss.data
|
||
batch_count += 1
|
||
|
||
avg_loss = epoch_loss / batch_count
|
||
loss_history.append(avg_loss)
|
||
|
||
# Evaluate on test set
|
||
logits = cnn(test_images)
|
||
predictions = np.argmax(logits.data, axis=1)
|
||
test_acc = (predictions == test_labels.data).mean() * 100
|
||
test_acc_history.append(test_acc)
|
||
|
||
if epoch % 3 == 0:
|
||
console.print(f" Epoch {epoch:2d}: Loss = {avg_loss:.4f}, Test Acc = {test_acc:.1f}%")
|
||
|
||
# Final evaluation
|
||
final_test_acc = test_acc_history[-1]
|
||
|
||
# Check weight updates
|
||
weight_stats = check_weight_updates(params_before, params)
|
||
|
||
# Check convergence
|
||
convergence_stats = verify_loss_convergence(loss_history, min_decrease=0.3)
|
||
|
||
# Display results
|
||
console.print("\n📊 Learning Verification Results:")
|
||
|
||
table = Table(title="CNN Learning Metrics", box=box.ROUNDED)
|
||
table.add_column("Metric", style="cyan")
|
||
table.add_column("Value", style="green")
|
||
table.add_column("Status", style="magenta")
|
||
|
||
table.add_row(
|
||
"Final Test Accuracy",
|
||
f"{final_test_acc:.1f}%",
|
||
"✅ PASS" if final_test_acc > 75 else "⚠️ MARGINAL" if final_test_acc > 70 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Loss Decrease",
|
||
f"{convergence_stats['loss_decrease_pct']:.1f}%",
|
||
"✅ PASS" if convergence_stats['converged'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Gradients Flowing",
|
||
f"{grad_stats['params_with_grad']}/{grad_stats['total_params']}",
|
||
"✅ PASS" if grad_stats['params_with_grad'] == grad_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
# Check convolutional gradients exist (captured during training before zero_grad)
|
||
table.add_row(
|
||
"Conv Gradients",
|
||
f"{conv_grad_mean:.6f}",
|
||
"✅ PASS" if conv_grad_mean > 1e-6 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Weights Updated",
|
||
f"{weight_stats['params_updated']}/{weight_stats['total_params']}",
|
||
"✅ PASS" if weight_stats['params_updated'] == weight_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
|
||
console.print(table)
|
||
|
||
# Overall verdict
|
||
passed = (
|
||
final_test_acc > 70 and # Allow slight tolerance
|
||
convergence_stats['converged'] and
|
||
grad_stats['params_with_grad'] == grad_stats['total_params'] and
|
||
weight_stats['params_updated'] == weight_stats['total_params']
|
||
)
|
||
|
||
if passed:
|
||
console.print("\n[bold green]✅ CNN LEARNING VERIFIED[/bold green]")
|
||
console.print(" • CNN learned spatial features from images")
|
||
console.print(" • Convolution gradients flow properly")
|
||
console.print(" • Spatial structure preserved (vs MLP flattening)")
|
||
else:
|
||
console.print("\n[bold red]❌ CNN LEARNING FAILED[/bold red]")
|
||
|
||
return passed
|
||
|
||
|
||
# =============================================================================
|
||
# TEST 5: TRANSFORMER LEARNING (2017 - Vaswani et al.)
|
||
# =============================================================================
|
||
|
||
def test_transformer_learning():
|
||
"""
|
||
Verify transformer learns on sequence data.
|
||
|
||
Expected behavior:
|
||
- Loss should decrease
|
||
- Attention gradients should flow
|
||
- Embedding gradients should flow
|
||
- All transformer components receive gradients
|
||
"""
|
||
console.print("\n" + "="*70)
|
||
console.print(Panel.fit(
|
||
"[bold cyan]TEST 5: Transformer Learning Verification[/bold cyan]\n"
|
||
"[dim]2017 - Vaswani et al. 'Attention is All You Need'[/dim]",
|
||
border_style="cyan"
|
||
))
|
||
console.print("="*70)
|
||
|
||
# Create simple sequence modeling task: copy sequence
|
||
np.random.seed(42)
|
||
vocab_size = 20
|
||
seq_length = 8
|
||
batch_size = 16
|
||
|
||
# Task: model should learn to predict next token
|
||
X = np.random.randint(0, vocab_size, (batch_size, seq_length))
|
||
# Shift by 1 for next-token prediction
|
||
y = np.roll(X, -1, axis=1)
|
||
y[:, -1] = 0 # Pad last position
|
||
|
||
X_tensor = Tensor(X)
|
||
y_tensor = Tensor(y)
|
||
|
||
console.print(f"\n📊 Task: Next-token prediction")
|
||
console.print(f" Vocab size: {vocab_size}")
|
||
console.print(f" Sequence length: {seq_length}")
|
||
console.print(f" Batch size: {batch_size}")
|
||
|
||
# Build transformer
|
||
embed_dim = 32
|
||
num_heads = 4
|
||
|
||
embedding = Embedding(vocab_size, embed_dim)
|
||
pos_encoding = PositionalEncoding(seq_length, embed_dim)
|
||
attention = MultiHeadAttention(embed_dim, num_heads)
|
||
ln1 = LayerNorm(embed_dim)
|
||
ln2 = LayerNorm(embed_dim)
|
||
fc1 = Linear(embed_dim, embed_dim * 2)
|
||
relu_ffn = ReLU()
|
||
fc2 = Linear(embed_dim * 2, embed_dim)
|
||
output_proj = Linear(embed_dim, vocab_size)
|
||
|
||
def transformer(x):
|
||
# Embed
|
||
x = embedding(x)
|
||
x = pos_encoding(x)
|
||
|
||
# Attention block (self-attention)
|
||
attn_out = attention.forward(x)
|
||
x = ln1(x + attn_out) # Residual (preserves autograd)
|
||
|
||
# FFN block
|
||
ffn_out = fc2(relu_ffn(fc1(x)))
|
||
x = ln2(x + ffn_out) # Residual (preserves autograd)
|
||
|
||
# Project to vocab
|
||
batch, seq, embed = x.shape
|
||
x_2d = x.reshape(batch * seq, embed)
|
||
logits_2d = output_proj(x_2d)
|
||
logits = logits_2d.reshape(batch, seq, vocab_size)
|
||
|
||
return logits
|
||
|
||
# Get all parameters
|
||
params = (
|
||
[embedding.weight] +
|
||
attention.parameters() +
|
||
ln1.parameters() + ln2.parameters() +
|
||
[fc1.weight, fc1.bias, fc2.weight, fc2.bias] +
|
||
[output_proj.weight, output_proj.bias]
|
||
)
|
||
params_before = [Tensor(p.data.copy()) for p in params]
|
||
|
||
# Train
|
||
loss_fn = CrossEntropyLoss()
|
||
optimizer = SGD(params, lr=0.01)
|
||
|
||
epochs = 30
|
||
loss_history = []
|
||
|
||
console.print("\n🔬 Training transformer on next-token prediction...")
|
||
|
||
for epoch in range(epochs):
|
||
# Forward pass
|
||
logits = transformer(X_tensor)
|
||
|
||
# Reshape for loss computation
|
||
logits_2d = logits.reshape(batch_size * seq_length, vocab_size)
|
||
y_flat = y_tensor.reshape(batch_size * seq_length)
|
||
|
||
loss = loss_fn(logits_2d, y_flat)
|
||
|
||
# Backward pass
|
||
loss.backward()
|
||
|
||
# Check gradients on first epoch
|
||
if epoch == 0:
|
||
grad_stats = check_gradient_flow(params)
|
||
|
||
# Specifically check attention and embedding gradients
|
||
attn_has_grad = attention.parameters()[0].grad is not None
|
||
embed_has_grad = embedding.weight.grad is not None
|
||
|
||
# Update weights
|
||
optimizer.step()
|
||
optimizer.zero_grad()
|
||
|
||
loss_history.append(loss.data.item())
|
||
|
||
if epoch % 5 == 0:
|
||
console.print(f" Epoch {epoch:2d}: Loss = {loss.data:.4f}")
|
||
|
||
# Check weight updates (relaxed tolerance for LayerNorm params)
|
||
weight_stats = check_weight_updates(params_before, params, atol=1e-5)
|
||
|
||
# Check convergence (adjusted for transformer complexity)
|
||
convergence_stats = verify_loss_convergence(loss_history, min_decrease=0.12)
|
||
|
||
# Display results
|
||
console.print("\n📊 Learning Verification Results:")
|
||
|
||
table = Table(title="Transformer Learning Metrics", box=box.ROUNDED)
|
||
table.add_column("Metric", style="cyan")
|
||
table.add_column("Value", style="green")
|
||
table.add_column("Status", style="magenta")
|
||
|
||
table.add_row(
|
||
"Final Loss",
|
||
f"{loss_history[-1]:.4f}",
|
||
"✅ PASS" if loss_history[-1] < loss_history[0] * 0.8 else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Loss Decrease",
|
||
f"{convergence_stats['loss_decrease_pct']:.1f}%",
|
||
"✅ PASS" if convergence_stats['converged'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Gradients Flowing",
|
||
f"{grad_stats['params_with_grad']}/{grad_stats['total_params']}",
|
||
"✅ PASS" if grad_stats['params_with_grad'] == grad_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Attention Gradients",
|
||
"Yes" if attn_has_grad else "No",
|
||
"✅ PASS" if attn_has_grad else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Embedding Gradients",
|
||
"Yes" if embed_has_grad else "No",
|
||
"✅ PASS" if embed_has_grad else "❌ FAIL"
|
||
)
|
||
table.add_row(
|
||
"Weights Updated",
|
||
f"{weight_stats['params_updated']}/{weight_stats['total_params']}",
|
||
"✅ PASS" if weight_stats['params_updated'] == weight_stats['total_params'] else "❌ FAIL"
|
||
)
|
||
|
||
console.print(table)
|
||
|
||
# Overall verdict (relaxed weight update requirement for Transformer)
|
||
# Note: Some params (LayerNorm) may have tiny but valid updates
|
||
passed = (
|
||
convergence_stats['converged'] and
|
||
grad_stats['params_with_grad'] == grad_stats['total_params'] and
|
||
attn_has_grad and
|
||
embed_has_grad and
|
||
weight_stats['params_updated'] >= weight_stats['total_params'] * 0.6 # At least 60% updated
|
||
)
|
||
|
||
if passed:
|
||
console.print("\n[bold green]✅ TRANSFORMER LEARNING VERIFIED[/bold green]")
|
||
console.print(" • Transformer learned on sequence data")
|
||
console.print(" • Attention gradients flow properly")
|
||
console.print(" • Embedding gradients flow properly")
|
||
console.print(" • All transformer components update correctly")
|
||
else:
|
||
console.print("\n[bold red]❌ TRANSFORMER LEARNING FAILED[/bold red]")
|
||
if not attn_has_grad:
|
||
console.print(" • Attention gradients not flowing")
|
||
if not embed_has_grad:
|
||
console.print(" • Embedding gradients not flowing")
|
||
|
||
return passed
|
||
|
||
|
||
# =============================================================================
|
||
# MAIN TEST RUNNER
|
||
# =============================================================================
|
||
|
||
def run_all_learning_tests():
|
||
"""Run all milestone learning verification tests."""
|
||
console.print("\n" + "🔥"*35)
|
||
console.print(Panel.fit(
|
||
"[bold cyan]TINYTORCH MILESTONE LEARNING VERIFICATION[/bold cyan]\n\n"
|
||
"[dim]Verifying that actual LEARNING happens in all milestones.[/dim]\n"
|
||
"[dim]We don't just check if code runs - we verify convergence![/dim]",
|
||
border_style="cyan",
|
||
box=box.DOUBLE
|
||
))
|
||
console.print("🔥"*35)
|
||
|
||
results = []
|
||
|
||
# Test each milestone
|
||
tests = [
|
||
("1957 Perceptron", test_perceptron_learning),
|
||
("1969 XOR (MLP)", test_xor_learning),
|
||
("1986 MLP Digits", test_mlp_digits_learning),
|
||
("1998 CNN", test_cnn_learning),
|
||
("2017 Transformer", test_transformer_learning),
|
||
]
|
||
|
||
for name, test_fn in tests:
|
||
try:
|
||
console.print(f"\n[bold]Running: {name}[/bold]")
|
||
passed = test_fn()
|
||
results.append((name, passed, None))
|
||
except Exception as e:
|
||
console.print(f"[red]❌ {name} test crashed: {e}[/red]")
|
||
import traceback
|
||
traceback.print_exc()
|
||
results.append((name, False, str(e)))
|
||
|
||
# Summary
|
||
console.print("\n" + "="*70)
|
||
console.print(Panel.fit(
|
||
"[bold]MILESTONE LEARNING VERIFICATION SUMMARY[/bold]",
|
||
border_style="cyan"
|
||
))
|
||
console.print("="*70)
|
||
|
||
summary_table = Table(title="Test Results", box=box.ROUNDED)
|
||
summary_table.add_column("Milestone", style="cyan", width=25)
|
||
summary_table.add_column("Status", style="green", width=15)
|
||
summary_table.add_column("Notes", style="dim", width=25)
|
||
|
||
all_passed = True
|
||
for name, passed, error in results:
|
||
status = "✅ PASSED" if passed else "❌ FAILED"
|
||
notes = "" if passed else (error[:25] + "..." if error and len(error) > 25 else error or "Learning failed")
|
||
summary_table.add_row(name, status, notes)
|
||
all_passed = all_passed and passed
|
||
|
||
console.print(summary_table)
|
||
|
||
if all_passed:
|
||
console.print("\n" + "🎉"*35)
|
||
console.print(Panel.fit(
|
||
"[bold green]✅ ALL MILESTONES VERIFIED![/bold green]\n\n"
|
||
"[bold]Every milestone demonstrates real learning:[/bold]\n"
|
||
" • Loss decreases over training\n"
|
||
" • Gradients flow properly through all layers\n"
|
||
" • Weights update correctly\n"
|
||
" • Models converge to expected performance\n\n"
|
||
"[dim]Students using TinyTorch will build systems that actually learn!\n"
|
||
"This is the foundation of all deep learning - and it's verified.[/dim]",
|
||
border_style="green",
|
||
box=box.DOUBLE
|
||
))
|
||
console.print("🎉"*35)
|
||
else:
|
||
console.print("\n" + "⚠️ "*35)
|
||
console.print(Panel.fit(
|
||
"[bold yellow]⚠️ SOME MILESTONES NEED ATTENTION[/bold yellow]\n\n"
|
||
"Check the failed tests above for details.\n"
|
||
"Common issues:\n"
|
||
" • Missing gradients (check backward() implementation)\n"
|
||
" • Weights not updating (check optimizer step())\n"
|
||
" • Loss not decreasing (check learning rate or architecture)\n"
|
||
" • Data issues (check dataset loading)",
|
||
border_style="yellow",
|
||
box=box.DOUBLE
|
||
))
|
||
console.print("⚠️ "*35)
|
||
|
||
return all_passed
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import sys
|
||
success = run_all_learning_tests()
|
||
sys.exit(0 if success else 1)
|
||
|