# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/07_training/training_dev.ipynb.
# %% auto 0
__all__ = ['CosineSchedule', 'save_checkpoint', 'load_checkpoint', 'Trainer']
# %% ../../modules/source/07_training/training_dev.ipynb 1
import numpy as np
import pickle
import time
from typing import Dict, List, Optional, Tuple, Any, Callable
from pathlib import Path
import sys
import os
# Import dependencies from other modules
from .tensor import Tensor
from .layers import Linear
from .losses import MSELoss, CrossEntropyLoss
from .optimizers import SGD, AdamW
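# NOTE: Trainer.train_epoch() below calls clip_grad_norm(), which is neither defined in
# this file nor imported above. The helper below is a minimal global-norm clipping sketch
# so the module is self-contained. It assumes each parameter exposes a `.grad` attribute
# holding either a NumPy array or a Tensor-like object with a `.data` array; if the
# project provides its own gradient-clipping utility, import that here instead.
def clip_grad_norm(parameters, max_norm: float) -> float:
    """Scale gradients in place so their global L2 norm does not exceed max_norm."""
    # Collect the raw gradient arrays (skipping parameters without gradients).
    grads = []
    for p in parameters:
        g = getattr(p, 'grad', None)
        if g is None:
            continue
        grads.append(g.data if hasattr(g, 'data') else np.asarray(g))
    if not grads:
        return 0.0
    # Global L2 norm across all parameters.
    total_norm = float(np.sqrt(sum(float(np.sum(g ** 2)) for g in grads)))
    if total_norm > max_norm and total_norm > 0.0:
        # Rescale every gradient in place by the same factor.
        scale = max_norm / total_norm
        for g in grads:
            g *= scale
    return total_norm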
# %% ../../modules/source/07_training/training_dev.ipynb 6
class CosineSchedule:
"""
Cosine annealing learning rate schedule.
Starts at max_lr and decays along a cosine curve to min_lr over total_epochs epochs.
This allows aggressive learning early on and fine-tuning at the end.
TODO: Implement cosine annealing schedule
APPROACH:
1. Store max_lr, min_lr, and total_epochs
2. In get_lr(), compute cosine factor: (1 + cos(π * epoch / total_epochs)) / 2
3. Interpolate: min_lr + (max_lr - min_lr) * cosine_factor
EXAMPLE:
>>> schedule = CosineSchedule(max_lr=0.1, min_lr=0.01, total_epochs=100)
>>> print(schedule.get_lr(0)) # Start: 0.1
>>> print(schedule.get_lr(50)) # Middle: ~0.055
>>> print(schedule.get_lr(100)) # End: 0.01
HINT: Use np.cos() and np.pi for the cosine calculation
"""
### BEGIN SOLUTION
def __init__(self, max_lr: float = 0.1, min_lr: float = 0.01, total_epochs: int = 100):
self.max_lr = max_lr
self.min_lr = min_lr
self.total_epochs = total_epochs
def get_lr(self, epoch: int) -> float:
"""Get learning rate for current epoch."""
if epoch >= self.total_epochs:
return self.min_lr
# Cosine annealing formula
cosine_factor = (1 + np.cos(np.pi * epoch / self.total_epochs)) / 2
return self.min_lr + (self.max_lr - self.min_lr) * cosine_factor
### END SOLUTION
# %% ../../modules/source/07_training/training_dev.ipynb 14
def save_checkpoint(checkpoint_dict: Dict[str, Any], path: str):
"""
Save checkpoint dictionary to disk using pickle.
This is a low-level utility for saving model state. Use this when you have
a custom training loop and want to save just what you need (model params,
config, metadata).
For complete training state with optimizer and scheduler, use
Trainer.save_checkpoint() instead.
TODO: Implement checkpoint saving with pickle
APPROACH:
1. Create parent directory if it doesn't exist (Path(path).parent.mkdir)
2. Open file in binary write mode ('wb')
3. Use pickle.dump() to serialize the checkpoint dictionary
4. Print confirmation message
EXAMPLE:
>>> model = SimpleModel()
>>> checkpoint = {
... 'model_params': [p.data.copy() for p in model.parameters()],
... 'config': {'embed_dim': 32, 'num_layers': 2},
... 'metadata': {'final_loss': 0.089, 'training_steps': 5000}
... }
>>> save_checkpoint(checkpoint, 'checkpoints/model.pkl')
✓ Checkpoint saved: checkpoints/model.pkl
HINTS:
- Use Path(path).parent.mkdir(parents=True, exist_ok=True)
- pickle.dump(obj, file) writes the object to file
- Always print a success message so users know it worked
"""
### BEGIN SOLUTION
# Create parent directory if needed
Path(path).parent.mkdir(parents=True, exist_ok=True)
# Save checkpoint using pickle
with open(path, 'wb') as f:
pickle.dump(checkpoint_dict, f)
print(f"✓ Checkpoint saved: {path}")
### END SOLUTION
# %% ../../modules/source/07_training/training_dev.ipynb 15
def load_checkpoint(path: str) -> Dict[str, Any]:
"""
Load checkpoint dictionary from disk using pickle.
Companion function to save_checkpoint(). Restores the checkpoint dictionary
so you can rebuild your model, resume training, or inspect saved metadata.
TODO: Implement checkpoint loading with pickle
APPROACH:
1. Open file in binary read mode ('rb')
2. Use pickle.load() to deserialize the checkpoint
3. Print confirmation message
4. Return the loaded dictionary
EXAMPLE:
>>> checkpoint = load_checkpoint('checkpoints/model.pkl')
✓ Checkpoint loaded: checkpoints/model.pkl
>>> print(checkpoint['metadata']['final_loss'])
0.089
>>> model_params = checkpoint['model_params']
>>> # Now restore model: for param, data in zip(model.parameters(), model_params)...
HINTS:
- pickle.load(file) reads and deserializes the object
- Return the loaded dictionary
- Print a success message for user feedback
"""
### BEGIN SOLUTION
# Load checkpoint using pickle
with open(path, 'rb') as f:
checkpoint = pickle.load(f)
print(f"✓ Checkpoint loaded: {path}")
return checkpoint
### END SOLUTION
# %% ../../modules/source/07_training/training_dev.ipynb 19
class Trainer:
"""
Complete training orchestrator for neural networks.
Handles the full training lifecycle: forward pass, loss computation,
backward pass, optimization, scheduling, checkpointing, and evaluation.
This is the central class that brings together all the components
you've built in previous modules.
TODO: Implement complete Trainer class
APPROACH:
1. Store model, optimizer, loss function, and optional scheduler
2. train_epoch(): Loop through data, compute loss, update parameters
3. evaluate(): Similar loop but without gradient updates
4. save/load_checkpoint(): Persist training state for resumption
DESIGN PATTERNS:
- Explicit train/eval mode flags (context managers are a natural extension)
- Gradient accumulation for effective large batch sizes
- Progress tracking for monitoring
- Flexible scheduling integration
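EXAMPLE (illustrative usage sketch; the Linear, SGD, and MSELoss constructor
signatures shown here are assumptions and may differ in this codebase):
>>> model = Linear(4, 2)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> scheduler = CosineSchedule(max_lr=0.1, min_lr=0.01, total_epochs=10)
>>> trainer = Trainer(model, optimizer, MSELoss(), scheduler=scheduler, grad_clip_norm=1.0)
>>> for _ in range(10):
...     train_loss = trainer.train_epoch(train_loader)
>>> eval_loss, accuracy = trainer.evaluate(test_loader)
>>> trainer.save_checkpoint('checkpoints/run1.pkl')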
"""
### BEGIN SOLUTION
def __init__(self, model, optimizer, loss_fn, scheduler=None, grad_clip_norm=None):
"""
Initialize trainer with model and training components.
Args:
model: Neural network to train
optimizer: Parameter update strategy (SGD, Adam, etc.)
loss_fn: Loss function (CrossEntropy, MSE, etc.)
scheduler: Optional learning rate scheduler
grad_clip_norm: Optional gradient clipping threshold
"""
self.model = model
self.optimizer = optimizer
self.loss_fn = loss_fn
self.scheduler = scheduler
self.grad_clip_norm = grad_clip_norm
# Training state
self.epoch = 0
self.step = 0
self.training_mode = True
# History tracking
self.history = {
'train_loss': [],
'eval_loss': [],
'learning_rates': []
}
def train_epoch(self, dataloader, accumulation_steps=1):
"""
Train for one epoch through the dataset.
Args:
dataloader: Iterable yielding (inputs, targets) batches
accumulation_steps: Number of batches to accumulate before update
Returns:
Average loss for the epoch
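EXAMPLE (illustrative; with accumulation_steps=4, gradients from 4 consecutive
batches are accumulated before each optimizer step):
>>> avg_loss = trainer.train_epoch(train_loader, accumulation_steps=4)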
"""
self.model.training = True
self.training_mode = True
total_loss = 0.0
num_batches = 0
accumulated_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(dataloader):
# Forward pass
outputs = self.model.forward(inputs)
loss = self.loss_fn.forward(outputs, targets)
# Track the scaled loss for logging (gradients themselves are summed, not averaged, across accumulated batches)
scaled_loss = loss.data / accumulation_steps
accumulated_loss += scaled_loss
# Backward pass
if hasattr(loss, 'backward'):
loss.backward()
# Update parameters every accumulation_steps
if (batch_idx + 1) % accumulation_steps == 0:
# Gradient clipping
if self.grad_clip_norm is not None:
params = []
if hasattr(self.model, 'parameters'):
params = self.model.parameters()
clip_grad_norm(params, self.grad_clip_norm)
# Optimizer step
self.optimizer.step()
self.optimizer.zero_grad()
total_loss += accumulated_loss
accumulated_loss = 0.0
num_batches += 1
self.step += 1
# Handle remaining accumulated gradients
if accumulated_loss > 0:
if self.grad_clip_norm is not None:
params = []
if hasattr(self.model, 'parameters'):
params = self.model.parameters()
clip_grad_norm(params, self.grad_clip_norm)
self.optimizer.step()
self.optimizer.zero_grad()
total_loss += accumulated_loss
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
self.history['train_loss'].append(avg_loss)
# Update scheduler
if self.scheduler is not None:
current_lr = self.scheduler.get_lr(self.epoch)
# Update optimizer learning rate
if hasattr(self.optimizer, 'lr'):
self.optimizer.lr = current_lr
self.history['learning_rates'].append(current_lr)
self.epoch += 1
return avg_loss
def evaluate(self, dataloader):
"""
Evaluate model on dataset without updating parameters.
Args:
dataloader: Iterable yielding (inputs, targets) batches
Returns:
Average loss and accuracy
"""
self.model.training = False
self.training_mode = False
total_loss = 0.0
correct = 0
total = 0
for inputs, targets in dataloader:
# Forward pass only
outputs = self.model.forward(inputs)
loss = self.loss_fn.forward(outputs, targets)
total_loss += loss.data
# Calculate accuracy (for classification)
if hasattr(outputs, 'data') and hasattr(targets, 'data'):
if len(outputs.data.shape) > 1: # Multi-class
predictions = np.argmax(outputs.data, axis=1)
if len(targets.data.shape) == 1: # Integer targets
correct += np.sum(predictions == targets.data)
else: # One-hot targets
correct += np.sum(predictions == np.argmax(targets.data, axis=1))
total += len(predictions)
avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0.0
accuracy = correct / total if total > 0 else 0.0
self.history['eval_loss'].append(avg_loss)
return avg_loss, accuracy
def save_checkpoint(self, path: str):
"""
Save complete training state for resumption.
This high-level method saves everything needed to resume training:
model parameters, optimizer state, scheduler state, and training history.
Uses the low-level save_checkpoint() function internally.
Args:
path: File path to save checkpoint
"""
checkpoint = {
'epoch': self.epoch,
'step': self.step,
'model_state': self._get_model_state(),
'optimizer_state': self._get_optimizer_state(),
'scheduler_state': self._get_scheduler_state(),
'history': self.history,
'training_mode': self.training_mode
}
# Use the standalone save_checkpoint function
save_checkpoint(checkpoint, path)
def load_checkpoint(self, path: str):
"""
Load training state from checkpoint.
This high-level method restores complete training state including
model parameters, optimizer state, scheduler state, and history.
Uses the low-level load_checkpoint() function internally.
Args:
path: File path to load checkpoint from
"""
# Use the standalone load_checkpoint function
checkpoint = load_checkpoint(path)
self.epoch = checkpoint['epoch']
self.step = checkpoint['step']
self.history = checkpoint['history']
self.training_mode = checkpoint['training_mode']
# Restore states (simplified for educational purposes)
if 'model_state' in checkpoint:
self._set_model_state(checkpoint['model_state'])
if 'optimizer_state' in checkpoint:
self._set_optimizer_state(checkpoint['optimizer_state'])
if 'scheduler_state' in checkpoint:
self._set_scheduler_state(checkpoint['scheduler_state'])
def _get_model_state(self):
"""Extract model parameters for checkpointing."""
if hasattr(self.model, 'parameters'):
return {i: param.data.copy() for i, param in enumerate(self.model.parameters())}
return {}
def _set_model_state(self, state):
"""Restore model parameters from checkpoint."""
if hasattr(self.model, 'parameters'):
for i, param in enumerate(self.model.parameters()):
if i in state:
param.data = state[i].copy()
def _get_optimizer_state(self):
"""Extract optimizer state for checkpointing."""
state = {}
if hasattr(self.optimizer, 'lr'):
state['lr'] = self.optimizer.lr
if hasattr(self.optimizer, 'momentum_buffers'):
state['momentum_buffers'] = self.optimizer.momentum_buffers.copy()
return state
def _set_optimizer_state(self, state):
"""Restore optimizer state from checkpoint."""
if 'lr' in state and hasattr(self.optimizer, 'lr'):
self.optimizer.lr = state['lr']
if 'momentum_buffers' in state and hasattr(self.optimizer, 'momentum_buffers'):
self.optimizer.momentum_buffers = state['momentum_buffers']
def _get_scheduler_state(self):
"""Extract scheduler state for checkpointing."""
if self.scheduler is None:
return None
return {
'max_lr': getattr(self.scheduler, 'max_lr', None),
'min_lr': getattr(self.scheduler, 'min_lr', None),
'total_epochs': getattr(self.scheduler, 'total_epochs', None)
}
def _set_scheduler_state(self, state):
"""Restore scheduler state from checkpoint."""
if state is None or self.scheduler is None:
return
for key, value in state.items():
if hasattr(self.scheduler, key):
setattr(self.scheduler, key, value)
### END SOLUTION