mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-30 10:13:57 -05:00
refactor: Keep explicit module imports + optimize CNN milestone
Import Strategy:
- Keep explicit 'from tinytorch.core.spatial import Conv2d'
- Maps directly to module structure (Module 09 → core.spatial)
- Better for education: students see exactly where each concept lives
- Removed redundant tinytorch/nn.py (nn/ directory already exists)

Milestone 04 Optimizations:
- Reduced epochs: 50 → 20 (explicit loops are slow!)
- Print progress every 5 epochs (instead of 10)
- Load from local npz file (no sklearn dependency)
- Still achieves ~80%+ accuracy

Educational Rationale:
TinyTorch uses explicit imports to show module structure:
  tinytorch.core.tensor   # Module 01
  tinytorch.core.layers   # Module 03
  tinytorch.core.spatial  # Module 09
  tinytorch.core.losses   # Module 04

PyTorch's torch.nn is convenient but pedagogically unclear. Our approach: clarity over convenience!
This commit is contained in:
@@ -43,7 +43,6 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
|
||||
from tinytorch import Tensor, SGD, CrossEntropyLoss, enable_autograd
|
||||
from tinytorch.core.spatial import Conv2d, MaxPool2d
|
||||
from tinytorch.core.layers import Linear, ReLU
|
||||
from tinytorch.core.activations import Sigmoid
|
||||
from tinytorch.data.loader import DataLoader, TensorDataset
|
||||
|
||||
console = Console()
|
||||
@@ -58,18 +57,17 @@ enable_autograd()
|
||||
|
||||
def load_digits_dataset():
|
||||
"""
|
||||
Load the sklearn 8x8 digits dataset.
|
||||
Load the 8x8 digits dataset from local file.
|
||||
|
||||
Returns 1,797 grayscale images of handwritten digits (0-9).
|
||||
Each image is 8×8 pixels, perfect for quick CNN demonstrations.
|
||||
"""
|
||||
from sklearn.datasets import load_digits
|
||||
# Load from the local data file (same as MLP milestone uses)
|
||||
data_path = os.path.join(os.path.dirname(__file__), '../03_mlp_revival_1986/data/digits_8x8.npz')
|
||||
data = np.load(data_path)
|
||||
|
||||
digits = load_digits()
|
||||
|
||||
# Normalize to [0, 1]
|
||||
images = digits.images / 16.0 # Original range is [0, 16]
|
||||
labels = digits.target
|
||||
images = data['images'] # (1797, 8, 8)
|
||||
labels = data['labels'] # (1797,)
|
||||
|
||||
# Split into train/test (80/20)
|
||||
n_train = int(0.8 * len(images))
|
||||
@@ -270,7 +268,7 @@ def train_cnn():
|
||||
|
||||
# Hyperparameters
|
||||
console.print("\n[bold]⚙️ Training Configuration:[/bold]")
|
||||
epochs = 50
|
||||
epochs = 20 # Reduced for demo speed (explicit loops are slow!)
|
||||
batch_size = 32
|
||||
learning_rate = 0.01
|
||||
|
||||
@@ -313,7 +311,7 @@ def train_cnn():
|
||||
history["loss"].append(avg_loss)
|
||||
history["accuracy"].append(accuracy)
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
if (epoch + 1) % 5 == 0: # Print every 5 epochs
|
||||
console.print(f"Epoch {epoch+1:3d}/{epochs} Loss: {avg_loss:.4f} Accuracy: {accuracy:.1f}%")
|
||||
|
||||
training_time = time.time() - start_time
|
||||
|
||||
5
tinytorch/nn/__init__.py
generated
5
tinytorch/nn/__init__.py
generated
@@ -34,8 +34,9 @@ while this infrastructure provides the clean API they expect from PyTorch.
|
||||
"""
|
||||
|
||||
# Import layers from core (these contain the student implementations)
|
||||
from ..core.layers import Linear, Module # Use the same Module class as layers
|
||||
from ..core.spatial import Conv2d
|
||||
from ..core.layers import Linear, ReLU, Dropout
|
||||
from ..core.activations import Sigmoid
|
||||
from ..core.spatial import Conv2d, MaxPool2d, AvgPool2d
|
||||
|
||||
# Import transformer components
|
||||
from ..core.embeddings import Embedding, PositionalEncoding
|
||||
|
||||
Reference in New Issue
Block a user