TinyTorch/examples/lenet_1998/train_mlp.py
#!/usr/bin/env python3
"""
Clean MNIST Example - What Students Built
=========================================
After completing modules 02-07, students can classify handwritten digits.
This demonstrates how multi-layer perceptrons solve real vision tasks.
MODULES EXERCISED IN THIS EXAMPLE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Module 02 (Tensor)      : Data structure with gradient tracking + basic autograd
Module 03 (Activations) : ReLU activation function
Module 04 (Layers)      : Linear layers + Module base + Flatten operation
Module 05 (Loss)        : CrossEntropy loss for multi-class classification
Module 06 (Optimizers)  : Adam optimizer with adaptive learning
Module 07 (Training)    : Complete training loops and evaluation
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
MLP Architecture:

┌─────────────┐    ┌─────────┐    ┌─────────┐    ┌─────────┐    ┌─────────┐
│ Input Image │    │ Flatten │    │  Dense  │    │  Dense  │    │ Output  │
│   (28×28)   │───▶│  (784)  │───▶│  (128)  │───▶│  (64)   │───▶│  (10)   │
│   Pixels    │    │ Module  │    │ Linear  │    │ Linear  │    │ Classes │
└─────────────┘    │   04    │    │  +ReLU  │    │  +ReLU  │    │Module 04│
                   └─────────┘    │Module 04│    │Module 04│    └─────────┘
                                  └─────────┘    └─────────┘
Key Insight: Simple MLPs can achieve 95%+ accuracy on MNIST digits;
their hidden layers learn hierarchical feature representations.
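
Worked parameter count (assuming, as the parameters() sum printed in main()
suggests, that each Linear layer holds a weight matrix plus a bias vector):
    fc1: 784 × 128 + 128 = 100,480
    fc2: 128 × 64  +  64 =   8,256
    fc3:  64 × 10  +  10 =     650
    total                = 109,386 trainable weights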
"""
from tinytorch import nn, optim
from tinytorch.core.tensor import Tensor
from tinytorch.core.training import CrossEntropyLoss
import numpy as np


class MNISTMLP(nn.Module):
    def __init__(self):
        super().__init__()              # Module 04: You built Module base class!
        self.fc1 = nn.Linear(784, 128)  # Module 04: You built Linear layers!
        self.fc2 = nn.Linear(128, 64)   # Module 04: You built weight matrices!
        self.fc3 = nn.Linear(64, 10)    # Module 04: Your output layer!

    def forward(self, x):
        x = nn.F.flatten(x, start_dim=1)  # Module 04: You built flatten!
        x = self.fc1(x)                   # Module 04: Your Linear.forward()!
        x = nn.F.relu(x)                  # Module 03: You built ReLU activation!
        x = self.fc2(x)                   # Module 04: Your hidden layer!
        x = nn.F.relu(x)                  # Module 03: Your non-linearity!
        return self.fc3(x)                # Module 04: Your classification layer!
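
# Quick usage sketch (illustrative only; assumes forward outputs expose their
# values through `.data` as a NumPy array, the same convention `loss.data`
# follows in the training loop below):
#   demo = MNISTMLP()
#   out = demo(Tensor(np.zeros((2, 28, 28), dtype=np.float32)))
#   out.data.shape  # -> (2, 10): one logit per digit class per image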


def main():
    # Generate MNIST-like data (real MNIST would use DataLoader)
    batch_size, num_samples = 32, 1000
    X = np.random.randn(num_samples, 28, 28).astype(np.float32)    # 28×28 images
    y = np.random.randint(0, 10, (num_samples,)).astype(np.int64)  # 10 digit classes
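
    # With real MNIST you would load actual pixels instead of noise. A hedged
    # sketch (the .npy file name is hypothetical, not shipped with this repo;
    # 0.1307 / 0.3081 are the widely used MNIST mean/std):
    #   raw = np.load("mnist_train_images.npy")  # (60000, 28, 28) uint8
    #   X = (raw.astype(np.float32) / 255.0 - 0.1307) / 0.3081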

    model = MNISTMLP()                                               # Module 04: Your neural network!
    optimizer = optim.Adam(model.parameters(), learning_rate=0.001)  # Module 06: You built Adam!
    loss_fn = CrossEntropyLoss()                                     # Module 05: You built cross-entropy loss!

    print("🔢 Training MNIST Digit Classifier")
    print("   Architecture: Input(784) → Dense(128) → Dense(64) → Output(10)")
    print(f"   Parameters: {sum(p.data.size for p in model.parameters())} trainable weights")
    print(f"   Dataset: {num_samples} handwritten digit images")
    print()

    # What students built: the complete digit-classification pipeline
    for epoch in range(10):
        total_loss = 0.0
        num_batches = 0
        for i in range(0, num_samples, batch_size):
            # Mini-batch processing
            batch_X = X[i:i + batch_size]
            batch_y = y[i:i + batch_size]

            inputs = Tensor(batch_X)   # Module 02: You built Tensor with gradients!
            targets = Tensor(batch_y)  # Module 02: Your data structure!

            outputs = model(inputs)           # Modules 03+04: Your forward pass!
            loss = loss_fn(outputs, targets)  # Module 05: You built CrossEntropy!

            loss.backward()        # Module 02: You built autodiff!
            optimizer.step()       # Module 06: You built Adam updates!
            optimizer.zero_grad()  # Module 06: Your gradient clearing!

            loss_value = loss.data.item() if hasattr(loss.data, 'item') else float(loss.data)
            total_loss += loss_value
            num_batches += 1

        avg_loss = total_loss / num_batches
        print(f"   Epoch {epoch+1:2d}: Loss = {avg_loss:.4f}")
print("\n✅ Success! MLP trained on digit classification")
print("\n🎯 What You Learned by Building:")
print(" • How dense layers transform high-dimensional inputs")
print(" • Why multiple hidden layers improve representation")
print(" • How cross-entropy loss handles multi-class problems")
print(" • Complete vision pipeline from pixels to predictions")
if __name__ == "__main__":
main()