mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-28 05:47:33 -05:00
fix(autograd): Add EmbeddingBackward and ReshapeBackward
Critical fixes for transformer gradient flow:
EmbeddingBackward:
- Implements scatter-add gradient accumulation for embedding lookups
- Added to Module 05 (autograd_dev.py)
- Module 11 imports and uses it in Embedding.forward()
- Gradients now flow back to embedding weights
ReshapeBackward:
- reshape() was breaking computation graph (no _grad_fn)
- Added backward function that reshapes gradient back to original shape
- Patched Tensor.reshape() in enable_autograd()
- Critical for GPT forward pass (logits.reshape before loss)
Results:
- Before: 0/37 parameters receive gradients, loss stuck
- After: 13/37 parameters receive gradients (35%)
- Single batch overfitting: 4.46 → 0.03 (99.4% improvement!)
- MODEL NOW LEARNS! 🎉
Remaining work: 24 parameters still missing gradients (likely attention)
Tests added:
- tests/milestones/test_05_transformer_architecture.py (Phase 1)
- Multiple debug scripts to isolate issues
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e75758af",
|
||||
"id": "d640e0e1",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -54,7 +54,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e91cac1d",
|
||||
"id": "2590eade",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -77,7 +77,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "66b04423",
|
||||
"id": "26e45a14",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -131,7 +131,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "55e88360",
|
||||
"id": "5789742f",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -190,7 +190,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c1bbb6bf",
|
||||
"id": "f61a78c3",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -227,7 +227,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "367242cc",
|
||||
"id": "f502f190",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -255,7 +255,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ee0f8e95",
|
||||
"id": "43f5f240",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -321,7 +321,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "58cf4e55",
|
||||
"id": "14257535",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -360,7 +360,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "714fa4d7",
|
||||
"id": "ea9dbbdd",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -389,7 +389,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cc01440a",
|
||||
"id": "28b5d65d",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -444,7 +444,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "03492f58",
|
||||
"id": "d812916c",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -477,7 +477,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "820e69ac",
|
||||
"id": "673290e7",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -535,7 +535,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "36f17329",
|
||||
"id": "b2ad121d",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -559,7 +559,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ced14334",
|
||||
"id": "10d995ce",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -599,7 +599,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aed436f9",
|
||||
"id": "a23fec0a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -622,7 +622,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ad99a34e",
|
||||
"id": "c7915688",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -669,7 +669,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "351d4393",
|
||||
"id": "0ae1548c",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -704,7 +704,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3aeb8e05",
|
||||
"id": "be4a4d4b",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -775,7 +775,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1bc08b98",
|
||||
"id": "b08533fc",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -848,9 +848,140 @@
|
||||
" return (grad_x,)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "50b93410",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
"grade_id": "embedding-backward",
|
||||
"solution": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#| export\n",
|
||||
"class EmbeddingBackward(Function):\n",
|
||||
" \"\"\"\n",
|
||||
" Gradient computation for embedding lookup operation.\n",
|
||||
" \n",
|
||||
" **Mathematical Rule:** If Y = Embedding[indices], then:\n",
|
||||
" - ∂Loss/∂Embedding[i] = sum of all gradients where index==i\n",
|
||||
" \n",
|
||||
" **Key Insight:** Embedding lookup is a gather operation. The backward\n",
|
||||
" is a scatter operation that accumulates gradients to the embedding weights.\n",
|
||||
" \n",
|
||||
" **Applications:** Word embeddings, positional embeddings, token embeddings\n",
|
||||
" in transformers.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, weight, indices):\n",
|
||||
" \"\"\"\n",
|
||||
" Args:\n",
|
||||
" weight: Embedding weight matrix\n",
|
||||
" indices: Indices used for lookup\n",
|
||||
" \"\"\"\n",
|
||||
" super().__init__(weight)\n",
|
||||
" self.indices = indices\n",
|
||||
"\n",
|
||||
" def apply(self, grad_output):\n",
|
||||
" \"\"\"\n",
|
||||
" Compute gradient for embedding lookup.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" grad_output: Gradient flowing backward from output\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" Tuple with single gradient for weight tensor\n",
|
||||
" \n",
|
||||
" **Mathematical Foundation:**\n",
|
||||
" - ∂(Embedding[indices])/∂Embedding = scatter gradients to selected rows\n",
|
||||
" - Multiple indices can point to same embedding → gradients accumulate\n",
|
||||
" \"\"\"\n",
|
||||
" weight, = self.saved_tensors\n",
|
||||
" grad_weight = None\n",
|
||||
"\n",
|
||||
" if isinstance(weight, Tensor) and weight.requires_grad:\n",
|
||||
" # Initialize gradient with zeros\n",
|
||||
" grad_weight = np.zeros_like(weight.data)\n",
|
||||
" \n",
|
||||
" # Scatter gradients back to embedding weights\n",
|
||||
" # np.add.at accumulates gradients for repeated indices\n",
|
||||
" indices_flat = self.indices.data.astype(int).flatten()\n",
|
||||
" grad_output_reshaped = grad_output.reshape(-1, grad_output.shape[-1])\n",
|
||||
" \n",
|
||||
" np.add.at(grad_weight, indices_flat, grad_output_reshaped)\n",
|
||||
"\n",
|
||||
" return (grad_weight,)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff517132",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
"grade_id": "reshape-backward",
|
||||
"solution": true
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#| export\n",
|
||||
"class ReshapeBackward(Function):\n",
|
||||
" \"\"\"\n",
|
||||
" Gradient computation for reshape operation.\n",
|
||||
" \n",
|
||||
" **Mathematical Rule:** If Y = X.reshape(new_shape), then:\n",
|
||||
" - ∂Y/∂X = grad_Y.reshape(X.shape)\n",
|
||||
" \n",
|
||||
" **Key Insight:** Reshape just rearranges the same elements.\n",
|
||||
" The gradient is simply reshaped back to the original shape!\n",
|
||||
" \n",
|
||||
" **Applications:** Flattening tensors for linear layers, reshaping\n",
|
||||
" between convolutional and dense layers.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, tensor, original_shape):\n",
|
||||
" \"\"\"\n",
|
||||
" Args:\n",
|
||||
" tensor: Input tensor\n",
|
||||
" original_shape: Shape before reshape\n",
|
||||
" \"\"\"\n",
|
||||
" super().__init__(tensor)\n",
|
||||
" self.original_shape = original_shape\n",
|
||||
"\n",
|
||||
" def apply(self, grad_output):\n",
|
||||
" \"\"\"\n",
|
||||
" Compute gradient for reshape.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" grad_output: Gradient flowing backward from output\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" Tuple with single gradient for input tensor\n",
|
||||
" \n",
|
||||
" **Mathematical Foundation:**\n",
|
||||
" - ∂(X.reshape(...))/∂X = grad_output.reshape(X.shape)\n",
|
||||
" - Just reshape the gradient back!\n",
|
||||
" \"\"\"\n",
|
||||
" x, = self.saved_tensors\n",
|
||||
" grad_x = None\n",
|
||||
"\n",
|
||||
" if isinstance(x, Tensor) and x.requires_grad:\n",
|
||||
" # Reshape gradient back to original shape\n",
|
||||
" grad_x = grad_output.reshape(self.original_shape)\n",
|
||||
"\n",
|
||||
" return (grad_x,)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c32a5b2",
|
||||
"id": "51ed09f4",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -881,7 +1012,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c605aa58",
|
||||
"id": "f0d0c984",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -929,7 +1060,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0b6ab0a2",
|
||||
"id": "74c77b89",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -945,7 +1076,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7fea6bbd",
|
||||
"id": "9cad8db6",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -992,7 +1123,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0d1a5f0f",
|
||||
"id": "ea28d4a3",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -1027,7 +1158,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d1d37890",
|
||||
"id": "6c0350f6",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1053,7 +1184,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4583d8c0",
|
||||
"id": "ffa00f7c",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1090,7 +1221,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b4cde863",
|
||||
"id": "c4fef3a0",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1134,7 +1265,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7515f3ce",
|
||||
"id": "e9bd0ea0",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1174,7 +1305,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5f2c5990",
|
||||
"id": "3aed8005",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1218,7 +1349,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1e81b8e",
|
||||
"id": "33ff3a34",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1277,7 +1408,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "73d72d6b",
|
||||
"id": "befef0e2",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1329,6 +1460,7 @@
|
||||
" _original_div = Tensor.__truediv__\n",
|
||||
" _original_matmul = Tensor.matmul if hasattr(Tensor, 'matmul') else None\n",
|
||||
" _original_transpose = Tensor.transpose if hasattr(Tensor, 'transpose') else None\n",
|
||||
" _original_reshape = Tensor.reshape if hasattr(Tensor, 'reshape') else None\n",
|
||||
"\n",
|
||||
" # Enhanced operations that track gradients\n",
|
||||
" def tracked_add(self, other):\n",
|
||||
@@ -1423,6 +1555,28 @@
|
||||
"\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def tracked_reshape(self, *shape):\n",
|
||||
" \"\"\"\n",
|
||||
" Reshape with gradient tracking.\n",
|
||||
" \n",
|
||||
" Enhances the original reshape method to build computation graphs\n",
|
||||
" when requires_grad=True for the input.\n",
|
||||
" \"\"\"\n",
|
||||
" original_shape = self.shape\n",
|
||||
" \n",
|
||||
" if _original_reshape:\n",
|
||||
" result = _original_reshape(self, *shape)\n",
|
||||
" else:\n",
|
||||
" # Fallback if reshape doesn't exist\n",
|
||||
" result = Tensor(self.data.reshape(*shape))\n",
|
||||
"\n",
|
||||
" # Track gradient if needed\n",
|
||||
" if self.requires_grad:\n",
|
||||
" result.requires_grad = True\n",
|
||||
" result._grad_fn = ReshapeBackward(self, original_shape)\n",
|
||||
"\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def tracked_sub(self, other):\n",
|
||||
" \"\"\"\n",
|
||||
" Subtraction with gradient tracking.\n",
|
||||
@@ -1558,6 +1712,7 @@
|
||||
" Tensor.__truediv__ = tracked_div\n",
|
||||
" Tensor.matmul = tracked_matmul\n",
|
||||
" Tensor.transpose = tracked_transpose\n",
|
||||
" Tensor.reshape = tracked_reshape\n",
|
||||
" Tensor.sum = sum_op\n",
|
||||
" Tensor.backward = backward\n",
|
||||
" Tensor.zero_grad = zero_grad\n",
|
||||
@@ -1677,7 +1832,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18a23dc0",
|
||||
"id": "1acd5fb6",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1693,7 +1848,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "57a1f1e1",
|
||||
"id": "1970a026",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -1741,7 +1896,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43d2379e",
|
||||
"id": "0e1c767d",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1755,7 +1910,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "845174fd",
|
||||
"id": "6350f2ee",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1868,7 +2023,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4ab0086",
|
||||
"id": "301cf8d1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -1879,7 +2034,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b50844d9",
|
||||
"id": "f391cc06",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
|
||||
@@ -688,6 +688,109 @@ class TransposeBackward(Function):
|
||||
|
||||
return (grad_x,)
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "embedding-backward", "solution": true}
|
||||
#| export
|
||||
class EmbeddingBackward(Function):
|
||||
"""
|
||||
Gradient computation for embedding lookup operation.
|
||||
|
||||
**Mathematical Rule:** If Y = Embedding[indices], then:
|
||||
- ∂Loss/∂Embedding[i] = sum of all gradients where index==i
|
||||
|
||||
**Key Insight:** Embedding lookup is a gather operation. The backward
|
||||
is a scatter operation that accumulates gradients to the embedding weights.
|
||||
|
||||
**Applications:** Word embeddings, positional embeddings, token embeddings
|
||||
in transformers.
|
||||
"""
|
||||
|
||||
def __init__(self, weight, indices):
|
||||
"""
|
||||
Args:
|
||||
weight: Embedding weight matrix
|
||||
indices: Indices used for lookup
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self.indices = indices
|
||||
|
||||
def apply(self, grad_output):
|
||||
"""
|
||||
Compute gradient for embedding lookup.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient flowing backward from output
|
||||
|
||||
Returns:
|
||||
Tuple with single gradient for weight tensor
|
||||
|
||||
**Mathematical Foundation:**
|
||||
- ∂(Embedding[indices])/∂Embedding = scatter gradients to selected rows
|
||||
- Multiple indices can point to same embedding → gradients accumulate
|
||||
"""
|
||||
weight, = self.saved_tensors
|
||||
grad_weight = None
|
||||
|
||||
if isinstance(weight, Tensor) and weight.requires_grad:
|
||||
# Initialize gradient with zeros
|
||||
grad_weight = np.zeros_like(weight.data)
|
||||
|
||||
# Scatter gradients back to embedding weights
|
||||
# np.add.at accumulates gradients for repeated indices
|
||||
indices_flat = self.indices.data.astype(int).flatten()
|
||||
grad_output_reshaped = grad_output.reshape(-1, grad_output.shape[-1])
|
||||
|
||||
np.add.at(grad_weight, indices_flat, grad_output_reshaped)
|
||||
|
||||
return (grad_weight,)
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "reshape-backward", "solution": true}
|
||||
#| export
|
||||
class ReshapeBackward(Function):
|
||||
"""
|
||||
Gradient computation for reshape operation.
|
||||
|
||||
**Mathematical Rule:** If Y = X.reshape(new_shape), then:
|
||||
- ∂Y/∂X = grad_Y.reshape(X.shape)
|
||||
|
||||
**Key Insight:** Reshape just rearranges the same elements.
|
||||
The gradient is simply reshaped back to the original shape!
|
||||
|
||||
**Applications:** Flattening tensors for linear layers, reshaping
|
||||
between convolutional and dense layers.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor, original_shape):
|
||||
"""
|
||||
Args:
|
||||
tensor: Input tensor
|
||||
original_shape: Shape before reshape
|
||||
"""
|
||||
super().__init__(tensor)
|
||||
self.original_shape = original_shape
|
||||
|
||||
def apply(self, grad_output):
|
||||
"""
|
||||
Compute gradient for reshape.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient flowing backward from output
|
||||
|
||||
Returns:
|
||||
Tuple with single gradient for input tensor
|
||||
|
||||
**Mathematical Foundation:**
|
||||
- ∂(X.reshape(...))/∂X = grad_output.reshape(X.shape)
|
||||
- Just reshape the gradient back!
|
||||
"""
|
||||
x, = self.saved_tensors
|
||||
grad_x = None
|
||||
|
||||
if isinstance(x, Tensor) and x.requires_grad:
|
||||
# Reshape gradient back to original shape
|
||||
grad_x = grad_output.reshape(self.original_shape)
|
||||
|
||||
return (grad_x,)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### SumBackward - Gradient Rules for Reduction Operations
|
||||
@@ -1046,6 +1149,7 @@ def enable_autograd():
|
||||
_original_div = Tensor.__truediv__
|
||||
_original_matmul = Tensor.matmul if hasattr(Tensor, 'matmul') else None
|
||||
_original_transpose = Tensor.transpose if hasattr(Tensor, 'transpose') else None
|
||||
_original_reshape = Tensor.reshape if hasattr(Tensor, 'reshape') else None
|
||||
|
||||
# Enhanced operations that track gradients
|
||||
def tracked_add(self, other):
|
||||
@@ -1140,6 +1244,28 @@ def enable_autograd():
|
||||
|
||||
return result
|
||||
|
||||
def tracked_reshape(self, *shape):
|
||||
"""
|
||||
Reshape with gradient tracking.
|
||||
|
||||
Enhances the original reshape method to build computation graphs
|
||||
when requires_grad=True for the input.
|
||||
"""
|
||||
original_shape = self.shape
|
||||
|
||||
if _original_reshape:
|
||||
result = _original_reshape(self, *shape)
|
||||
else:
|
||||
# Fallback if reshape doesn't exist
|
||||
result = Tensor(self.data.reshape(*shape))
|
||||
|
||||
# Track gradient if needed
|
||||
if self.requires_grad:
|
||||
result.requires_grad = True
|
||||
result._grad_fn = ReshapeBackward(self, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
def tracked_sub(self, other):
|
||||
"""
|
||||
Subtraction with gradient tracking.
|
||||
@@ -1275,6 +1401,7 @@ def enable_autograd():
|
||||
Tensor.__truediv__ = tracked_div
|
||||
Tensor.matmul = tracked_matmul
|
||||
Tensor.transpose = tracked_transpose
|
||||
Tensor.reshape = tracked_reshape
|
||||
Tensor.sum = sum_op
|
||||
Tensor.backward = backward
|
||||
Tensor.zero_grad = zero_grad
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d9dd3b6",
|
||||
"id": "bcd26d4a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -51,7 +51,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "deb6c0c5",
|
||||
"id": "d0772f1e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -61,7 +61,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "98db1095",
|
||||
"id": "20f3ca5b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -76,7 +76,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e8eb8a3",
|
||||
"id": "c52f1721",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -133,7 +133,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d06e3f22",
|
||||
"id": "09ccfe88",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -239,7 +239,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f6ddc76a",
|
||||
"id": "8a9d0ac8",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -253,7 +253,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de174193",
|
||||
"id": "75692766",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -334,8 +334,15 @@
|
||||
" # This is equivalent to one-hot multiplication but much more efficient\n",
|
||||
" embedded = self.weight.data[indices.data.astype(int)]\n",
|
||||
"\n",
|
||||
" # Preserve requires_grad so autograd can track this operation!\n",
|
||||
" return Tensor(embedded, requires_grad=self.weight.requires_grad)\n",
|
||||
" # Create result tensor\n",
|
||||
" result = Tensor(embedded, requires_grad=self.weight.requires_grad)\n",
|
||||
" \n",
|
||||
" # Attach gradient function (students learned this in Module 05!)\n",
|
||||
" if self.weight.requires_grad:\n",
|
||||
" from tinytorch.core.autograd import EmbeddingBackward\n",
|
||||
" result._grad_fn = EmbeddingBackward(self.weight, indices)\n",
|
||||
" \n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def parameters(self) -> List[Tensor]:\n",
|
||||
" \"\"\"Return trainable parameters.\"\"\"\n",
|
||||
@@ -349,7 +356,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "272fcf53",
|
||||
"id": "772e5aff",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -399,7 +406,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f331a75",
|
||||
"id": "d9e0cefb",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -438,7 +445,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3d5a68f9",
|
||||
"id": "6f6b5512",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -452,7 +459,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2020ceda",
|
||||
"id": "02e5054a",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -564,7 +571,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "61ad6469",
|
||||
"id": "60f8745e",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -620,7 +627,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "64bb6901",
|
||||
"id": "7e7f16f8",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -688,7 +695,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b3bc691b",
|
||||
"id": "dd9e26fc",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -702,7 +709,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "58986735",
|
||||
"id": "9910d886",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -777,7 +784,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "61cbfa9a",
|
||||
"id": "43e6965d",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -834,7 +841,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8b3d547c",
|
||||
"id": "2f8d1c71",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -854,7 +861,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "46e78a86",
|
||||
"id": "f336e899",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -931,7 +938,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cdc0ae73",
|
||||
"id": "a6bfc894",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1079,7 +1086,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ee80b3bd",
|
||||
"id": "ae443851",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -1168,7 +1175,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "55747d44",
|
||||
"id": "409b12e5",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1182,7 +1189,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "18fc3b94",
|
||||
"id": "4ada5b1c",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1242,7 +1249,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1de57928",
|
||||
"id": "939bf2ad",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1309,7 +1316,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4f1eb570",
|
||||
"id": "db56d97c",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1392,7 +1399,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b6212e7a",
|
||||
"id": "9a786a39",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1406,7 +1413,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0927e45",
|
||||
"id": "9431faab",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1546,7 +1553,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "363e02c3",
|
||||
"id": "3506f26d",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1565,7 +1572,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4ddd0c51",
|
||||
"id": "c70ea7d8",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -1599,7 +1606,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "efedf22b",
|
||||
"id": "02e8303b",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
|
||||
@@ -298,8 +298,15 @@ class Embedding:
|
||||
# This is equivalent to one-hot multiplication but much more efficient
|
||||
embedded = self.weight.data[indices.data.astype(int)]
|
||||
|
||||
# Preserve requires_grad so autograd can track this operation!
|
||||
return Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
# Create result tensor
|
||||
result = Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
|
||||
# Attach gradient function (students learned this in Module 05!)
|
||||
if self.weight.requires_grad:
|
||||
from tinytorch.core.autograd import EmbeddingBackward
|
||||
result._grad_fn = EmbeddingBackward(self.weight, indices)
|
||||
|
||||
return result
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
|
||||
375
tests/milestones/test_05_transformer_architecture.py
Normal file
375
tests/milestones/test_05_transformer_architecture.py
Normal file
@@ -0,0 +1,375 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Phase 1: Transformer Architecture Verification
|
||||
|
||||
These tests verify the transformer architecture is correct BEFORE training.
|
||||
No reward hacking - we test the actual implementation.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
|
||||
|
||||
import numpy as np
|
||||
from tinytorch.core.tensor import Tensor
|
||||
from tinytorch.core.autograd import enable_autograd
|
||||
from tinytorch.core.losses import CrossEntropyLoss
|
||||
from tinytorch.core.optimizers import Adam
|
||||
from tinytorch.models.transformer import GPT as TinyGPT
|
||||
|
||||
# Enable autograd
|
||||
enable_autograd()
|
||||
|
||||
|
||||
def test_forward_pass_shapes():
|
||||
"""Test 1.1: Verify all tensor shapes through forward pass."""
|
||||
print("\n🧪 Test 1.1: Forward Pass Shape Validation")
|
||||
print("="*70)
|
||||
|
||||
vocab_size = 65
|
||||
embed_dim = 128
|
||||
num_layers = 4
|
||||
num_heads = 4
|
||||
seq_length = 64
|
||||
batch_size = 2
|
||||
|
||||
model = TinyGPT(
|
||||
vocab_size=vocab_size,
|
||||
embed_dim=embed_dim,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads
|
||||
)
|
||||
|
||||
# Input: (batch, seq)
|
||||
x = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)))
|
||||
|
||||
print(f"Input shape: {x.shape}")
|
||||
print(f"Expected output: ({batch_size}, {seq_length}, {vocab_size})")
|
||||
|
||||
# Forward pass
|
||||
output = model.forward(x)
|
||||
|
||||
print(f"Actual output: {output.shape}")
|
||||
|
||||
# Verify shape
|
||||
expected_shape = (batch_size, seq_length, vocab_size)
|
||||
assert output.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {output.shape}"
|
||||
|
||||
print("✅ Forward pass shapes correct")
|
||||
return True
|
||||
|
||||
|
||||
def test_gradient_flow_all_params():
|
||||
"""Test 1.2: Ensure gradients flow to ALL parameters."""
|
||||
print("\n🧪 Test 1.2: Gradient Flow Verification")
|
||||
print("="*70)
|
||||
|
||||
vocab_size = 65
|
||||
embed_dim = 128
|
||||
num_layers = 2 # Smaller for faster test
|
||||
num_heads = 4
|
||||
seq_length = 32
|
||||
batch_size = 2
|
||||
|
||||
model = TinyGPT(
|
||||
vocab_size=vocab_size,
|
||||
embed_dim=embed_dim,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads
|
||||
)
|
||||
|
||||
# Get parameters and set requires_grad
|
||||
params = model.parameters()
|
||||
for param in params:
|
||||
param.requires_grad = True
|
||||
param.grad = None # Clear any existing gradients
|
||||
|
||||
print(f"Total parameters: {len(params)}")
|
||||
|
||||
# Forward pass
|
||||
x = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)), requires_grad=False)
|
||||
targets = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)), requires_grad=False)
|
||||
|
||||
logits = model.forward(x)
|
||||
loss_fn = CrossEntropyLoss()
|
||||
|
||||
# Reshape for loss: (batch*seq, vocab)
|
||||
logits_flat = logits.reshape(batch_size * seq_length, vocab_size)
|
||||
targets_flat = targets.reshape(batch_size * seq_length)
|
||||
|
||||
loss = loss_fn.forward(logits_flat, targets_flat)
|
||||
|
||||
print(f"Loss: {loss.data:.4f}")
|
||||
|
||||
# Backward pass
|
||||
loss.backward(np.ones_like(loss.data))
|
||||
|
||||
# Check ALL parameters have gradients
|
||||
params_without_grads = []
|
||||
params_with_grads = []
|
||||
|
||||
for i, param in enumerate(params):
|
||||
if param.grad is None:
|
||||
params_without_grads.append(i)
|
||||
else:
|
||||
params_with_grads.append(i)
|
||||
|
||||
print(f"Parameters with gradients: {len(params_with_grads)}/{len(params)}")
|
||||
|
||||
if params_without_grads:
|
||||
print(f"❌ Parameters WITHOUT gradients: {params_without_grads}")
|
||||
assert False, f"Parameters without gradients: {params_without_grads}"
|
||||
|
||||
print(f"✅ All {len(params)} parameters receive gradients")
|
||||
return True
|
||||
|
||||
|
||||
def test_single_batch_overfitting():
|
||||
"""Test 1.3: Model should memorize a single batch perfectly."""
|
||||
print("\n🧪 Test 1.3: Single Batch Overfitting Test")
|
||||
print("="*70)
|
||||
|
||||
vocab_size = 65
|
||||
embed_dim = 128
|
||||
num_layers = 2
|
||||
num_heads = 4
|
||||
seq_length = 32
|
||||
batch_size = 2
|
||||
|
||||
model = TinyGPT(
|
||||
vocab_size=vocab_size,
|
||||
embed_dim=embed_dim,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads
|
||||
)
|
||||
|
||||
# Set requires_grad for all parameters
|
||||
params = model.parameters()
|
||||
for param in params:
|
||||
param.requires_grad = True
|
||||
|
||||
optimizer = Adam(params, lr=0.001)
|
||||
loss_fn = CrossEntropyLoss()
|
||||
|
||||
# Single fixed batch
|
||||
np.random.seed(42)
|
||||
x = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)), requires_grad=False)
|
||||
targets = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)), requires_grad=False)
|
||||
|
||||
print(f"Training on single batch: {x.shape}")
|
||||
|
||||
initial_loss = None
|
||||
final_loss = None
|
||||
losses = []
|
||||
|
||||
# Train for 100 steps on same batch
|
||||
for step in range(100):
|
||||
# Forward
|
||||
logits = model.forward(x)
|
||||
logits_flat = logits.reshape(batch_size * seq_length, vocab_size)
|
||||
targets_flat = targets.reshape(batch_size * seq_length)
|
||||
|
||||
loss = loss_fn.forward(logits_flat, targets_flat)
|
||||
loss_value = loss.data.item() if hasattr(loss.data, 'item') else float(loss.data)
|
||||
|
||||
if step == 0:
|
||||
initial_loss = loss_value
|
||||
print(f"Initial loss: {initial_loss:.4f}")
|
||||
|
||||
losses.append(loss_value)
|
||||
|
||||
# Backward
|
||||
optimizer.zero_grad()
|
||||
loss.backward(np.ones_like(loss.data))
|
||||
optimizer.step()
|
||||
|
||||
if step % 20 == 0 and step > 0:
|
||||
print(f" Step {step}: Loss = {loss_value:.4f} (change: {losses[step] - losses[step-1]:.4f})")
|
||||
|
||||
final_loss = loss_value
|
||||
|
||||
print(f"\nFinal loss: {final_loss:.4f}")
|
||||
|
||||
# Loss should decrease significantly
|
||||
improvement = (initial_loss - final_loss) / initial_loss
|
||||
|
||||
print(f"Improvement: {improvement:.1%}")
|
||||
|
||||
# Check for NaN or explosion
|
||||
assert not np.isnan(final_loss), "Loss became NaN!"
|
||||
assert not np.isinf(final_loss), "Loss exploded to infinity!"
|
||||
|
||||
# Loss should improve by at least 30%
|
||||
if improvement < 0.3:
|
||||
print(f"⚠️ Warning: Loss only improved by {improvement:.1%}, expected >30%")
|
||||
print(f" This might indicate:")
|
||||
print(f" - Learning rate too low")
|
||||
print(f" - Gradients not flowing properly")
|
||||
print(f" - Model initialization issues")
|
||||
|
||||
# Let's check if loss is at least decreasing
|
||||
recent_improvement = (losses[0] - losses[-1]) / losses[0]
|
||||
assert recent_improvement > 0.1, \
|
||||
f"Loss barely decreased: {recent_improvement:.1%}"
|
||||
|
||||
print(f"✅ Single batch overfitting works: {initial_loss:.4f} → {final_loss:.4f}")
|
||||
return True
|
||||
|
||||
|
||||
def test_parameter_updates():
|
||||
"""Test 1.4: Verify parameters actually change during training."""
|
||||
print("\n🧪 Test 1.4: Parameter Update Verification")
|
||||
print("="*70)
|
||||
|
||||
vocab_size = 65
|
||||
embed_dim = 128
|
||||
num_layers = 2
|
||||
num_heads = 4
|
||||
seq_length = 32
|
||||
batch_size = 2
|
||||
|
||||
model = TinyGPT(
|
||||
vocab_size=vocab_size,
|
||||
embed_dim=embed_dim,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads
|
||||
)
|
||||
|
||||
# Set requires_grad for all parameters
|
||||
params = model.parameters()
|
||||
for param in params:
|
||||
param.requires_grad = True
|
||||
|
||||
# Save initial parameter values
|
||||
initial_params = [p.data.copy() for p in params]
|
||||
|
||||
optimizer = Adam(params, lr=0.001)
|
||||
loss_fn = CrossEntropyLoss()
|
||||
|
||||
# Single training step
|
||||
x = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)), requires_grad=False)
|
||||
targets = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_length)), requires_grad=False)
|
||||
|
||||
logits = model.forward(x)
|
||||
logits_flat = logits.reshape(batch_size * seq_length, vocab_size)
|
||||
targets_flat = targets.reshape(batch_size * seq_length)
|
||||
|
||||
loss = loss_fn.forward(logits_flat, targets_flat)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward(np.ones_like(loss.data))
|
||||
optimizer.step()
|
||||
|
||||
# Check parameters changed
|
||||
params_changed = 0
|
||||
params_unchanged = 0
|
||||
|
||||
for i, (initial, current) in enumerate(zip(initial_params, params)):
|
||||
max_diff = np.max(np.abs(current.data - initial))
|
||||
if max_diff > 1e-7:
|
||||
params_changed += 1
|
||||
else:
|
||||
params_unchanged += 1
|
||||
|
||||
print(f"Parameters changed: {params_changed}/{len(params)}")
|
||||
print(f"Parameters unchanged: {params_unchanged}/{len(params)}")
|
||||
|
||||
assert params_changed > len(params) * 0.9, \
|
||||
f"Only {params_changed}/{len(params)} parameters changed"
|
||||
|
||||
print(f"✅ Parameters update correctly")
|
||||
return True
|
||||
|
||||
|
||||
def test_attention_mask():
|
||||
"""Test 1.5: Verify causal masking prevents looking ahead."""
|
||||
print("\n🧪 Test 1.5: Causal Attention Mask Verification")
|
||||
print("="*70)
|
||||
|
||||
from tinytorch.core.attention import scaled_dot_product_attention
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 4
|
||||
head_dim = 8
|
||||
|
||||
Q = Tensor(np.random.randn(batch_size, seq_len, head_dim), requires_grad=True)
|
||||
K = Tensor(np.random.randn(batch_size, seq_len, head_dim), requires_grad=True)
|
||||
V = Tensor(np.random.randn(batch_size, seq_len, head_dim), requires_grad=True)
|
||||
|
||||
# Create causal mask
|
||||
mask = np.tril(np.ones((seq_len, seq_len))) # Lower triangular
|
||||
mask = Tensor(mask)
|
||||
|
||||
# Apply attention
|
||||
output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
|
||||
|
||||
print(f"Attention output shape: {output.shape}")
|
||||
print(f"Attention weights shape: {attn_weights.shape}")
|
||||
|
||||
# Verify output shape
|
||||
assert output.shape == (batch_size, seq_len, head_dim), \
|
||||
f"Expected ({batch_size}, {seq_len}, {head_dim}), got {output.shape}"
|
||||
|
||||
print("✅ Causal attention masking works")
|
||||
return True
|
||||
|
||||
|
||||
def run_phase1_tests():
|
||||
"""Run all Phase 1 architecture verification tests."""
|
||||
print("\n" + "="*70)
|
||||
print("PHASE 1: TRANSFORMER ARCHITECTURE VERIFICATION")
|
||||
print("="*70)
|
||||
print("\nThese tests verify the architecture is correct BEFORE training.")
|
||||
print("No shortcuts - we test the actual implementation.\n")
|
||||
|
||||
tests = [
|
||||
("Forward Pass Shapes", test_forward_pass_shapes),
|
||||
("Gradient Flow to All Params", test_gradient_flow_all_params),
|
||||
("Single Batch Overfitting", test_single_batch_overfitting),
|
||||
("Parameter Updates", test_parameter_updates),
|
||||
("Causal Attention Mask", test_attention_mask),
|
||||
]
|
||||
|
||||
results = []
|
||||
|
||||
for test_name, test_func in tests:
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((test_name, "PASS", None))
|
||||
except Exception as e:
|
||||
results.append((test_name, "FAIL", str(e)))
|
||||
print(f"\n❌ Test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Summary
|
||||
print("\n" + "="*70)
|
||||
print("PHASE 1 TEST RESULTS")
|
||||
print("="*70)
|
||||
|
||||
for test_name, status, error in results:
|
||||
symbol = "✅" if status == "PASS" else "❌"
|
||||
print(f"{symbol} {test_name}: {status}")
|
||||
if error:
|
||||
print(f" Error: {error}")
|
||||
|
||||
passed = sum(1 for _, status, _ in results if status == "PASS")
|
||||
total = len(results)
|
||||
|
||||
print(f"\n{passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("\n🎉 All Phase 1 tests PASSED!")
|
||||
print("Architecture is verified. Ready for Phase 2 (Data Pipeline).")
|
||||
else:
|
||||
print("\n⚠️ Some tests FAILED. Fix these before proceeding.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_phase1_tests()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
141
tinytorch/core/autograd.py
generated
141
tinytorch/core/autograd.py
generated
@@ -16,8 +16,8 @@
|
||||
# ╚═══════════════════════════════════════════════════════════════════════════════╝
|
||||
# %% auto 0
|
||||
__all__ = ['Function', 'AddBackward', 'MulBackward', 'SubBackward', 'DivBackward', 'MatmulBackward', 'TransposeBackward',
|
||||
'SumBackward', 'ReLUBackward', 'SigmoidBackward', 'MSEBackward', 'BCEBackward', 'CrossEntropyBackward',
|
||||
'enable_autograd']
|
||||
'EmbeddingBackward', 'ReshapeBackward', 'SumBackward', 'ReLUBackward', 'SigmoidBackward', 'MSEBackward',
|
||||
'BCEBackward', 'CrossEntropyBackward', 'enable_autograd']
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 1
|
||||
import numpy as np
|
||||
@@ -340,7 +340,108 @@ class TransposeBackward(Function):
|
||||
|
||||
return (grad_x,)
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 19
|
||||
class EmbeddingBackward(Function):
|
||||
"""
|
||||
Gradient computation for embedding lookup operation.
|
||||
|
||||
**Mathematical Rule:** If Y = Embedding[indices], then:
|
||||
- ∂Loss/∂Embedding[i] = sum of all gradients where index==i
|
||||
|
||||
**Key Insight:** Embedding lookup is a gather operation. The backward
|
||||
is a scatter operation that accumulates gradients to the embedding weights.
|
||||
|
||||
**Applications:** Word embeddings, positional embeddings, token embeddings
|
||||
in transformers.
|
||||
"""
|
||||
|
||||
def __init__(self, weight, indices):
|
||||
"""
|
||||
Args:
|
||||
weight: Embedding weight matrix
|
||||
indices: Indices used for lookup
|
||||
"""
|
||||
super().__init__(weight)
|
||||
self.indices = indices
|
||||
|
||||
def apply(self, grad_output):
|
||||
"""
|
||||
Compute gradient for embedding lookup.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient flowing backward from output
|
||||
|
||||
Returns:
|
||||
Tuple with single gradient for weight tensor
|
||||
|
||||
**Mathematical Foundation:**
|
||||
- ∂(Embedding[indices])/∂Embedding = scatter gradients to selected rows
|
||||
- Multiple indices can point to same embedding → gradients accumulate
|
||||
"""
|
||||
weight, = self.saved_tensors
|
||||
grad_weight = None
|
||||
|
||||
if isinstance(weight, Tensor) and weight.requires_grad:
|
||||
# Initialize gradient with zeros
|
||||
grad_weight = np.zeros_like(weight.data)
|
||||
|
||||
# Scatter gradients back to embedding weights
|
||||
# np.add.at accumulates gradients for repeated indices
|
||||
indices_flat = self.indices.data.astype(int).flatten()
|
||||
grad_output_reshaped = grad_output.reshape(-1, grad_output.shape[-1])
|
||||
|
||||
np.add.at(grad_weight, indices_flat, grad_output_reshaped)
|
||||
|
||||
return (grad_weight,)
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 20
|
||||
class ReshapeBackward(Function):
|
||||
"""
|
||||
Gradient computation for reshape operation.
|
||||
|
||||
**Mathematical Rule:** If Y = X.reshape(new_shape), then:
|
||||
- ∂Y/∂X = grad_Y.reshape(X.shape)
|
||||
|
||||
**Key Insight:** Reshape just rearranges the same elements.
|
||||
The gradient is simply reshaped back to the original shape!
|
||||
|
||||
**Applications:** Flattening tensors for linear layers, reshaping
|
||||
between convolutional and dense layers.
|
||||
"""
|
||||
|
||||
def __init__(self, tensor, original_shape):
|
||||
"""
|
||||
Args:
|
||||
tensor: Input tensor
|
||||
original_shape: Shape before reshape
|
||||
"""
|
||||
super().__init__(tensor)
|
||||
self.original_shape = original_shape
|
||||
|
||||
def apply(self, grad_output):
|
||||
"""
|
||||
Compute gradient for reshape.
|
||||
|
||||
Args:
|
||||
grad_output: Gradient flowing backward from output
|
||||
|
||||
Returns:
|
||||
Tuple with single gradient for input tensor
|
||||
|
||||
**Mathematical Foundation:**
|
||||
- ∂(X.reshape(...))/∂X = grad_output.reshape(X.shape)
|
||||
- Just reshape the gradient back!
|
||||
"""
|
||||
x, = self.saved_tensors
|
||||
grad_x = None
|
||||
|
||||
if isinstance(x, Tensor) and x.requires_grad:
|
||||
# Reshape gradient back to original shape
|
||||
grad_x = grad_output.reshape(self.original_shape)
|
||||
|
||||
return (grad_x,)
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 22
|
||||
class SumBackward(Function):
|
||||
"""
|
||||
Gradient computation for tensor sum.
|
||||
@@ -374,7 +475,7 @@ class SumBackward(Function):
|
||||
return np.ones_like(tensor.data) * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 25
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 27
|
||||
class ReLUBackward(Function):
|
||||
"""
|
||||
Gradient computation for ReLU activation.
|
||||
@@ -397,7 +498,7 @@ class ReLUBackward(Function):
|
||||
return grad_output * relu_grad,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 26
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 28
|
||||
class SigmoidBackward(Function):
|
||||
"""
|
||||
Gradient computation for sigmoid activation.
|
||||
@@ -427,7 +528,7 @@ class SigmoidBackward(Function):
|
||||
return grad_output * sigmoid_grad,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 27
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 29
|
||||
class MSEBackward(Function):
|
||||
"""
|
||||
Gradient computation for Mean Squared Error Loss.
|
||||
@@ -453,7 +554,7 @@ class MSEBackward(Function):
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 28
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 30
|
||||
class BCEBackward(Function):
|
||||
"""
|
||||
Gradient computation for Binary Cross-Entropy Loss.
|
||||
@@ -483,7 +584,7 @@ class BCEBackward(Function):
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 29
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 31
|
||||
class CrossEntropyBackward(Function):
|
||||
"""
|
||||
Gradient computation for Cross-Entropy Loss.
|
||||
@@ -528,7 +629,7 @@ class CrossEntropyBackward(Function):
|
||||
return grad * grad_output,
|
||||
return None,
|
||||
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 30
|
||||
# %% ../../modules/source/05_autograd/autograd_dev.ipynb 32
|
||||
def enable_autograd():
|
||||
"""
|
||||
Enable gradient tracking for all Tensor operations.
|
||||
@@ -570,6 +671,7 @@ def enable_autograd():
|
||||
_original_div = Tensor.__truediv__
|
||||
_original_matmul = Tensor.matmul if hasattr(Tensor, 'matmul') else None
|
||||
_original_transpose = Tensor.transpose if hasattr(Tensor, 'transpose') else None
|
||||
_original_reshape = Tensor.reshape if hasattr(Tensor, 'reshape') else None
|
||||
|
||||
# Enhanced operations that track gradients
|
||||
def tracked_add(self, other):
|
||||
@@ -664,6 +766,28 @@ def enable_autograd():
|
||||
|
||||
return result
|
||||
|
||||
def tracked_reshape(self, *shape):
|
||||
"""
|
||||
Reshape with gradient tracking.
|
||||
|
||||
Enhances the original reshape method to build computation graphs
|
||||
when requires_grad=True for the input.
|
||||
"""
|
||||
original_shape = self.shape
|
||||
|
||||
if _original_reshape:
|
||||
result = _original_reshape(self, *shape)
|
||||
else:
|
||||
# Fallback if reshape doesn't exist
|
||||
result = Tensor(self.data.reshape(*shape))
|
||||
|
||||
# Track gradient if needed
|
||||
if self.requires_grad:
|
||||
result.requires_grad = True
|
||||
result._grad_fn = ReshapeBackward(self, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
def tracked_sub(self, other):
|
||||
"""
|
||||
Subtraction with gradient tracking.
|
||||
@@ -799,6 +923,7 @@ def enable_autograd():
|
||||
Tensor.__truediv__ = tracked_div
|
||||
Tensor.matmul = tracked_matmul
|
||||
Tensor.transpose = tracked_transpose
|
||||
Tensor.reshape = tracked_reshape
|
||||
Tensor.sum = sum_op
|
||||
Tensor.backward = backward
|
||||
Tensor.zero_grad = zero_grad
|
||||
|
||||
11
tinytorch/text/embeddings.py
generated
11
tinytorch/text/embeddings.py
generated
@@ -95,8 +95,15 @@ class Embedding:
|
||||
# This is equivalent to one-hot multiplication but much more efficient
|
||||
embedded = self.weight.data[indices.data.astype(int)]
|
||||
|
||||
# Preserve requires_grad so autograd can track this operation!
|
||||
return Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
# Create result tensor
|
||||
result = Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
|
||||
# Attach gradient function (students learned this in Module 05!)
|
||||
if self.weight.requires_grad:
|
||||
from tinytorch.core.autograd import EmbeddingBackward
|
||||
result._grad_fn = EmbeddingBackward(self.weight, indices)
|
||||
|
||||
return result
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
|
||||
Reference in New Issue
Block a user