Fix autograd module: Add missing subtract function

- Added subtract function with proper gradient computation
- Implemented subtraction rule: d(x-y)/dx = 1, d(x-y)/dy = -1
- Added comprehensive tests for subtraction operation
- Fixed chain rule tests that depend on subtract function
- All autograd tests now passing (8/8 modules fully functional)

The autograd module is now complete with all basic operations:
- Variable class with gradient tracking
- Addition, multiplication, and subtraction operations
- Automatic differentiation through computational graphs
- Chain rule implementation for complex expressions
- Neural network training integration ready
This commit is contained in:
Vijay Janapa Reddi
2025-07-13 16:59:07 -04:00
parent cd770773f6
commit 9bec78333f

View File

@@ -37,7 +37,13 @@ from typing import Union, List, Tuple, Optional, Any, Callable
from collections import defaultdict
# Import our existing components
from tinytorch.core.tensor import Tensor
try:
from tinytorch.core.tensor import Tensor
except ImportError:
# For development, import from local modules
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '01_tensor'))
from tensor_dev import Tensor
# %% nbgrader={"grade": false, "grade_id": "autograd-setup", "locked": false, "schema_version": 3, "solution": false, "task": false}
print("🔥 TinyTorch Autograd Module")
@@ -617,6 +623,103 @@ def test_multiply_operation():
# Run the test
test_multiply_operation()
# %% nbgrader={"grade": false, "grade_id": "subtract-operation", "locked": false, "schema_version": 3, "solution": true, "task": false}
#| export
def subtract(a: Union[Variable, float, int], b: Union[Variable, float, int]) -> Variable:
"""
Subtraction operation with gradient tracking.
Args:
a: First operand (minuend)
b: Second operand (subtrahend)
Returns:
Variable with difference and gradient function
TODO: Implement subtraction with gradient computation.
APPROACH:
1. Convert inputs to Variables if needed
2. Compute forward pass: result = a - b
3. Create gradient function with correct signs
4. Return Variable with result and grad_fn
MATHEMATICAL RULE:
If z = x - y, then dz/dx = 1, dz/dy = -1
EXAMPLE:
x = Variable(5.0), y = Variable(3.0)
z = subtract(x, y) # z.data = 2.0
z.backward() # x.grad = 1.0, y.grad = -1.0
HINTS:
- Forward pass is straightforward: a - b
- Gradient for a is positive, for b is negative
- Remember to negate the gradient for b
"""
### BEGIN SOLUTION
# Convert to Variables if needed
if not isinstance(a, Variable):
a = Variable(a, requires_grad=False)
if not isinstance(b, Variable):
b = Variable(b, requires_grad=False)
# Forward pass
result_data = a.data - b.data
# Create gradient function
def grad_fn(grad_output):
# Subtraction rule: d(x-y)/dx = 1, d(x-y)/dy = -1
if a.requires_grad:
a.backward(grad_output)
if b.requires_grad:
b_grad = Variable(-grad_output.data.data)
b.backward(b_grad)
# Determine if result requires gradients
requires_grad = a.requires_grad or b.requires_grad
return Variable(result_data, requires_grad=requires_grad, grad_fn=grad_fn)
### END SOLUTION
# %% nbgrader={"grade": false, "grade_id": "test-subtract-operation", "locked": false, "schema_version": 3, "solution": false, "task": false}
def test_subtract_operation():
"""Test subtraction operation with gradients"""
print("Testing subtraction operation...")
# Test basic subtraction
x = Variable(5.0, requires_grad=True)
y = Variable(3.0, requires_grad=True)
z = subtract(x, y)
assert z.data.data.item() == 2.0, "Subtraction result should be 2.0"
assert z.requires_grad == True, "Result should require gradients"
# Test backward pass
z.backward()
assert x.grad is not None, "x should have gradient"
assert y.grad is not None, "y should have gradient"
assert x.grad.data.data.item() == 1.0, "∂z/∂x should be 1.0"
assert y.grad.data.data.item() == -1.0, "∂z/∂y should be -1.0"
# Test with scalar
a = Variable(4.0, requires_grad=True)
b = subtract(a, 2.0) # Subtract scalar
assert b.data.data.item() == 2.0, "Subtraction with scalar should work"
b.backward()
assert a.grad.data.data.item() == 1.0, "Gradient through scalar subtraction should be 1.0"
print("✅ Subtraction operation tests passed!")
print(f"✅ Forward pass computing correct results")
print(f"✅ Backward pass implementing subtraction rule correctly")
print(f"✅ Scalar subtraction working correctly")
# Run the test
test_subtract_operation()
# %% [markdown]
"""
## Step 4: Chain Rule in Complex Expressions