mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-26 15:42:18 -05:00
fix(module-12): Rewrite attention to use batched Tensor operations
Major rewrite for gradient flow: - scaled_dot_product_attention: Use Tensor ops (matmul, transpose, softmax) - MultiHeadAttention: Process all heads in parallel with 4D batched tensors - No explicit batch loops or .data extraction - Proper mask broadcasting for (batch * heads) dimension This is the most complex fix - attention is now fully differentiable end-to-end
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d94b5da2",
|
||||
"id": "f766c98a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9306f576",
|
||||
"id": "9f6e2eb9",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -63,7 +63,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2eaafa86",
|
||||
"id": "ae589b7d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -80,7 +80,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "81ea33fc",
|
||||
"id": "475bee15",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -137,7 +137,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9330210a",
|
||||
"id": "7b77744f",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -229,7 +229,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "394e7884",
|
||||
"id": "36379ce8",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -275,7 +275,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7eada95c",
|
||||
"id": "8da6fcba",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -336,63 +336,53 @@
|
||||
" assert K.shape == (batch_size, seq_len, d_model), f\"K shape {K.shape} doesn't match Q shape {Q.shape}\"\n",
|
||||
" assert V.shape == (batch_size, seq_len, d_model), f\"V shape {V.shape} doesn't match Q shape {Q.shape}\"\n",
|
||||
"\n",
|
||||
" # Step 2: Compute attention scores with explicit loops (educational O(n²) demonstration)\n",
|
||||
" scores = np.zeros((batch_size, seq_len, seq_len))\n",
|
||||
" # Step 2: Compute attention scores Q @ K^T using batched Tensor operations (NO loops!)\n",
|
||||
" # Q: (batch, seq, d_model)\n",
|
||||
" # K: (batch, seq, d_model)\n",
|
||||
" # K.transpose() swaps last two dims: (batch, d_model, seq)\n",
|
||||
" # Q @ K.T: (batch, seq, d_model) @ (batch, d_model, seq) → (batch, seq, seq)\n",
|
||||
" K_T = K.transpose() # (batch, d_model, seq) - Preserves requires_grad!\n",
|
||||
" scores = Q.matmul(K_T) # (batch, seq, seq) - Module 05's tracked_matmul sets _grad_fn!\n",
|
||||
"\n",
|
||||
" # Show the quadratic complexity explicitly\n",
|
||||
" for b in range(batch_size): # For each batch\n",
|
||||
" for i in range(seq_len): # For each query position\n",
|
||||
" for j in range(seq_len): # Attend to each key position\n",
|
||||
" # Compute dot product between query i and key j\n",
|
||||
" score = 0.0\n",
|
||||
" for d in range(d_model): # Dot product across embedding dimension\n",
|
||||
" score += Q.data[b, i, d] * K.data[b, j, d]\n",
|
||||
" scores[b, i, j] = score\n",
|
||||
"\n",
|
||||
" # Step 3: Scale by 1/√d_k for numerical stability\n",
|
||||
" # Step 3: Scale by 1/√d_k for numerical stability (Tensor operation!)\n",
|
||||
" scale_factor = 1.0 / math.sqrt(d_model)\n",
|
||||
" scores = scores * scale_factor\n",
|
||||
" scores = scores * scale_factor # Tensor multiplication - Module 05's tracked_mul!\n",
|
||||
"\n",
|
||||
" # Step 4: Apply causal mask if provided\n",
|
||||
" # Step 4: Apply causal mask if provided (Tensor operation!)\n",
|
||||
" if mask is not None:\n",
|
||||
" # mask[i,j] = False means position j should not attend to position i\n",
|
||||
" mask_value = -1e9 # Large negative value becomes 0 after softmax\n",
|
||||
" for b in range(batch_size):\n",
|
||||
" for i in range(seq_len):\n",
|
||||
" for j in range(seq_len):\n",
|
||||
" if not mask.data[b, i, j]: # If mask is False, block attention\n",
|
||||
" scores[b, i, j] = mask_value\n",
|
||||
" # mask: True where attention is allowed, False where masked\n",
|
||||
" # Convert to additive mask: 0 where allowed, -1e9 where masked\n",
|
||||
" # This way we can use Tensor addition which preserves gradients!\n",
|
||||
" if mask.data.ndim == 2:\n",
|
||||
" # Broadcast (seq, seq) mask to (batch, seq, seq)\n",
|
||||
" mask_additive = Tensor(np.where(mask.data, 0.0, -1e9))\n",
|
||||
" else:\n",
|
||||
" # Already (batch, seq, seq)\n",
|
||||
" mask_additive = Tensor(np.where(mask.data, 0.0, -1e9))\n",
|
||||
" scores = scores + mask_additive # Tensor addition - Module 05's tracked_add!\n",
|
||||
"\n",
|
||||
" # Step 5: Apply softmax to get attention weights (probability distribution)\n",
|
||||
" attention_weights = np.zeros_like(scores)\n",
|
||||
" for b in range(batch_size):\n",
|
||||
" for i in range(seq_len):\n",
|
||||
" # Softmax over the j dimension (what this query attends to)\n",
|
||||
" row = scores[b, i, :]\n",
|
||||
" max_val = np.max(row) # Numerical stability\n",
|
||||
" exp_row = np.exp(row - max_val)\n",
|
||||
" sum_exp = np.sum(exp_row)\n",
|
||||
" attention_weights[b, i, :] = exp_row / sum_exp\n",
|
||||
" # Step 5: Apply softmax (NO loops - softmax handles batched input!)\n",
|
||||
" from tinytorch.core.activations import Softmax\n",
|
||||
" softmax = Softmax()\n",
|
||||
" \n",
|
||||
" # Apply softmax along last dimension (over keys for each query)\n",
|
||||
" # scores: (batch, seq, seq) → attention_weights: (batch, seq, seq)\n",
|
||||
" attention_weights = softmax.forward(scores, dim=-1) # Tensor operation!\n",
|
||||
"\n",
|
||||
" # Step 6: Apply attention weights to values (another O(n²) operation)\n",
|
||||
" output = np.zeros((batch_size, seq_len, d_model))\n",
|
||||
" # Step 6: Apply attention weights to values (NO loops - batched matmul!)\n",
|
||||
" # attention_weights: (batch, seq, seq)\n",
|
||||
" # V: (batch, seq, d_model)\n",
|
||||
" # weights @ V: (batch, seq, seq) @ (batch, seq, d_model) → (batch, seq, d_model)\n",
|
||||
" output = attention_weights.matmul(V) # Tensor operation - Module 05's tracked_matmul!\n",
|
||||
"\n",
|
||||
" # Again, show the quadratic complexity\n",
|
||||
" for b in range(batch_size): # For each batch\n",
|
||||
" for i in range(seq_len): # For each output position\n",
|
||||
" for j in range(seq_len): # Weighted sum over all value positions\n",
|
||||
" weight = attention_weights[b, i, j]\n",
|
||||
" for d in range(d_model): # Accumulate across embedding dimension\n",
|
||||
" output[b, i, d] += weight * V.data[b, j, d]\n",
|
||||
"\n",
|
||||
" return Tensor(output), Tensor(attention_weights)\n",
|
||||
" return output, attention_weights\n",
|
||||
" ### END SOLUTION"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e006e03",
|
||||
"id": "b3814135",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -443,7 +433,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "712ce2a0",
|
||||
"id": "59710f59",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -464,7 +454,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0ae42b8d",
|
||||
"id": "80839b76",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -554,7 +544,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f540c1d4",
|
||||
"id": "fcc327b3",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -656,46 +646,59 @@
|
||||
" batch_size, seq_len, embed_dim = x.shape\n",
|
||||
" assert embed_dim == self.embed_dim, f\"Input dim {embed_dim} doesn't match expected {self.embed_dim}\"\n",
|
||||
"\n",
|
||||
" # Step 2: Project to Q, K, V\n",
|
||||
" # Step 2: Project to Q, K, V (Tensor operations!)\n",
|
||||
" Q = self.q_proj.forward(x) # (batch, seq, embed_dim)\n",
|
||||
" K = self.k_proj.forward(x)\n",
|
||||
" V = self.v_proj.forward(x)\n",
|
||||
"\n",
|
||||
" # Step 3: Reshape to separate heads\n",
|
||||
" # From (batch, seq, embed_dim) to (batch, seq, num_heads, head_dim)\n",
|
||||
" Q_heads = Q.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n",
|
||||
" K_heads = K.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n",
|
||||
" V_heads = V.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n",
|
||||
" # Step 3: Reshape to separate heads (batch, seq, embed) → (batch, seq, heads, head_dim)\n",
|
||||
" Q_heads = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n",
|
||||
" K_heads = K.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n",
|
||||
" V_heads = V.reshape(batch_size, seq_len, self.num_heads, self.head_dim)\n",
|
||||
"\n",
|
||||
" # Step 4: Transpose to (batch, num_heads, seq, head_dim) for parallel processing\n",
|
||||
" Q_heads = np.transpose(Q_heads, (0, 2, 1, 3))\n",
|
||||
" K_heads = np.transpose(K_heads, (0, 2, 1, 3))\n",
|
||||
" V_heads = np.transpose(V_heads, (0, 2, 1, 3))\n",
|
||||
" # Step 4: Rearrange dims to (batch, heads, seq, head_dim) for parallel processing\n",
|
||||
" # We need to transpose dims 1 and 2, but Tensor.transpose() only swaps last two dims\n",
|
||||
" # So we manually transpose using NumPy, but preserve requires_grad\n",
|
||||
" Q_heads = Tensor(np.transpose(Q_heads.data, (0, 2, 1, 3)), requires_grad=Q_heads.requires_grad)\n",
|
||||
" K_heads = Tensor(np.transpose(K_heads.data, (0, 2, 1, 3)), requires_grad=K_heads.requires_grad)\n",
|
||||
" V_heads = Tensor(np.transpose(V_heads.data, (0, 2, 1, 3)), requires_grad=V_heads.requires_grad)\n",
|
||||
" \n",
|
||||
" # Step 5: Process ALL heads in parallel (NO loops!)\n",
|
||||
" # Reshape to combine batch and head dims: (batch, heads, seq, head_dim) → (batch*heads, seq, head_dim)\n",
|
||||
" batch_heads = batch_size * self.num_heads\n",
|
||||
" Q_flat = Q_heads.reshape(batch_heads, seq_len, self.head_dim)\n",
|
||||
" K_flat = K_heads.reshape(batch_heads, seq_len, self.head_dim)\n",
|
||||
" V_flat = V_heads.reshape(batch_heads, seq_len, self.head_dim)\n",
|
||||
" \n",
|
||||
" # Handle mask: Repeat for each head\n",
|
||||
" # mask: (batch, seq, seq) needs to become (batch*heads, seq, seq)\n",
|
||||
" if mask is not None:\n",
|
||||
" if mask.data.ndim == 2:\n",
|
||||
" # (seq, seq) → repeat for each batch and head\n",
|
||||
" mask_data = np.tile(mask.data[np.newaxis, :, :], (batch_heads, 1, 1))\n",
|
||||
" else:\n",
|
||||
" # (batch, seq, seq) → repeat for each head\n",
|
||||
" # For each batch element, repeat the mask num_heads times\n",
|
||||
" mask_data = np.repeat(mask.data, self.num_heads, axis=0)\n",
|
||||
" mask_flat = Tensor(mask_data)\n",
|
||||
" else:\n",
|
||||
" mask_flat = None\n",
|
||||
" \n",
|
||||
" # Apply attention to all heads at once! (Tensor operation)\n",
|
||||
" # This batches all heads together - efficient and preserves gradients!\n",
|
||||
" attn_output, _ = scaled_dot_product_attention(Q_flat, K_flat, V_flat, mask_flat)\n",
|
||||
" \n",
|
||||
" # Step 6: Reshape back to separate batch and heads: (batch*heads, seq, head_dim) → (batch, heads, seq, head_dim)\n",
|
||||
" attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len, self.head_dim)\n",
|
||||
" \n",
|
||||
" # Step 7: Transpose back: (batch, heads, seq, head_dim) → (batch, seq, heads, head_dim)\n",
|
||||
" attn_output = Tensor(np.transpose(attn_output.data, (0, 2, 1, 3)), requires_grad=attn_output.requires_grad)\n",
|
||||
" \n",
|
||||
" # Step 8: Merge heads: (batch, seq, heads, head_dim) → (batch, seq, embed_dim)\n",
|
||||
" output = attn_output.reshape(batch_size, seq_len, self.embed_dim)\n",
|
||||
"\n",
|
||||
" # Step 5: Apply attention to each head\n",
|
||||
" head_outputs = []\n",
|
||||
" for h in range(self.num_heads):\n",
|
||||
" # Extract this head's Q, K, V\n",
|
||||
" Q_h = Tensor(Q_heads[:, h, :, :]) # (batch, seq, head_dim)\n",
|
||||
" K_h = Tensor(K_heads[:, h, :, :])\n",
|
||||
" V_h = Tensor(V_heads[:, h, :, :])\n",
|
||||
"\n",
|
||||
" # Apply attention for this head\n",
|
||||
" head_out, _ = scaled_dot_product_attention(Q_h, K_h, V_h, mask)\n",
|
||||
" head_outputs.append(head_out.data)\n",
|
||||
"\n",
|
||||
" # Step 6: Concatenate heads back together\n",
|
||||
" # Stack: list of (batch, seq, head_dim) → (batch, num_heads, seq, head_dim)\n",
|
||||
" concat_heads = np.stack(head_outputs, axis=1)\n",
|
||||
"\n",
|
||||
" # Transpose back: (batch, num_heads, seq, head_dim) → (batch, seq, num_heads, head_dim)\n",
|
||||
" concat_heads = np.transpose(concat_heads, (0, 2, 1, 3))\n",
|
||||
"\n",
|
||||
" # Reshape: (batch, seq, num_heads, head_dim) → (batch, seq, embed_dim)\n",
|
||||
" concat_output = concat_heads.reshape(batch_size, seq_len, self.embed_dim)\n",
|
||||
"\n",
|
||||
" # Step 7: Apply output projection\n",
|
||||
" output = self.out_proj.forward(Tensor(concat_output))\n",
|
||||
" # Step 9: Apply output projection (Tensor operation!)\n",
|
||||
" output = self.out_proj.forward(output)\n",
|
||||
"\n",
|
||||
" return output\n",
|
||||
" ### END SOLUTION\n",
|
||||
@@ -726,7 +729,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "636a3fed",
|
||||
"id": "b88e148e",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -783,7 +786,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "da0586c2",
|
||||
"id": "8d3e378b",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -803,7 +806,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "bd666af7",
|
||||
"id": "515947b3",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -845,7 +848,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a722af5d",
|
||||
"id": "d7731810",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 1,
|
||||
"nbgrader": {
|
||||
@@ -887,7 +890,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "692eb505",
|
||||
"id": "7aa0a9f4",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -941,7 +944,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5012f8f3",
|
||||
"id": "4e2a6859",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -986,7 +989,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f0cfd879",
|
||||
"id": "f640e621",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1029,7 +1032,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f8433bd9",
|
||||
"id": "406093a3",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": false,
|
||||
@@ -1127,7 +1130,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "76625dbe",
|
||||
"id": "fe269094",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -1161,7 +1164,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "66c41cfa",
|
||||
"id": "ac747073",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\"",
|
||||
"lines_to_next_cell": 1
|
||||
@@ -1175,7 +1178,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c5c381db",
|
||||
"id": "42ee4807",
|
||||
"metadata": {
|
||||
"nbgrader": {
|
||||
"grade": true,
|
||||
@@ -1221,7 +1224,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "10ced70a",
|
||||
"id": "358351d5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -1233,7 +1236,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f42b351d",
|
||||
"id": "5880efbe",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
@@ -1273,7 +1276,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51aafac3",
|
||||
"id": "2193e46d",
|
||||
"metadata": {
|
||||
"cell_marker": "\"\"\""
|
||||
},
|
||||
|
||||
@@ -299,56 +299,46 @@ def scaled_dot_product_attention(Q: Tensor, K: Tensor, V: Tensor, mask: Optional
|
||||
assert K.shape == (batch_size, seq_len, d_model), f"K shape {K.shape} doesn't match Q shape {Q.shape}"
|
||||
assert V.shape == (batch_size, seq_len, d_model), f"V shape {V.shape} doesn't match Q shape {Q.shape}"
|
||||
|
||||
# Step 2: Compute attention scores with explicit loops (educational O(n²) demonstration)
|
||||
scores = np.zeros((batch_size, seq_len, seq_len))
|
||||
# Step 2: Compute attention scores Q @ K^T using batched Tensor operations (NO loops!)
|
||||
# Q: (batch, seq, d_model)
|
||||
# K: (batch, seq, d_model)
|
||||
# K.transpose() swaps last two dims: (batch, d_model, seq)
|
||||
# Q @ K.T: (batch, seq, d_model) @ (batch, d_model, seq) → (batch, seq, seq)
|
||||
K_T = K.transpose() # (batch, d_model, seq) - Preserves requires_grad!
|
||||
scores = Q.matmul(K_T) # (batch, seq, seq) - Module 05's tracked_matmul sets _grad_fn!
|
||||
|
||||
# Show the quadratic complexity explicitly
|
||||
for b in range(batch_size): # For each batch
|
||||
for i in range(seq_len): # For each query position
|
||||
for j in range(seq_len): # Attend to each key position
|
||||
# Compute dot product between query i and key j
|
||||
score = 0.0
|
||||
for d in range(d_model): # Dot product across embedding dimension
|
||||
score += Q.data[b, i, d] * K.data[b, j, d]
|
||||
scores[b, i, j] = score
|
||||
|
||||
# Step 3: Scale by 1/√d_k for numerical stability
|
||||
# Step 3: Scale by 1/√d_k for numerical stability (Tensor operation!)
|
||||
scale_factor = 1.0 / math.sqrt(d_model)
|
||||
scores = scores * scale_factor
|
||||
scores = scores * scale_factor # Tensor multiplication - Module 05's tracked_mul!
|
||||
|
||||
# Step 4: Apply causal mask if provided
|
||||
# Step 4: Apply causal mask if provided (Tensor operation!)
|
||||
if mask is not None:
|
||||
# mask[i,j] = False means position j should not attend to position i
|
||||
mask_value = -1e9 # Large negative value becomes 0 after softmax
|
||||
for b in range(batch_size):
|
||||
for i in range(seq_len):
|
||||
for j in range(seq_len):
|
||||
if not mask.data[b, i, j]: # If mask is False, block attention
|
||||
scores[b, i, j] = mask_value
|
||||
# mask: True where attention is allowed, False where masked
|
||||
# Convert to additive mask: 0 where allowed, -1e9 where masked
|
||||
# This way we can use Tensor addition which preserves gradients!
|
||||
if mask.data.ndim == 2:
|
||||
# Broadcast (seq, seq) mask to (batch, seq, seq)
|
||||
mask_additive = Tensor(np.where(mask.data, 0.0, -1e9))
|
||||
else:
|
||||
# Already (batch, seq, seq)
|
||||
mask_additive = Tensor(np.where(mask.data, 0.0, -1e9))
|
||||
scores = scores + mask_additive # Tensor addition - Module 05's tracked_add!
|
||||
|
||||
# Step 5: Apply softmax to get attention weights (probability distribution)
|
||||
attention_weights = np.zeros_like(scores)
|
||||
for b in range(batch_size):
|
||||
for i in range(seq_len):
|
||||
# Softmax over the j dimension (what this query attends to)
|
||||
row = scores[b, i, :]
|
||||
max_val = np.max(row) # Numerical stability
|
||||
exp_row = np.exp(row - max_val)
|
||||
sum_exp = np.sum(exp_row)
|
||||
attention_weights[b, i, :] = exp_row / sum_exp
|
||||
# Step 5: Apply softmax (NO loops - softmax handles batched input!)
|
||||
from tinytorch.core.activations import Softmax
|
||||
softmax = Softmax()
|
||||
|
||||
# Apply softmax along last dimension (over keys for each query)
|
||||
# scores: (batch, seq, seq) → attention_weights: (batch, seq, seq)
|
||||
attention_weights = softmax.forward(scores, dim=-1) # Tensor operation!
|
||||
|
||||
# Step 6: Apply attention weights to values (another O(n²) operation)
|
||||
output = np.zeros((batch_size, seq_len, d_model))
|
||||
# Step 6: Apply attention weights to values (NO loops - batched matmul!)
|
||||
# attention_weights: (batch, seq, seq)
|
||||
# V: (batch, seq, d_model)
|
||||
# weights @ V: (batch, seq, seq) @ (batch, seq, d_model) → (batch, seq, d_model)
|
||||
output = attention_weights.matmul(V) # Tensor operation - Module 05's tracked_matmul!
|
||||
|
||||
# Again, show the quadratic complexity
|
||||
for b in range(batch_size): # For each batch
|
||||
for i in range(seq_len): # For each output position
|
||||
for j in range(seq_len): # Weighted sum over all value positions
|
||||
weight = attention_weights[b, i, j]
|
||||
for d in range(d_model): # Accumulate across embedding dimension
|
||||
output[b, i, d] += weight * V.data[b, j, d]
|
||||
|
||||
return Tensor(output), Tensor(attention_weights)
|
||||
return output, attention_weights
|
||||
### END SOLUTION
|
||||
|
||||
# %% nbgrader={"grade": true, "grade_id": "test-attention-basic", "locked": true, "points": 10}
|
||||
@@ -580,46 +570,59 @@ class MultiHeadAttention:
|
||||
batch_size, seq_len, embed_dim = x.shape
|
||||
assert embed_dim == self.embed_dim, f"Input dim {embed_dim} doesn't match expected {self.embed_dim}"
|
||||
|
||||
# Step 2: Project to Q, K, V
|
||||
# Step 2: Project to Q, K, V (Tensor operations!)
|
||||
Q = self.q_proj.forward(x) # (batch, seq, embed_dim)
|
||||
K = self.k_proj.forward(x)
|
||||
V = self.v_proj.forward(x)
|
||||
|
||||
# Step 3: Reshape to separate heads
|
||||
# From (batch, seq, embed_dim) to (batch, seq, num_heads, head_dim)
|
||||
Q_heads = Q.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
K_heads = K.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
V_heads = V.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
# Step 3: Reshape to separate heads (batch, seq, embed) → (batch, seq, heads, head_dim)
|
||||
Q_heads = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
K_heads = K.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
V_heads = V.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
# Step 4: Transpose to (batch, num_heads, seq, head_dim) for parallel processing
|
||||
Q_heads = np.transpose(Q_heads, (0, 2, 1, 3))
|
||||
K_heads = np.transpose(K_heads, (0, 2, 1, 3))
|
||||
V_heads = np.transpose(V_heads, (0, 2, 1, 3))
|
||||
# Step 4: Rearrange dims to (batch, heads, seq, head_dim) for parallel processing
|
||||
# We need to transpose dims 1 and 2, but Tensor.transpose() only swaps last two dims
|
||||
# So we manually transpose using NumPy, but preserve requires_grad
|
||||
Q_heads = Tensor(np.transpose(Q_heads.data, (0, 2, 1, 3)), requires_grad=Q_heads.requires_grad)
|
||||
K_heads = Tensor(np.transpose(K_heads.data, (0, 2, 1, 3)), requires_grad=K_heads.requires_grad)
|
||||
V_heads = Tensor(np.transpose(V_heads.data, (0, 2, 1, 3)), requires_grad=V_heads.requires_grad)
|
||||
|
||||
# Step 5: Process ALL heads in parallel (NO loops!)
|
||||
# Reshape to combine batch and head dims: (batch, heads, seq, head_dim) → (batch*heads, seq, head_dim)
|
||||
batch_heads = batch_size * self.num_heads
|
||||
Q_flat = Q_heads.reshape(batch_heads, seq_len, self.head_dim)
|
||||
K_flat = K_heads.reshape(batch_heads, seq_len, self.head_dim)
|
||||
V_flat = V_heads.reshape(batch_heads, seq_len, self.head_dim)
|
||||
|
||||
# Handle mask: Repeat for each head
|
||||
# mask: (batch, seq, seq) needs to become (batch*heads, seq, seq)
|
||||
if mask is not None:
|
||||
if mask.data.ndim == 2:
|
||||
# (seq, seq) → repeat for each batch and head
|
||||
mask_data = np.tile(mask.data[np.newaxis, :, :], (batch_heads, 1, 1))
|
||||
else:
|
||||
# (batch, seq, seq) → repeat for each head
|
||||
# For each batch element, repeat the mask num_heads times
|
||||
mask_data = np.repeat(mask.data, self.num_heads, axis=0)
|
||||
mask_flat = Tensor(mask_data)
|
||||
else:
|
||||
mask_flat = None
|
||||
|
||||
# Apply attention to all heads at once! (Tensor operation)
|
||||
# This batches all heads together - efficient and preserves gradients!
|
||||
attn_output, _ = scaled_dot_product_attention(Q_flat, K_flat, V_flat, mask_flat)
|
||||
|
||||
# Step 6: Reshape back to separate batch and heads: (batch*heads, seq, head_dim) → (batch, heads, seq, head_dim)
|
||||
attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
|
||||
|
||||
# Step 7: Transpose back: (batch, heads, seq, head_dim) → (batch, seq, heads, head_dim)
|
||||
attn_output = Tensor(np.transpose(attn_output.data, (0, 2, 1, 3)), requires_grad=attn_output.requires_grad)
|
||||
|
||||
# Step 8: Merge heads: (batch, seq, heads, head_dim) → (batch, seq, embed_dim)
|
||||
output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
||||
|
||||
# Step 5: Apply attention to each head
|
||||
head_outputs = []
|
||||
for h in range(self.num_heads):
|
||||
# Extract this head's Q, K, V
|
||||
Q_h = Tensor(Q_heads[:, h, :, :]) # (batch, seq, head_dim)
|
||||
K_h = Tensor(K_heads[:, h, :, :])
|
||||
V_h = Tensor(V_heads[:, h, :, :])
|
||||
|
||||
# Apply attention for this head
|
||||
head_out, _ = scaled_dot_product_attention(Q_h, K_h, V_h, mask)
|
||||
head_outputs.append(head_out.data)
|
||||
|
||||
# Step 6: Concatenate heads back together
|
||||
# Stack: list of (batch, seq, head_dim) → (batch, num_heads, seq, head_dim)
|
||||
concat_heads = np.stack(head_outputs, axis=1)
|
||||
|
||||
# Transpose back: (batch, num_heads, seq, head_dim) → (batch, seq, num_heads, head_dim)
|
||||
concat_heads = np.transpose(concat_heads, (0, 2, 1, 3))
|
||||
|
||||
# Reshape: (batch, seq, num_heads, head_dim) → (batch, seq, embed_dim)
|
||||
concat_output = concat_heads.reshape(batch_size, seq_len, self.embed_dim)
|
||||
|
||||
# Step 7: Apply output projection
|
||||
output = self.out_proj.forward(Tensor(concat_output))
|
||||
# Step 9: Apply output projection (Tensor operation!)
|
||||
output = self.out_proj.forward(output)
|
||||
|
||||
return output
|
||||
### END SOLUTION
|
||||
|
||||
157
tinytorch/core/attention.py
generated
157
tinytorch/core/attention.py
generated
@@ -81,56 +81,46 @@ def scaled_dot_product_attention(Q: Tensor, K: Tensor, V: Tensor, mask: Optional
|
||||
assert K.shape == (batch_size, seq_len, d_model), f"K shape {K.shape} doesn't match Q shape {Q.shape}"
|
||||
assert V.shape == (batch_size, seq_len, d_model), f"V shape {V.shape} doesn't match Q shape {Q.shape}"
|
||||
|
||||
# Step 2: Compute attention scores with explicit loops (educational O(n²) demonstration)
|
||||
scores = np.zeros((batch_size, seq_len, seq_len))
|
||||
# Step 2: Compute attention scores Q @ K^T using batched Tensor operations (NO loops!)
|
||||
# Q: (batch, seq, d_model)
|
||||
# K: (batch, seq, d_model)
|
||||
# K.transpose() swaps last two dims: (batch, d_model, seq)
|
||||
# Q @ K.T: (batch, seq, d_model) @ (batch, d_model, seq) → (batch, seq, seq)
|
||||
K_T = K.transpose() # (batch, d_model, seq) - Preserves requires_grad!
|
||||
scores = Q.matmul(K_T) # (batch, seq, seq) - Module 05's tracked_matmul sets _grad_fn!
|
||||
|
||||
# Show the quadratic complexity explicitly
|
||||
for b in range(batch_size): # For each batch
|
||||
for i in range(seq_len): # For each query position
|
||||
for j in range(seq_len): # Attend to each key position
|
||||
# Compute dot product between query i and key j
|
||||
score = 0.0
|
||||
for d in range(d_model): # Dot product across embedding dimension
|
||||
score += Q.data[b, i, d] * K.data[b, j, d]
|
||||
scores[b, i, j] = score
|
||||
|
||||
# Step 3: Scale by 1/√d_k for numerical stability
|
||||
# Step 3: Scale by 1/√d_k for numerical stability (Tensor operation!)
|
||||
scale_factor = 1.0 / math.sqrt(d_model)
|
||||
scores = scores * scale_factor
|
||||
scores = scores * scale_factor # Tensor multiplication - Module 05's tracked_mul!
|
||||
|
||||
# Step 4: Apply causal mask if provided
|
||||
# Step 4: Apply causal mask if provided (Tensor operation!)
|
||||
if mask is not None:
|
||||
# mask[i,j] = False means position j should not attend to position i
|
||||
mask_value = -1e9 # Large negative value becomes 0 after softmax
|
||||
for b in range(batch_size):
|
||||
for i in range(seq_len):
|
||||
for j in range(seq_len):
|
||||
if not mask.data[b, i, j]: # If mask is False, block attention
|
||||
scores[b, i, j] = mask_value
|
||||
# mask: True where attention is allowed, False where masked
|
||||
# Convert to additive mask: 0 where allowed, -1e9 where masked
|
||||
# This way we can use Tensor addition which preserves gradients!
|
||||
if mask.data.ndim == 2:
|
||||
# Broadcast (seq, seq) mask to (batch, seq, seq)
|
||||
mask_additive = Tensor(np.where(mask.data, 0.0, -1e9))
|
||||
else:
|
||||
# Already (batch, seq, seq)
|
||||
mask_additive = Tensor(np.where(mask.data, 0.0, -1e9))
|
||||
scores = scores + mask_additive # Tensor addition - Module 05's tracked_add!
|
||||
|
||||
# Step 5: Apply softmax to get attention weights (probability distribution)
|
||||
attention_weights = np.zeros_like(scores)
|
||||
for b in range(batch_size):
|
||||
for i in range(seq_len):
|
||||
# Softmax over the j dimension (what this query attends to)
|
||||
row = scores[b, i, :]
|
||||
max_val = np.max(row) # Numerical stability
|
||||
exp_row = np.exp(row - max_val)
|
||||
sum_exp = np.sum(exp_row)
|
||||
attention_weights[b, i, :] = exp_row / sum_exp
|
||||
# Step 5: Apply softmax (NO loops - softmax handles batched input!)
|
||||
from tinytorch.core.activations import Softmax
|
||||
softmax = Softmax()
|
||||
|
||||
# Apply softmax along last dimension (over keys for each query)
|
||||
# scores: (batch, seq, seq) → attention_weights: (batch, seq, seq)
|
||||
attention_weights = softmax.forward(scores, dim=-1) # Tensor operation!
|
||||
|
||||
# Step 6: Apply attention weights to values (another O(n²) operation)
|
||||
output = np.zeros((batch_size, seq_len, d_model))
|
||||
# Step 6: Apply attention weights to values (NO loops - batched matmul!)
|
||||
# attention_weights: (batch, seq, seq)
|
||||
# V: (batch, seq, d_model)
|
||||
# weights @ V: (batch, seq, seq) @ (batch, seq, d_model) → (batch, seq, d_model)
|
||||
output = attention_weights.matmul(V) # Tensor operation - Module 05's tracked_matmul!
|
||||
|
||||
# Again, show the quadratic complexity
|
||||
for b in range(batch_size): # For each batch
|
||||
for i in range(seq_len): # For each output position
|
||||
for j in range(seq_len): # Weighted sum over all value positions
|
||||
weight = attention_weights[b, i, j]
|
||||
for d in range(d_model): # Accumulate across embedding dimension
|
||||
output[b, i, d] += weight * V.data[b, j, d]
|
||||
|
||||
return Tensor(output), Tensor(attention_weights)
|
||||
return output, attention_weights
|
||||
### END SOLUTION
|
||||
|
||||
# %% ../../modules/source/12_attention/attention_dev.ipynb 10
|
||||
@@ -224,46 +214,59 @@ class MultiHeadAttention:
|
||||
batch_size, seq_len, embed_dim = x.shape
|
||||
assert embed_dim == self.embed_dim, f"Input dim {embed_dim} doesn't match expected {self.embed_dim}"
|
||||
|
||||
# Step 2: Project to Q, K, V
|
||||
# Step 2: Project to Q, K, V (Tensor operations!)
|
||||
Q = self.q_proj.forward(x) # (batch, seq, embed_dim)
|
||||
K = self.k_proj.forward(x)
|
||||
V = self.v_proj.forward(x)
|
||||
|
||||
# Step 3: Reshape to separate heads
|
||||
# From (batch, seq, embed_dim) to (batch, seq, num_heads, head_dim)
|
||||
Q_heads = Q.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
K_heads = K.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
V_heads = V.data.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
# Step 3: Reshape to separate heads (batch, seq, embed) → (batch, seq, heads, head_dim)
|
||||
Q_heads = Q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
K_heads = K.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
V_heads = V.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
# Step 4: Transpose to (batch, num_heads, seq, head_dim) for parallel processing
|
||||
Q_heads = np.transpose(Q_heads, (0, 2, 1, 3))
|
||||
K_heads = np.transpose(K_heads, (0, 2, 1, 3))
|
||||
V_heads = np.transpose(V_heads, (0, 2, 1, 3))
|
||||
# Step 4: Rearrange dims to (batch, heads, seq, head_dim) for parallel processing
|
||||
# We need to transpose dims 1 and 2, but Tensor.transpose() only swaps last two dims
|
||||
# So we manually transpose using NumPy, but preserve requires_grad
|
||||
Q_heads = Tensor(np.transpose(Q_heads.data, (0, 2, 1, 3)), requires_grad=Q_heads.requires_grad)
|
||||
K_heads = Tensor(np.transpose(K_heads.data, (0, 2, 1, 3)), requires_grad=K_heads.requires_grad)
|
||||
V_heads = Tensor(np.transpose(V_heads.data, (0, 2, 1, 3)), requires_grad=V_heads.requires_grad)
|
||||
|
||||
# Step 5: Process ALL heads in parallel (NO loops!)
|
||||
# Reshape to combine batch and head dims: (batch, heads, seq, head_dim) → (batch*heads, seq, head_dim)
|
||||
batch_heads = batch_size * self.num_heads
|
||||
Q_flat = Q_heads.reshape(batch_heads, seq_len, self.head_dim)
|
||||
K_flat = K_heads.reshape(batch_heads, seq_len, self.head_dim)
|
||||
V_flat = V_heads.reshape(batch_heads, seq_len, self.head_dim)
|
||||
|
||||
# Handle mask: Repeat for each head
|
||||
# mask: (batch, seq, seq) needs to become (batch*heads, seq, seq)
|
||||
if mask is not None:
|
||||
if mask.data.ndim == 2:
|
||||
# (seq, seq) → repeat for each batch and head
|
||||
mask_data = np.tile(mask.data[np.newaxis, :, :], (batch_heads, 1, 1))
|
||||
else:
|
||||
# (batch, seq, seq) → repeat for each head
|
||||
# For each batch element, repeat the mask num_heads times
|
||||
mask_data = np.repeat(mask.data, self.num_heads, axis=0)
|
||||
mask_flat = Tensor(mask_data)
|
||||
else:
|
||||
mask_flat = None
|
||||
|
||||
# Apply attention to all heads at once! (Tensor operation)
|
||||
# This batches all heads together - efficient and preserves gradients!
|
||||
attn_output, _ = scaled_dot_product_attention(Q_flat, K_flat, V_flat, mask_flat)
|
||||
|
||||
# Step 6: Reshape back to separate batch and heads: (batch*heads, seq, head_dim) → (batch, heads, seq, head_dim)
|
||||
attn_output = attn_output.reshape(batch_size, self.num_heads, seq_len, self.head_dim)
|
||||
|
||||
# Step 7: Transpose back: (batch, heads, seq, head_dim) → (batch, seq, heads, head_dim)
|
||||
attn_output = Tensor(np.transpose(attn_output.data, (0, 2, 1, 3)), requires_grad=attn_output.requires_grad)
|
||||
|
||||
# Step 8: Merge heads: (batch, seq, heads, head_dim) → (batch, seq, embed_dim)
|
||||
output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
|
||||
|
||||
# Step 5: Apply attention to each head
|
||||
head_outputs = []
|
||||
for h in range(self.num_heads):
|
||||
# Extract this head's Q, K, V
|
||||
Q_h = Tensor(Q_heads[:, h, :, :]) # (batch, seq, head_dim)
|
||||
K_h = Tensor(K_heads[:, h, :, :])
|
||||
V_h = Tensor(V_heads[:, h, :, :])
|
||||
|
||||
# Apply attention for this head
|
||||
head_out, _ = scaled_dot_product_attention(Q_h, K_h, V_h, mask)
|
||||
head_outputs.append(head_out.data)
|
||||
|
||||
# Step 6: Concatenate heads back together
|
||||
# Stack: list of (batch, seq, head_dim) → (batch, num_heads, seq, head_dim)
|
||||
concat_heads = np.stack(head_outputs, axis=1)
|
||||
|
||||
# Transpose back: (batch, num_heads, seq, head_dim) → (batch, seq, num_heads, head_dim)
|
||||
concat_heads = np.transpose(concat_heads, (0, 2, 1, 3))
|
||||
|
||||
# Reshape: (batch, seq, num_heads, head_dim) → (batch, seq, embed_dim)
|
||||
concat_output = concat_heads.reshape(batch_size, seq_len, self.embed_dim)
|
||||
|
||||
# Step 7: Apply output projection
|
||||
output = self.out_proj.forward(Tensor(concat_output))
|
||||
# Step 9: Apply output projection (Tensor operation!)
|
||||
output = self.out_proj.forward(output)
|
||||
|
||||
return output
|
||||
### END SOLUTION
|
||||
|
||||
Reference in New Issue
Block a user