diff --git a/tinytorch/src/01_tensor/01_tensor.py b/tinytorch/src/01_tensor/01_tensor.py index 3b55c17cb..34b8d5576 100644 --- a/tinytorch/src/01_tensor/01_tensor.py +++ b/tinytorch/src/01_tensor/01_tensor.py @@ -936,15 +936,15 @@ Matrix Multiplication Process: A (2×3) B (3×2) C (2×2) ┌ ┐ ┌ ┐ ┌ ┐ │ 1 2 3 │ │ 7 8 │ │ 1×7+2×9+3×1 │ ┌ ┐ - │ │ × │ 9 1 │ = │ │ = │ 28 13│ - │ 4 5 6 │ │ 1 2 │ │ 4×7+5×9+6×1 │ │ 79 37│ + │ │ × │ 9 1 │ = │ │ = │ 28 16│ + │ 4 5 6 │ │ 1 2 │ │ 4×7+5×9+6×1 │ │ 79 49│ └ ┘ └ ┘ └ ┘ └ ┘ Computation Breakdown: C[0,0] = A[0,:] · B[:,0] = [1,2,3] · [7,9,1] = 1×7 + 2×9 + 3×1 = 28 -C[0,1] = A[0,:] · B[:,1] = [1,2,3] · [8,1,2] = 1×8 + 2×1 + 3×2 = 13 +C[0,1] = A[0,:] · B[:,1] = [1,2,3] · [8,1,2] = 1×8 + 2×1 + 3×2 = 16 C[1,0] = A[1,:] · B[:,0] = [4,5,6] · [7,9,1] = 4×7 + 5×9 + 6×1 = 79 -C[1,1] = A[1,:] · B[:,1] = [4,5,6] · [8,1,2] = 4×8 + 5×1 + 6×2 = 37 +C[1,1] = A[1,:] · B[:,1] = [4,5,6] · [8,1,2] = 4×8 + 5×1 + 6×2 = 49 Key Rule: Inner dimensions must match! A(m,n) @ B(n,p) = C(m,p) @@ -1491,12 +1491,12 @@ Step 1: Matrix Multiply [[1, 2, 3]] @ [[0.1, 0.2]] = [[1×0.1+2×0.3+3×0.5, 1×0.2+2×0.4+3×0.6]] [[4, 5, 6]] [[0.3, 0.4]] [[4×0.1+5×0.3+6×0.5, 4×0.2+5×0.4+6×0.6]] [[0.5, 0.6]] - = [[1.6, 2.6], - [4.9, 6.8]] + = [[2.2, 2.8], + [4.9, 6.4]] Step 2: Add Bias (Broadcasting) -[[1.6, 2.6]] + [0.1, 0.2] = [[1.7, 2.8], - [4.9, 6.8]] [5.0, 7.0]] +[[2.2, 2.8]] + [0.1, 0.2] = [[2.3, 3.0], + [4.9, 6.4]] [5.0, 6.6]] This is the foundation of every neural network layer! ``` diff --git a/tinytorch/src/09_convolutions/09_convolutions.py b/tinytorch/src/09_convolutions/09_convolutions.py index da935e727..4495c841d 100644 --- a/tinytorch/src/09_convolutions/09_convolutions.py +++ b/tinytorch/src/09_convolutions/09_convolutions.py @@ -207,7 +207,7 @@ Max Pooling Example (2×2 window): Input: Output: ┌───────────────┐ ┌───────┐ │ 1 3 2 4 │ │ 6 8 │ ← max([1,3,5,6])=6, max([2,4,7,8])=8 -│ 5 6 7 8 │ │ 9 9 │ ← max([5,2,9,1])=9, max([7,4,9,3])=9 +│ 5 6 7 8 │ │ 9 9 │ ← max([2,9,0,1])=9, max([1,3,9,3])=9 │ 2 9 1 3 │ └───────┘ │ 0 1 9 3 │ └───────────────┘ @@ -215,7 +215,7 @@ Input: Output: Average Pooling (same window): ┌─────────────┐ │ 3.75 5.25 │ ← avg([1,3,5,6])=3.75, avg([2,4,7,8])=5.25 -│ 2.75 5.75 │ ← avg([5,2,9,1])=4.25, avg([7,4,9,3])=5.75 +│ 3.0 4.0 │ ← avg([2,9,0,1])=3.0, avg([1,3,9,3])=4.0 └─────────────┘ ```