mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-27 10:56:02 -05:00
337 lines
13 KiB
Python
Generated
337 lines
13 KiB
Python
Generated
# ╔═══════════════════════════════════════════════════════════════════════════════╗
|
|
# ║ 🚨 CRITICAL WARNING 🚨 ║
|
|
# ║ AUTOGENERATED! DO NOT EDIT! ║
|
|
# ║ ║
|
|
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
|
|
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
|
|
# ║ ║
|
|
# ║ ✅ TO EDIT: modules/source/11_training/training_dev.py ║
|
|
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
|
|
# ║ ║
|
|
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
|
|
# ║ Editing it directly may break module functionality and training. ║
|
|
# ║ ║
|
|
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
|
|
# ║ happens! The tinytorch/ directory is just the compiled output. ║
|
|
# ╚═══════════════════════════════════════════════════════════════════════════════╝
|
|
# %% auto 0
# Names exported by `from <this module> import *`.
__all__ = [
    "CosineSchedule",
    "Trainer",
]
|
|
|
|
# %% ../../modules/source/07_training/training_dev.ipynb 1
|
|
import numpy as np
|
|
import pickle
|
|
import time
|
|
from typing import Dict, List, Optional, Tuple, Any, Callable
|
|
from pathlib import Path
|
|
import sys
|
|
import os
|
|
|
|
# Import dependencies from other modules
|
|
from .tensor import Tensor
|
|
from .layers import Linear
|
|
from .losses import MSELoss, CrossEntropyLoss
|
|
from .optimizers import SGD, AdamW
|
|
|
|
# %% ../../modules/source/07_training/training_dev.ipynb 6
|
|
class CosineSchedule:
    """
    Cosine annealing learning rate schedule.

    The rate begins at ``max_lr`` and follows half a cosine wave down to
    ``min_lr`` across ``total_epochs`` epochs, after which it stays at
    ``min_lr``. Large early steps explore quickly; small late steps fine-tune.

    TODO: Implement cosine annealing schedule

    APPROACH:
    1. Store max_lr, min_lr, and total_epochs
    2. In get_lr(), compute cosine factor: (1 + cos(π * epoch / total_epochs)) / 2
    3. Interpolate: min_lr + (max_lr - min_lr) * cosine_factor

    EXAMPLE:
    >>> schedule = CosineSchedule(max_lr=0.1, min_lr=0.01, total_epochs=100)
    >>> print(schedule.get_lr(0))    # Start: 0.1
    >>> print(schedule.get_lr(50))   # Middle: ~0.055
    >>> print(schedule.get_lr(100))  # End: 0.01

    HINT: Use np.cos() and np.pi for the cosine calculation
    """
    ### BEGIN SOLUTION
    def __init__(self, max_lr: float = 0.1, min_lr: float = 0.01, total_epochs: int = 100):
        # Schedule endpoints and its length in epochs.
        self.max_lr, self.min_lr, self.total_epochs = max_lr, min_lr, total_epochs

    def get_lr(self, epoch: int) -> float:
        """Return the learning rate for ``epoch``.

        Any epoch at or beyond ``total_epochs`` is clamped to ``min_lr``.
        """
        if epoch < self.total_epochs:
            angle = np.pi * epoch / self.total_epochs
            # Weight is 1 at epoch 0 and shrinks toward 0 as epoch nears total_epochs.
            weight = (1 + np.cos(angle)) / 2
            span = self.max_lr - self.min_lr
            return self.min_lr + span * weight
        return self.min_lr
    ### END SOLUTION
|
|
|
|
# %% ../../modules/source/07_training/training_dev.ipynb 14
|
|
class Trainer:
    """
    Complete training orchestrator for neural networks.

    Handles the full training lifecycle: forward pass, loss computation,
    backward pass, optimization, scheduling, checkpointing, and evaluation.

    This is the central class that brings together all the components
    you've built in previous modules.

    TODO: Implement complete Trainer class

    APPROACH:
    1. Store model, optimizer, loss function, and optional scheduler
    2. train_epoch(): Loop through data, compute loss, update parameters
    3. evaluate(): Similar loop but without gradient updates
    4. save/load_checkpoint(): Persist training state for resumption

    DESIGN PATTERNS:
    - Explicit train/eval mode flags set on the model
    - Gradient accumulation (averaged) for effective large batch sizes
    - Progress tracking for monitoring
    - Flexible scheduling integration
    """
    ### BEGIN SOLUTION
    def __init__(self, model, optimizer, loss_fn, scheduler=None, grad_clip_norm=None):
        """
        Initialize trainer with model and training components.

        Args:
            model: Neural network to train (must provide forward(); parameters() is optional)
            optimizer: Parameter update strategy (SGD, Adam, etc.) with step() and zero_grad()
            loss_fn: Loss function (CrossEntropy, MSE, etc.) with forward(outputs, targets)
            scheduler: Optional learning rate scheduler exposing get_lr(epoch)
            grad_clip_norm: Optional maximum global L2 norm for gradients
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.scheduler = scheduler
        self.grad_clip_norm = grad_clip_norm

        # Training state: epochs completed and optimizer steps taken.
        self.epoch = 0
        self.step = 0
        self.training_mode = True

        # History tracking (one entry per epoch / evaluate call).
        self.history = {
            'train_loss': [],
            'eval_loss': [],
            'learning_rates': []
        }

    def train_epoch(self, dataloader, accumulation_steps=1):
        """
        Train for one epoch through the dataset.

        If a scheduler is set, its rate for the current epoch is written to
        the optimizer *before* any update, so epoch ``e`` trains with
        ``scheduler.get_lr(e)``.

        Gradients from ``accumulation_steps`` consecutive batches are averaged
        before each optimizer step. A trailing partial group is averaged over
        the batches it actually contains and is always applied, so no
        gradients leak into the next epoch.

        Args:
            dataloader: Iterable yielding (inputs, targets) batches
            accumulation_steps: Number of batches to accumulate before update (>= 1)

        Returns:
            Average per-batch loss for the epoch (0.0 if the dataloader is empty)

        Raises:
            ValueError: If accumulation_steps is less than 1
        """
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

        self.model.training = True
        self.training_mode = True

        # Set this epoch's LR before the first update so the schedule does not lag an epoch.
        if self.scheduler is not None:
            current_lr = self.scheduler.get_lr(self.epoch)
            if hasattr(self.optimizer, 'lr'):
                self.optimizer.lr = current_lr
            self.history['learning_rates'].append(current_lr)

        total_loss = 0.0
        num_batches = 0
        pending = 0  # batches whose gradients have been computed but not yet applied

        for inputs, targets in dataloader:
            # Forward pass
            outputs = self.model.forward(inputs)
            loss = self.loss_fn.forward(outputs, targets)
            total_loss += loss.data
            num_batches += 1

            # Backward pass
            if hasattr(loss, 'backward'):
                loss.backward()
            pending += 1

            # Update parameters every accumulation_steps batches
            if pending == accumulation_steps:
                self._apply_accumulated_update(pending)
                pending = 0

        # Flush a trailing partial group regardless of the sign of its loss.
        if pending > 0:
            self._apply_accumulated_update(pending)

        avg_loss = total_loss / max(num_batches, 1)
        self.history['train_loss'].append(avg_loss)

        self.epoch += 1
        return avg_loss

    def _apply_accumulated_update(self, num_accumulated):
        """Average accumulated gradients, optionally clip them, then step and reset the optimizer."""
        params = self._parameters()
        if num_accumulated > 1:
            # backward() ran on unscaled losses; assuming it accumulates into .grad
            # (which gradient accumulation relies on), divide to get the mean gradient.
            self._scale_gradients(params, 1.0 / num_accumulated)
        if self.grad_clip_norm is not None:
            self._clip_gradients(params, self.grad_clip_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1

    def _parameters(self):
        """Return the model's parameters as a list ([] if the model exposes none)."""
        if hasattr(self.model, 'parameters'):
            # list() so the parameters can be iterated more than once (norm, then rescale).
            return list(self.model.parameters())
        return []

    @staticmethod
    def _scale_gradients(params, factor):
        """Multiply every present gradient by ``factor`` (parameters with grad None are skipped)."""
        for param in params:
            grad = getattr(param, 'grad', None)
            if grad is None:
                continue
            # ndarray is checked first because ndarrays also have a (buffer) `.data` attribute.
            if isinstance(grad, np.ndarray) or not hasattr(grad, 'data'):
                param.grad = grad * factor
            else:
                # Tensor-like gradient wrapper: rescale its underlying values.
                grad.data = grad.data * factor

    def _clip_gradients(self, params, max_norm):
        """Rescale gradients so their global L2 norm is at most ``max_norm``.

        Returns:
            The global gradient norm measured before clipping.
        """
        squared_sum = 0.0
        for param in params:
            grad = getattr(param, 'grad', None)
            if grad is None:
                continue
            values = grad if isinstance(grad, np.ndarray) else getattr(grad, 'data', grad)
            squared_sum += float(np.sum(np.square(values)))
        total_norm = float(np.sqrt(squared_sum))
        if total_norm > max_norm and total_norm > 0:
            self._scale_gradients(params, max_norm / total_norm)
        return total_norm

    def evaluate(self, dataloader):
        """
        Evaluate model on dataset without updating parameters.

        Works with any iterable, including generators without ``__len__``.

        Args:
            dataloader: Iterable yielding (inputs, targets) batches

        Returns:
            Tuple (average per-batch loss, accuracy). Accuracy is 0.0 when no
            multi-class outputs with `.data` were seen.
        """
        self.model.training = False
        self.training_mode = False

        total_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for inputs, targets in dataloader:
            # Forward pass only
            outputs = self.model.forward(inputs)
            loss = self.loss_fn.forward(outputs, targets)

            total_loss += loss.data
            num_batches += 1

            # Calculate accuracy (for classification)
            if hasattr(outputs, 'data') and hasattr(targets, 'data'):
                if len(outputs.data.shape) > 1:  # Multi-class
                    predictions = np.argmax(outputs.data, axis=1)
                    if len(targets.data.shape) == 1:  # Integer targets
                        correct += np.sum(predictions == targets.data)
                    else:  # One-hot targets
                        correct += np.sum(predictions == np.argmax(targets.data, axis=1))
                    total += len(predictions)

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        self.history['eval_loss'].append(avg_loss)

        return avg_loss, accuracy

    def save_checkpoint(self, path: str):
        """
        Save complete training state for resumption.

        Args:
            path: File path to save checkpoint (parent directories are created)
        """
        checkpoint = {
            'epoch': self.epoch,
            'step': self.step,
            'model_state': self._get_model_state(),
            'optimizer_state': self._get_optimizer_state(),
            'scheduler_state': self._get_scheduler_state(),
            'history': self.history,
            'training_mode': self.training_mode
        }

        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(checkpoint, f)

    def load_checkpoint(self, path: str):
        """
        Load training state from checkpoint.

        Args:
            path: File path to load checkpoint from
        """
        # SECURITY: pickle.load can execute arbitrary code; only load checkpoints you trust.
        with open(path, 'rb') as f:
            checkpoint = pickle.load(f)

        self.epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.history = checkpoint['history']
        self.training_mode = checkpoint['training_mode']

        # Restore states (simplified for educational purposes)
        if 'model_state' in checkpoint:
            self._set_model_state(checkpoint['model_state'])
        if 'optimizer_state' in checkpoint:
            self._set_optimizer_state(checkpoint['optimizer_state'])
        if 'scheduler_state' in checkpoint:
            self._set_scheduler_state(checkpoint['scheduler_state'])

    def _get_model_state(self):
        """Extract model parameters (copied, keyed by position) for checkpointing."""
        if hasattr(self.model, 'parameters'):
            return {i: param.data.copy() for i, param in enumerate(self.model.parameters())}
        return {}

    def _set_model_state(self, state):
        """Restore model parameters from checkpoint, matching by position."""
        if hasattr(self.model, 'parameters'):
            for i, param in enumerate(self.model.parameters()):
                if i in state:
                    param.data = state[i].copy()

    def _get_optimizer_state(self):
        """Extract optimizer state (lr and, if present, momentum buffers) for checkpointing."""
        state = {}
        if hasattr(self.optimizer, 'lr'):
            state['lr'] = self.optimizer.lr
        if hasattr(self.optimizer, 'momentum_buffers'):
            state['momentum_buffers'] = self.optimizer.momentum_buffers.copy()
        return state

    def _set_optimizer_state(self, state):
        """Restore optimizer state from checkpoint."""
        if 'lr' in state and hasattr(self.optimizer, 'lr'):
            self.optimizer.lr = state['lr']
        if 'momentum_buffers' in state and hasattr(self.optimizer, 'momentum_buffers'):
            self.optimizer.momentum_buffers = state['momentum_buffers']

    def _get_scheduler_state(self):
        """Extract scheduler hyperparameters for checkpointing (None if no scheduler)."""
        if self.scheduler is None:
            return None
        return {
            'max_lr': getattr(self.scheduler, 'max_lr', None),
            'min_lr': getattr(self.scheduler, 'min_lr', None),
            'total_epochs': getattr(self.scheduler, 'total_epochs', None)
        }

    def _set_scheduler_state(self, state):
        """Restore scheduler attributes that exist on the current scheduler."""
        if state is None or self.scheduler is None:
            return
        for key, value in state.items():
            if hasattr(self.scheduler, key):
                setattr(self.scheduler, key, value)
    ### END SOLUTION
|