Module improvements: Core modules (01-08)

- Update tensor module notebook
- Enhance activations module
- Expand layers module functionality
- Improve autograd implementation
- Enhance optimizers module
- Update training module
- Refine dataloader notebook
Vijay Janapa Reddi
2025-11-11 19:05:00 -05:00
parent f445e133ac
commit c8555bdb78
7 changed files with 787 additions and 403 deletions

@@ -445,6 +445,75 @@ class SGD(Optimizer):
        self.momentum_buffers = [None for _ in self.params]
        ### END SOLUTION

    def has_momentum(self) -> bool:
        """
        Check if this optimizer uses momentum.

        This explicit API method replaces the need for hasattr() checks
        in checkpointing code (Module 07).

        Returns:
            bool: True if momentum is enabled (momentum > 0), False otherwise

        Example:
            >>> optimizer = SGD(params, lr=0.01, momentum=0.9)
            >>> optimizer.has_momentum()
            True
        """
        return self.momentum > 0

    def get_momentum_state(self) -> Optional[List]:
        """
        Get momentum buffers for checkpointing.

        This explicit API method provides safe access to momentum buffers
        without using hasattr(), making the API contract clear.

        Returns:
            Optional[List]: List of momentum buffers if momentum is enabled,
                None otherwise

        Example:
            >>> optimizer = SGD(params, lr=0.01, momentum=0.9)
            >>> optimizer.step()  # Initialize buffers
            >>> state = optimizer.get_momentum_state()
            >>> # Later: optimizer.set_momentum_state(state)
        """
        if not self.has_momentum():
            return None
        return [buf.copy() if buf is not None else None
                for buf in self.momentum_buffers]

    def set_momentum_state(self, state: Optional[List]) -> None:
        """
        Restore momentum buffers from a checkpoint.

        This explicit API method provides safe restoration of momentum state
        without using hasattr().

        Args:
            state: List of momentum buffers or None

        Example:
            >>> optimizer = SGD(params, lr=0.01, momentum=0.9)
            >>> state = optimizer.get_momentum_state()
            >>> # Training interruption...
            >>> new_optimizer = SGD(params, lr=0.01, momentum=0.9)
            >>> new_optimizer.set_momentum_state(state)
        """
        if state is None or not self.has_momentum():
            return
        if len(state) != len(self.momentum_buffers):
            raise ValueError(
                f"State length {len(state)} doesn't match "
                f"optimizer parameter count {len(self.momentum_buffers)}"
            )
        for i, buf in enumerate(state):
            if buf is not None:
                self.momentum_buffers[i] = buf.copy()

    def step(self):
        """
        Perform SGD update step with momentum.
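For context, here is how checkpointing code (Module 07) could use this explicit API instead of hasattr() probing. This is a minimal sketch, not code from the commit: it assumes the SGD class above with NumPy-backed parameters, and save_checkpoint / load_checkpoint are illustrative names, not part of the module's API.

    import numpy as np

    def save_checkpoint(optimizer) -> dict:
        # get_momentum_state() already returns None when momentum == 0,
        # so no hasattr() probing is needed here.
        return {"momentum": optimizer.get_momentum_state()}

    def load_checkpoint(optimizer, checkpoint: dict) -> None:
        # set_momentum_state() is a no-op for a None state or when momentum
        # is disabled, so this path is safe for plain and momentum SGD alike.
        optimizer.set_momentum_state(checkpoint["momentum"])

    # Round trip: snapshot, rebuild the optimizer, restore.
    params = [np.zeros(3)]
    opt = SGD(params, lr=0.01, momentum=0.9)
    # ... training steps would populate opt.momentum_buffers ...
    state = save_checkpoint(opt)             # {'momentum': [None]} until step() runs
    new_opt = SGD(params, lr=0.01, momentum=0.9)
    load_checkpoint(new_opt, state)          # no hasattr() anywhere in the path

Because both accessors handle the momentum == 0 case internally, the same checkpoint path works for plain SGD and SGD with momentum without any attribute inspection.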