mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-06-02 08:32:31 -05:00
Enhance autograd module with comprehensive computational graph theory
- Added detailed explanation of gradient computation challenges at scale - Enhanced computational graph theory with forward/backward pass details - Included mathematical foundation of chain rule and differentiation modes - Comprehensive real-world impact examples (deep learning revolution) - Performance considerations and optimization strategies - Connection to neural network training and modern AI applications - Better explanation of why autograd is revolutionary for ML systems
This commit is contained in:
@@ -73,83 +73,235 @@ from tinytorch.core.activations import ReLU, Sigmoid, Tanh
|
||||
### Definition
|
||||
**Automatic differentiation (autograd)** is a technique that automatically computes derivatives of functions represented as computational graphs. It's the magic that makes neural network training possible.
|
||||
|
||||
### Why Autograd Matters in ML
|
||||
Without autograd, we'd have to manually compute gradients for every operation:
|
||||
- **Manual gradients**: Error-prone, time-consuming, doesn't scale
|
||||
- **Numerical gradients**: Slow, imprecise, unstable
|
||||
- **Automatic gradients**: Fast, precise, scalable to any complexity
|
||||
### The Fundamental Challenge: Computing Gradients at Scale
|
||||
|
||||
### The Key Insight: Computational Graphs
|
||||
Every mathematical expression can be represented as a graph:
|
||||
```
|
||||
Expression: f(x, y) = (x + y) * (x - y)
|
||||
Graph: x ──┐ ┌── add ──┐
|
||||
│ │ │
|
||||
├─────┤ ├── multiply ── output
|
||||
│ │ │
|
||||
y ──┘ └── sub ──┘
|
||||
#### **The Problem**
|
||||
Neural networks have millions or billions of parameters. To train them, we need to compute the gradient of the loss function with respect to every single parameter:
|
||||
|
||||
```python
|
||||
# For a neural network with parameters θ = [w1, w2, ..., wn, b1, b2, ..., bm]
|
||||
# We need to compute: ∇θ L = [∂L/∂w1, ∂L/∂w2, ..., ∂L/∂wn, ∂L/∂b1, ∂L/∂b2, ..., ∂L/∂bm]
|
||||
```
|
||||
|
||||
### Forward vs Backward Pass
|
||||
- **Forward pass**: Compute the function value
|
||||
- **Backward pass**: Compute gradients using the chain rule
|
||||
#### **Why Manual Differentiation Fails**
|
||||
- **Complexity**: Neural networks are compositions of thousands of operations
|
||||
- **Error-prone**: Manual computation is extremely difficult and error-prone
|
||||
- **Inflexible**: Every architecture change requires re-deriving gradients
|
||||
- **Inefficient**: Manual computation doesn't exploit computational structure
|
||||
|
||||
### Real-World Examples
|
||||
- **Neural networks**: Backpropagation through layers
|
||||
- **Optimization**: Gradient descent for parameter updates
|
||||
- **Scientific computing**: Sensitivity analysis, inverse problems
|
||||
- **Machine learning**: Any gradient-based optimization
|
||||
|
||||
Let's start building our autograd system!
|
||||
"""
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## 🧠 The Mathematical Foundation
|
||||
|
||||
### Chain Rule: The Heart of Backpropagation
|
||||
The chain rule is what makes automatic differentiation possible:
|
||||
|
||||
```
|
||||
If z = f(g(x)), then dz/dx = (dz/df) * (df/dx)
|
||||
#### **Why Numerical Differentiation is Inadequate**
|
||||
```python
|
||||
# Numerical differentiation: f'(x) ≈ (f(x + h) - f(x)) / h
|
||||
def numerical_gradient(f, x, h=1e-5):
|
||||
return (f(x + h) - f(x)) / h
|
||||
```
|
||||
|
||||
### Computational Graph Perspective
|
||||
For a graph with nodes and edges:
|
||||
- **Nodes**: Variables and operations
|
||||
- **Edges**: Data flow and dependencies
|
||||
- **Forward pass**: Compute values following edges
|
||||
- **Backward pass**: Compute gradients following edges in reverse
|
||||
Problems:
|
||||
- **Slow**: Requires 2 function evaluations per parameter
|
||||
- **Imprecise**: Numerical errors accumulate
|
||||
- **Unstable**: Sensitive to choice of h
|
||||
- **Expensive**: O(n) cost for n parameters
|
||||
|
||||
### Example: Simple Expression
|
||||
```
|
||||
f(x, y) = x * y + sin(x)
|
||||
### The Solution: Computational Graphs
|
||||
|
||||
Forward:
|
||||
x = 2, y = 3
|
||||
a = x * y = 6
|
||||
b = sin(x) = sin(2) ≈ 0.909
|
||||
f = a + b = 6.909
|
||||
#### **Key Insight: Every Computation is a Graph**
|
||||
Any mathematical expression can be represented as a directed acyclic graph (DAG):
|
||||
|
||||
Backward:
|
||||
df/df = 1
|
||||
df/da = 1, df/db = 1
|
||||
da/dx = y = 3, da/dy = x = 2
|
||||
db/dx = cos(x) = cos(2) ≈ -0.416
|
||||
df/dx = df/da * da/dx + df/db * db/dx = 1*3 + 1*(-0.416) = 2.584
|
||||
df/dy = df/da * da/dy = 1*2 = 2
|
||||
```python
|
||||
# Expression: f(x, y) = (x + y) * (x - y)
|
||||
# Graph representation:
|
||||
# x ──┐ ┌── add ──┐
|
||||
# │ │ │
|
||||
# ├─────┤ ├── multiply ── output
|
||||
# │ │ │
|
||||
# y ──┘ └── sub ──┘
|
||||
```
|
||||
|
||||
### Connection to Neural Networks
|
||||
- **Layers**: Nodes in the computational graph
|
||||
- **Weights**: Parameters with gradients
|
||||
- **Loss function**: Final output node
|
||||
- **Backpropagation**: Backward pass through the entire network
|
||||
#### **Forward Pass: Computing Values**
|
||||
Traverse the graph from inputs to outputs, computing values at each node:
|
||||
|
||||
```python
|
||||
# Forward pass for f(x, y) = (x + y) * (x - y)
|
||||
x = 3, y = 2
|
||||
add_result = x + y = 5
|
||||
sub_result = x - y = 1
|
||||
output = add_result * sub_result = 5
|
||||
```
|
||||
|
||||
#### **Backward Pass: Computing Gradients**
|
||||
Traverse the graph from outputs to inputs, computing gradients using the chain rule:
|
||||
|
||||
```python
|
||||
# Backward pass for f(x, y) = (x + y) * (x - y)
|
||||
# Starting from output gradient = 1
|
||||
∂output/∂multiply = 1
|
||||
∂output/∂add = ∂output/∂multiply * ∂multiply/∂add = 1 * sub_result = 1
|
||||
∂output/∂sub = ∂output/∂multiply * ∂multiply/∂sub = 1 * add_result = 5
|
||||
∂output/∂x = ∂output/∂add * ∂add/∂x + ∂output/∂sub * ∂sub/∂x = 1 * 1 + 5 * 1 = 6
|
||||
∂output/∂y = ∂output/∂add * ∂add/∂y + ∂output/∂sub * ∂sub/∂y = 1 * 1 + 5 * (-1) = -4
|
||||
```
|
||||
|
||||
### Mathematical Foundation: The Chain Rule
|
||||
|
||||
#### **Single Variable Chain Rule**
|
||||
For composite functions: If z = f(g(x)), then:
|
||||
```
|
||||
dz/dx = (dz/df) * (df/dx)
|
||||
```
|
||||
|
||||
#### **Multivariable Chain Rule**
|
||||
For functions of multiple variables: If z = f(x, y) where x = g(t) and y = h(t), then:
|
||||
```
|
||||
dz/dt = (∂z/∂x) * (dx/dt) + (∂z/∂y) * (dy/dt)
|
||||
```
|
||||
|
||||
#### **Chain Rule in Computational Graphs**
|
||||
For any path from input to output through intermediate nodes:
|
||||
```
|
||||
∂output/∂input = ∏(∂node_{i+1}/∂node_i) for all nodes in the path
|
||||
```
|
||||
|
||||
### Automatic Differentiation Modes
|
||||
|
||||
#### **Forward Mode (Forward Accumulation)**
|
||||
- **Process**: Compute derivatives alongside forward pass
|
||||
- **Efficiency**: Efficient when #inputs << #outputs
|
||||
- **Use case**: Jacobian-vector products, sensitivity analysis
|
||||
|
||||
#### **Reverse Mode (Backpropagation)**
|
||||
- **Process**: Compute derivatives in reverse pass after forward pass
|
||||
- **Efficiency**: Efficient when #outputs << #inputs
|
||||
- **Use case**: Neural network training (many parameters, few outputs)
|
||||
|
||||
#### **Why Reverse Mode Dominates ML**
|
||||
Neural networks typically have:
|
||||
- **Many inputs**: Millions of parameters
|
||||
- **Few outputs**: Single loss value or small output vector
|
||||
- **Reverse mode**: O(1) cost per parameter vs O(n) for forward mode
|
||||
|
||||
### The Computational Graph Abstraction
|
||||
|
||||
#### **Nodes: Operations and Variables**
|
||||
- **Variable nodes**: Store values and gradients
|
||||
- **Operation nodes**: Define how to compute forward and backward passes
|
||||
|
||||
#### **Edges: Data Dependencies**
|
||||
- **Forward edges**: Data flow from inputs to outputs
|
||||
- **Backward edges**: Gradient flow from outputs to inputs
|
||||
|
||||
#### **Dynamic vs Static Graphs**
|
||||
- **Static graphs**: Define once, execute many times (TensorFlow 1.x)
|
||||
- **Dynamic graphs**: Build graph during execution (PyTorch, TensorFlow 2.x)
|
||||
|
||||
### Real-World Impact: What Autograd Enables
|
||||
|
||||
#### **Deep Learning Revolution**
|
||||
```python
|
||||
# Before autograd: Manual gradient computation
|
||||
def manual_gradient(x, y, w1, w2, b1, b2):
|
||||
# Forward pass
|
||||
z1 = w1 * x + b1
|
||||
a1 = sigmoid(z1)
|
||||
z2 = w2 * a1 + b2
|
||||
a2 = sigmoid(z2)
|
||||
loss = (a2 - y) ** 2
|
||||
|
||||
# Backward pass (manual)
|
||||
dloss_da2 = 2 * (a2 - y)
|
||||
da2_dz2 = sigmoid_derivative(z2)
|
||||
dz2_dw2 = a1
|
||||
dz2_db2 = 1
|
||||
dz2_da1 = w2
|
||||
da1_dz1 = sigmoid_derivative(z1)
|
||||
dz1_dw1 = x
|
||||
dz1_db1 = 1
|
||||
|
||||
# Chain rule application
|
||||
dloss_dw2 = dloss_da2 * da2_dz2 * dz2_dw2
|
||||
dloss_db2 = dloss_da2 * da2_dz2 * dz2_db2
|
||||
dloss_dw1 = dloss_da2 * da2_dz2 * dz2_da1 * da1_dz1 * dz1_dw1
|
||||
dloss_db1 = dloss_da2 * da2_dz2 * dz2_da1 * da1_dz1 * dz1_db1
|
||||
|
||||
return dloss_dw1, dloss_db1, dloss_dw2, dloss_db2
|
||||
|
||||
# With autograd: Automatic gradient computation
|
||||
def autograd_gradient(x, y, w1, w2, b1, b2):
|
||||
# Forward pass with gradient tracking
|
||||
z1 = w1 * x + b1
|
||||
a1 = sigmoid(z1)
|
||||
z2 = w2 * a1 + b2
|
||||
a2 = sigmoid(z2)
|
||||
loss = (a2 - y) ** 2
|
||||
|
||||
# Backward pass (automatic)
|
||||
loss.backward()
|
||||
|
||||
return w1.grad, b1.grad, w2.grad, b2.grad
|
||||
```
|
||||
|
||||
#### **Scientific Computing**
|
||||
- **Optimization**: Gradient-based optimization algorithms
|
||||
- **Inverse problems**: Parameter estimation from observations
|
||||
- **Sensitivity analysis**: How outputs change with input perturbations
|
||||
|
||||
#### **Modern AI Applications**
|
||||
- **Neural architecture search**: Differentiable architecture optimization
|
||||
- **Meta-learning**: Learning to learn with gradient-based meta-algorithms
|
||||
- **Differentiable programming**: Entire programs as differentiable functions
|
||||
|
||||
### Performance Considerations
|
||||
- **Memory**: Store intermediate values for backward pass
|
||||
- **Computation**: Reuse computations where possible
|
||||
- **Numerical stability**: Handle edge cases and precision
|
||||
|
||||
#### **Memory Management**
|
||||
- **Intermediate storage**: Must store forward pass results for backward pass
|
||||
- **Memory optimization**: Checkpointing, gradient accumulation
|
||||
- **Trade-offs**: Memory vs computation time
|
||||
|
||||
#### **Computational Efficiency**
|
||||
- **Graph optimization**: Fuse operations, eliminate redundancy
|
||||
- **Parallelization**: Compute independent gradients simultaneously
|
||||
- **Hardware acceleration**: Specialized gradient computation on GPUs/TPUs
|
||||
|
||||
#### **Numerical Stability**
|
||||
- **Gradient clipping**: Prevent exploding gradients
|
||||
- **Numerical precision**: Balance between float16 and float32
|
||||
- **Accumulation order**: Minimize numerical errors
|
||||
|
||||
### Connection to Neural Network Training
|
||||
|
||||
#### **The Training Loop**
|
||||
```python
|
||||
for epoch in range(num_epochs):
|
||||
for batch in dataloader:
|
||||
# Forward pass
|
||||
predictions = model(batch.inputs)
|
||||
loss = criterion(predictions, batch.targets)
|
||||
|
||||
# Backward pass (autograd)
|
||||
loss.backward()
|
||||
|
||||
# Parameter update
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
#### **Gradient-Based Optimization**
|
||||
- **Stochastic Gradient Descent**: Use gradients to update parameters
|
||||
- **Adaptive methods**: Adam, RMSprop use gradient statistics
|
||||
- **Second-order methods**: Use gradient and Hessian information
|
||||
|
||||
### Why Autograd is Revolutionary
|
||||
|
||||
#### **Democratization of Deep Learning**
|
||||
- **Research acceleration**: Focus on architecture, not gradient computation
|
||||
- **Experimentation**: Easy to try new ideas and architectures
|
||||
- **Accessibility**: Researchers don't need to be differentiation experts
|
||||
|
||||
#### **Scalability**
|
||||
- **Large models**: Handle millions/billions of parameters automatically
|
||||
- **Complex architectures**: Support arbitrary computational graphs
|
||||
- **Distributed training**: Coordinate gradients across multiple devices
|
||||
|
||||
Let's implement the Variable class that makes this magic possible!
|
||||
"""
|
||||
|
||||
# %% [markdown]
|
||||
|
||||
Reference in New Issue
Block a user