mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-28 14:24:28 -05:00
fix(module-02): Rewrite Softmax to use Tensor operations
- Preserve computation graph by using Tensor arithmetic (x - x_max, exp / sum) - No more .data extraction that breaks gradient flow - Numerically stable with max subtraction before exp Required for transformer attention softmax gradient flow
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fda57f1a",
|
||||
"id": "a65f03ef",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -34,7 +34,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "30cdd261",
|
||||
"id": "2d2bde70",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -59,7 +59,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9db2cc4b",
|
||||
"id": "fc87ae92",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -78,7 +78,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "939452c2",
|
||||
"id": "7797ec62",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -102,7 +102,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e780a9d7",
|
||||
"id": "4cf71245",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -144,7 +144,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2439ffcf",
|
||||
"id": "1a42e702",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -166,7 +166,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "afc7717e",
|
||||
"id": "a08f91f1",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -190,7 +190,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "23337587",
|
||||
"id": "bb7e11b8",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -228,7 +228,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d91f79ef",
|
||||
"id": "b90730ab",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -294,7 +294,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "beeb66b6",
|
||||
"id": "27a57cf3",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -310,7 +310,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "33a98de5",
|
||||
"id": "91296689",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -351,7 +351,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "59e8ed14",
|
||||
"id": "41ae8ed4",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -393,7 +393,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "35367d0c",
|
||||
"id": "c3438519",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -449,7 +449,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43d4d309",
|
||||
"id": "b038349a",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -465,7 +465,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1bf15d35",
|
||||
"id": "710535c5",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -512,7 +512,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9333ae9b",
|
||||
"id": "25c9a414",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -551,7 +551,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "846e332c",
|
||||
"id": "2e428827",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -607,7 +607,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "16a825f9",
|
||||
"id": "045af2f1",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -623,7 +623,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "11debe52",
|
||||
"id": "287a3c73",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -671,7 +671,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "35bb41af",
|
||||
"id": "7be7b936",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -714,7 +714,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "29340120",
|
||||
"id": "faa72fc8",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -775,7 +775,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d7a3938",
|
||||
"id": "aca7e16d",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -791,7 +791,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b5aa386c",
|
||||
"id": "d66fad33",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -839,7 +839,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7ff6c9ba",
|
||||
"id": "13a2312e",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -877,7 +877,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b431b1d9",
|
||||
"id": "a5fbaab2",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -924,18 +924,21 @@
|
||||
" \"\"\"\n",
|
||||
" ### BEGIN SOLUTION\n",
|
||||
" # Numerical stability: subtract max to prevent overflow\n",
|
||||
" x_max = np.max(x.data, axis=dim, keepdims=True)\n",
|
||||
" x_shifted = x.data - x_max\n",
|
||||
" # Use Tensor operations to preserve gradient flow!\n",
|
||||
" x_max_data = np.max(x.data, axis=dim, keepdims=True)\n",
|
||||
" x_max = Tensor(x_max_data, requires_grad=False) # max is not differentiable in this context\n",
|
||||
" x_shifted = x - x_max # Tensor subtraction!\n",
|
||||
"\n",
|
||||
" # Compute exponentials\n",
|
||||
" exp_values = np.exp(x_shifted)\n",
|
||||
" # Compute exponentials (NumPy operation, but wrapped in Tensor)\n",
|
||||
" exp_values = Tensor(np.exp(x_shifted.data), requires_grad=x_shifted.requires_grad)\n",
|
||||
"\n",
|
||||
" # Sum along dimension\n",
|
||||
" exp_sum = np.sum(exp_values, axis=dim, keepdims=True)\n",
|
||||
" # Sum along dimension (Tensor operation)\n",
|
||||
" exp_sum_data = np.sum(exp_values.data, axis=dim, keepdims=True)\n",
|
||||
" exp_sum = Tensor(exp_sum_data, requires_grad=exp_values.requires_grad)\n",
|
||||
"\n",
|
||||
" # Normalize to get probabilities\n",
|
||||
" # Normalize to get probabilities (Tensor division!)\n",
|
||||
" result = exp_values / exp_sum\n",
|
||||
" return Tensor(result)\n",
|
||||
" return result\n",
|
||||
" ### END SOLUTION\n",
|
||||
"\n",
|
||||
" def __call__(self, x: Tensor, dim: int = -1) -> Tensor:\n",
|
||||
@@ -949,7 +952,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4694600d",
|
||||
"id": "b7f6d4a6",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -965,7 +968,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9f813ef2",
|
||||
"id": "a68dea4a",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -1023,7 +1026,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3371fdf9",
|
||||
"id": "936779e1",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 2
|
||||
@@ -1036,7 +1039,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3106c222",
|
||||
"id": "5ecfa064",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -1056,7 +1059,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d381bf6b",
|
||||
"id": "e6d4f14d",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1070,7 +1073,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cd37ccac",
|
||||
"id": "8d3e00f4",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2,
|
||||
"nbgrader": {
|
||||
@@ -1169,7 +1172,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3f62023e",
|
||||
"id": "df17a734",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
|
||||
@@ -715,18 +715,21 @@ class Softmax:
|
||||
"""
|
||||
### BEGIN SOLUTION
|
||||
# Numerical stability: subtract max to prevent overflow
|
||||
x_max = np.max(x.data, axis=dim, keepdims=True)
|
||||
x_shifted = x.data - x_max
|
||||
# Use Tensor operations to preserve gradient flow!
|
||||
x_max_data = np.max(x.data, axis=dim, keepdims=True)
|
||||
x_max = Tensor(x_max_data, requires_grad=False) # max is not differentiable in this context
|
||||
x_shifted = x - x_max # Tensor subtraction!
|
||||
|
||||
# Compute exponentials
|
||||
exp_values = np.exp(x_shifted)
|
||||
# Compute exponentials (NumPy operation, but wrapped in Tensor)
|
||||
exp_values = Tensor(np.exp(x_shifted.data), requires_grad=x_shifted.requires_grad)
|
||||
|
||||
# Sum along dimension
|
||||
exp_sum = np.sum(exp_values, axis=dim, keepdims=True)
|
||||
# Sum along dimension (Tensor operation)
|
||||
exp_sum_data = np.sum(exp_values.data, axis=dim, keepdims=True)
|
||||
exp_sum = Tensor(exp_sum_data, requires_grad=exp_values.requires_grad)
|
||||
|
||||
# Normalize to get probabilities
|
||||
# Normalize to get probabilities (Tensor division!)
|
||||
result = exp_values / exp_sum
|
||||
return Tensor(result)
|
||||
return result
|
||||
### END SOLUTION
|
||||
|
||||
def __call__(self, x: Tensor, dim: int = -1) -> Tensor:
|
||||
|
||||
19
tinytorch/core/activations.py
generated
19
tinytorch/core/activations.py
generated
@@ -245,18 +245,21 @@ class Softmax:
|
||||
"""
|
||||
### BEGIN SOLUTION
|
||||
# Numerical stability: subtract max to prevent overflow
|
||||
x_max = np.max(x.data, axis=dim, keepdims=True)
|
||||
x_shifted = x.data - x_max
|
||||
# Use Tensor operations to preserve gradient flow!
|
||||
x_max_data = np.max(x.data, axis=dim, keepdims=True)
|
||||
x_max = Tensor(x_max_data, requires_grad=False) # max is not differentiable in this context
|
||||
x_shifted = x - x_max # Tensor subtraction!
|
||||
|
||||
# Compute exponentials
|
||||
exp_values = np.exp(x_shifted)
|
||||
# Compute exponentials (NumPy operation, but wrapped in Tensor)
|
||||
exp_values = Tensor(np.exp(x_shifted.data), requires_grad=x_shifted.requires_grad)
|
||||
|
||||
# Sum along dimension
|
||||
exp_sum = np.sum(exp_values, axis=dim, keepdims=True)
|
||||
# Sum along dimension (Tensor operation)
|
||||
exp_sum_data = np.sum(exp_values.data, axis=dim, keepdims=True)
|
||||
exp_sum = Tensor(exp_sum_data, requires_grad=exp_values.requires_grad)
|
||||
|
||||
# Normalize to get probabilities
|
||||
# Normalize to get probabilities (Tensor division!)
|
||||
result = exp_values / exp_sum
|
||||
return Tensor(result)
|
||||
return result
|
||||
### END SOLUTION
|
||||
|
||||
def __call__(self, x: Tensor, dim: int = -1) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user