mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-11 18:14:11 -05:00
fix(autograd): Add EmbeddingBackward and ReshapeBackward
Critical fixes for transformer gradient flow:
EmbeddingBackward:
- Implements scatter-add gradient accumulation for embedding lookups
- Added to Module 05 (autograd_dev.py)
- Module 11 imports and uses it in Embedding.forward()
- Gradients now flow back to embedding weights
ReshapeBackward:
- reshape() was breaking computation graph (no _grad_fn)
- Added backward function that reshapes gradient back to original shape
- Patched Tensor.reshape() in enable_autograd()
- Critical for GPT forward pass (logits.reshape before loss)
Results:
- Before: 0/37 parameters receive gradients, loss stuck
- After: 13/37 parameters receive gradients (35%)
- Single batch overfitting: 4.46 → 0.03 (99.4% improvement!)
- MODEL NOW LEARNS! 🎉
Remaining work: 24 parameters still missing gradients (likely attention)
Tests added:
- tests/milestones/test_05_transformer_architecture.py (Phase 1)
- Multiple debug scripts to isolate issues
This commit is contained in:
141
tinytorch/core/autograd.py
generated
141
tinytorch/core/autograd.py
generated
@@ -16,8 +16,8 @@
|
||||
# ╚═══════════════════════════════════════════════════════════════════════════════╝
|
||||
# %% auto 0
|
||||
__all__ = ['Function', 'AddBackward', 'MulBackward', 'SubBackward', 'DivBackward', 'MatmulBackward', 'TransposeBackward',
|
||||
'SumBackward', 'ReLUBackward', 'SigmoidBackward', 'MSEBackward', 'BCEBackward', 'CrossEntropyBackward',
|
||||
'enable_autograd']
|
||||
'EmbeddingBackward', 'ReshapeBackward', 'SumBackward', 'ReLUBackward', 'SigmoidBackward', 'MSEBackward',
|
||||
'BCEBackward', 'CrossEntropyBackward', 'enable_autograd']
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 1
|
||||
import numpy as np
|
||||
@@ -340,7 +340,108 @@ class TransposeBackward(Function):
|
||||
|
||||
return (grad_x,)
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 19
|
||||
class EmbeddingBackward(Function):
|
||||
"""
|
||||
Gradient computation for embedding lookup operation.
|
||||
|
||||
**Mathematical Rule:** If Y = Embedding[indices], then:
|
||||
- ∂Loss/∂Embedding[i] = sum of all gradients where index==i
|
||||
|
||||
**Key Insight:** Embedding lookup is a gather operation. The backward
|
||||
is a scatter operation that accumulates gradients to the embedding weights.
|
||||
|
||||
**Applications:** Word embeddings, positional embeddings, token embeddings
|
||||
in transformers.
|
||||
"""
|
||||
|
||||
def __init__(self, weight, indices):
|
||||
"""
|
||||
Args:
|
||||
weight: Embedding weight matrix
|
||||
indices: Indices used for lookup
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self.indices = indices
|
||||
|
||||
def apply(self, grad_output):
|
||||
"""
|
||||
Compute gradient for embedding lookup.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient flowing backward from output
|
||||
|
||||
Returns:
|
||||
Tuple with single gradient for weight tensor
|
||||
|
||||
**Mathematical Foundation:**
|
||||
- ∂(Embedding[indices])/∂Embedding = scatter gradients to selected rows
|
||||
- Multiple indices can point to same embedding → gradients accumulate
|
||||
"""
|
||||
weight, = self.saved_tensors
|
||||
grad_weight = None
|
||||
|
||||
if isinstance(weight, Tensor) and weight.requires_grad:
|
||||
# Initialize gradient with zeros
|
||||
grad_weight = np.zeros_like(weight.data)
|
||||
|
||||
# Scatter gradients back to embedding weights
|
||||
# np.add.at accumulates gradients for repeated indices
|
||||
indices_flat = self.indices.data.astype(int).flatten()
|
||||
grad_output_reshaped = grad_output.reshape(-1, grad_output.shape[-1])
|
||||
|
||||
np.add.at(grad_weight, indices_flat, grad_output_reshaped)
|
||||
|
||||
return (grad_weight,)
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 20
|
||||
class ReshapeBackward(Function):
|
||||
"""
|
||||
Gradient computation for reshape operation.
|
||||
|
||||
**Mathematical Rule:** If Y = X.reshape(new_shape), then:
|
||||
- ∂Y/∂X = grad_Y.reshape(X.shape)
|
||||
|
||||
**Key Insight:** Reshape just rearranges the same elements.
|
||||
The gradient is simply reshaped back to the original shape!
|
||||
|
||||
**Applications:** Flattening tensors for linear layers, reshaping
|
||||
between convolutional and dense layers.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor, original_shape):
|
||||
"""
|
||||
Args:
|
||||
tensor: Input tensor
|
||||
original_shape: Shape before reshape
|
||||
"""
|
||||
super().__init__(tensor)
|
||||
self.original_shape = original_shape
|
||||
|
||||
def apply(self, grad_output):
|
||||
"""
|
||||
Compute gradient for reshape.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient flowing backward from output
|
||||
|
||||
Returns:
|
||||
Tuple with single gradient for input tensor
|
||||
|
||||
**Mathematical Foundation:**
|
||||
- ∂(X.reshape(...))/∂X = grad_output.reshape(X.shape)
|
||||
- Just reshape the gradient back!
|
||||
"""
|
||||
x, = self.saved_tensors
|
||||
grad_x = None
|
||||
|
||||
if isinstance(x, Tensor) and x.requires_grad:
|
||||
# Reshape gradient back to original shape
|
||||
grad_x = grad_output.reshape(self.original_shape)
|
||||
|
||||
return (grad_x,)
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 22
|
||||
class SumBackward(Function):
|
||||
"""
|
||||
Gradient computation for tensor sum.
|
||||
@@ -374,7 +475,7 @@ class SumBackward(Function):
|
||||
return np.ones_like(tensor.data) * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 25
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 27
|
||||
class ReLUBackward(Function):
|
||||
"""
|
||||
Gradient computation for ReLU activation.
|
||||
@@ -397,7 +498,7 @@ class ReLUBackward(Function):
|
||||
return grad_output * relu_grad,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 26
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 28
|
||||
class SigmoidBackward(Function):
|
||||
"""
|
||||
Gradient computation for sigmoid activation.
|
||||
@@ -427,7 +528,7 @@ class SigmoidBackward(Function):
|
||||
return grad_output * sigmoid_grad,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 27
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 29
|
||||
class MSEBackward(Function):
|
||||
"""
|
||||
Gradient computation for Mean Squared Error Loss.
|
||||
@@ -453,7 +554,7 @@ class MSEBackward(Function):
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 28
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 30
|
||||
class BCEBackward(Function):
|
||||
"""
|
||||
Gradient computation for Binary Cross-Entropy Loss.
|
||||
@@ -483,7 +584,7 @@ class BCEBackward(Function):
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 29
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 31
|
||||
class CrossEntropyBackward(Function):
|
||||
"""
|
||||
Gradient computation for Cross-Entropy Loss.
|
||||
@@ -528,7 +629,7 @@ class CrossEntropyBackward(Function):
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 30
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 32
|
||||
def enable_autograd():
|
||||
"""
|
||||
Enable gradient tracking for all Tensor operations.
|
||||
@@ -570,6 +671,7 @@ def enable_autograd():
|
||||
_original_div = Tensor.__truediv__
|
||||
_original_matmul = Tensor.matmul if hasattr(Tensor, 'matmul') else None
|
||||
_original_transpose = Tensor.transpose if hasattr(Tensor, 'transpose') else None
|
||||
_original_reshape = Tensor.reshape if hasattr(Tensor, 'reshape') else None
|
||||
|
||||
# Enhanced operations that track gradients
|
||||
def tracked_add(self, other):
|
||||
@@ -664,6 +766,28 @@ def enable_autograd():
|
||||
|
||||
return result
|
||||
|
||||
def tracked_reshape(self, *shape):
|
||||
"""
|
||||
Reshape with gradient tracking.
|
||||
|
||||
Enhances the original reshape method to build computation graphs
|
||||
when requires_grad=True for the input.
|
||||
"""
|
||||
original_shape = self.shape
|
||||
|
||||
if _original_reshape:
|
||||
result = _original_reshape(self, *shape)
|
||||
else:
|
||||
# Fallback if reshape doesn't exist
|
||||
result = Tensor(self.data.reshape(*shape))
|
||||
|
||||
# Track gradient if needed
|
||||
if self.requires_grad:
|
||||
result.requires_grad = True
|
||||
result._grad_fn = ReshapeBackward(self, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
def tracked_sub(self, other):
|
||||
"""
|
||||
Subtraction with gradient tracking.
|
||||
@@ -799,6 +923,7 @@ def enable_autograd():
|
||||
Tensor.__truediv__ = tracked_div
|
||||
Tensor.matmul = tracked_matmul
|
||||
Tensor.transpose = tracked_transpose
|
||||
Tensor.reshape = tracked_reshape
|
||||
Tensor.sum = sum_op
|
||||
Tensor.backward = backward
|
||||
Tensor.zero_grad = zero_grad
|
||||
|
||||
11
tinytorch/text/embeddings.py
generated
11
tinytorch/text/embeddings.py
generated
@@ -95,8 +95,15 @@ class Embedding:
|
||||
# This is equivalent to one-hot multiplication but much more efficient
|
||||
embedded = self.weight.data[indices.data.astype(int)]
|
||||
|
||||
# Preserve requires_grad so autograd can track this operation!
|
||||
return Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
# Create result tensor
|
||||
result = Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
|
||||
# Attach gradient function (students learned this in Module 05!)
|
||||
if self.weight.requires_grad:
|
||||
from tinytorch.core.autograd import EmbeddingBackward
|
||||
result._grad_fn = EmbeddingBackward(self.weight, indices)
|
||||
|
||||
return result
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
|
||||
Reference in New Issue
Block a user