mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-06-04 00:27:19 -05:00
fix(module-11): Fix Embedding and PositionalEncoding gradient flow
- Embedding.forward() now preserves requires_grad from the weight tensor
- PositionalEncoding.forward() uses Tensor addition (x + pos) instead of adding raw .data arrays
- Critical for transformer input embeddings to have gradients

Both changes ensure that gradients flow from the loss back to the embedding weights.
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7d9a210d",
|
||||
"id": "4d9dd3b6",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -51,7 +51,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e8440ef",
|
||||
"id": "deb6c0c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -61,7 +61,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "70d81596",
|
||||
"id": "98db1095",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -76,7 +76,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1ddbc881",
|
||||
"id": "0e8eb8a3",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -133,7 +133,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f1e4bdc9",
|
||||
"id": "d06e3f22",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -239,7 +239,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b0e68de7",
|
||||
"id": "f6ddc76a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -253,7 +253,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "19fc003f",
|
||||
"id": "de174193",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -334,7 +334,8 @@
|
||||
" # This is equivalent to one-hot multiplication but much more efficient\n",
|
||||
" embedded = self.weight.data[indices.data.astype(int)]\n",
|
||||
"\n",
|
||||
" return Tensor(embedded)\n",
|
||||
" # Preserve requires_grad so autograd can track this operation!\n",
|
||||
" return Tensor(embedded, requires_grad=self.weight.requires_grad)\n",
|
||||
"\n",
|
||||
" def parameters(self) -> List[Tensor]:\n",
|
||||
" \"\"\"Return trainable parameters.\"\"\"\n",
|
||||
@@ -348,7 +349,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "17498acf",
|
||||
"id": "272fcf53",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -398,7 +399,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2aeb2910",
|
||||
"id": "2f331a75",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -437,7 +438,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d67a811",
|
||||
"id": "3d5a68f9",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -451,7 +452,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b82481b4",
|
||||
"id": "2020ceda",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -537,16 +538,19 @@
|
||||
" f\"Embedding dimension mismatch: expected {self.embed_dim}, got {embed_dim}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Get position embeddings for this sequence length\n",
|
||||
" pos_embeddings = self.position_embeddings.data[:seq_len] # (seq_len, embed_dim)\n",
|
||||
" # Get position embeddings for this sequence length (slice using .data for efficiency)\n",
|
||||
" pos_embeddings_data = self.position_embeddings.data[:seq_len] # (seq_len, embed_dim)\n",
|
||||
"\n",
|
||||
" # Broadcast to match batch dimension: (1, seq_len, embed_dim)\n",
|
||||
" pos_embeddings = pos_embeddings[np.newaxis, :, :]\n",
|
||||
" pos_embeddings_data = pos_embeddings_data[np.newaxis, :, :]\n",
|
||||
" \n",
|
||||
" # Wrap in Tensor to preserve requires_grad\n",
|
||||
" pos_embeddings = Tensor(pos_embeddings_data, requires_grad=self.position_embeddings.requires_grad)\n",
|
||||
"\n",
|
||||
" # Add positional information to input embeddings\n",
|
||||
" result = x.data + pos_embeddings\n",
|
||||
" # Add positional information using Tensor operation to preserve gradients!\n",
|
||||
" result = x + pos_embeddings\n",
|
||||
"\n",
|
||||
" return Tensor(result)\n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
" def parameters(self) -> List[Tensor]:\n",
|
||||
" \"\"\"Return trainable parameters.\"\"\"\n",
|
||||
@@ -560,7 +564,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c9f01de4",
|
||||
"id": "61ad6469",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -616,7 +620,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "809099dc",
|
||||
"id": "64bb6901",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -684,7 +688,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6fa0a064",
|
||||
"id": "b3bc691b",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -698,7 +702,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "825a0299",
|
||||
"id": "58986735",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -773,7 +777,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "440952af",
|
||||
"id": "61cbfa9a",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -830,7 +834,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6ffbaced",
|
||||
"id": "8b3d547c",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -850,7 +854,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e25bf119",
|
||||
"id": "46e78a86",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -927,7 +931,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2f033b57",
|
||||
"id": "cdc0ae73",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1075,7 +1079,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "72670535",
|
||||
"id": "ee80b3bd",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -1164,7 +1168,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "27c9db3b",
|
||||
"id": "55747d44",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1178,7 +1182,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d05c51c1",
|
||||
"id": "18fc3b94",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1238,7 +1242,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1b96f57b",
|
||||
"id": "1de57928",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1305,7 +1309,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62edc85a",
|
||||
"id": "4f1eb570",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1388,7 +1392,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ffcbbfe8",
|
||||
"id": "b6212e7a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1402,7 +1406,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9a0587e9",
|
||||
"id": "a0927e45",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -1542,7 +1546,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0ed96c77",
|
||||
"id": "363e02c3",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1561,7 +1565,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a6aff95f",
|
||||
"id": "4ddd0c51",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -1595,7 +1599,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6b073262",
|
||||
"id": "efedf22b",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
|
||||
@@ -298,7 +298,8 @@ class Embedding:
|
||||
# This is equivalent to one-hot multiplication but much more efficient
|
||||
embedded = self.weight.data[indices.data.astype(int)]
|
||||
|
||||
return Tensor(embedded)
|
||||
# Preserve requires_grad so autograd can track this operation!
|
||||
return Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
@@ -462,16 +463,19 @@ class PositionalEncoding:
|
||||
f"Embedding dimension mismatch: expected {self.embed_dim}, got {embed_dim}"
|
||||
)
|
||||
|
||||
# Get position embeddings for this sequence length
|
||||
pos_embeddings = self.position_embeddings.data[:seq_len] # (seq_len, embed_dim)
|
||||
# Get position embeddings for this sequence length (slice using .data for efficiency)
|
||||
pos_embeddings_data = self.position_embeddings.data[:seq_len] # (seq_len, embed_dim)
|
||||
|
||||
# Broadcast to match batch dimension: (1, seq_len, embed_dim)
|
||||
pos_embeddings = pos_embeddings[np.newaxis, :, :]
|
||||
pos_embeddings_data = pos_embeddings_data[np.newaxis, :, :]
|
||||
|
||||
# Wrap in Tensor to preserve requires_grad
|
||||
pos_embeddings = Tensor(pos_embeddings_data, requires_grad=self.position_embeddings.requires_grad)
|
||||
|
||||
# Add positional information to input embeddings
|
||||
result = x.data + pos_embeddings
|
||||
# Add positional information using Tensor operation to preserve gradients!
|
||||
result = x + pos_embeddings
|
||||
|
||||
return Tensor(result)
|
||||
return result
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
|
||||
18
tinytorch/text/embeddings.py
generated
18
tinytorch/text/embeddings.py
generated
@@ -95,7 +95,8 @@ class Embedding:
|
||||
# This is equivalent to one-hot multiplication but much more efficient
|
||||
embedded = self.weight.data[indices.data.astype(int)]
|
||||
|
||||
return Tensor(embedded)
|
||||
# Preserve requires_grad so autograd can track this operation!
|
||||
return Tensor(embedded, requires_grad=self.weight.requires_grad)
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
@@ -180,16 +181,19 @@ class PositionalEncoding:
|
||||
f"Embedding dimension mismatch: expected {self.embed_dim}, got {embed_dim}"
|
||||
)
|
||||
|
||||
# Get position embeddings for this sequence length
|
||||
pos_embeddings = self.position_embeddings.data[:seq_len] # (seq_len, embed_dim)
|
||||
# Get position embeddings for this sequence length (slice using .data for efficiency)
|
||||
pos_embeddings_data = self.position_embeddings.data[:seq_len] # (seq_len, embed_dim)
|
||||
|
||||
# Broadcast to match batch dimension: (1, seq_len, embed_dim)
|
||||
pos_embeddings = pos_embeddings[np.newaxis, :, :]
|
||||
pos_embeddings_data = pos_embeddings_data[np.newaxis, :, :]
|
||||
|
||||
# Wrap in Tensor to preserve requires_grad
|
||||
pos_embeddings = Tensor(pos_embeddings_data, requires_grad=self.position_embeddings.requires_grad)
|
||||
|
||||
# Add positional information to input embeddings
|
||||
result = x.data + pos_embeddings
|
||||
# Add positional information using Tensor operation to preserve gradients!
|
||||
result = x + pos_embeddings
|
||||
|
||||
return Tensor(result)
|
||||
return result
|
||||
|
||||
def parameters(self) -> List[Tensor]:
|
||||
"""Return trainable parameters."""
|
||||
|
||||
Reference in New Issue
Block a user