Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2026-04-29 18:42:33 -05:00)
fix(module-05): Add SubBackward and DivBackward for autograd
- Implement gradient functions for subtraction and division operations
- Patch Tensor.__sub__ and Tensor.__truediv__ in enable_autograd()
- Required for LayerNorm (x - mean) and (normalized / std) operations

These operations are used extensively in normalization layers.
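For context, the computation that motivates the change, written as a plain NumPy sketch (illustrative only, not part of this commit): the backward pass must differentiate through both the subtraction and the division, and the local gradient formulas below are exactly the ones the new SubBackward and DivBackward nodes implement.

import numpy as np

# Forward pass of a LayerNorm-style expression: (x - mean) / std
x = np.array([1.0, 2.0, 3.0, 4.0])
mean = x.mean()
std = x.std() + 1e-5
centered = x - mean            # in TinyTorch this is Tensor.__sub__  -> SubBackward
normalized = centered / std    # in TinyTorch this is Tensor.__truediv__ -> DivBackward

# Local gradients for loss = normalized.sum(), i.e. an upstream gradient of ones
grad_out = np.ones_like(normalized)
grad_centered = grad_out / std                            # ∂(a/b)/∂a = 1/b
grad_std = float(-(grad_out * centered / std**2).sum())   # ∂(a/b)/∂b = -a/b²
grad_x = grad_centered * 1.0                              # ∂(a-b)/∂a = 1
grad_mean = float(-grad_centered.sum())                   # ∂(a-b)/∂b = -1
# (Local gradients only; a full autograd pass also chains through
#  mean and std as functions of x.)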
Changed file: tinytorch/core/autograd.py (generated), 123 lines changed
@@ -15,8 +15,8 @@
 # ║ happens! The tinytorch/ directory is just the compiled output. ║
 # ╚═══════════════════════════════════════════════════════════════════════════════╝
 # %% auto 0
-__all__ = ['Function', 'AddBackward', 'MulBackward', 'MatmulBackward', 'SumBackward', 'ReLUBackward', 'SigmoidBackward',
-           'MSEBackward', 'BCEBackward', 'CrossEntropyBackward', 'enable_autograd']
+__all__ = ['Function', 'AddBackward', 'MulBackward', 'SubBackward', 'DivBackward', 'MatmulBackward', 'SumBackward',
+           'ReLUBackward', 'SigmoidBackward', 'MSEBackward', 'BCEBackward', 'CrossEntropyBackward', 'enable_autograd']
 
 # %% ../../modules/source/05_autograd/autograd_dev.ipynb 1
 import numpy as np
@@ -164,6 +164,65 @@ class MulBackward(Function):
         return grad_a, grad_b
 
 # %% ../../modules/source/05_autograd/autograd_dev.ipynb 13
+class SubBackward(Function):
+    """
+    Gradient computation for tensor subtraction.
+
+    **Mathematical Rule:** If z = a - b, then ∂z/∂a = 1 and ∂z/∂b = -1
+    """
+
+    def apply(self, grad_output):
+        """
+        Compute gradients for subtraction.
+
+        Returns:
+            Tuple of (grad_a, grad_b) where grad_b is negated
+        """
+        a, b = self.saved_tensors
+        grad_a = grad_b = None
+
+        if isinstance(a, Tensor) and a.requires_grad:
+            grad_a = grad_output  # ∂(a-b)/∂a = 1
+
+        if isinstance(b, Tensor) and b.requires_grad:
+            grad_b = -grad_output  # ∂(a-b)/∂b = -1 (note the negative!)
+
+        return grad_a, grad_b
+
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 15
+class DivBackward(Function):
+    """
+    Gradient computation for tensor division.
+
+    **Mathematical Rule:** If z = a / b, then:
+    - ∂z/∂a = 1/b
+    - ∂z/∂b = -a/b²
+    """
+
+    def apply(self, grad_output):
+        """
+        Compute gradients for division using quotient rule.
+
+        Returns:
+            Tuple of (grad_a, grad_b)
+        """
+        a, b = self.saved_tensors
+        grad_a = grad_b = None
+
+        if isinstance(a, Tensor) and a.requires_grad:
+            # ∂(a/b)/∂a = 1/b
+            if isinstance(b, Tensor):
+                grad_a = grad_output / b.data
+            else:
+                grad_a = grad_output / b
+
+        if isinstance(b, Tensor) and b.requires_grad:
+            # ∂(a/b)/∂b = -a/b²
+            grad_b = -grad_output * a.data / (b.data ** 2)
+
+        return grad_a, grad_b
+
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 17
 class MatmulBackward(Function):
     """
     Gradient computation for matrix multiplication.
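As a sanity check on the two rules above (again plain NumPy, outside the Tensor class), the analytic gradients used by SubBackward and DivBackward can be compared against central finite differences of loss = sum(grad_output * op(a, b)):

import numpy as np

a, b = np.array([2.0, -1.0, 0.5]), np.array([4.0, 3.0, -2.0])
g = np.array([1.0, 2.0, 3.0])   # upstream gradient (grad_output)
eps = 1e-6

def num_grad(f, wrt, idx):
    # Central finite difference of sum(g * f(a, b)) w.r.t. one element.
    def loss(a_, b_):
        return float((g * f(a_, b_)).sum())
    pa, pb = a.copy(), b.copy()
    (pa if wrt == 'a' else pb)[idx] += eps
    ma, mb = a.copy(), b.copy()
    (ma if wrt == 'a' else mb)[idx] -= eps
    return (loss(pa, pb) - loss(ma, mb)) / (2 * eps)

# Subtraction: grad_a = g, grad_b = -g
assert np.isclose(num_grad(lambda x, y: x - y, 'a', 0), g[0])
assert np.isclose(num_grad(lambda x, y: x - y, 'b', 0), -g[0])

# Division: grad_a = g / b, grad_b = -g * a / b**2
assert np.isclose(num_grad(lambda x, y: x / y, 'a', 1), (g / b)[1])
assert np.isclose(num_grad(lambda x, y: x / y, 'b', 1), (-g * a / b**2)[1])
print("analytic rules match finite differences")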
@@ -206,7 +265,7 @@ class MatmulBackward(Function):
 
         return grad_a, grad_b
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 15
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 19
 class SumBackward(Function):
     """
     Gradient computation for tensor sum.
@@ -240,7 +299,7 @@ class SumBackward(Function):
             return np.ones_like(tensor.data) * grad_output,
         return None,
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 20
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 24
 class ReLUBackward(Function):
     """
     Gradient computation for ReLU activation.
@@ -263,7 +322,7 @@ class ReLUBackward(Function):
             return grad_output * relu_grad,
         return None,
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 21
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 25
 class SigmoidBackward(Function):
     """
     Gradient computation for sigmoid activation.
@@ -293,7 +352,7 @@ class SigmoidBackward(Function):
            return grad_output * sigmoid_grad,
        return None,
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 22
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 26
 class MSEBackward(Function):
     """
     Gradient computation for Mean Squared Error Loss.
@@ -319,7 +378,7 @@ class MSEBackward(Function):
             return grad * grad_output,
         return None,
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 23
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 27
 class BCEBackward(Function):
     """
     Gradient computation for Binary Cross-Entropy Loss.
@@ -349,7 +408,7 @@ class BCEBackward(Function):
             return grad * grad_output,
         return None,
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 24
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 28
 class CrossEntropyBackward(Function):
     """
     Gradient computation for Cross-Entropy Loss.
@@ -394,7 +453,7 @@ class CrossEntropyBackward(Function):
             return grad * grad_output,
         return None,
 
-# %% ../../modules/source/05_autograd/autograd_dev.ipynb 25
+# %% ../../modules/source/05_autograd/autograd_dev.ipynb 29
 def enable_autograd():
     """
     Enable gradient tracking for all Tensor operations.
@@ -431,7 +490,9 @@ def enable_autograd():
 
     # Store original operations
     _original_add = Tensor.__add__
+    _original_sub = Tensor.__sub__
    _original_mul = Tensor.__mul__
+    _original_div = Tensor.__truediv__
     _original_matmul = Tensor.matmul if hasattr(Tensor, 'matmul') else None
 
     # Enhanced operations that track gradients
@@ -499,6 +560,48 @@ def enable_autograd():
 
         return result
 
+    def tracked_sub(self, other):
+        """
+        Subtraction with gradient tracking.
+
+        Enhances the original __sub__ method to build computation graphs
+        when requires_grad=True for any input.
+        """
+        # Convert scalar to Tensor if needed
+        if not isinstance(other, Tensor):
+            other = Tensor(other)
+
+        # Call original operation
+        result = _original_sub(self, other)
+
+        # Track gradient if needed
+        if self.requires_grad or other.requires_grad:
+            result.requires_grad = True
+            result._grad_fn = SubBackward(self, other)
+
+        return result
+
+    def tracked_div(self, other):
+        """
+        Division with gradient tracking.
+
+        Enhances the original __truediv__ method to build computation graphs
+        when requires_grad=True for any input.
+        """
+        # Convert scalar to Tensor if needed
+        if not isinstance(other, Tensor):
+            other = Tensor(other)
+
+        # Call original operation
+        result = _original_div(self, other)
+
+        # Track gradient if needed
+        if self.requires_grad or other.requires_grad:
+            result.requires_grad = True
+            result._grad_fn = DivBackward(self, other)
+
+        return result
+
     def sum_op(self, axis=None, keepdims=False):
         """
         Sum operation with gradient tracking.
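tracked_sub and tracked_div follow the same wrap-and-patch pattern as the existing tracked_add and tracked_mul: call the stored original operator, then attach a backward node whenever either operand requires gradients. A self-contained toy illustration of that pattern, using a hypothetical Value class rather than the real Tensor:

class Value:
    """Toy scalar container, used only to illustrate the wrapping pattern."""
    def __init__(self, data, requires_grad=False):
        self.data = data
        self.requires_grad = requires_grad
        self._grad_fn = None

    def __sub__(self, other):
        return Value(self.data - other.data)

_original_sub = Value.__sub__        # keep a handle on the untracked operator

def tracked_sub(self, other):
    # Call the original operator, then record how to backpropagate through it.
    result = _original_sub(self, other)
    if self.requires_grad or other.requires_grad:
        result.requires_grad = True
        # This closure plays the role of SubBackward.apply.
        result._grad_fn = lambda grad_out: (grad_out, -grad_out)
    return result

Value.__sub__ = tracked_sub          # monkey-patch, as enable_autograd() does

z = Value(5.0, requires_grad=True) - Value(2.0)
print(z.data, z.requires_grad, z._grad_fn(1.0))   # 3.0 True (1.0, -1.0)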
@@ -587,7 +690,9 @@ def enable_autograd():
 
     # Install enhanced operations
     Tensor.__add__ = tracked_add
+    Tensor.__sub__ = tracked_sub
     Tensor.__mul__ = tracked_mul
+    Tensor.__truediv__ = tracked_div
     Tensor.matmul = tracked_matmul
     Tensor.sum = sum_op
     Tensor.backward = backward
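With the patched operators installed, the intended end-to-end flow looks roughly like the sketch below. It assumes the Tensor class behaves as in the earlier tensor module (constructed from a NumPy array, with a settable requires_grad flag and a .grad attribute populated by backward()); those details are assumptions, since this diff does not show them.

import numpy as np
from tinytorch.core.tensor import Tensor          # import path assumed
from tinytorch.core.autograd import enable_autograd

enable_autograd()                  # installs tracked_sub / tracked_div, among others

x = Tensor(np.array([1.0, 2.0, 3.0]))
x.requires_grad = True

normalized = (x - 2.0) / 0.8165    # scalars are wrapped via Tensor(other);
                                   # builds a SubBackward then a DivBackward node
loss = normalized.sum()            # SumBackward
loss.backward()                    # gradients flow back through the new nodes

print(x.grad)                      # expected: 1 / 0.8165 for each element (.grad assumed)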