Mirror of https://github.com/MLSysBook/TinyTorch.git, synced 2026-04-30 00:47:31 -05:00
Reset package and export modules 01-07 only (skip broken spatial module)
tinytorch/core/activations.py (generated, 18 lines changed)
@@ -1,5 +1,19 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/02_activations/activations_dev.ipynb.

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/03_activations/activations_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝
# %% auto 0
__all__ = ['Sigmoid', 'ReLU', 'Tanh', 'GELU', 'Softmax']
tinytorch/core/attention.py (generated, 8 lines changed)
@@ -1,8 +0,0 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/12_attention/attention_dev.ipynb.

# %% auto 0
__all__ = []

# %% ../../modules/source/12_attention/attention_dev.ipynb 0
#| default_exp core.attention
#| export
tinytorch/core/layers.py (generated, 22 lines changed)
@@ -1,5 +1,19 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/03_layers/layers_dev.ipynb.

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/04_layers/layers_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝
# %% auto 0
__all__ = ['Linear', 'Dropout']
@@ -194,6 +208,10 @@ class Dropout:
            return Tensor(output_data)
        ### END SOLUTION

    def __call__(self, x, training=True):
        """Allows the layer to be called like a function."""
        return self.forward(x, training)

    def parameters(self):
        """Dropout has no parameters."""
        return []
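A minimal usage sketch of the Dropout layer shown above; the constructor argument p and the Tensor-from-NumPy construction are assumptions, since neither appears in this hunk:

import numpy as np
from tinytorch.core.tensor import Tensor
from tinytorch.core.layers import Dropout

x = Tensor(np.ones((4, 8)))         # small activation batch (assumed Tensor constructor)
drop = Dropout(p=0.5)               # hypothetical constructor argument; __init__ is not shown here

train_out = drop(x, training=True)  # __call__ forwards to forward(x, training)
eval_out = drop(x, training=False)  # with training=False dropout is conventionally a no-op
print(drop.parameters())            # [] - Dropout has no learnable parameters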
tinytorch/core/losses.py (generated, 18 lines changed)
@@ -1,5 +1,19 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/04_losses/losses_dev.ipynb.

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/XX_losses/losses_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝
# %% auto 0
__all__ = ['import_previous_module', 'MSELoss', 'CrossEntropyLoss', 'BinaryCrossEntropyLoss']
tinytorch/core/optimizers.py (generated, 53 lines changed)
@@ -1,5 +1,19 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/06_optimizers/optimizers_dev.ipynb.

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/10_optimizers/optimizers_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝
# %% auto 0
__all__ = ['Optimizer', 'SGD', 'Adam', 'AdamW']
@@ -7,10 +21,10 @@ __all__ = ['Optimizer', 'SGD', 'Adam', 'AdamW']
import numpy as np
from typing import List, Union, Optional, Dict, Any

# Import Tensor from Module 01
from tinytorch.core.tensor import Tensor
# Import Tensor from Module 01 (now with gradient support from Module 05)
from .tensor import Tensor

# %% Base Optimizer class
# %% ../../modules/source/06_optimizers/optimizers_dev.ipynb 5
class Optimizer:
    """
    Base class for all optimizers.
@@ -37,6 +51,7 @@ class Optimizer:

        HINT: Check that each parameter has requires_grad=True
        """
        ### BEGIN SOLUTION
        # Validate and store parameters
        if not isinstance(params, list):
            params = list(params)
@@ -50,6 +65,7 @@ class Optimizer:

        self.params = params
        self.step_count = 0  # For algorithms that need step counting
        ### END SOLUTION

    def zero_grad(self):
        """
@@ -67,8 +83,10 @@ class Optimizer:

        WHY: Gradients accumulate by default, so we need to clear them between batches
        """
        ### BEGIN SOLUTION
        for param in self.params:
            param.grad = None
        ### END SOLUTION

    def step(self):
        """
@@ -78,9 +96,7 @@ class Optimizer:
        """
        raise NotImplementedError("Subclasses must implement step()")
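As a quick illustration of the zero_grad()/step() contract above, here is a minimal sketch using a stand-in parameter object; whether the shipped SGD accepts a duck-typed parameter like this, and what type it expects in .grad, is not visible in this diff, so treat it as schematic:

import numpy as np
from tinytorch.core.optimizers import SGD

# Stand-in parameter (NOT TinyTorch's Tensor) exposing the attributes the
# Optimizer base class checks: .data, .grad, and requires_grad.
class _Param:
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)
        self.grad = None
        self.requires_grad = True

w = _Param([1.0, 2.0, 3.0])
opt = SGD([w], lr=0.1)

w.grad = np.array([0.5, 0.5, 0.5], dtype=np.float32)  # pretend a backward pass produced this
opt.step()        # parameter update based on the gradient
opt.zero_grad()   # clear gradients so the next batch starts fresh (they accumulate otherwise)
print(w.data, w.grad)  # updated weights, grad reset to None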
# %% SGD Optimizer
# %% ../../modules/source/06_optimizers/optimizers_dev.ipynb 9
class SGD(Optimizer):
    """
    Stochastic Gradient Descent with momentum.
@@ -108,6 +124,7 @@ class SGD(Optimizer):
        - Momentum buffers should be initialized as None
        - They'll be created lazily on first step
        """
        ### BEGIN SOLUTION
        super().__init__(params)

        self.lr = lr
@@ -116,6 +133,7 @@ class SGD(Optimizer):

        # Initialize momentum buffers (created lazily)
        self.momentum_buffers = [None for _ in self.params]
        ### END SOLUTION

    def step(self):
        """
@@ -139,6 +157,7 @@ class SGD(Optimizer):
        - Initialize momentum buffers on first use
        - Use in-place operations to save memory
        """
        ### BEGIN SOLUTION
        for i, param in enumerate(self.params):
            if param.grad is None:
                continue
@@ -165,10 +184,9 @@ class SGD(Optimizer):

        # Increment step counter
        self.step_count += 1
        ### END SOLUTION
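For reference, the momentum recurrence that the SGD hints describe, written directly in NumPy; whether TinyTorch's SGD uses exactly this (PyTorch-style) formulation is not shown in the hunk above:

import numpy as np

lr, momentum = 0.1, 0.9
param = np.array([1.0, -2.0])
grad = np.array([0.2, -0.4])
buf = None  # momentum buffer, created lazily on the first step

for step in range(3):
    if buf is None:
        buf = grad.copy()              # first step: buffer starts as the raw gradient
    else:
        buf = momentum * buf + grad    # v_t = mu * v_{t-1} + g_t
    param -= lr * buf                  # p_t = p_{t-1} - lr * v_t
    print(step, param, buf)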
# %% Adam Optimizer
# %% ../../modules/source/06_optimizers/optimizers_dev.ipynb 13
class Adam(Optimizer):
    """
    Adam optimizer with adaptive learning rates.
@@ -198,6 +216,7 @@ class Adam(Optimizer):
        EXAMPLE:
        >>> optimizer = Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        """
        ### BEGIN SOLUTION
        super().__init__(params)

        self.lr = lr
@@ -208,6 +227,7 @@ class Adam(Optimizer):
        # Initialize moment buffers (created lazily)
        self.m_buffers = [None for _ in self.params]  # First moment (mean)
        self.v_buffers = [None for _ in self.params]  # Second moment (variance)
        ### END SOLUTION

    def step(self):
        """
@@ -235,6 +255,7 @@ class Adam(Optimizer):
        - Use step_count for bias correction
        - Square gradients element-wise for second moment
        """
        ### BEGIN SOLUTION
        # Increment step counter first (needed for bias correction)
        self.step_count += 1
@@ -270,10 +291,9 @@ class Adam(Optimizer):

            # Update parameter
            param.data = param.data - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
        ### END SOLUTION
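A standalone NumPy sketch of the Adam arithmetic that the hints and the final update line describe; only that last line is visible in the diff, so the moment updates here follow the standard Adam formulation and are an assumption about the rest of the solution:

import numpy as np

lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
param = np.array([1.0, -1.0])
grad = np.array([0.3, -0.1])
m = np.zeros_like(param)   # first moment (running mean of gradients)
v = np.zeros_like(param)   # second moment (running mean of squared gradients)

t = 1                                      # step_count, incremented before the update
m = beta1 * m + (1 - beta1) * grad         # update first moment
v = beta2 * v + (1 - beta2) * grad**2      # square gradients element-wise
m_hat = m / (1 - beta1**t)                 # bias correction: moments start at zero
v_hat = v / (1 - beta2**t)
param = param - lr * m_hat / (np.sqrt(v_hat) + eps)   # matches the update line above
print(param)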
# %% AdamW Optimizer
# %% ../../modules/source/06_optimizers/optimizers_dev.ipynb 17
class AdamW(Optimizer):
    """
    AdamW optimizer with decoupled weight decay.
@@ -301,6 +321,7 @@ class AdamW(Optimizer):
        EXAMPLE:
        >>> optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        """
        ### BEGIN SOLUTION
        super().__init__(params)

        self.lr = lr
@@ -311,6 +332,7 @@ class AdamW(Optimizer):
        # Initialize moment buffers (same as Adam)
        self.m_buffers = [None for _ in self.params]
        self.v_buffers = [None for _ in self.params]
        ### END SOLUTION

    def step(self):
        """
@@ -336,6 +358,7 @@ class AdamW(Optimizer):

        HINT: Apply weight decay after gradient update for proper decoupling
        """
        ### BEGIN SOLUTION
        # Increment step counter first
        self.step_count += 1
@@ -369,4 +392,4 @@ class AdamW(Optimizer):
            # Apply decoupled weight decay
            if self.weight_decay != 0:
                param.data = param.data * (1 - self.lr * self.weight_decay)

        ### END SOLUTION
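The decoupling the AdamW hint refers to, reduced to the two lines that matter; the adam_update value below is a placeholder for lr * m_hat / (sqrt(v_hat) + eps), not the file's actual intermediate variable:

import numpy as np

lr, weight_decay = 0.001, 0.01
param = np.array([2.0, -3.0])
adam_update = np.array([0.05, -0.02])      # stand-in for lr * m_hat / (np.sqrt(v_hat) + eps)

# L2-regularized Adam would add weight_decay * param to the gradient before the
# moment estimates. AdamW instead applies the gradient-based update first...
param = param - adam_update
# ...and then shrinks the weights directly, independent of the adaptive scaling,
# exactly as in the decoupled-decay line of the hunk above:
param = param * (1 - lr * weight_decay)
print(param)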
tinytorch/core/spatial.py (generated, 64 lines changed)
@@ -1,64 +0,0 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/09_spatial/spatial_dev.ipynb.

# %% auto 0
__all__ = []

# %% ../../modules/source/09_spatial/spatial_dev.ipynb 1
import numpy as np
import sys
import os
import time

# Import dependencies from other modules
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '01_tensor'))
from tensor_dev import Tensor

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '03_layers'))
from layers_dev import Module

# Note: Keeping simplified implementations for reference during development
class _SimplifiedTensor:
    """Simplified tensor for spatial operations development."""

    def __init__(self, data, requires_grad=False):
        self.data = np.array(data, dtype=np.float32)
        self.shape = self.data.shape
        self.requires_grad = requires_grad
        self.grad = None

    def __repr__(self):
        return f"Tensor(shape={self.shape}, data=\n{self.data})"

    def __add__(self, other):
        if isinstance(other, Tensor):
            return Tensor(self.data + other.data)
        return Tensor(self.data + other)

    def __mul__(self, other):
        if isinstance(other, Tensor):
            return Tensor(self.data * other.data)
        return Tensor(self.data * other)

    def sum(self):
        return Tensor(np.sum(self.data))

    def mean(self):
        return Tensor(np.mean(self.data))

# Create a simple Module base class for inheritance
class Module:
    """Simple base class for neural network modules."""
    def __init__(self):
        pass

    def forward(self, x):
        raise NotImplementedError("Subclasses must implement forward()")

    def parameters(self):
        """Return list of parameters for this module."""
        params = []
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if hasattr(attr, 'data') and hasattr(attr, 'requires_grad'):
                params.append(attr)
        return params
tinytorch/core/tensor.py (generated, 18 lines changed)
@@ -1,5 +1,19 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/01_tensor/tensor_dev.ipynb.

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/02_tensor/tensor_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝
# %% auto 0
__all__ = ['Tensor']
tinytorch/core/training.py (generated, 332 lines changed)
@@ -1,7 +1,21 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/07_training/training_dev.ipynb.

# ╔═══════════════════════════════════════════════════════════════════════════════╗
# ║ 🚨 CRITICAL WARNING 🚨 ║
# ║ AUTOGENERATED! DO NOT EDIT! ║
# ║ ║
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
# ║ ║
# ║ ✅ TO EDIT: modules/source/11_training/training_dev.py ║
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
# ║ ║
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
# ║ Editing it directly may break module functionality and training. ║
# ║ ║
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
# ║ happens! The tinytorch/ directory is just the compiled output. ║
# ╚═══════════════════════════════════════════════════════════════════════════════╝
# %% auto 0
__all__ = []
__all__ = ['CosineSchedule', 'Trainer']

# %% ../../modules/source/07_training/training_dev.ipynb 1
import numpy as np
@@ -13,14 +27,310 @@ import sys
import os

# Import dependencies from other modules
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '01_tensor'))
from tensor_dev import Tensor

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '03_layers'))
from layers_dev import Linear

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '04_losses'))
from losses_dev import MSELoss, CrossEntropyLoss

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '06_optimizers'))
from optimizers_dev import SGD, AdamW

from .tensor import Tensor
from .layers import Linear
from .losses import MSELoss, CrossEntropyLoss
from .optimizers import SGD, AdamW

# %% ../../modules/source/07_training/training_dev.ipynb 6
class CosineSchedule:
    """
    Cosine annealing learning rate schedule.

    Starts at max_lr, decreases following a cosine curve to min_lr over T epochs.
    This provides aggressive learning initially, then fine-tuning at the end.

    TODO: Implement cosine annealing schedule

    APPROACH:
    1. Store max_lr, min_lr, and total_epochs
    2. In get_lr(), compute cosine factor: (1 + cos(π * epoch / total_epochs)) / 2
    3. Interpolate: min_lr + (max_lr - min_lr) * cosine_factor

    EXAMPLE:
    >>> schedule = CosineSchedule(max_lr=0.1, min_lr=0.01, total_epochs=100)
    >>> print(schedule.get_lr(0))    # Start: 0.1
    >>> print(schedule.get_lr(50))   # Middle: ~0.055
    >>> print(schedule.get_lr(100))  # End: 0.01

    HINT: Use np.cos() and np.pi for the cosine calculation
    """
    ### BEGIN SOLUTION
    def __init__(self, max_lr: float = 0.1, min_lr: float = 0.01, total_epochs: int = 100):
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.total_epochs = total_epochs

    def get_lr(self, epoch: int) -> float:
        """Get learning rate for current epoch."""
        if epoch >= self.total_epochs:
            return self.min_lr

        # Cosine annealing formula
        cosine_factor = (1 + np.cos(np.pi * epoch / self.total_epochs)) / 2
        return self.min_lr + (self.max_lr - self.min_lr) * cosine_factor
    ### END SOLUTION
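Plugging a few epochs into the formula above reproduces the docstring's example values; this is plain arithmetic with no TinyTorch imports:

import numpy as np

max_lr, min_lr, T = 0.1, 0.01, 100
for epoch in [0, 25, 50, 75, 100]:
    if epoch >= T:
        lr = min_lr
    else:
        cosine_factor = (1 + np.cos(np.pi * epoch / T)) / 2   # 1.0 at epoch 0, 0.5 at T/2
        lr = min_lr + (max_lr - min_lr) * cosine_factor
    print(epoch, round(lr, 4))
# 0 -> 0.1, 25 -> ~0.0868, 50 -> 0.055, 75 -> ~0.0232, 100 -> 0.01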
# %% ../../modules/source/07_training/training_dev.ipynb 14
class Trainer:
    """
    Complete training orchestrator for neural networks.

    Handles the full training lifecycle: forward pass, loss computation,
    backward pass, optimization, scheduling, checkpointing, and evaluation.

    This is the central class that brings together all the components
    you've built in previous modules.

    TODO: Implement complete Trainer class

    APPROACH:
    1. Store model, optimizer, loss function, and optional scheduler
    2. train_epoch(): Loop through data, compute loss, update parameters
    3. evaluate(): Similar loop but without gradient updates
    4. save/load_checkpoint(): Persist training state for resumption

    DESIGN PATTERNS:
    - Context managers for train/eval modes
    - Gradient accumulation for effective large batch sizes
    - Progress tracking for monitoring
    - Flexible scheduling integration
    """
    ### BEGIN SOLUTION
    def __init__(self, model, optimizer, loss_fn, scheduler=None, grad_clip_norm=None):
        """
        Initialize trainer with model and training components.

        Args:
            model: Neural network to train
            optimizer: Parameter update strategy (SGD, Adam, etc.)
            loss_fn: Loss function (CrossEntropy, MSE, etc.)
            scheduler: Optional learning rate scheduler
            grad_clip_norm: Optional gradient clipping threshold
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.scheduler = scheduler
        self.grad_clip_norm = grad_clip_norm

        # Training state
        self.epoch = 0
        self.step = 0
        self.training_mode = True

        # History tracking
        self.history = {
            'train_loss': [],
            'eval_loss': [],
            'learning_rates': []
        }

    def train_epoch(self, dataloader, accumulation_steps=1):
        """
        Train for one epoch through the dataset.

        Args:
            dataloader: Iterable yielding (inputs, targets) batches
            accumulation_steps: Number of batches to accumulate before update

        Returns:
            Average loss for the epoch
        """
        self.model.training = True
        self.training_mode = True

        total_loss = 0.0
        num_batches = 0
        accumulated_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            # Forward pass
            outputs = self.model.forward(inputs)
            loss = self.loss_fn.forward(outputs, targets)

            # Scale loss for accumulation
            scaled_loss = loss.data / accumulation_steps
            accumulated_loss += scaled_loss

            # Backward pass
            if hasattr(loss, 'backward'):
                loss.backward()

            # Update parameters every accumulation_steps
            if (batch_idx + 1) % accumulation_steps == 0:
                # Gradient clipping
                if self.grad_clip_norm is not None:
                    params = []
                    if hasattr(self.model, 'parameters'):
                        params = self.model.parameters()
                    clip_grad_norm(params, self.grad_clip_norm)

                # Optimizer step
                self.optimizer.step()
                self.optimizer.zero_grad()

                total_loss += accumulated_loss
                accumulated_loss = 0.0
                num_batches += 1
                self.step += 1

        # Handle remaining accumulated gradients
        if accumulated_loss > 0:
            if self.grad_clip_norm is not None:
                params = []
                if hasattr(self.model, 'parameters'):
                    params = self.model.parameters()
                clip_grad_norm(params, self.grad_clip_norm)

            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss += accumulated_loss
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        self.history['train_loss'].append(avg_loss)

        # Update scheduler
        if self.scheduler is not None:
            current_lr = self.scheduler.get_lr(self.epoch)
            # Update optimizer learning rate
            if hasattr(self.optimizer, 'lr'):
                self.optimizer.lr = current_lr
            self.history['learning_rates'].append(current_lr)

        self.epoch += 1
        return avg_loss
    def evaluate(self, dataloader):
        """
        Evaluate model on dataset without updating parameters.

        Args:
            dataloader: Iterable yielding (inputs, targets) batches

        Returns:
            Average loss and accuracy
        """
        self.model.training = False
        self.training_mode = False

        total_loss = 0.0
        correct = 0
        total = 0

        for inputs, targets in dataloader:
            # Forward pass only
            outputs = self.model.forward(inputs)
            loss = self.loss_fn.forward(outputs, targets)

            total_loss += loss.data

            # Calculate accuracy (for classification)
            if hasattr(outputs, 'data') and hasattr(targets, 'data'):
                if len(outputs.data.shape) > 1:  # Multi-class
                    predictions = np.argmax(outputs.data, axis=1)
                    if len(targets.data.shape) == 1:  # Integer targets
                        correct += np.sum(predictions == targets.data)
                    else:  # One-hot targets
                        correct += np.sum(predictions == np.argmax(targets.data, axis=1))
                    total += len(predictions)

        avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0.0
        accuracy = correct / total if total > 0 else 0.0

        self.history['eval_loss'].append(avg_loss)

        return avg_loss, accuracy

    def save_checkpoint(self, path: str):
        """
        Save complete training state for resumption.

        Args:
            path: File path to save checkpoint
        """
        checkpoint = {
            'epoch': self.epoch,
            'step': self.step,
            'model_state': self._get_model_state(),
            'optimizer_state': self._get_optimizer_state(),
            'scheduler_state': self._get_scheduler_state(),
            'history': self.history,
            'training_mode': self.training_mode
        }

        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(checkpoint, f)

    def load_checkpoint(self, path: str):
        """
        Load training state from checkpoint.

        Args:
            path: File path to load checkpoint from
        """
        with open(path, 'rb') as f:
            checkpoint = pickle.load(f)

        self.epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.history = checkpoint['history']
        self.training_mode = checkpoint['training_mode']

        # Restore states (simplified for educational purposes)
        if 'model_state' in checkpoint:
            self._set_model_state(checkpoint['model_state'])
        if 'optimizer_state' in checkpoint:
            self._set_optimizer_state(checkpoint['optimizer_state'])
        if 'scheduler_state' in checkpoint:
            self._set_scheduler_state(checkpoint['scheduler_state'])

    def _get_model_state(self):
        """Extract model parameters for checkpointing."""
        if hasattr(self.model, 'parameters'):
            return {i: param.data.copy() for i, param in enumerate(self.model.parameters())}
        return {}

    def _set_model_state(self, state):
        """Restore model parameters from checkpoint."""
        if hasattr(self.model, 'parameters'):
            for i, param in enumerate(self.model.parameters()):
                if i in state:
                    param.data = state[i].copy()

    def _get_optimizer_state(self):
        """Extract optimizer state for checkpointing."""
        state = {}
        if hasattr(self.optimizer, 'lr'):
            state['lr'] = self.optimizer.lr
        if hasattr(self.optimizer, 'momentum_buffers'):
            state['momentum_buffers'] = self.optimizer.momentum_buffers.copy()
        return state

    def _set_optimizer_state(self, state):
        """Restore optimizer state from checkpoint."""
        if 'lr' in state and hasattr(self.optimizer, 'lr'):
            self.optimizer.lr = state['lr']
        if 'momentum_buffers' in state and hasattr(self.optimizer, 'momentum_buffers'):
            self.optimizer.momentum_buffers = state['momentum_buffers']

    def _get_scheduler_state(self):
        """Extract scheduler state for checkpointing."""
        if self.scheduler is None:
            return None
        return {
            'max_lr': getattr(self.scheduler, 'max_lr', None),
            'min_lr': getattr(self.scheduler, 'min_lr', None),
            'total_epochs': getattr(self.scheduler, 'total_epochs', None)
        }

    def _set_scheduler_state(self, state):
        """Restore scheduler state from checkpoint."""
        if state is None or self.scheduler is None:
            return
        for key, value in state.items():
            if hasattr(self.scheduler, key):
                setattr(self.scheduler, key, value)
    ### END SOLUTION
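To see how the modules exported in this commit compose, here is a hedged end-to-end sketch. The constructor signatures assumed below (Linear(in_features, out_features), MSELoss() with no arguments, Tensor wrapping a NumPy array) are not shown in this diff and may differ from the actual modules:

import numpy as np
from tinytorch.core.tensor import Tensor
from tinytorch.core.layers import Linear
from tinytorch.core.losses import MSELoss
from tinytorch.core.optimizers import AdamW
from tinytorch.core.training import Trainer, CosineSchedule

model = Linear(4, 1)                                   # assumed constructor signature
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
schedule = CosineSchedule(max_lr=0.001, min_lr=0.0001, total_epochs=10)
trainer = Trainer(model, optimizer, MSELoss(), scheduler=schedule)

# Any iterable of (inputs, targets) batches works as a dataloader.
data = [(Tensor(np.random.randn(8, 4)), Tensor(np.random.randn(8, 1))) for _ in range(8)]

for _ in range(10):
    # With accumulation_steps=4, optimizer.step() fires every 4 batches, so each
    # update effectively sees 8 * 4 = 32 samples; per-batch losses are divided by
    # accumulation_steps before summing, so the recorded train loss is the average
    # over that effective batch.
    train_loss = trainer.train_epoch(data, accumulation_steps=4)
    eval_loss, accuracy = trainer.evaluate(data)
    print(trainer.epoch, train_loss, eval_loss)

trainer.save_checkpoint("checkpoints/latest.pkl")      # persists epoch, step, params, optimizer state
trainer.load_checkpoint("checkpoints/latest.pkl")      # restores that state for resumption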