mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-11 22:03:34 -05:00
Update autograd module with latest changes
This commit is contained in:
216
tinytorch/core/autograd.py
generated
216
tinytorch/core/autograd.py
generated
@@ -1,7 +1,22 @@
|
||||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../modules/source/05_autograd/autograd_dev.ipynb.
|
||||
|
||||
# ╔═══════════════════════════════════════════════════════════════════════════════╗
|
||||
# ║ 🚨 CRITICAL WARNING 🚨 ║
|
||||
# ║ AUTOGENERATED! DO NOT EDIT! ║
|
||||
# ║ ║
|
||||
# ║ This file is AUTOMATICALLY GENERATED from source modules. ║
|
||||
# ║ ANY CHANGES MADE HERE WILL BE LOST when modules are re-exported! ║
|
||||
# ║ ║
|
||||
# ║ ✅ TO EDIT: modules/source/05_autograd/autograd_dev.py                        ║
|
||||
# ║ ✅ TO EXPORT: Run 'tito module complete <module_name>' ║
|
||||
# ║ ║
|
||||
# ║ 🛡️ STUDENT PROTECTION: This file contains optimized implementations. ║
|
||||
# ║ Editing it directly may break module functionality and training. ║
|
||||
# ║ ║
|
||||
# ║ 🎓 LEARNING TIP: Work in modules/source/ - that's where real development ║
|
||||
# ║ happens! The tinytorch/ directory is just the compiled output. ║
|
||||
# ╚═══════════════════════════════════════════════════════════════════════════════╝
|
||||
# %% auto 0
|
||||
__all__ = ['Function', 'AddBackward', 'MulBackward', 'MatmulBackward', 'SumBackward', 'SigmoidBackward', 'BCEBackward']
|
||||
__all__ = ['Function', 'AddBackward', 'MulBackward', 'MatmulBackward', 'SumBackward', 'SigmoidBackward', 'BCEBackward',
|
||||
'enable_autograd']
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 1
|
||||
import numpy as np
|
||||
@@ -284,3 +299,198 @@ class BCEBackward(Function):
|
||||
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 22
|
||||
def enable_autograd():
    """
    Enable gradient tracking for all Tensor operations.

    Enhances the existing Tensor class with autograd capabilities by
    monkey-patching its arithmetic methods. Call once to activate
    gradients globally; a second call is a no-op.

    **What it does:**
    - Wraps Tensor.__add__, Tensor.__mul__ and Tensor.matmul with
      gradient-tracking versions that delegate to the originals
    - Installs sum(), backward() and zero_grad() methods
    - Enables computation-graph building via the ``_grad_fn`` attribute
    - Maintains full backward compatibility (tensors without
      requires_grad behave exactly as before)

    **After calling this:**
    - Tensor operations will track computation graphs
    - backward() method becomes available
    - Gradients will flow through operations
    - requires_grad=True enables tracking per tensor

    **Example:**
    ```python
    enable_autograd()  # Call once
    x = Tensor([2.0], requires_grad=True)
    y = x * 3
    y.backward()
    print(x.grad)  # [3.0]
    ```
    """
    # Idempotence guard: patching twice would wrap the wrappers and
    # create duplicate graph nodes per operation.
    if hasattr(Tensor, '_autograd_enabled'):
        print("⚠️ Autograd already enabled")
        return

    # Keep references to the original (untracked) operations so the
    # wrappers below can delegate the actual numeric work to them.
    _original_add = Tensor.__add__
    _original_mul = Tensor.__mul__
    _original_matmul = Tensor.matmul if hasattr(Tensor, 'matmul') else None

    def tracked_add(self, other):
        """
        Addition with gradient tracking.

        Records an AddBackward node on the result when either operand
        has requires_grad=True.
        """
        # Promote plain scalars/arrays so the backward pass always
        # sees Tensor operands.
        if not isinstance(other, Tensor):
            other = Tensor(other)

        result = _original_add(self, other)

        # Only build graph nodes when a gradient is actually needed.
        if self.requires_grad or other.requires_grad:
            result.requires_grad = True
            result._grad_fn = AddBackward(self, other)

        return result

    def tracked_mul(self, other):
        """
        Multiplication with gradient tracking.

        Records a MulBackward node on the result when either operand
        has requires_grad=True.
        """
        # NOTE: unlike tracked_add, `other` is deliberately passed
        # through unconverted so the original __mul__ keeps its scalar
        # handling; the requires_grad check below guards the attribute
        # access accordingly.
        result = _original_mul(self, other)

        if self.requires_grad or (isinstance(other, Tensor) and other.requires_grad):
            result.requires_grad = True
            result._grad_fn = MulBackward(self, other)

        return result

    def tracked_matmul(self, other):
        """
        Matrix multiplication with gradient tracking.

        Records a MatmulBackward node on the result when either operand
        has requires_grad=True.
        """
        if _original_matmul:
            result = _original_matmul(self, other)
        else:
            # Fallback if the Tensor class never defined matmul.
            result = Tensor(np.dot(self.data, other.data))

        if self.requires_grad or other.requires_grad:
            result.requires_grad = True
            result._grad_fn = MatmulBackward(self, other)

        return result

    def sum_op(self, axis=None, keepdims=False):
        """
        Sum reduction with gradient tracking.

        Records a SumBackward node on the result when this tensor has
        requires_grad=True.
        """
        result = Tensor(np.sum(self.data, axis=axis, keepdims=keepdims))

        if self.requires_grad:
            result.requires_grad = True
            result._grad_fn = SumBackward(self)

        return result

    def backward(self, gradient=None):
        """
        Compute gradients via reverse-mode automatic differentiation.

        Accumulates ``gradient`` into ``self.grad``, then recursively
        propagates gradients to the tensors saved by this node's
        ``_grad_fn``.

        Args:
            gradient: Upstream gradient array. May be omitted only for
                single-element outputs, where it defaults to ones.

        Raises:
            ValueError: If ``gradient`` is omitted for a non-scalar
                output.

        **Example:**
        ```python
        x = Tensor([2.0], requires_grad=True)
        y = x * 3
        y.backward()  # Computes gradients for x
        print(x.grad)  # [3.0]
        ```
        """
        # Tensors outside the graph have nothing to accumulate.
        if not self.requires_grad:
            return

        # dL/dL == 1 for scalar outputs; larger outputs need an
        # explicit seed gradient from the caller.
        if gradient is None:
            if self.data.size == 1:
                gradient = np.ones_like(self.data)
            else:
                raise ValueError("backward() requires gradient for non-scalar outputs")

        # Accumulate (never overwrite) so multiple graph paths into the
        # same tensor sum their contributions.
        if self.grad is None:
            self.grad = np.zeros_like(self.data)
        self.grad += gradient

        # Propagate through the recorded computation graph.
        if hasattr(self, '_grad_fn') and self._grad_fn:
            grads = self._grad_fn.apply(gradient)

            # Recurse into each parent that still needs a gradient.
            for tensor, grad in zip(self._grad_fn.saved_tensors, grads):
                if isinstance(tensor, Tensor) and tensor.requires_grad and grad is not None:
                    tensor.backward(grad)

    def zero_grad(self):
        """
        Reset gradients to zero.

        Call before each backward pass to prevent gradient accumulation
        from previous iterations.
        """
        self.grad = None

    # Install the enhanced operations on the Tensor class.
    Tensor.__add__ = tracked_add
    Tensor.__mul__ = tracked_mul
    Tensor.matmul = tracked_matmul
    Tensor.sum = sum_op
    Tensor.backward = backward
    Tensor.zero_grad = zero_grad

    # Mark as enabled so a second call becomes a no-op.
    Tensor._autograd_enabled = True

    print("✅ Autograd enabled! Tensors now track gradients.")
    print(" - Operations build computation graphs")
    print(" - backward() computes gradients")
    print(" - requires_grad=True enables tracking")

# Auto-enable when module is imported
enable_autograd()
|
||||
|
||||
Reference in New Issue
Block a user