mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-11 22:33:36 -05:00
Module improvements: Core modules (01-08)
- Update tensor module notebook - Enhance activations module - Expand layers module functionality - Improve autograd implementation - Add optimizers enhancements - Update training module - Refine dataloader notebook
This commit is contained in:
@@ -445,6 +445,75 @@ class SGD(Optimizer):
|
||||
self.momentum_buffers = [None for _ in self.params]
|
||||
### END SOLUTION
|
||||
|
||||
def has_momentum(self) -> bool:
|
||||
"""
|
||||
Check if this optimizer uses momentum.
|
||||
|
||||
This explicit API method replaces the need for hasattr() checks
|
||||
in checkpointing code (Module 07).
|
||||
|
||||
Returns:
|
||||
bool: True if momentum is enabled (momentum > 0), False otherwise
|
||||
|
||||
Example:
|
||||
>>> optimizer = SGD(params, lr=0.01, momentum=0.9)
|
||||
>>> optimizer.has_momentum()
|
||||
True
|
||||
"""
|
||||
return self.momentum > 0
|
||||
|
||||
def get_momentum_state(self) -> Optional[List]:
|
||||
"""
|
||||
Get momentum buffers for checkpointing.
|
||||
|
||||
This explicit API method provides safe access to momentum buffers
|
||||
without using hasattr(), making the API contract clear.
|
||||
|
||||
Returns:
|
||||
Optional[List]: List of momentum buffers if momentum is enabled,
|
||||
None otherwise
|
||||
|
||||
Example:
|
||||
>>> optimizer = SGD(params, lr=0.01, momentum=0.9)
|
||||
>>> optimizer.step() # Initialize buffers
|
||||
>>> state = optimizer.get_momentum_state()
|
||||
>>> # Later: optimizer.set_momentum_state(state)
|
||||
"""
|
||||
if not self.has_momentum():
|
||||
return None
|
||||
return [buf.copy() if buf is not None else None
|
||||
for buf in self.momentum_buffers]
|
||||
|
||||
def set_momentum_state(self, state: Optional[List]) -> None:
|
||||
"""
|
||||
Restore momentum buffers from checkpointing.
|
||||
|
||||
This explicit API method provides safe restoration of momentum state
|
||||
without using hasattr().
|
||||
|
||||
Args:
|
||||
state: List of momentum buffers or None
|
||||
|
||||
Example:
|
||||
>>> optimizer = SGD(params, lr=0.01, momentum=0.9)
|
||||
>>> state = optimizer.get_momentum_state()
|
||||
>>> # Training interruption...
|
||||
>>> new_optimizer = SGD(params, lr=0.01, momentum=0.9)
|
||||
>>> new_optimizer.set_momentum_state(state)
|
||||
"""
|
||||
if state is None or not self.has_momentum():
|
||||
return
|
||||
|
||||
if len(state) != len(self.momentum_buffers):
|
||||
raise ValueError(
|
||||
f"State length {len(state)} doesn't match "
|
||||
f"optimizer parameters {len(self.momentum_buffers)}"
|
||||
)
|
||||
|
||||
for i, buf in enumerate(state):
|
||||
if buf is not None:
|
||||
self.momentum_buffers[i] = buf.copy()
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform SGD update step with momentum.
|
||||
|
||||
Reference in New Issue
Block a user