---
quiz: footnote_context_quizzes.json
concepts: frameworks_concepts.yml
glossary: frameworks_glossary.json
engine: jupyter
---

# ML Frameworks {#sec-ml-frameworks}

::: {layout-narrow}
::: {.column-margin}
\chapterminitoc
:::

\noindent
{fig-alt="Colorful illustration showing machine learning framework concepts with icons representing TensorFlow, Keras, PyTorch, and other tools, featuring geometric shapes, charts, and training and inference workflow labels."}

:::

## Purpose {.unnumbered}

\begin{marginfigure}
\mlsysstack{35}{90}{40}{45}{15}{0}{0}{10}
\end{marginfigure}

_Why does the framework---the layer between your math and your hardware---silently constrain every decision that follows?_

Neural networks are defined by mathematics (matrix multiplications, gradient computations, activation functions), but mathematics does not execute itself. Between the equations and the silicon lies a translation layer that decides how operations are scheduled on hardware, how memory is allocated across the compute hierarchy, and how gradients flow backward through the computational graph. The framework *is* this translation layer, and the translation is not neutral. An eager-mode framework that prioritizes debugging flexibility sacrifices the graph-level optimizations that can halve inference latency. A framework lacking support for your target accelerator renders your hardware investment useless. A framework with a rich training API but no export path to edge devices means the model you built cannot reach the deployment target it was designed for. Architecture choices are at least visible: engineers debate model size, layer count, and attention mechanisms explicitly. Framework choices are more insidious because they operate below the level of daily attention, silently determining which optimizations are possible, which hardware is reachable, and which deployment paths exist. In the AI Triad (@sec-introduction), the framework is the invisible mediator between Algorithm and Machine, and its design choices---baked into its compilation stack, memory management, and operator libraries---are difficult to reverse. Migrating between frameworks invalidates data pipelines, serving infrastructure, model checkpoints, and team expertise, typically requiring months of engineering effort for production systems. Framework selection is therefore not a tooling preference but an infrastructure commitment that determines what your system *can* do on the hardware it *must* run on.

::: {.content-visible when-format="pdf"}
\newpage
:::

::: {.callout-tip title="Learning Objectives"}

- Explain how ML frameworks solve three core problems: execution (computational graphs), differentiation (automatic differentiation), and abstraction (hardware-optimized operations)
- Compare eager, static, and hybrid (JIT) execution strategies using the Compilation Continuum and Dispatch Overhead principles to determine when compilation benefits outweigh costs
- Describe the nn.Module abstraction pattern for hierarchical composition, automatic parameter discovery, and mode-dependent behavior
- Analyze how the memory wall drives framework optimization strategies including kernel fusion, mixed-precision training, and activation checkpointing
- Evaluate major framework architectures (TensorFlow, PyTorch, JAX) based on their execution models, differentiation approaches, and deployment trade-offs
- Evaluate framework selection trade-offs by matching model requirements, hardware constraints, and deployment targets across the cloud-to-edge spectrum

:::

```{python}
#| label: a100-specs-blas
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ A100 SPECS FOR BLAS FOOTNOTE
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: BLAS footnote discussing hardware-optimized implementations
# │
# │ Goal: Establish the foundation of hardware-optimized BLAS.
# │ Show: The ~312 TFLOPS dense capability of A100 tensor cores.
# │ How: Retrieve peak throughput from the mlsys Hardware registry.
# │
# │ Imports: mlsys (Hardware), mlsys.constants (TFLOPs, second), mlsys.formatting (fmt)
# │ Exports: A100BLAS.dense_tflops_str, A100BLAS.sparse_tflops_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys import Hardware
from mlsys.formatting import fmt, check
from mlsys.constants import TFLOPs, second

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class A100BLAS:
    """
    Namespace for A100 BLAS Specs.
    Scenario: Dense vs Sparse Tensor Core throughput.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    gpu = Hardware.Cloud.A100

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    dense_flops = gpu.peak_flops.to(TFLOPs/second).magnitude
    sparse_flops = dense_flops * 2  # 2:4 structured sparsity doubles peak throughput

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    dense_tflops_str = fmt(dense_flops, precision=0, commas=False)
    sparse_tflops_str = fmt(sparse_flops, precision=0, commas=False)

# Note: Use A100BLAS.dense_tflops_str directly.
```

## Three Framework Problems {#sec-ml-frameworks-three-problems-every-framework-must-solve-317d}

Two lines of code: `model = Transformer(...)` followed by `loss.backward()`. Between them, invisible to the programmer, the framework orchestrates billions of floating-point operations across memory hierarchies, computes exact gradients through millions of parameters using **automatic differentiation** (the systematic application of the chain rule[^fn-chain-rule] to compute derivatives), schedules thousands of GPU kernel launches, and manages gigabytes of intermediate state. The simplicity is an illusion. Those two lines trigger machinery as complex as a compiler, because that is exactly what a modern ML framework is.

[^fn-chain-rule]: **Chain Rule**: The foundational calculus rule for differentiating composite functions, first used by Leibniz (1676) and formalized by Euler (1748). For a composition $f(g(x))$, the chain rule states $\frac{df}{dx} = \frac{df}{dg} \cdot \frac{dg}{dx}$—the derivative of the whole equals the product of derivatives of each part. This rule is the mathematical engine behind backpropagation: a neural network is a deeply nested composition of functions (layers), and the chain rule decomposes the gradient of the loss with respect to any parameter into a product of local derivatives along the path from output to that parameter. Robert Wengert's 1964 paper showed that this decomposition could be mechanized as a graph algorithm, laying the foundation for automatic differentiation systems in modern frameworks.

The architectures defined in @sec-network-architectures specify *what* computations neural networks perform, but knowing *what* to compute is entirely different from knowing *how* to compute it efficiently. A Transformer's attention mechanism (introduced in @sec-network-architectures) requires coordinating computation across memory hierarchies and accelerator cores in patterns that naive implementations would execute 100× slower than optimized ones. Implementing these operations from scratch for every model would make deep learning economically infeasible. ML frameworks exist to bridge this gap by translating high-level model definitions into hardware-specific execution plans that extract maximum performance from silicon.

\index{ML Framework!compiler analogy}
A framework is to machine learning what a compiler is to traditional programming. A C compiler translates human-readable code into optimized machine instructions, managing register allocation, instruction scheduling, and memory layout. An ML framework translates high-level model definitions into hardware-specific execution plans, managing operator fusion, memory reuse, and device placement. This analogy is more than metaphor: modern frameworks literally include compilers, as we will see throughout this chapter.

\index{Deferred Execution!graph construction}
Every ML framework, regardless of API or design philosophy, must solve three core problems\index{Framework!core problems (execution, differentiation, abstraction)}. First, the *execution problem*: when and how should computation happen? Should operations execute immediately as written (**eager execution**[^fn-eager-execution]), or should the framework build a complete description first---a **computational graph**[^fn-computational-graph] (a structured representation of operations and their dependencies)---and optimize before executing (graph execution)? This choice shapes debugging capability, optimization potential, and deployment flexibility. Second, the *differentiation problem*: how should the framework compute gradients automatically? As established in @sec-neural-computation, training (the complex orchestration detailed in @sec-model-training) requires derivatives of a loss function with respect to millions or billions of parameters, and manual differentiation is error-prone at this scale. Frameworks must implement **automatic differentiation** systems that compute exact gradients for arbitrary compositions of operations while managing the memory overhead of storing intermediate values. Third, the *hardware abstraction problem*: how should the framework target diverse hardware from a single interface? The same model definition should run on CPUs, GPUs, TPUs, and mobile devices, each with different memory constraints and optimal execution patterns.

[^fn-computational-graph]: **Computational Graph**\index{Computational Graph!etymology}: A directed acyclic graph (DAG) where nodes represent operations and edges represent data (tensors). The concept traces to Wengert's 1964 paper showing that graph-based computation enables systematic gradient evaluation via the chain rule. Theano [@al2016theano] was the first ML framework to compile computational graphs into optimized GPU code. The representation enables powerful optimizations---operation fusion, dead code elimination, memory planning---impossible when operations execute independently.

[^fn-eager-execution]: **Eager Execution**\index{Eager Execution!etymology}: Also called "define-by-run" or "imperative mode." Popularized by Chainer (2015) and PyTorch [@paszke2019pytorch] (2016), eager execution runs each operation immediately, producing concrete results that can be inspected and debugged with standard Python tools. TensorFlow 2.0 (2019) adopted it as default. The trade-off is fundamental: sacrificing the global view of computation for developer productivity. Modern hybrid approaches like `torch.compile` and JAX's `jit` attempt to recover graph-level optimizations while preserving the eager model.

These three problems are deeply interconnected. The execution model determines *when* differentiation occurs and *what* optimizations are possible. The abstraction layer must support both execution styles across all hardware targets. Solving any one problem in isolation leads to frameworks that excel in narrow contexts but fail in broader deployment. Because these problems are ultimately about translating mathematics into efficient hardware execution, a useful perspective is to view frameworks not as libraries but as compilers.

::: {.callout-perspective title="The ML Compiler"}
In the context of the **Iron Law** (@sec-introduction-iron-law-ml-systems-c32a), a framework is a **Compiler for the Silicon Contract**\index{Framework!compiler for the Silicon Contract}.

Your "Source Code" is the model architecture (the **$O$** term). The framework's job is to take this high-level math and compile it into a series of hardware-specific kernel launches that:

1. Minimize **Data Movement ($D_{vol}$)** through techniques like kernel fusion.
2. Maximize **Utilization ($\eta$)** by matching operations to specialized hardware units like Tensor Cores.
3. Minimize **Overhead ($L_{lat}$)** through efficient asynchronous dispatch and graph capture.

Choosing a framework means choosing the compiler that determines *how* efficiently your model utilizes hardware.
:::

With these three problems in mind, we can now define *what* a machine learning framework fundamentally is.

::: {.callout-definition title="Machine Learning Frameworks"}

***Machine Learning Frameworks***\index{ML Framework} are the **Compilers** for the **Silicon Contract**. They translate high-level mathematical definitions into hardware-specific execution plans, managing the **Abstraction Gap** between algorithmic logic and physical silicon constraints (memory layout, kernel dispatch, differentiation).

:::

The scale of this translation is not obvious from the API surface. A single call to `loss.backward()` triggers operation recording, memory allocation for gradients, reverse-order graph traversal, and hardware-optimized kernel dispatch---machinery that would require hundreds of lines of manual calculus for even a three-layer network. For a contemporary language model, the framework additionally orchestrates billions of floating-point operations across distributed hardware, coordinating memory hierarchies, communication protocols, and numerical precision. Building this from scratch would be economically prohibitive for most organizations, which is why the history of ML frameworks is a history of progressively automating these layers.

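To make the "record, then traverse in reverse" behavior of `loss.backward()` concrete, here is a minimal sketch of tape-based reverse-mode automatic differentiation in plain Python. It is a toy under stated assumptions: values are scalars, only addition and multiplication are supported, and `Var` and `backward` are hypothetical names rather than any framework's API. Real frameworks record tensor operations and dispatch hardware kernels, but the control flow is analogous: the forward pass records each operation with its local derivatives, and the backward pass applies the chain rule in reverse topological order.

```python
# Toy reverse-mode autodiff on scalars. `Var` and `backward` are
# hypothetical names for illustration, not any framework's API.

class Var:
    """A scalar that remembers how it was computed."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # Each parent is (input Var, local derivative d(self)/d(input)).
        self.parents = parents

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def backward(output):
    """Accumulate d(output)/d(node) into node.grad for every recorded node."""
    order, visited = [], set()

    def visit(node):  # depth-first post-order yields a topological order
        if id(node) not in visited:
            visited.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):
        for parent, local_grad in node.parents:
            # Chain rule: upstream gradient times local derivative, accumulated.
            parent.grad += node.grad * local_grad


# Forward pass records the graph: loss = (w * x + b) * (w * x + b)
w, x, b = Var(2.0), Var(3.0), Var(1.0)
y = w * x + b
loss = y * y
backward(loss)
print(loss.value, w.grad, x.grad, b.grad)
```

Because `y` feeds `loss` twice, the `+=` accumulation (rather than plain assignment) is what makes its gradient correct. The sketch also shows why intermediate values must stay alive until the backward pass: the local derivative for each multiplication is the *other* input's forward value.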
The three problems---execution, differentiation, and abstraction---did not emerge simultaneously. Each arose as a response to scaling limitations in the previous generation of tools. Tracing this evolution reveals why modern frameworks are designed as they are and why the particular trade-offs they embody were, in hindsight, inevitable.

## Framework Evolution {#sec-ml-frameworks-frameworks-evolved-ac68}

\index{ML Framework!historical evolution}
\index{NumPy!framework foundation}
Modern frameworks reflect decades of incremental abstraction, each generation solving problems that made the previous generation impractical. This evolution progressed through three distinct levels of abstraction:

1. **Hardware Primitives (1979–1992)**: The **Basic Linear Algebra Subprograms (BLAS)**\index{BLAS!historical foundation}[^fn-blas] established standardized, hardware-optimized implementations of operations like matrix multiplication (`GEMM`[^fn-gemm]) [@lawson1979blas]. **LAPACK**[^fn-lapack] extended this with higher-level solvers. These libraries remain the hidden foundation of every modern ML system; vendors like Intel (MKL) and NVIDIA (cuBLAS) provide highly tuned versions for their silicon, as the hardware acceleration hierarchy in @sec-hardware-acceleration explains.

[^fn-gemm]: **GEMM (General Matrix Multiply)**: The most performance-critical operation in deep learning, classified as a BLAS Level 3 routine. GEMM computes $C \leftarrow \alpha AB + \beta C$ for matrices $A$, $B$, and $C$ with scalar coefficients $\alpha$ and $\beta$. The "General" in GEMM distinguishes it from specialized variants (SYMM for symmetric matrices, TRMM for triangular). Nearly every neural network layer—fully connected, convolutional (via im2col transformation), and attention—reduces to GEMM at the hardware level. An A100 GPU's Tensor Cores peak at `{python} A100BLAS.dense_tflops_str` TFLOPS on dense FP16 GEMM, and cuBLAS uses them most efficiently when matrix dimensions are multiples of 8, which aligns the operands with Tensor Core tile shapes. This single operation's efficiency determines whether a framework achieves 30% or 90% of peak hardware utilization, explaining why hardware vendors invest years optimizing their GEMM implementations.

[^fn-lapack]: **LAPACK (Linear Algebra PACKage)**: Developed at the University of Tennessee, Oak Ridge National Laboratory, and NAG Ltd., first released in 1992. LAPACK extends BLAS with higher-level linear algebra routines: solving systems of equations, eigenvalue decomposition, singular value decomposition (SVD), and least-squares problems. The "PACKage" suffix reflects its origins as a curated collection of Fortran subroutines. While ML frameworks rarely call LAPACK directly during training (GEMM dominates), LAPACK routines appear in data preprocessing (PCA via SVD), model initialization (orthogonal initialization via QR decomposition), and numerical stability analysis. NVIDIA's cuSOLVER provides GPU-accelerated LAPACK equivalents for these operations.

[^fn-blas]: **BLAS (Basic Linear Algebra Subprograms)**: A standardized specification that defines portable APIs for dense linear algebra in three levels: Level 1 (vector operations, published in 1979), Level 2 (matrix-vector, 1988), and Level 3 (matrix-matrix operations such as GEMM, 1990). Hardware vendors implement optimized BLAS libraries: Intel MKL achieves near-peak FLOPS on x86 CPUs through AVX-512 vectorization, while NVIDIA cuBLAS uses Tensor Cores for up to `{python} A100BLAS.dense_tflops_str` TFLOPS (FP16/BF16) on A100 GPUs, or `{python} A100BLAS.sparse_tflops_str` TFLOPS with 2:4 structured sparsity. This interface, first standardized in 1979, remains the performance foundation of modern ML frameworks.

2. **Vectorized Productivity (2006)**: **NumPy**[^fn-numpy] made Python viable for numerical computing by delegating heavy computation to underlying C and Fortran BLAS libraries. This "vectorization" approach (writing code in high-level Python but executing it in low-level C) became the dominant pattern, drastically reducing the gap between research ideas and execution speed. @sec-benchmarking quantifies this trade-off.

[^fn-numpy]: **NumPy**: Contraction of "Numerical Python," created by Travis Oliphant in 2005 by merging two earlier projects (Numeric and Numarray). Released publicly in 2006, NumPy established the n-dimensional array as Python's standard numerical container. Its array-oriented computing model, borrowed from APL and MATLAB, remains the conceptual foundation for PyTorch tensors and TensorFlow arrays.

3. **Automatic Differentiation (2015–present)**: While NumPy required engineers to manually derive and code gradients, modern frameworks like **TensorFlow**\index{TensorFlow!historical introduction} [@abadi2016tensorflow] and **PyTorch**\index{PyTorch!historical introduction} automated this through the **computational graph**. This architectural shift, separating the *definition* of the model from the *computation* of its derivatives, enabled the scaling of deep learning.

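The vectorization pattern from the second stage can be seen by writing GEMM, $C \leftarrow \alpha AB + \beta C$, two ways. The sketch below assumes NumPy is installed; `gemm_naive` is a hypothetical reference implementation in pure Python loops, used only to confirm that the vectorized expression, whose `A @ B` NumPy hands to the BLAS library it was built against for float64 matrices, computes the same result. Timings depend entirely on the machine and the BLAS build, so no particular speedup is claimed.

```python
import time

import numpy as np


def gemm_naive(alpha, A, B, beta, C):
    """Reference GEMM (hypothetical helper): alpha * A @ B + beta * C via Python loops."""
    m, k = A.shape
    k2, n = B.shape
    assert k == k2 and C.shape == (m, n)
    out = np.empty((m, n))
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += A[i, p] * B[p, j]
            out[i, j] = alpha * acc + beta * C[i, j]
    return out


rng = np.random.default_rng(0)
m, k, n = 64, 64, 64
A = rng.standard_normal((m, k))
B = rng.standard_normal((k, n))
C = rng.standard_normal((m, n))
alpha, beta = 1.5, 0.5

t0 = time.perf_counter()
slow = gemm_naive(alpha, A, B, beta, C)
t1 = time.perf_counter()
fast = alpha * (A @ B) + beta * C  # the matmul runs inside compiled BLAS code
t2 = time.perf_counter()

print(f"max abs difference: {np.max(np.abs(slow - fast)):.2e}")
print(f"python loops: {t1 - t0:.4f} s, vectorized: {t2 - t1:.6f} s")
```

The two results agree to floating-point rounding; the difference lies only in where the inner loop executes, which is precisely the gap NumPy closed.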
This evolution highlights a critical engineering lesson: scaling ML development required turning the mathematical chain rule into a software primitive. The transition from manual gradients to static graphs (Theano\index{Theano!static graphs pioneer} [@al2016theano], TensorFlow 1.x) and eventually to dynamic graphs (PyTorch[^fn-pytorch-paszke]) reflects the industry's search for the optimal balance between performance and developer velocity.

[^fn-pytorch-paszke]: **PyTorch**: Originally developed by Adam Paszke while still an undergraduate at the University of Warsaw, building on the Lua-based Torch framework. Released by Meta AI (then Facebook AI Research) in 2016, PyTorch's "define-by-run" approach made debugging neural networks as natural as debugging Python code. Paszke's insight was that researcher productivity mattered more than raw execution speed---an engineering trade-off that proved correct as PyTorch became the dominant framework in academic research. The framework's success demonstrated that developer experience could be a decisive competitive advantage in infrastructure software.

Trace this evolution in @fig-mlfm-timeline, where each generation builds upon its predecessor: BLAS established optimized primitives, NumPy made them accessible from Python, and modern frameworks added automatic differentiation on top of this foundation.

::: {#fig-mlfm-timeline fig-env="figure" fig-pos="htb" fig-cap="**Computational Library Evolution**: Modern machine learning frameworks build upon decades of numerical computing advancements, transitioning from low-level routines like BLAS and LAPACK to high-level abstractions in NumPy and SciPy, and finally to deep learning frameworks such as Theano [@bergstra2010theano], TensorFlow, and PyTorch. SciPy was first released in 2001; the 2007 entry marks the period when both SciPy's maturing ecosystem and Theano's introduction of computational graphs jointly established Python as the dominant language for scientific and machine learning computing." fig-alt="Horizontal timeline from 1979 to 2018 with colored boxes marking key years. Dashed arrows connect to milestones below: 1979 BLAS introduced, 1992 LAPACK extends BLAS, 2006 NumPy becomes Python's numerical backbone, 2007 SciPy adds advanced computations and Theano introduces computational graphs, 2015 TensorFlow revolutionizes distributed ML, 2016 PyTorch introduces dynamic graphs, 2018 JAX introduces functional paradigms."}
```{.tikz}
\begin{tikzpicture}[node distance=1mm,outer sep=0pt,font=\small\usefont{T1}{phv}{m}{n}]
\tikzset{%
  Line/.style={line width=1.0pt,black!50
  },
  Box/.style={inner xsep=1pt,
    draw=none,
    fill=#1,
    anchor=west,
    text width=27mm,align=flush center,
    minimum width=28mm, minimum height=13mm
  },
  Box/.default=red
}
\definecolor{col1}{RGB}{128, 179, 255}
\definecolor{col2}{RGB}{255, 255, 128}
\definecolor{col3}{RGB}{204, 255, 204}
\definecolor{col4}{RGB}{230, 179, 255}
\definecolor{col5}{RGB}{255, 153, 204}
\definecolor{col6}{RGB}{245, 82, 102}
\definecolor{col7}{RGB}{255, 102, 102}

\node[Box={col1}](B1){1979};
\node[Box={col2!},right=of B1](B2){1992};
\node[Box={col3},right=of B2](B3){2006};
\node[Box={col4},right=of B3](B4){2007};
\node[Box={col5},right=of B4](B5){2015};
\node[Box={col6},right=of B5](B6){2016};
\node[Box={col7},right=of B6](B7){2018};
%%
\foreach \x in{1,2,...,7}
\draw[dashed,thick,-latex](B\x)--++(270:6);

\path[red]([yshift=-8mm]B1.south west)coordinate(P)-|coordinate(K)(B7.south east);

\draw[line width=2pt,-latex](P)--(K)--++(0:3mm);

\node[Box={col1!50},below=2 of B1](BB1){BLAS introduced};
\node[Box={col2!50},below=2 of B2](BB2){LAPACK extends BLAS};
\node[Box={col3!50},below=2 of B3](BB3){NumPy becomes Python's numerical backbone};
\node[Box={col4!50},below=2 of B4](BB4){SciPy adds advanced computations};
\node[Box={col4!50},below= 2mm of BB4](BBB4){Theano introduces computational graphs};
\node[Box={col5!50},below=2 of B5](BB5){TensorFlow revolutionizes distributed ML};
\node[Box={col6!50},below=2 of B6](BB6){PyTorch introduces dynamic graphs};
\node[Box={col7!50},below=2 of B7](BB7){JAX introduces functional paradigms};
\end{tikzpicture}
```
:::

Each generation abstracted away details that consumed engineering effort in the previous one, yet each abstraction introduced new trade-offs. BLAS hid assembly-level optimization but fixed the interface. NumPy hid memory management but required manual differentiation. Modern frameworks hide gradient computation but introduce the execution model choice we examine next.

All modern frameworks converge on the same three core problems: *how* to execute computation, *how* to differentiate it, and *how* to abstract across hardware. We begin with the most visible of these, the execution problem, because its resolution determines which optimizations the other two problems can exploit.

## Execution Problem {#sec-ml-frameworks-execution-problem-e1e1}

\index{Execution Problem!definition}
Consider two engineers writing the same neural network. The first debugs interactively, printing tensor shapes after each operation, inspecting intermediate values, and stepping through code with `pdb`. The second waits 30 seconds for compilation, then watches the model run 3× faster with no ability to inspect any intermediate state. Both are correct; they have simply made different choices about the execution problem, the question of whether operations should execute immediately as written or be recorded for later execution. This choice creates a cascade of engineering trade-offs that shape every aspect of framework behavior, from debugging workflows to deployment options to peak hardware utilization.

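The second engineer's trade-off can be put in numbers: compilation pays off once the cumulative per-step savings exceed the one-time compile cost, so the break-even point is $N^* = T_{\text{compile}} / (t_{\text{eager}} - t_{\text{compiled}})$. The 30-second compile and 3× speedup below come from the scenario above; the 80 ms eager step time is a hypothetical value chosen only for illustration, and `break_even_steps` is an illustrative helper, not a framework API.

```python
import math


def break_even_steps(compile_s, eager_step_s, speedup):
    """Number of steps after which a one-time compile cost is fully repaid."""
    compiled_step_s = eager_step_s / speedup
    saved_per_step = eager_step_s - compiled_step_s
    return math.ceil(compile_s / saved_per_step)


# From the scenario: 30 s compile, 3x faster steps. The 80 ms step is hypothetical.
steps = break_even_steps(compile_s=30.0, eager_step_s=0.080, speedup=3.0)
print(f"compilation pays for itself after {steps} steps")
```

A training run of tens of thousands of steps amortizes the compile cost easily, while an interactive debugging session that runs a handful of steps never does, which is why the first engineer's choice is equally rational.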
### Why Execution Strategy Matters: The Memory Wall {#sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8}

```{python}
#| label: a100-memory-wall
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ A100 SPECS FOR MEMORY WALL DISCUSSION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Introducing the Memory Wall concept with A100 hardware numbers
# │
# │ Goal: Demonstrate the fundamental compute-vs-bandwidth imbalance.
# │ Show: Why memory-bound ops achieve <1% compute utilization on A100.
# │ How: Calculate the ridge point using peak FLOPS and HBM bandwidth.
# │
# │ Imports: mlsys (Hardware), mlsys.constants (TFLOPs, TB, second), mlsys.formatting (fmt, check)
# │ Exports: MemoryWallSpecs.a100_tflops_fp16_str, MemoryWallSpecs.a100_bw_tbs_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys import Hardware
from mlsys.constants import TFLOPs, TB, second
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class MemoryWallSpecs:
    """
    Namespace for A100 Memory Wall Specs.
    Scenario: Demonstrating the 150x gap between compute and bandwidth.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    gpu = Hardware.Cloud.A100

    flops_fp16 = gpu.peak_flops.to(TFLOPs/second).magnitude
    bw_tbs = gpu.memory_bw.to(TB/second).magnitude

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    # Arithmetic Intensity "Ridge Point" (FLOPs per byte)
    ridge_point = gpu.ridge_point().magnitude

    # ┌── 3. INVARIANTS (Guardrails) ───────────────────────────────────────────
    check(ridge_point >= 100, f"A100 ridge point ({ridge_point:.1f}) is too low to claim a 'Memory Wall'.")

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    a100_tflops_fp16_str = fmt(flops_fp16, precision=0, commas=False)
    a100_bw_tbs_str = fmt(bw_tbs, precision=1, commas=False)

# Note: Use MemoryWallSpecs.a100_tflops_fp16_str directly.
```

\index{Memory Wall!execution strategy impact}
To understand why execution strategy matters so much, consider the **Memory Wall**\index{Memory Wall!definition} (first introduced in @sec-neural-computation), the growing gap between processor computational speed and memory bandwidth. Modern GPUs can perform arithmetic far faster than they can fetch data. On an A100 GPU with `{python} MemoryWallSpecs.a100_tflops_fp16_str` TFLOPS of compute and `{python} MemoryWallSpecs.a100_bw_tbs_str` TB/s of memory bandwidth, element-wise operations like ReLU achieve less than 1% of peak compute capacity, not because the hardware is slow, but because they spend nearly all their time waiting for data. The Roofline Model (@sec-machine-foundations-roofline-model-2529) formalizes this trade-off, showing exactly when operations are memory-bound versus compute-bound.

This creates a critical classification: operations are either **compute-bound**\index{Compute-bound operations} (limited by arithmetic throughput, like large matrix multiplications) or **memory-bound**\index{Memory-bound operations} (limited by data movement, like activation functions and normalization). Most individual neural network operation types (activations, normalizations, element-wise operations) are memory-bound, though the large matrix multiplications that dominate total compute time can be compute-bound.

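A back-of-the-envelope Roofline calculation shows where the "less than 1%" figure comes from. The sketch assumes A100 values of 312 TFLOPS dense FP16 Tensor Core throughput and about 2.0 TB/s of HBM bandwidth (the exact bandwidth depends on the A100 variant), counts one FLOP per ReLU element, and assumes each FP16 element is read once and written once. The GEMM line assumes the ideal case in which each of the three square matrices crosses HBM exactly once. Attainable throughput is the Roofline bound, $\min(\text{peak}, I \times BW)$, where $I$ is arithmetic intensity in FLOPs per byte.

```python
PEAK_FLOPS = 312e12   # assumed A100 dense FP16 Tensor Core peak, FLOP/s
BANDWIDTH = 2.0e12    # assumed HBM bandwidth, bytes/s (varies by A100 variant)


def attainable_flops(intensity):
    """Roofline bound: throughput is capped by compute or by bytes delivered."""
    return min(PEAK_FLOPS, intensity * BANDWIDTH)


ridge_point = PEAK_FLOPS / BANDWIDTH  # FLOPs/byte needed to become compute-bound

# ReLU on FP16: 1 FLOP (a max) per element; 2 bytes read + 2 bytes written.
relu_intensity = 1 / (2 + 2)

# Square FP16 GEMM of size n: 2*n^3 FLOPs over 3 matrices of n^2 elements, 2 bytes each.
n = 4096
gemm_intensity = (2 * n**3) / (3 * n**2 * 2)

for name, ai in [("ReLU", relu_intensity), (f"GEMM n={n}", gemm_intensity)]:
    util = attainable_flops(ai) / PEAK_FLOPS
    print(f"{name}: {ai:.2f} FLOP/byte (ridge {ridge_point:.0f}), "
          f"peak utilization <= {util:.2%}")
```

Under these assumptions ReLU is capped at roughly 0.16% of peak, consistent with the "less than 1%" claim, while the large GEMM sits well past the ridge point and is compute-bound. Neither estimate accounts for caches, kernel launch overhead, or imperfect tiling.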
The key optimization for memory-bound operations is **kernel fusion**\index{Kernel Fusion!optimizing memory-bound ops}, combining multiple operations into a single GPU function (called a *kernel*)[^fn-kernel] to avoid intermediate memory traffic. Fusing a sequence of LayerNorm, Dropout, and ReLU into one kernel can yield a 5× speedup by eliminating intermediate writes between operations. FlashAttention[^fn-flashattention-frameworks] fuses the entire attention computation, reducing HBM traffic by 10–20× and achieving 2–4× wall-clock speedup.

[^fn-kernel]: **Kernel**: From German "Kern" (core/nucleus), borrowed from operating systems where it denotes the core program with full hardware access. In GPU programming, a kernel is the function that executes in parallel across thousands of threads. The metaphor extends: just as an OS kernel mediates between software and hardware, a GPU kernel is the fundamental unit where algorithms meet silicon.

[^fn-flashattention-frameworks]: FlashAttention (introduced in @sec-network-architectures) exemplifies kernel fusion taken to its logical extreme, fusing the entire attention computation into a single kernel that tiles data to fit in SRAM. This demonstrates that frameworks enabling such fusions can transform memory-bound operations into compute-bound ones.

A framework can only fuse operations it can see together. If operations execute immediately one at a time (eager execution)\index{Eager Execution!optimization limitations}, the framework cannot fuse them. If operations are recorded first into a graph (deferred execution)\index{Graph Execution!optimization advantages}, the framework can analyze and optimize the entire computation. This is why execution strategy matters so much: it determines what optimizations are even possible.

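The visibility argument can be sketched with a toy element-wise pipeline in plain Python. Everything here is hypothetical: `LazyTensor` is not a real framework class, a Python list stands in for a GPU buffer, and one full pass over the data stands in for one kernel launch that reads and writes memory. `run_unfused` executes each recorded operation as its own pass, as an eager framework must; `run_fused` exploits the fact that the whole chain was recorded before execution and applies it in a single pass.

```python
import math


class LazyTensor:
    """Records element-wise ops instead of running them (hypothetical API)."""

    def __init__(self, data, ops=()):
        self.data = data
        self.ops = list(ops)  # recorded (name, fn) pairs, not yet executed

    def map(self, name, fn):
        return LazyTensor(self.data, self.ops + [(name, fn)])

    def run_unfused(self):
        """Eager-style: one pass per op, so every intermediate is materialized."""
        passes, values = 0, self.data
        for _, fn in self.ops:
            values = [fn(v) for v in values]  # read whole buffer, write whole buffer
            passes += 1
        return values, passes

    def run_fused(self):
        """Deferred: the whole chain is visible, so apply it in one pass."""
        def fused(v):
            for _, fn in self.ops:
                v = fn(v)  # intermediate stays in a local ("register"), not a buffer
            return v
        return [fused(v) for v in self.data], 1


x = LazyTensor([-2.0, -0.5, 0.0, 1.5, 3.0])
y = (x.map("scale", lambda v: 0.5 * v)
      .map("shift", lambda v: v + 0.25)
      .map("relu", lambda v: max(v, 0.0))
      .map("tanh", math.tanh))

eager_out, eager_passes = y.run_unfused()
fused_out, fused_passes = y.run_fused()
print(eager_passes, fused_passes, eager_out == fused_out)
```

Both strategies produce identical values, but the unfused version sweeps the buffer four times and materializes three intermediate buffers, while the fused version sweeps it once. For memory-bound operations on a GPU, that reduction in memory traffic is where fusion's speedup comes from.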
### The Computational Graph {#sec-ml-frameworks-computational-graph-00f7}
|
||
|
||
Kernel fusion is the key optimization for memory-bound operations, but fusion requires seeing multiple operations together. How do frameworks represent computation in a way that makes this visibility possible? The answer is the **computational graph**\index{Computational Graph!definition (DAG)}, a directed acyclic graph (DAG)[^fn-dag] where nodes represent operations and edges represent data dependencies. This graph is the framework's internal model of the computation.
|
||
|
||
[^fn-dag]: **DAG (Directed Acyclic Graph)**: A graph where edges have direction (data flows from producer to consumer) and no cycles exist (no operation can depend on its own output). The "directed" property ensures that data flows in one direction through the computation, defining a clear execution order. The "acyclic" property guarantees that execution terminates—cycles would create infinite loops. DAGs have deep roots in computer science: they appear in build systems (Make, 1976), task scheduling, version control (Git), and compiler optimization. For ML frameworks, the DAG representation enables topological sorting (determining valid execution orders), dependency analysis (identifying independent operations for parallel execution), and memory lifetime analysis (determining when intermediate tensors can be freed).
|
||
|
||
To ground this abstraction, examine @fig-comp-graph: computing $z = x \times y$ maps onto two input nodes ($x$ and $y$), one operation node (multiplication), and one output node ($z$). The execution problem asks: *when* is this graph constructed, and *when* is it executed?

::: {#fig-comp-graph fig-env="figure" fig-pos="htb" fig-cap="**Simple Computational Graph.** The computation $z = x \\times y$ represented as a graph, where nodes define operations and edges specify data flow." fig-alt="Simple directed graph with nodes x and y flowing into function f(x,y) which outputs z."}
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{%
Line/.style={line width=1.0pt,black!50,rounded corners
},
Box/.style={align=flush center,
shape=circle,
inner xsep=1pt,
node distance=1.4,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,
minimum width=8mm,
},
}
\node[Box,fill=GreenFill,draw=GreenLine,minimum width=13mm, ](B1){$f(x,y)$};
\node[Box,right=of B1,fill=GreenFill,draw=GreenLine](B2){$z$};
\node[Box,above left=0.1 and 2 of B1,fill=GreenFill,draw=GreenLine](B3){$x$};
\node[Box,below left=0.1 and 2 of B1,fill=GreenFill,draw=GreenLine](B4){$y$};
\draw[-latex,Line](B1)--(B2);
\draw[-latex,Line](B3)to[bend left=25](B1);
\draw[-latex,Line](B4)to[bend right=25](B1);
\end{tikzpicture}
```
:::

Real machine learning models require much more complex graph structures. @fig-mlfm-comp-graph extends this representation to show a neural network computation graph alongside the system components that reason about it. In the left panel, notice how data flows through six operation nodes in a directed acyclic graph: each node's output feeds the nodes that depend on it, and Operation Node 2 feeds two consumers (Nodes 3 and 4). The right panel reveals what the framework gains by having this graph: it can query the structure to plan memory allocation for each tensor's lifetime, and it can assign operations to devices based on data dependencies rather than execution order. The critical insight is that the graph exists independently of execution, enabling the framework to optimize *before* any arithmetic occurs.

::: {#fig-mlfm-comp-graph fig-env="figure" fig-pos="htb" fig-cap="**Computation Graph with System Interactions.** A neural network computation graph (left) alongside system components including memory management and device placement (right) that interact with the graph to optimize resource allocation before execution." fig-alt="Left side shows computational graph with 6 operation nodes connected by data flow edges. Right side shows system components box with Memory Management and Device Placement nodes that interact with the computational graph."}
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{Line/.style={line width=1.0pt,black!50
},
Box/.style={align=flush center,
inner xsep=2pt,
node distance=1.1,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,
text width=26mm,
minimum width=26mm, minimum height=10mm
},
Text/.style={%
inner sep=3pt,
draw=none,
line width=0.75pt,
fill=TextColor!80,
text=black,
font=\usefont{T1}{phv}{m}{n}\footnotesize,
align=flush center,
minimum width=7mm, minimum height=5mm
}
}
\begin{scope}[local bounding box=scope1]
\node[Box,fill=BlueL,draw=BlueLine](B1){Operation Node 1};
\node[Box,fill=BlueL,draw=BlueLine,below=of B1](B2){Operation Node 2};
\node[Box,fill=BlueL,draw=BlueLine,below left=0.75 and 0.1 of B2](B3){Operation Node 3};
\node[Box,fill=BlueL,draw=BlueLine,below right=0.75 and 0.1 of B2](B4){Operation Node 4};
\node[Box,fill=BlueL,draw=BlueLine,below=of B3](B5){Operation Node 5};
\node[Box,fill=BlueL,draw=BlueLine,below=of B4](B6){Operation Node 6};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=6mm,yshift=2mm,
fill=BackColor!80,fit=(B1)(B3)(B6),line width=0.75pt](BB1){};
\node[below=2pt of BB1.north east,anchor=north east]{Computational Graph};
\end{scope}
%
\begin{scope}[local bounding box=scope2, shift={($(scope1.east)+(45mm,10mm)$)}]
\node[Box,fill=OrangeL,draw=OrangeLine](2B1){Memory Management};
\node[Box,fill=OrangeL,draw=OrangeLine,below=of 2B1](2B2){Device Placement};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=6mm,yshift=2mm,
fill=BackColor!50,fit=(2B1)(2B2),line width=0.75pt](2BB1){};
\node[below=2pt of 2BB1.north east,anchor=north east]{System Components};
\end{scope}
\draw[-latex,Line](B1)--node[Text,pos=0.45]{Data Flow}(B2);
\draw[-latex,Line](B3)--node[Text,pos=0.45]{Data Flow}(B5);
\draw[-latex,Line](B4)--node[Text,pos=0.45]{Data Flow}(B6);
\draw[-latex,Line](B2)-|node[Text,pos=0.45]{Data Flow}(B3);
\draw[-latex,Line](B2)-|node[Text,pos=0.45]{Data Flow}(B4);
\draw[latex-,Line](2B2) --node[Text,pos=0.55]{Interacts with} (scope1.east|-2B2);
\draw[latex-,Line](2B1) --node[Text,pos=0.55]{Interacts with} (scope1.east|-2B1);
\end{tikzpicture}
```
:::
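
The two queries attributed to the system components in the right panel, a valid execution order and the point at which each intermediate tensor can be freed, can be sketched with Python's standard-library `graphlib`. This is a minimal illustration under stated assumptions: the node names `op1` through `op6` are hypothetical labels mirroring the figure, the edges follow its data-flow arrows, and real frameworks use their own graph data structures rather than `graphlib`.

```python
from graphlib import TopologicalSorter

# Each node maps to the nodes whose outputs it consumes (edges from the figure).
deps = {
    "op1": (),
    "op2": ("op1",),
    "op3": ("op2",),
    "op4": ("op2",),
    "op5": ("op3",),
    "op6": ("op4",),
}

# Query 1, execution order: any topological order respects every data dependency.
order = list(TopologicalSorter(deps).static_order())
position = {node: i for i, node in enumerate(order)}

# Query 2, memory lifetime: a node's output can be freed once its last consumer has run.
consumers = {n: [m for m, preds in deps.items() if n in preds] for n in deps}
free_after = {
    n: order[max(position[c] for c in cs)]
    for n, cs in consumers.items()
    if cs  # nodes with no consumers (op5, op6) are graph outputs and are kept
}

print(order)
print(free_after)
```

Node `op2` shows why lifetime analysis needs the whole graph: its output feeds two consumers, so the buffer cannot be released until both have executed.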

This graph representation is more than a visualization; it is the data structure that enables both efficient execution and automatic differentiation. *When* this graph is constructed is a design choice with cascading implications:

- **For debugging**: Can you print intermediate values? Step through code with a debugger?
- **For optimization**: Can the framework see multiple operations at once to fuse them?
- **For deployment**: Can the model run without a Python interpreter?
- **For flexibility**: Can control flow depend on computed tensor values?

No single execution model optimizes all these dimensions. Frameworks must choose their position in this trade-off space, and practitioners must understand these trade-offs to select appropriate tools and write efficient code. The following sections examine *how* different execution strategies navigate these constraints.

### Three Execution Strategies {#sec-ml-frameworks-three-execution-strategies-5934}

The computational graph representation is powerful, but it raises a critical design question: *when* should the framework build this graph? Consider a simple operation like `y = x * 2`. Two distinct approaches exist:

1. **Immediate execution**: Perform the multiplication right now, storing the result in `y`. Natural and debuggable, but the framework sees only one operation at a time.

2. **Deferred execution**: Record the intention to multiply, building a graph of operations. Execute later when explicitly requested. Less intuitive, but the framework sees the complete computation, enabling optimization.

Neither approach dominates; each embodies different trade-offs between **flexibility** and **optimization potential**. Modern frameworks have explored three primary execution strategies: eager execution with dynamic graphs, static computation graphs, and hybrid approaches that combine JIT compilation with eager development. We examine each through its systems implications.

#### Eager Execution with Dynamic Graphs {#sec-ml-frameworks-eager-execution-dynamic-graphs-29c2}

\index{Eager Execution!define-by-run model}
\index{Dynamic Graphs!eager execution}

::: {.callout-example title="Eager vs. Graph Execution Code Comparison"}
**PyTorch (Eager Execution)**:
```{.python}
import torch

x = torch.tensor([1.0, 2.0])
y = x * 2
print(f"Intermediate value: {y}")  # Works immediately
z = y.sum()
```

**TensorFlow 1.x (Static Graph)**:

```{.python}
import tensorflow.compat.v1 as tf  # TF 1.x graph API (also available in TF 2.x)

tf.disable_v2_behavior()
x = tf.placeholder(tf.float32)
y = x * 2
# print(y) -> Prints Tensor("mul:0", ...), not a value!
z = tf.reduce_sum(y)

with tf.Session() as sess:
    result = sess.run(z, feed_dict={x: [1.0, 2.0]})  # result == 6.0
```
:::

Eager execution runs operations immediately as encountered, building the computation graph dynamically during execution. When a programmer writes `y = x * 2`, the multiplication happens instantly and the result is available for immediate use.

This provides the flexibility of normal programming: developers can print intermediate values, use conditionals based on computed results, and debug with standard tools. The framework records operations as they happen, constructing a **dynamic graph**\index{Dynamic Graphs!runtime construction} that reflects the actual execution path taken.

\index{Autograd Tape!definition}
For gradient computation, the framework records a history of operations in what is called an **autograd tape**\index{Autograd Tape!dynamic graph construction}[^fn-autograd-tape], a transient data structure built during execution. Each tensor operation creates a node that records the operation performed, references to its input tensors, and how to compute its gradients. These nodes form a DAG of operations built **during** forward pass execution, not before.

[^fn-autograd-tape]: **Autograd Tape**: A dynamically constructed data structure recording operations during forward pass execution. Each operation adds a node to the tape containing: (1) the operation type, (2) references to input tensors, (3) saved intermediate values needed for gradient computation, and (4) the backward function implementing chain rule application. The tape is freed after the backward pass to release memory, unless the caller explicitly asks to keep it (e.g., `retain_graph=True` in PyTorch).

Consider this example using PyTorch, which implements eager execution as its default mode. @lst-autograd-tape-example shows *how* operations are recorded as they execute.

::: {#lst-autograd-tape-example lst-cap="**Autograd Tape Construction**: Each operation executes immediately while recording a backward node to the autograd tape for later gradient computation."}
```{.python}
import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2  # Executes immediately; records a MulBackward0 node
z = y + 1  # Executes immediately; records an AddBackward0 node
# The autograd tape exists NOW, built during execution
```
:::

After these two operations, the framework has constructed an autograd tape with two nodes: one for the multiplication and one for the addition. The tape records that `z` depends on `y`, and `y` depends on `x`.

Calling `z.backward()` traverses this tape in reverse topological order, applying the chain rule at each node:

1. Compute $\frac{\partial z}{\partial z} = 1$ (seed gradient)
2. Call `AddBackward0.backward()` $\rightarrow \frac{\partial z}{\partial y} = 1$
3. Call `MulBackward0.backward()` $\rightarrow \frac{\partial z}{\partial x} = 2$
4. Accumulate gradient in `x.grad`

After `backward()` completes, the autograd tape is **destroyed** to free memory, and the next forward pass builds a completely new tape. This design keeps training memory-efficient: the intermediate values saved for gradient computation are held only from the forward pass until the backward pass consumes them, rather than accumulating across iterations.
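
The tape walk described above can be observed directly in PyTorch. The sketch below rebuilds the example's tensors, inspects the recorded backward nodes through each tensor's `grad_fn` attribute, and checks that `x.grad` ends up holding 2, matching step 3. It assumes a reasonably recent PyTorch release.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2
z = y + 1

# Each non-leaf tensor points at the backward node recorded when it was created.
add_node = z.grad_fn
mul_node = add_node.next_functions[0][0]  # edge from the add node back to the mul node
print(type(add_node).__name__, type(mul_node).__name__)

z.backward()   # traverse the tape in reverse, applying the chain rule
print(x.grad)  # dz/dx accumulated on the leaf tensor
```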

Follow this "define-by-run" execution model step by step in @fig-mlfm-dynamic-graph-flow. Notice the alternating pattern: define, execute, define, execute. Each operation completes entirely before the next begins, which is why standard Python debuggers work---a developer can set a breakpoint between any two operations and inspect the actual tensor values. This contrasts sharply with static graphs, where all operations must be defined before any execution occurs.

::: {#fig-mlfm-dynamic-graph-flow fig-env="figure" fig-pos="htb" fig-cap="**Dynamic Graph Execution Flow**: In eager execution, each operation is defined and immediately executed before the next operation begins. This define-by-run model enables natural debugging and data-dependent control flow at the cost of optimization opportunities." fig-alt="Flow diagram showing Start to Operation 1 to Operation 1 Executed to Operation 2 to Operation 2 Executed to End. Above arrows show Define Operation, Execute Operation, Define Next Operation, Execute Operation, Repeat Until Done."}
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
\tikzset{Line/.style={line width=1.0pt,black!50
},
Box/.style={align=flush center,
inner xsep=2pt,
node distance=1.0,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,
text width=18mm,
minimum width=18mm,
minimum height=10mm
},
Text/.style={%
inner sep=4pt,
draw=none,
line width=0.75pt,
fill=TextColor!80,
text=black,
font=\usefont{T1}{phv}{m}{n}\footnotesize,
align=flush center,
minimum width=7mm, minimum height=5mm
},
}
\node[Box,text width=12mm,minimum width=14mm,
fill=OliveL!70,draw=OliveLine](B1){Start};
\node[Box,fill=VioletL,draw=VioletLine,right=of B1](B2){Operation 1};
\node[Box,fill=GreenL,draw=GreenLine,right=of B2,
minimum height=14mm](B3){Operation 1 Executed};
\node[Box,node distance=2.1,fill=VioletL,draw=VioletLine,right=of B3](B4){Operation 2};
\node[Box,fill=GreenL,draw=GreenLine,right=of B4,
minimum height=14mm](B5){Operation 2 Executed};
\node[Box,right=of B5,text width=12mm,minimum width=14mm,
fill=OliveL!70,draw=OliveLine](B6){End};
%%
\foreach \x/\y in{1/2,2/3,3/4,4/5,5/6}
\draw[-latex,Line](B\x)--(B\y);
\def\vi{15mm}
\draw[thick]($(B1.east)!0.5!(B2.west)$)--++(90:\vi)
node[Text]{Define\\ Operation};
\draw[thick]($(B2.east)!0.5!(B3.west)$)--++(90:\vi)
node[Text]{Execute\\ Operation};
\draw[thick]($(B3.east)!0.5!(B4.west)$)--++(90:\vi)
node[Text]{Define Next\\ Operation};
\draw[thick]($(B4.east)!0.5!(B5.west)$)--++(90:\vi)
node[Text]{Execute\\ Operation};
\draw[thick]($(B5.east)!0.5!(B6.west)$)--++(90:\vi)
node[Text](BB6){Repeat\\ Until Done};
\end{tikzpicture}
```
:::

##### Systems Implications: Flexibility {.unnumbered}

The dynamic autograd tape enables capabilities that static graphs support only awkwardly or not at all. Conditionals and loops can depend on tensor values computed during execution, enabling algorithms like beam search, dynamic RNN lengths, or adaptive computation that adjust their behavior based on intermediate results. Different iterations can process tensors of different sizes without redefining the computation, which is essential for natural language processing where sentence lengths vary. Most importantly, because operations execute immediately in standard Python, developers can print tensors, inspect values, and use standard debuggers (`pdb`, breakpoints) to diagnose errors in the same way they would debug any Python program.
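
A minimal sketch of data-dependent control flow in eager mode, assuming PyTorch: the hypothetical helper `halve_until_small` loops a number of times that depends on tensor values computed at run time, autograd differentiates through whichever path actually ran, and a differently shaped input needs no redefinition.

```python
import torch


def halve_until_small(x, threshold=1.0):
    """Hypothetical helper: how many times the loop runs depends on the data."""
    steps = 0
    while x.norm() > threshold:  # ordinary Python `while` on a computed tensor value
        x = x / 2
        steps += 1
    return x, steps


x = torch.tensor([8.0, 6.0], requires_grad=True)  # norm 10, so four halvings are needed
out, steps = halve_until_small(x)
out.sum().backward()  # gradients follow the path that actually executed

# A new shape with a small norm: same function, zero loop iterations.
_, steps_short = halve_until_small(torch.tensor([0.5, 0.5, 0.5]))

print(steps, x.grad, steps_short)
```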

##### Systems Implications: Overhead {.unnumbered}

This flexibility comes with performance costs that map directly to the Iron Law (@sec-introduction-iron-law-ml-systems-c32a). Each forward pass rebuilds the autograd tape from scratch, adding Python object creation, reference counting, and node-linking overhead to $L_{lat}$ on every iteration. Every operation goes through Python dispatch (function lookup, argument parsing, type checking), costing on the order of 10 μs per operation, which becomes significant for models with thousands of operations. Because the graph is built during execution, the framework cannot see across operations to fuse kernels, so each operation launches its own GPU kernel: every launch adds to $L_{lat}$, and every intermediate result makes a round trip through memory, inflating $D_{vol}$. The autograd tape itself stores references to all saved intermediate tensors and `Function` nodes, increasing memory consumption by 2–3× compared to forward-only execution and adding pressure to $D_{vol}$. Together, these costs create a performance ceiling that becomes visible as models grow smaller and dispatch overhead, rather than arithmetic, dominates run time.

For a typical ResNet-50 forward pass, eager execution overhead adds approximately 5--10 ms compared to an optimized compiled version, with the majority spent in Python dispatch and tape construction rather than actual computation.
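
Dispatch overhead is easy to observe on your own machine. The sketch below times many tiny eager operations on a one-element CPU tensor, where the arithmetic itself is negligible, so the per-operation time is dominated by framework dispatch. The helper name `eager_op_overhead_us` is hypothetical, and the number it reports depends heavily on hardware, PyTorch version, and device (timing GPU work would also require synchronization), so treat it as an order-of-magnitude check rather than a benchmark.

```python
import time

import torch


def eager_op_overhead_us(n_ops: int = 20_000) -> float:
    """Average wall-clock microseconds per eager op on a 1-element CPU tensor."""
    x = torch.ones(1)
    for _ in range(100):  # warm-up so one-time initialization is not timed
        x = x * 1.0
    start = time.perf_counter()
    for _ in range(n_ops):
        x = x * 1.0  # one dispatched operation, almost no arithmetic
    elapsed = time.perf_counter() - start
    return elapsed / n_ops * 1e6


per_op = eager_op_overhead_us()
print(f"~{per_op:.1f} us per op, so ~{per_op:.1f} ms for 1,000 such ops")
```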

The overhead costs of eager execution raise a natural question: what if we could see the entire computation *before* executing any of it? This is precisely what static computation graphs provide.

#### Static Computation Graphs {#sec-ml-frameworks-static-computation-graphs-e100}

\index{Static Graphs!define-then-run model}
\index{Computational Graph!static}
Static graph execution defines the complete computational graph as a symbolic representation first, then executes it separately. This "define-then-run" execution model means the graph exists **before** any computation occurs, enabling aggressive ahead-of-time optimization. The key insight is that if the framework sees the entire computation before running it, the framework can analyze, transform, and optimize the graph globally---a visibility impossible when operations execute immediately one at a time.

##### Two-Phase Execution {.unnumbered}

Static graphs implement a clear separation between graph construction and execution. @lst-tf-static-graph illustrates the two phases using TensorFlow 1.x, which pioneered this approach: symbolic definition creates placeholders and operations without computation, while explicit execution triggers actual arithmetic:

::: {#lst-tf-static-graph lst-cap="**Static Graph Two-Phase Execution**: Graph construction (symbolic definition) is separated from execution (actual computation), enabling ahead-of-time optimization."}
```{.python}
# Phase 1: Graph Construction (symbolic, no computation)
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Define graph symbolically
x = tf.placeholder(tf.float32, shape=[1])  # Just a placeholder
y = x * 2  # Not executed, just recorded
z = y + 1  # Still no execution
# At this point, nothing has been computed

# Phase 2: Graph Execution (actual computation)
with tf.Session() as sess:
    result = sess.run(z, feed_dict={x: [1.0]})
    # Now computation happens: result = [3.0]
```
:::

Compare this with the dynamic model by examining @fig-mlfm-static-graph. Notice the clear boundary between phases: in the definition phase (left), the framework builds a complete blueprint without touching any data; in the execution phase (right), data flows through an already-optimized graph. This separation enables the framework to answer questions during the definition phase that are impossible to answer operation-by-operation: "Which intermediate tensors can share memory?" "Which operations can fuse into a single kernel?" "What's the total memory footprint?" By the time execution begins, these optimizations are already baked in.

::: {#fig-mlfm-static-graph fig-env="figure" fig-pos="htb" fig-cap="**Static Graph: Define then Execute.** The two phases of static graph execution. The definition phase (left) declares operations and builds the graph. The execution phase (right) loads data, runs the optimized graph, and produces results." fig-alt="Flow diagram showing two phases. Definition Phase: Define Operations, Declare Variables, Build Graph. Execution Phase: Load Data, Run Graph, Get Results. Arrows connect boxes left to right."}
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{Line/.style={line width=1.0pt,black!50,rounded corners
},
Box/.style={align=flush center,
inner xsep=2pt,
node distance=0.7,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,
text width=18mm,
minimum width=18mm, minimum height=10mm
},
}
\node[Box,fill=VioletL,draw=VioletLine](B1){Define Operations};
\node[Box,fill=VioletL,draw=VioletLine,right=of B1](B2){Declare Variables};
\node[Box,fill=VioletL,draw=VioletLine,right=of B2](B3){Build Graph};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=6mm,yshift=2mm,
fill=BackColor!80,fit=(B1)(B2)(B3),line width=0.75pt](BB1){};
\node[below=2pt of BB1.north,anchor=north]{Definition Phase};
%
\node[Box,node distance=1.5,fill=BrownL,draw=BrownLine,right=of B3](B4){Load Data};
\node[Box,fill=BrownL,draw=BrownLine,right=of B4](B5){Run Graph};
\node[Box,fill=BrownL,draw=BrownLine,right=of B5](B6){Get Results};
%
\scoped[on background layer]
\node[draw=GreenLine,inner xsep=4mm,inner ysep=6mm,yshift=2mm,
fill=GreenL!20,fit=(B4)(B5)(B6),line width=0.75pt](BB2){};
\node[below=2pt of BB2.north,anchor=north]{Execution Phase};
%
\foreach \x/\y in{1/2,2/3,3/4,4/5,5/6}
\draw[-latex,Line](B\x)--(B\y);
\end{tikzpicture}
```
:::

The key difference from eager execution is that during construction, `x`, `y`, and `z` are not tensors containing values but symbolic nodes in a graph. Operations like `*` and `+` add nodes to the graph definition without performing any arithmetic. Calling `print(y)` at this point would reveal the distinction: it prints tensor metadata such as `Tensor("mul:0", ...)`, not a computed value. Execution is triggered explicitly through `sess.run()`, at which point the framework analyzes the complete graph, optimizes it, and executes the optimized version with the provided input data.
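
The define-then-run idea does not depend on TensorFlow. The toy sketch below is plain Python with hypothetical names (`Sym`, `placeholder`, `run`). It uses operator overloading so that `*` and `+` record graph nodes instead of computing, and only `run` performs arithmetic, once values for the placeholders are supplied. It mirrors the shape of @lst-tf-static-graph, not TensorFlow's actual implementation.

```python
class Sym:
    """A symbolic graph node: records an operation instead of computing it."""

    def __init__(self, op, inputs=(), value=None, name=None):
        self.op = op
        self.inputs = tuple(inputs)
        self.value = value  # only used by "const" nodes
        self.name = name    # only used by "placeholder" nodes

    def __mul__(self, other):
        return Sym("mul", (self, _lift(other)))

    def __add__(self, other):
        return Sym("add", (self, _lift(other)))

    def __repr__(self):
        return f"Sym(op={self.op!r})"  # metadata only; there is no value yet


def _lift(v):
    """Wrap Python numbers as constant nodes so they can join the graph."""
    return v if isinstance(v, Sym) else Sym("const", value=float(v))


def placeholder(name):
    return Sym("placeholder", name=name)


def run(node, feed):
    """Phase 2: evaluate the recorded graph, given values for the placeholders."""
    if node.op == "placeholder":
        return feed[node.name]
    if node.op == "const":
        return node.value
    a, b = (run(i, feed) for i in node.inputs)
    return a * b if node.op == "mul" else a + b


# Phase 1: build the graph; no arithmetic happens here
x = placeholder("x")
y = x * 2
z = y + 1
print(y)                   # Sym(op='mul')
print(run(z, {"x": 1.0}))  # 3.0
```

Because `run` receives the whole graph at once, this is the point where a real framework would first analyze and rewrite the graph before evaluating it.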

##### Ahead-of-Time Optimization {.unnumbered}

\index{Ahead-of-Time (AOT) Optimization}
Because the framework has the complete graph before execution, it can perform optimizations impossible in eager mode. The kernel fusion\index{Kernel Fusion!static graph optimization} opportunity introduced in @sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8 becomes actionable here: because the framework sees `y = x * 2` and `z = y + 1` together in the graph, it can fuse them into `z = x * 2 + 1`, eliminating the intermediate `y` and halving memory traffic. With the full graph visible, the compiler can also calculate exact memory requirements for all tensors before execution, pre-allocating memory in a single pass and reusing buffers where lifetimes do not overlap. Tensor layouts can be transformed globally (e.g., NCHW to NHWC) to match hardware preferences without runtime copying. Dead code elimination[^fn-dead-code]\index{Dead Code Elimination} removes operations whose results are never consumed, and constant folding\index{Constant Folding} pre-computes operations on constant values at graph construction time, so the cost is paid once rather than on every forward pass.
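
The "halving memory traffic" claim follows from counting tensor reads and writes. The sketch below assumes that each kernel reads its inputs from main memory and writes its output back (no cache reuse) and that elements are float32, and it ignores the scalar constants. The function name `bytes_moved` is hypothetical.

```python
def bytes_moved(n_elements: int, fused: bool, bytes_per_elem: int = 4) -> int:
    """Main-memory traffic for z = x * 2 + 1 under a no-cache-reuse assumption."""
    if fused:
        transfers = 2  # one kernel: read x, write z
    else:
        transfers = 4  # kernel 1: read x, write y; kernel 2: read y, write z
    return transfers * n_elements * bytes_per_elem


n = 10_000_000  # ten million float32 elements = 40 MB per tensor
print(bytes_moved(n, fused=False) / 1e6, "MB unfused")  # 160.0 MB unfused
print(bytes_moved(n, fused=True) / 1e6, "MB fused")     # 80.0 MB fused
```

For a memory-bound elementwise chain like this one, run time tracks bytes moved rather than FLOPs, so the fused version can approach a 2× speedup.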

[^fn-dead-code]: **Dead Code Elimination (DCE)**: A compiler optimization dating to the earliest optimizing compilers of the 1960s. The term "dead" refers to code whose results are never used by any subsequent operation—it is "dead" in the sense that removing it has no observable effect on the program's output. In ML computational graphs, dead code arises from debugging operations left in production code (print nodes, assertion checks), unused branches in conditional computations, and gradient computations for parameters not being trained (frozen layers). DCE is safe to apply because the DAG structure makes it easy to verify that no downstream node consumes a candidate node's output. For large transformer models, DCE can eliminate 5--15% of graph nodes, reducing both memory allocation and kernel launch overhead.

These optimizations map directly to **Iron Law** terms: kernel fusion reduces $D_{vol}$ by eliminating intermediate memory writes, constant folding reduces $O$ by computing values once, memory pre-allocation reduces $L_{lat}$ by avoiding runtime allocation overhead, and dead code elimination reduces both $O$ and $D_{vol}$.
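
Two of these passes, constant folding and dead code elimination, are simple enough to sketch on a toy graph IR. The `Node` format, op names, and pass functions below are hypothetical and far simpler than any production compiler. The graph is assumed to be listed in topological order, which a framework would obtain by sorting its DAG.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Node:
    name: str
    op: str                                        # "input", "const", "mul", or "add"
    args: List[str] = field(default_factory=list)  # names of producer nodes
    value: Optional[float] = None                  # set only for "const" nodes


def constant_fold(graph: List[Node]) -> List[Node]:
    """Replace any op whose inputs are all constants with a precomputed constant."""
    known: Dict[str, float] = {}
    folded = []
    for node in graph:  # topological order: producers come before consumers
        if node.op == "const":
            known[node.name] = node.value
            folded.append(node)
        elif node.op in ("mul", "add") and all(a in known for a in node.args):
            lhs, rhs = (known[a] for a in node.args)
            result = lhs * rhs if node.op == "mul" else lhs + rhs
            known[node.name] = result
            folded.append(Node(node.name, "const", [], result))
        else:
            folded.append(node)
    return folded


def eliminate_dead_code(graph: List[Node], outputs: List[str]) -> List[Node]:
    """Keep only nodes that some requested output transitively depends on."""
    by_name = {node.name: node for node in graph}
    live, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name not in live:
            live.add(name)
            stack.extend(by_name[name].args)
    return [node for node in graph if node.name in live]


graph = [
    Node("x", "input"),
    Node("two", "const", value=2.0),
    Node("three", "const", value=3.0),
    Node("scale", "mul", ["two", "three"]),  # all-constant inputs: foldable to 6.0
    Node("y", "mul", ["x", "scale"]),
    Node("debug", "add", ["y", "two"]),      # never reaches the output: dead
    Node("one", "const", value=1.0),
    Node("z", "add", ["y", "one"]),
]

optimized = eliminate_dead_code(constant_fold(graph), outputs=["z"])
print([(n.name, n.op, n.value) for n in optimized])
```

The passes compose: folding `scale` into a constant is what leaves `two` and `three` without live consumers, so the subsequent elimination pass removes them along with the unused `debug` node.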

\index{XLA (Accelerated Linear Algebra)!definition}
Compilation frameworks like XLA (Accelerated Linear Algebra)\index{XLA (Accelerated Linear Algebra)!graph compilation}[^fn-xla] [@GoogleXLA] take this further, compiling the TensorFlow graph to optimized machine code for specific hardware. For a transformer encoder block, XLA can achieve 1.5–2× speedup over unoptimized execution through aggressive fusion and hardware-specific code generation.

##### Systems Implications {.unnumbered}

\index{Static Graphs!optimization benefits}
Static graphs achieve high performance through ahead-of-time optimization. Kernel fusion reduces memory bandwidth requirements (often the bottleneck for ML workloads), and hardware-specific compilation enables near-peak utilization.

The cost of this performance is reduced flexibility. Standard Python control flow (`if`, `for`) cannot depend on computed tensor values in static graphs. TensorFlow provides graph-level control flow primitives (`tf.cond` and `tf.while_loop`) that support data-dependent conditions, but these require special syntax that diverges from standard Python, making code harder to write and reason about. Debugging is difficult because stack traces point to graph construction code, not execution code. Error messages often reference symbolic node names rather than the actual operations that failed.

#### Hybrid Approaches: JIT Compilation {#sec-ml-frameworks-hybrid-approaches-jit-compilation-8954}

\index{JIT Compilation!fidelity vs. generality}
Can we have both eager debugging and graph optimization? JIT compilation attempts this by capturing computation at runtime. The core trade-off is *fidelity versus generality*. Tracing captures the exact execution path taken during a sample run, producing high fidelity to that specific input but missing branches not taken. Source-level compilation (scripting) analyzes the full program structure, preserving all control flow branches but requiring a restricted language subset. Both approaches produce an intermediate representation (IR) that enables the same ahead-of-time optimizations available to static graphs: operator fusion, constant folding, dead code elimination, and buffer reuse.

This trade-off has a direct **Iron Law** consequence. JIT compilation amortizes dispatch overhead ($L_{lat}$) across the compiled region: the longer the compiled region, the more operations share that fixed cost. This explains why graph breaks (points where the compiler cannot capture an operation and must hand control back to eager Python) are performance-critical: each break forces a return to eager dispatch, resetting the amortization.

PyTorch's TorchScript exemplifies both strategies. Tracing\index{JIT Compilation!tracing} executes a function once with example inputs and records every tensor operation into a static computation graph. @lst-torchscript-trace demonstrates the approach: the traced function becomes a compiled artifact that can be serialized, optimized, and executed independently of the Python interpreter:

::: {#lst-torchscript-trace lst-cap="**TorchScript Tracing**: Captures tensor operations by executing a function with example inputs and recording the execution path into a static computation graph."}
```{.python}
import torch


def forward(x):
    y = x * 2
    z = y + 1
    return z


# Trace the function by running it once
x_example = torch.tensor([1.0])
traced = torch.jit.trace(forward, x_example)

# traced is now a compiled TorchScript function
# (tracing an nn.Module instead yields a ScriptModule)
# Can serialize: torch.jit.save(traced, "model.pt")
# Can optimize: fusion, constant folding
# Can run without a Python interpreter (e.g., from C++ via LibTorch)
```
:::
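
The serialization claim can be checked with an in-memory round trip. The sketch below traces a small `nn.Module` (tracing a module yields a `ScriptModule`), saves it with `torch.jit.save` to an `io.BytesIO` buffer instead of a file, and reloads it with `torch.jit.load`. The module name `TinyModel` is hypothetical.

```python
import io

import torch


class TinyModel(torch.nn.Module):  # hypothetical module with the same arithmetic
    def forward(self, x):
        y = x * 2
        z = y + 1
        return z


traced = torch.jit.trace(TinyModel(), torch.tensor([1.0]))

buffer = io.BytesIO()
torch.jit.save(traced, buffer)  # serializes the recorded graph, not the Python source
buffer.seek(0)
restored = torch.jit.load(buffer)

print(restored(torch.tensor([5.0])))  # tensor([11.])
```

The same saved file can be loaded from C++ with LibTorch's `torch::jit::load`, which is what makes Python-free deployment possible.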

The critical limitation of tracing reveals the fidelity-generality trade-off concretely. Because tracing records a single execution path, it cannot handle data-dependent control flow. @lst-tracing-silent-failure illustrates a silent correctness failure.

::: {#lst-tracing-silent-failure lst-cap="**Tracing Silent Failure**: Tracing records only the execution path taken by the example input, silently ignoring all other branches of data-dependent control flow."}
```{.python}
import torch


def conditional_forward(x):
    if x.sum() > 0:  # Data-dependent condition
        return x * 2
    else:
        return x * 3


traced = torch.jit.trace(conditional_forward, torch.tensor([1.0]))
# PyTorch emits a TracerWarning here, but no error
# Tracing captures ONLY the x.sum() > 0 branch
# If input later has sum <= 0, traced version
# still executes x * 2 branch
```
:::

Tracing records whichever branch executed for the example input. Subsequent executions always follow the traced path regardless of input values, silently producing incorrect results for inputs that would have taken the other branch. This failure mode is particularly dangerous because it raises no error at run time; the only signal is a `TracerWarning` emitted during tracing, which is easy to overlook. In production, such bugs can persist for months before anyone notices that a small fraction of inputs are being misclassified, and by then debugging is a forensic exercise.
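
The failure is easy to reproduce. The sketch below, assuming a recent PyTorch with `torch.jit` available, traces the same function with a positive example and then feeds it a negative input. The eager function takes the `else` branch, while the traced version silently replays `x * 2`. The warning filter only keeps the demonstration quiet; in real code you would want to see that warning.

```python
import warnings

import torch


def conditional_forward(x):
    if x.sum() > 0:  # data-dependent condition
        return x * 2
    else:
        return x * 3


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # hide the TracerWarning, as a noisy log might
    traced = torch.jit.trace(conditional_forward, torch.tensor([1.0]))

negative = torch.tensor([-1.0])
print(conditional_forward(negative))  # eager: else branch -> tensor([-3.])
print(traced(negative))               # traced: replays x * 2 -> tensor([-2.])
```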

The alternative, scripting\index{JIT Compilation!scripting}, achieves generality by analyzing Python source code directly and compiling it to TorchScript IR without executing. The scripting compiler parses the abstract syntax tree (AST), converts supported operations to IR operations, and preserves the branching structure so that both branches of a conditional exist in the compiled representation. The cost of this generality is a restricted Python subset: type annotations are required where inference fails, arbitrary Python objects and standard library modules are excluded, and dynamic metaprogramming is forbidden.

Tracing suits feed-forward models without conditionals (ResNet, VGG, Vision Transformer) and models where control flow depends only on hyperparameters fixed at trace time. Scripting suits models with data-dependent control flow (RNN variants, recursive networks, adaptive computation) and deployment to environments without a Python interpreter. The following examples demonstrate scripting syntax (@lst-torchscript-script), control flow preservation (@lst-torchscript-conditional), language restrictions (@lst-torchscript-restrictions), and IR inspection (@lst-torchscript-ir).

::: {#lst-torchscript-script lst-cap="**TorchScript Scripting**: Compiles Python source code directly to TorchScript IR by parsing the AST, preserving control flow structure without requiring example inputs."}
```{.python}
import torch


@torch.jit.script
def forward(x):
    y = x * 2
    z = y + 1
    return z


# Compiles Python source code to TorchScript IR
# No example inputs needed
# Preserves control flow structure
```
:::

The key advantage of scripting appears when handling conditionals. Unlike tracing, which captures only one branch, scripting preserves both paths in the IR.

::: {#lst-torchscript-conditional lst-cap="**Scripted Control Flow**: Unlike tracing, scripting preserves both branches of conditionals in the IR, enabling correct execution based on runtime input values."}
```{.python}
import torch


@torch.jit.script
def conditional_forward(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return x * 2
    else:
        return x * 3


# Both branches preserved in IR
# Correct branch executes based on runtime input values
```
:::
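
Running the scripted function on inputs of both signs confirms that the branch is chosen at run time. A minimal sketch, assuming a recent PyTorch in which TorchScript implicitly converts a one-element tensor condition to a boolean:

```python
import torch


@torch.jit.script
def conditional_forward(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return x * 2
    else:
        return x * 3


print(conditional_forward(torch.tensor([1.0])))   # positive sum -> tensor([2.])
print(conditional_forward(torch.tensor([-1.0])))  # negative sum -> tensor([-3.])
```

Compared with the traced version, which returned `tensor([-2.])` for the negative input, the scripted function produces the eager result for both signs.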

To understand what the compiler produces, we can inspect the generated intermediate representation directly.

::: {#lst-torchscript-ir lst-cap="**TorchScript IR Inspection**: The generated intermediate representation shows primitive operations and constants, useful for debugging and understanding compilation results."}
```{.python}
import torch


@torch.jit.script
def example(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


# Inspect generated IR (value names and ordering vary by PyTorch version):
print(example.graph)
# graph(%x : Tensor):
#   %1 : int = prim::Constant[value=2]()
#   %2 : Tensor = aten::mul(%x, %1)
#   %3 : int = prim::Constant[value=1]()
#   %4 : Tensor = aten::add(%2, %3, %3)
#   return (%4)
```
:::

However, scripting imposes constraints on what Python constructs are supported.

::: {#lst-torchscript-restrictions lst-cap="**TorchScript Restrictions**: Scripting requires a restricted Python subset. Common unsupported features include arbitrary imports, NumPy operations, and f-string format specifications."}
```{.python}
import torch


@torch.jit.script  # Compilation fails here, when the decorator runs
def invalid_script(x):
    import numpy as np  # ERROR: Cannot import arbitrary modules

    result = np.array([1, 2, 3])  # ERROR: NumPy not supported
    print(f"Debug: {x:.2f}")  # ERROR: f-string format specs not supported
    return result


# Valid alternative:
@torch.jit.script
def valid_script(x: torch.Tensor) -> torch.Tensor:
    # Use TorchScript-compatible operations
    result = torch.tensor([1, 2, 3], dtype=x.dtype, device=x.device)
    return result
```
:::

Scripting requires a restricted Python subset because TorchScript must statically analyze code that Python normally interprets dynamically. Function signatures and variables need explicit type annotations when type inference fails, and only tensor operations, numeric types, and standard containers (lists, dicts, tuples) are permitted---no arbitrary Python objects, no standard library modules like `os` or `sys`, and no dynamic class modification or metaprogramming. These constraints are the price of compilation: every feature that makes Python flexible also makes it unpredictable for a compiler.
|
||
|
||
\index{Single Static Assignment!compiler IR}
|
||
The TorchScript IR represents operations using the `aten` namespace for core tensor operations, the `prim` namespace for primitives and control flow, static types for every value, and Single Static Assignment (SSA) form, where each variable is assigned exactly once to simplify compiler analysis. This IR enables optimizations independent of Python: operator fusion combines adjacent operations into single kernels, constant folding evaluates constant expressions at compile time, dead code elimination removes unused operations, and memory optimization reuses buffers when possible. @tbl-tracing-vs-scripting summarizes the key trade-offs between these two approaches.
|
||
|
||
| **Aspect** | **Tracing** | **Scripting** |
|
||
|:----------------------|:-----------------------------|:------------------------------|
|
||
| **Input requirement** | Example inputs needed | No inputs needed |
|
||
| **Control flow** | Cannot handle data-dependent | Supports data-dependent |
|
||
| **Conversion ease** | Simpler (just run function) | Harder (restricted Python) |
|
||
| **Type annotations** | Not required | Required when inference fails |
|
||
| **Error detection** | Runtime (wrong results) | Compile time (syntax errors) |
|
||
| **Best for** | Feed-forward models | Models with conditionals |
|
||
|
||
: **Tracing vs. Scripting Trade-offs.** The fidelity-generality trade-off manifests concretely: tracing is simpler to use but silently ignores data-dependent control flow, while scripting preserves all branches at the cost of a restricted Python subset. Choose tracing for static architectures and scripting for models with runtime conditionals. {#tbl-tracing-vs-scripting}
|
||
|
||
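Constant folding and dead code elimination are easiest to see on a tiny SSA program. The sketch below uses a made-up instruction format of `(dest, op, args)` triples rather than TorchScript's real IR, and supports only `add` and `mul`. Because every name is defined exactly once, folding can substitute a known constant wherever the name appears, and liveness can be computed in a single backward pass.

```python
import operator

# Hypothetical toy IR: (dest, op, args), where args are SSA names
# (strings starting with "%") or Python numbers.
OPS = {"add": operator.add, "mul": operator.mul}


def constant_fold(program):
    """Evaluate instructions whose arguments are all constants."""
    known = {}  # SSA name -> constant value computed at compile time
    folded = []
    for dest, op, args in program:
        args = [known.get(a, a) if isinstance(a, str) else a for a in args]
        if all(not isinstance(a, str) for a in args):
            known[dest] = OPS[op](*args)
        else:
            folded.append((dest, op, args))
    return folded, known


def dead_code_eliminate(program, outputs):
    """Keep only instructions whose results reach an output."""
    live = set(outputs)
    kept = []
    for dest, op, args in reversed(program):  # walk backward from outputs
        if dest in live:
            kept.append((dest, op, args))
            live.update(a for a in args if isinstance(a, str))
    return list(reversed(kept))


# %x is the graph input
program = [
    ("%c", "mul", [2, 3]),  # constant: folds to 6
    ("%unused", "add", [1, 1]),  # constant and never used
    ("%y", "mul", ["%x", "%c"]),
    ("%z", "add", ["%y", 1]),
    ("%dead", "mul", ["%x", 10]),  # depends on input but never used
]
folded, consts = constant_fold(program)
optimized = dead_code_eliminate(folded, outputs=["%z"])
print(optimized)
# [('%y', 'mul', ['%x', 6]), ('%z', 'add', ['%y', 1])]
```

Folding evaluates `%c` and `%unused` at compile time, and elimination then drops `%dead`, which depends on the input but never reaches the output `%z`.
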
#### Modern Compilation: torch.compile {#sec-ml-frameworks-modern-compilation-torchcompile-d025}

\index{Graph Compilation!eager-mode capture}
The previous approaches force a choice: write flexible code (eager execution) or fast code (static graphs). Modern JIT compilation attempts to eliminate this trade-off by automatically compiling eager code into optimized graphs with minimal developer intervention.

PyTorch 2.0's `torch.compile` [@ansel2024pytorch2] represents this approach: developers write natural Python code that executes eagerly during development, but the framework automatically captures and compiles hot paths into optimized kernels for production. @lst-torch-compile-intro shows the basic usage pattern:

::: {#lst-torch-compile-intro lst-cap="**torch.compile**: PyTorch 2.0's compiler captures execution on first call, compiles an optimized kernel, then reuses compiled code for subsequent calls with matching shapes."}
```{.python}
import torch


@torch.compile
def forward(x):
    return x * 2 + 1


# First call: captures execution and compiles an optimized kernel
# (a one-time cost that can take seconds, even for tiny functions)
result1 = forward(torch.tensor([1.0]))

# Same shape and dtype: reuses the cached compiled code
result2 = forward(torch.tensor([2.0]))

# New shape: triggers a recompilation before running
result3 = forward(torch.randn(10, 10))
```
:::

The cost profile in this example (a one-time compilation cost on the first call, then only microseconds of dispatch on each reuse) explains when torch.compile pays off. The deeper question is *why* compiled code runs faster at all. The answer lies in *the physics of software overhead*. Dispatch costs that seem negligible for a single operation, a few microseconds each, compound dramatically across the thousands of operations in a forward pass. The following analysis quantifies this phenomenon.

```{python}
#| label: fusion-speedup-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ FUSION SPEEDUP CALCULATION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Callout "The Physics of Software Overhead" comparing eager vs compiled
# │
# │ Goal: Quantify the latency and bandwidth benefits of kernel fusion.
# │ Show: A 2× speedup from eliminating redundant kernel launches and memory traffic.
# │ How: Model dispatch and memory overhead for eager vs. fused Add+ReLU.
# │
# │ Imports: None (pure calculation)
# │ Exports: python_dispatch_us_str, kernel_launch_only_us_str, memory_access_us_str,
# │          kernel_launch_us_str, eager_overhead_str, compiled_overhead_str,
# │          overhead_speedup_str, bw_efficiency_str
# └─────────────────────────────────────────────────────────────────────────────

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class FusionSpeedup:
    """
    Namespace for Kernel Fusion Speedup calculation.
    Scenario: Comparing Eager (2 launches) vs Fused (1 launch) overheads.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    python_dispatch_us = 10
    kernel_launch_us = 5
    memory_access_us = 1

    eager_ops = 2
    fused_ops = 1

    eager_mem_factor = 4  # 2R + 2W
    fused_mem_factor = 2  # 1R + 1W (intermediate fused)

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    launch_overhead = python_dispatch_us + kernel_launch_us

    eager_total_overhead = eager_ops * launch_overhead
    fused_total_overhead = fused_ops * launch_overhead

    speedup = eager_total_overhead / fused_total_overhead
    bw_efficiency = eager_mem_factor / fused_mem_factor

    # ┌── 3. INVARIANTS (Guardrails) ───────────────────────────────────────────
    check(speedup >= 1.5, f"Fusion speedup ({speedup:.1f}x) is too small to justify compilation complexity.")

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    python_dispatch_us_str = f"{python_dispatch_us}"
    kernel_launch_only_us_str = f"{kernel_launch_us}"
    memory_access_us_str = f"{memory_access_us}"
    kernel_launch_us_str = f"{launch_overhead}"

    eager_overhead_str = f"{eager_total_overhead}"
    compiled_overhead_str = f"{fused_total_overhead}"

    overhead_speedup_str = f"{int(speedup)}"
    bw_efficiency_str = f"{int(bw_efficiency)}"

# Note: Use FusionSpeedup.python_dispatch_us_str directly.
```

::: {.callout-notebook title="The Physics of Software Overhead"}
**The Iron Law Connection:**
For small operations, the **Latency Term** ($\text{Latency}_{\text{fixed}}$) in the Iron Law is dominated by software overhead: dispatching instructions from Python to the GPU.

**The Constants of Latency:**

* **Python Dispatch:** ~`{python} FusionSpeedup.python_dispatch_us_str` μs per operation.
* **Kernel Launch:** ~`{python} FusionSpeedup.kernel_launch_only_us_str` μs per operation.
* **Memory Access (VRAM):** ~`{python} FusionSpeedup.memory_access_us_str` μs.

**Scenario 1: Eager Mode (The "Tiny Op" Trap)**
Consider a simple activation block: `y = relu(x + bias)`.

* **Operations:** 2 (Add, ReLU).
* **Execution:**
    1. Launch `Add` Kernel: `{python} FusionSpeedup.kernel_launch_us_str` μs overhead (dispatch + launch).
    2. Read/Write Memory: $2N$ bytes.
    3. Launch `ReLU` Kernel: `{python} FusionSpeedup.kernel_launch_us_str` μs overhead (dispatch + launch).
    4. Read/Write Memory: $2N$ bytes.
* **Total Overhead:** `{python} FusionSpeedup.eager_overhead_str` μs.
* **Total Memory Traffic:** $4N$ bytes.

**Scenario 2: Compiled Mode (Fusion)**
The compiler fuses this into one kernel: `FusedAddRelu`.

* **Execution:**
    1. Launch `Fused` Kernel: `{python} FusionSpeedup.compiled_overhead_str` μs overhead.
    2. Read/Write Memory: $2N$ bytes (intermediate result stays in registers).
* **Total Overhead:** `{python} FusionSpeedup.compiled_overhead_str` μs (**`{python} FusionSpeedup.overhead_speedup_str`× speedup**).
* **Total Memory Traffic:** $2N$ bytes (**`{python} FusionSpeedup.bw_efficiency_str`× less memory traffic**).

**The Conclusion:**
Compilation is not magic; it is **overhead amortization**. For small, element-wise operations (like LayerNorm, GELU, Add), overhead often exceeds compute time by 10–100×. Fusing them is the primary way to keep the hardware busy.
:::

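The callout's arithmetic extends naturally to longer chains. The sketch below reuses its illustrative constants (10 μs dispatch, 5 μs launch) and, like the callout, assumes each element-wise op reads one N-byte tensor and writes one N-byte result. The helper `chain_costs` is hypothetical, not part of any framework.

```python
def chain_costs(n_ops, n_bytes, dispatch_us=10, launch_us=5):
    """Fixed overhead and memory traffic for a chain of n element-wise ops.

    Assumes each op reads one N-byte tensor and writes one N-byte result,
    as in the callout (the bias read is ignored).
    """
    per_launch_us = dispatch_us + launch_us
    eager = {
        "overhead_us": n_ops * per_launch_us,  # one dispatch + launch per op
        "traffic_bytes": n_ops * 2 * n_bytes,  # every intermediate hits memory
    }
    fused = {
        "overhead_us": per_launch_us,  # a single fused kernel
        "traffic_bytes": 2 * n_bytes,  # read input once, write output once
    }
    return eager, fused


for n in (2, 4, 8):
    eager, fused = chain_costs(n, n_bytes=4_000_000)
    ratio = eager["overhead_us"] / fused["overhead_us"]
    print(f"{n} ops: {eager['overhead_us']} μs eager vs "
          f"{fused['overhead_us']} μs fused ({ratio:.0f}× less overhead)")
```

Under these assumptions both the overhead ratio and the traffic ratio grow linearly with chain length, which is why long runs of small element-wise operations are prime fusion targets.
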
@fig-python-tax shows this tax concretely. In eager execution (top), the GPU sits idle in the gaps while Python dispatches the next kernel: the blue compute regions are short, and the red dispatch regions are comparatively long. Compilation (bottom) fuses these operations into a single kernel launch, so the dispatch cost is paid once rather than once per operation and the GPU's idle time shrinks accordingly.

```{python}
#| label: fig-python-tax
#| echo: false
#| fig-cap: "**The Python Tax**: Visualizing the overhead analysis from the preceding callout. In Eager Mode (top), the GPU (blue) finishes processing each op in microseconds but must sit idle while the Python interpreter (red) dispatches the next kernel launch. Compilation (bottom) fuses these operations into a single kernel, paying the dispatch latency once instead of once per operation and sharply reducing GPU idle time."
#| fig-alt: "Gantt chart of execution timeline. Eager mode shows alternating red (Python) and blue (GPU) blocks with gaps. Compiled mode shows one small red block followed by one long blue block."

# ┌─────────────────────────────────────────────────────────────────────────────
# │ FIG-PYTHON-TAX
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @fig-python-tax illustrating the Python dispatch overhead
# │
# │ Goal: Visualize the dispatch overhead of eager execution.
# │ Show: How the "Python Tax" creates idle gaps between GPU kernels.
# │ How: Plot a Gantt chart comparing alternating dispatch/compute vs. fused kernels.
# │
# │ Imports: mlsys.viz (viz)
# │ Exports: figure only, no prose variables
# └─────────────────────────────────────────────────────────────────────────────
from mlsys import viz

fig, ax, COLORS, plt = viz.setup_plot()

# =============================================================================
# PLOT: The Python Tax
# =============================================================================
t_dispatch, t_compute, n_ops = 10, 1, 5

# Eager execution: alternating dispatch and compute
y_eager = 1
for i in range(n_ops):
    start = i * (t_dispatch + t_compute)
    ax.barh(y_eager, t_dispatch, left=start, height=0.4, color=COLORS['RedLine'], alpha=0.6, label='Python Overhead' if i==0 else "")
    ax.barh(y_eager, t_compute, left=start+t_dispatch, height=0.4, color=COLORS['BlueLine'], alpha=0.8, label='GPU Kernel' if i==0 else "")

# Compiled execution: one dispatch, one fused kernel
y_compiled = 0
t_fused_compute = t_compute * n_ops * 0.8
t_compiled_dispatch = 5
ax.barh(y_compiled, t_compiled_dispatch, left=0, height=0.4, color=COLORS['RedLine'], alpha=0.6)
ax.barh(y_compiled, t_fused_compute, left=t_compiled_dispatch, height=0.4, color=COLORS['BlueLine'], alpha=0.8)

ax.set_yticks([0, 1])
ax.set_yticklabels(['Compiled / Fused', 'Eager Execution'])
ax.set_xlabel('Execution Time (microseconds)')
ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=2, fontsize=8)
ax.text(30, 0.5, "Gap = The Python Tax", color=COLORS['RedLine'], ha='center', va='center', fontsize=9, fontweight='bold', bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
plt.show()
```

Compilers can perform this fusion automatically. PyTorch 2.0's `torch.compile`[^fn-torch-compile] captures eager code and compiles it into fused kernels without requiring users to write custom CUDA.[^fn-cuda-frameworks]

[^fn-torch-compile]: **`torch.compile`**: Introduced in PyTorch 2.0 (March 2023), `torch.compile` represents PyTorch's answer to the eager-vs-graph trade-off that has defined framework design since 2015. Rather than requiring users to rewrite code for a static graph system, `torch.compile` uses Python bytecode interception (via TorchDynamo) to extract computational graphs automatically from existing eager PyTorch code, then compiles them into optimized kernels. The usage pattern `model = torch.compile(model)` resembles wrapping a function with a decorator such as `@functools.lru_cache`: a single-line change that optimizes transparently without altering semantics. Typical speedups range from 1.3--2× for transformer models, with the compilation overhead amortized over subsequent executions and reduced further by on-disk caching.

[^fn-cuda-frameworks]: **CUDA (Compute Unified Device Architecture)**: NVIDIA's parallel computing platform, first released in June 2007, which made GPU programming accessible through a C-like language. For ML frameworks, CUDA serves as the foundational layer between high-level Python operations and GPU silicon. When PyTorch executes `torch.matmul(A, B)` on GPU tensors, the call traverses the framework's dispatcher, selects a CUDA kernel implementation (often from cuBLAS), and launches it on the GPU. Each kernel launch incurs 5--20 μs of overhead as the CPU assembles parameters and signals the GPU. The framework's ability to fuse multiple operations into fewer kernel launches, reducing both launch overhead and the associated memory traffic, is why compilation (via `torch.compile` or TensorRT) yields significant speedups over eager execution.

##### Architecture: Three-Stage Compilation Pipeline {.unnumbered}

`torch.compile` consists of three coordinated components, each handling a distinct phase of the compilation process:

1. **TorchDynamo** (graph capture): Intercepts Python bytecode execution using CPython's PEP 523 frame evaluation API. Unlike `torch.jit.trace`, which records a single execution path and silently ignores alternative branches, TorchDynamo also captures operations as they execute but inserts *graph breaks* when it encounters unsupported code (print statements, arbitrary Python), ensuring correctness rather than silent failure. At each break, the current graph is finalized for compilation, the unsupported code executes eagerly, and a new graph begins afterward.

2. **FX Graph** (intermediate representation): Operations captured by TorchDynamo are converted to FX graph format, PyTorch's node-based directed acyclic graph where each node represents an operation with explicit inputs and outputs. The FX graph serves as PyTorch's analog to LLVM IR: a standardized representation that separates frontend (Python code capture) from backend (hardware-specific code generation). This design allows different backends (TorchInductor, ONNX Runtime, TensorRT) to consume FX graphs and enables optimization passes such as dead code elimination, constant folding, and pattern matching for fusion opportunities.

3. **TorchInductor**[^fn-torchinductor] (code generation): The default backend that compiles FX graphs to optimized machine code. For CUDA GPUs, TorchInductor generates Triton\index{Triton!GPU kernel language}[^fn-triton] kernels, written in a Python-based GPU kernel language that compiles to PTX[^fn-ptx]. For CPUs, it generates C++ code with vectorization instructions (AVX2, AVX-512). TorchInductor applies three key optimizations: kernel fusion (combining operations to reduce memory traffic), memory layout optimization (choosing tensor layouts that minimize access overhead), and autotuning (measuring performance across implementation variants to select the fastest).

[^fn-torchinductor]: **TorchInductor**: The code generation backend for `torch.compile`, introduced in PyTorch 2.0. The name "Inductor" evokes electromagnetic induction, which converts one form of energy into another, reflecting its role in converting high-level FX graph representations into low-level machine code. TorchInductor's architecture mirrors a traditional compiler backend: it receives an intermediate representation (FX graph), applies optimization passes (fusion, layout transformation, memory planning), and emits target-specific code (Triton for GPUs, C++ for CPUs). A key design decision is generating Triton code rather than raw CUDA: Triton's higher-level abstractions (block-based programming, automatic memory coalescing) make code generation simpler while achieving 80--95% of hand-tuned CUDA performance.

[^fn-ptx]: **PTX (Parallel Thread Execution)**: NVIDIA's intermediate assembly language for GPU programs, first introduced with CUDA in 2007. PTX occupies the same role for GPU computing that LLVM IR occupies for CPU computing: a portable, human-readable intermediate representation that abstracts hardware-specific details. Triton kernels and CUDA code both compile to PTX, which is then further compiled by NVIDIA's `ptxas` assembler into architecture-specific machine code (SASS) for the target GPU. The PTX layer enables forward compatibility: code compiled to PTX for one GPU generation can be re-compiled by the driver for newer architectures, though with potentially suboptimal performance. For framework developers, PTX provides a stable compilation target that shields code generation from the rapid pace of GPU architecture changes.

[^fn-triton]: **Triton**: Named after Triton, the Greek sea god and son of Poseidon, evoking mastery over waves of parallel threads. OpenAI released this GPU programming language in 2021 to enable writing custom kernels in Python-like syntax without low-level CUDA knowledge. Triton compiles to PTX (NVIDIA's intermediate assembly), handling memory coalescing and thread synchronization automatically. It achieves 80--95% of hand-tuned CUDA performance while reducing development time from weeks to hours.

The generated code is cached on disk: TorchInductor maintains its own compilation cache, and Triton kernels are additionally cached in `~/.triton/cache/`. Subsequent runs with the same input shapes can skip kernel compilation and execute the cached code directly.

##### Execution Flow {.unnumbered}

The first execution follows a multi-step process: TorchDynamo intercepts the bytecode and records operations into an FX graph, the FX graph is passed to TorchInductor for compilation (5--30 seconds for transformer models), and the compiled code is cached and executed. Subsequent executions with the same input shapes dispatch directly to the compiled code with microseconds of overhead. If input shapes change, the cached code no longer matches and TorchInductor must recompile (shape specialization). PyTorch can keep a separate compiled version for each shape configuration it has specialized on; recent releases instead tend to recompile once with symbolic (dynamic) dimensions after detecting a shape change, and passing `dynamic=True` to `torch.compile` requests dynamic shapes from the start.

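The cache-hit and cache-miss behavior described above can be modeled with a dictionary keyed on input shape. In this deliberately simplified sketch, the class name is hypothetical, a list's length stands in for tensor shape, the element type stands in for dtype, and "compilation" is just a counter. Real `torch.compile` guards check many more properties and can switch to dynamic shapes rather than specializing on every size.

```python
class ShapeSpecializedCache:
    """Toy model of shape specialization: one "compiled" version per
    (length, element type) key. Real torch.compile guards check far more."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled = {}  # key -> compiled callable
        self.compilations = 0

    def _compile(self, key):
        self.compilations += 1  # stands in for the slow compile step
        return self.fn

    def __call__(self, x):
        key = (len(x), type(x[0]).__name__)  # "shape" and "dtype"
        if key not in self.compiled:  # cache miss: compile for this shape
            self.compiled[key] = self._compile(key)
        return self.compiled[key](x)  # cache hit: dispatch directly


model = ShapeSpecializedCache(lambda x: [v * 2 + 1 for v in x])
for batch in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0, 7.0], [8.0, 9.0]):
    model(batch)
print(model.compilations)  # 2: one for length 2, one for length 3
```

Workloads whose shapes cycle through a handful of sizes settle into pure cache hits after warmup; workloads where every batch has a new shape keep paying the compile cost, which is what dynamic-shape compilation is designed to avoid.
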
##### Graph Breaks: Causes and Detection {.unnumbered}

Graph breaks occur when torch.compile encounters code it cannot compile, forcing execution to fall back to eager mode. Understanding what causes graph breaks is the foundation for good compiled performance.

Data-dependent control flow requires tensor values unavailable at compile time, as shown in @lst-graph-break-control-flow.

::: {#lst-graph-break-control-flow lst-cap="**Graph Break from Control Flow**: Data-dependent conditionals force a graph break because tensor values are unavailable at compile time, splitting execution into separate compiled regions."}
```{.python}
import torch


@torch.compile
def conditional_compute(x):
    if x.sum() > 0:  # Graph break: tensor value needed
        return x * 2
    else:
        return x * 3


# Creates separate compiled regions: the code before the if
# (computing x.sum() > 0) and the branch that is taken
# The if statement itself executes eagerly
```
:::

TorchDynamo creates a graph break: operations before the if statement are compiled, the if statement executes eagerly (evaluating which branch to take), and the chosen branch is compiled as a separate region.

Unsupported operations also cause graph breaks, as @lst-graph-break-io demonstrates.

::: {#lst-graph-break-io lst-cap="**Graph Break from I/O**: Unsupported operations like `print` force a graph break, splitting compiled code into two regions with eager execution in between."}
```{.python}
import torch


@torch.compile
def debug_compute(x):
    y = x * 2
    print(f"y = {y}")  # Graph break: I/O operation
    z = y + 1
    return z


# Creates two compiled regions: before and after print
```
:::

Common unsupported operations include I/O (`print`, file operations), custom Python objects, and calls to non-PyTorch libraries. Each graph break incurs overhead: tensors must be marshalled from compiled code back to Python (possibly copying from GPU to CPU, for example to evaluate a condition or print a value), the eager operation executes, and results are marshalled into the next compiled region. Operations on opposite sides of a break also cannot be fused with each other.

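The cost of graph breaks can be reasoned about with a toy partitioner that splits a linear sequence of operations into compiled regions wherever an unsupported operation appears. The function names and the cost constants (15 μs per region launch, 20 μs per eager transition) are assumptions chosen for illustration, not measurements, and the partitioner is not TorchDynamo's algorithm.

```python
def split_into_regions(ops, supported):
    """Partition a linear op sequence into compiled regions.

    Unsupported ops cause a "graph break": they run eagerly and end the
    current region.
    """
    regions, current, eager_ops = [], [], []
    for op in ops:
        if op in supported:
            current.append(op)
        else:
            if current:  # close the region before the break
                regions.append(current)
                current = []
            eager_ops.append(op)
    if current:
        regions.append(current)
    return regions, eager_ops


def overhead_us(regions, eager_ops, launch_us=15, break_us=20):
    """Illustrative cost: one launch per compiled region plus an assumed
    fixed cost for each transition out to eager Python and back."""
    return len(regions) * launch_us + len(eager_ops) * break_us


supported = {"mul", "add", "relu"}
with_print = ["mul", "add", "print", "add", "relu"]
regions, eager_ops = split_into_regions(with_print, supported)
print(regions, eager_ops)  # [['mul', 'add'], ['add', 'relu']] ['print']
print(overhead_us(regions, eager_ops))  # 50
print(overhead_us(*split_into_regions(["mul", "add", "add", "relu"], supported)))  # 15
```

In this model, a single `print` doubles the number of compiled regions and more than triples the fixed overhead; it also prevents the two `add` operations on either side of the break from being fused.
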
Shape changes can also prevent compiled code reuse, as @lst-graph-break-shapes illustrates.

::: {#lst-graph-break-shapes lst-cap="**Recompilation from Shape Changes**: With shape specialization, each unique input shape triggers a separate compilation, causing significant overhead when shapes vary frequently."}
```{.python}
import torch


# dynamic=False forces shape specialization, so every new length
# produces a new compiled graph
@torch.compile(dynamic=False)
def variable_length(x, length):
    return x[:, :length]  # Output shape changes each call


x = torch.randn(8, 64)
# Each unique length triggers recompilation
for length in range(1, 5):
    result = variable_length(x, length)  # One compilation per unique length
```
:::

Graph breaks can be detected as shown in @lst-graph-break-detect.

::: {#lst-graph-break-detect lst-cap="**Detecting Graph Breaks**: Setting `TORCH_LOGS` to `graph_breaks` prints each break location and reason during execution."}
```{.bash}
TORCH_LOGS="graph_breaks" python train.py
```
:::

This prints each break location and reason, in a form similar to `Graph break in user code at file.py:15 / Reason: call to unsupported function print` (the exact wording varies across PyTorch versions); setting `TORCH_LOGS="recompiles"` similarly reports recompilations caused by shape changes. Minimizing graph breaks is key to performance: move unsupported operations outside compiled regions, replace data-dependent control flow with branch-free selection (`torch.where`), or accept eager execution for inherently dynamic sections.

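The `torch.where` rewrite replaces a host-side branch with elementwise selection: both candidate results are computed and a condition chooses between them, so no tensor value has to be inspected by Python. In PyTorch, the conditional from @lst-graph-break-control-flow could be written as `torch.where(x.sum() > 0, x * 2, x * 3)`. The pure-Python sketch below models the same select semantics with lists; its `where` helper is our own stand-in, not the PyTorch function.

```python
def where(cond, a, b):
    """Elementwise select: take a[i] where cond[i] is true, else b[i]."""
    return [ai if ci else bi for ci, ai, bi in zip(cond, a, b)]


def branchy(x):
    # Host-side branch: the sum must be known before choosing a path
    return [v * 2 for v in x] if sum(x) > 0 else [v * 3 for v in x]


def branch_free(x):
    # Compute both candidates, then select; the condition stays data
    doubled = [v * 2 for v in x]
    tripled = [v * 3 for v in x]
    cond = [sum(x) > 0] * len(x)  # broadcast the scalar condition
    return where(cond, doubled, tripled)


for x in ([1.0, 2.0], [-1.0, -2.0]):
    assert branchy(x) == branch_free(x)  # same results, no branch on data
```

The trade-off is that both candidates are always computed. That is cheap for small element-wise branches like these, but wasteful when each branch is an expensive subnetwork.
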
##### Compilation Modes and Backends {.unnumbered}

\index{Graph Compilation!compilation modes}
As a project matures from prototyping to production, engineers progressively increase compilation aggressiveness. The default mode (`mode='default'`) applies moderate optimization with fast compilation (5--30 seconds for transformer models), making it suitable for development and training where compilation overhead is amortized over many iterations. When deploying an inference server with fixed input shapes, `mode='reduce-overhead'` minimizes Python interpreter overhead by aggressively capturing operations and enabling CUDA graphs, which replay recorded kernel launches with a single CPU call, improving throughput by 20--40% over the default. For production training that will run for days, `mode='max-autotune'` generates and benchmarks multiple implementation variants for each operation, increasing compilation time (minutes to hours for large models) but improving runtime performance by 10--30%. This progression (default for development, reduce-overhead for inference, max-autotune for long training runs) mirrors the Compilation Continuum principle we formalize below.

The compilation mode controls *how aggressively* to optimize; the backend controls *what target* to optimize for. TorchInductor (the default) generates Triton kernels for CUDA and C++ for CPU, providing the best general-purpose performance for both training and inference. When cross-platform deployment is required, the ONNX Runtime backend\index{ONNX Runtime!compilation backend} exports the FX graph to ONNX format, enabling execution on CPUs, GPUs, mobile, and edge devices, although limited ONNX operator coverage may cause more graph breaks. For maximum inference throughput on NVIDIA GPUs, the TensorRT backend\index{TensorRT!inference compiler} compiles to NVIDIA's inference engine with optional reduced-precision execution (FP16, or INT8 with calibration), layer fusion, and kernel autotuning, often achieving 1.5--2× speedup over TorchInductor. The trade-off is clear: each backend narrows the target to unlock deeper optimization, echoing the flexibility-versus-performance axis that distinguishes eager from graph execution.

##### Practical Example: Measuring Speedup {.unnumbered}

@lst-torch-compile-benchmark implements correct GPU benchmarking methodology, incorporating CUDA synchronization, warmup iterations to exclude compilation time, and sufficient iterations to amortize measurement overhead:

::: {#lst-torch-compile-benchmark lst-cap="**Benchmarking torch.compile**: Properly measuring speedup requires CUDA synchronization, warmup to exclude compilation time, and sufficient iterations to amortize measurement overhead."}
```{.python}
import time

import torch


def forward(x, w):
    return torch.matmul(x, w).relu()


x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 512, device="cuda")

# Eager mode benchmark
forward(x, w)  # Warmup: initialize CUDA context and cuBLAS handles
torch.cuda.synchronize()  # Ensure pending GPU operations complete
start = time.perf_counter()
for _ in range(100):
    y = forward(x, w)
torch.cuda.synchronize()  # Wait for GPU kernel completion
eager_time = time.perf_counter() - start

# Compiled mode benchmark
forward_compiled = torch.compile(forward)
forward_compiled(x, w)  # Warmup: trigger compilation
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(100):
    y = forward_compiled(x, w)
torch.cuda.synchronize()
compiled_time = time.perf_counter() - start

print(f"Speedup: {eager_time / compiled_time:.2f}×")
# The speedup depends on the workload: matmul-dominated code like this
# may gain little, while chains of element-wise ops usually gain more
```
:::

Critical benchmarking details: (1) Call `torch.cuda.synchronize()` before reading the clock, because CUDA operations are asynchronous; without synchronization, timing measures only kernel launch time, not execution time. (2) Warm up both versions before timing: the first compiled call triggers compilation, and the first eager call pays one-time CUDA initialization costs. (3) Run 100+ iterations to amortize measurement overhead, using a monotonic high-resolution clock such as `time.perf_counter()`.

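The three rules above can be packaged into a reusable helper. The sketch below is our own utility rather than a PyTorch API: `sync` is a caller-supplied callable (for GPU code, `torch.cuda.synchronize`), and reporting the median over several trials is an assumed choice for damping noise from other processes.

```python
import statistics
import time


def benchmark(fn, *args, warmup=3, iters=100, trials=5, sync=lambda: None):
    """Median seconds per call of fn(*args) over several timed trials.

    `sync` must block until queued asynchronous work finishes; for CUDA
    code pass torch.cuda.synchronize. The no-op default suits CPU code.
    """
    for _ in range(warmup):  # exclude compilation and lazy initialization
        fn(*args)
    per_call = []
    for _ in range(trials):
        sync()  # start timing from an idle device
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        sync()  # wait until the last launched work completes
        per_call.append((time.perf_counter() - start) / iters)
    return statistics.median(per_call)


calls = []
t = benchmark(calls.append, 1, warmup=2, iters=10, trials=3)
print(f"{t * 1e6:.3f} μs per call over {len(calls)} total calls")
```

PyTorch also ships `torch.utils.benchmark`, whose `Timer` class performs warmup and CUDA synchronization itself; a hand-rolled helper like this one is mainly useful for seeing what those steps are.
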
##### Systems Implications {.unnumbered}

First execution includes compilation time, roughly 5--10 s for small models, 30--60 s for BERT-base transformers, and 5--10 min for GPT-3 scale models. This overhead is amortized across training (compile once, train for thousands of iterations) but slows development iteration. Compiled kernels are cached on disk, so subsequent runs can skip most of the compilation work.

Compilation adds memory overhead: 100--500 MB for FX graph construction, 500 MB--2 GB at peak during Triton compilation, and 10--100 MB of storage per compiled graph. Runtime memory usage is similar to eager mode (kernel fusion can reduce intermediate tensors, but compiled code may allocate temporary buffers). Compiled models typically use 90--110% of eager mode memory.

Errors in compiled code produce stack traces pointing to generated code, not to the source Python. Print statements inside compiled regions cause graph breaks (they execute eagerly rather than being compiled). For debugging, remove `@torch.compile` to revert to eager execution, fix bugs, then re-enable compilation. Use `TORCH_COMPILE_DEBUG=1` for verbose compilation logs.

##### When to Use torch.compile {.unnumbered}

The decision follows directly from the compilation cost model. Long training runs amortize compilation overhead across hundreds of iterations, and stable architectures with fixed control flow minimize graph breaks---making training the strongest use case. Inference is equally compelling: a deployed model compiles once at startup and serves thousands of requests, where `mode='reduce-overhead'` minimizes per-request overhead. Compilation should be deferred, however, during rapid prototyping, where the overhead slows iteration time and the architecture has not yet stabilized. Models with frequent graph breaks or dynamic shape changes prevent effective compilation, and debugging is harder in compiled mode because error locations point to generated code rather than source Python. The practical strategy is to develop in eager mode, stabilize the architecture, then enable compilation for training and deployment.

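The amortization argument reduces to a break-even count: compilation pays off once the per-step savings have repaid the one-time compile cost, that is, once $n \cdot (t_\text{eager} - t_\text{compiled}) \ge t_\text{compile}$. The helper below solves that inequality for the smallest integer $n$; the timings in the example are hypothetical placeholders, not measurements, and integer milliseconds are used to avoid floating-point rounding at the boundary.

```python
import math


def break_even_steps(compile_ms, eager_step_ms, compiled_step_ms):
    """Smallest n with compile_ms + n * compiled_step_ms <= n * eager_step_ms.

    Returns None when the compiled step is not faster, because the
    compile cost is then never repaid.
    """
    saving_ms = eager_step_ms - compiled_step_ms
    if saving_ms <= 0:
        return None
    return math.ceil(compile_ms / saving_ms)


# Hypothetical timings: 60 s compile, 500 ms eager step, 400 ms compiled step
print(break_even_steps(60_000, 500, 400))  # 600
print(break_even_steps(60_000, 500, 500))  # None
```

A training run of tens of thousands of steps clears a threshold of a few hundred easily; an exploratory script that runs for a few dozen steps does not.
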
##### Comparison of Execution Models {.unnumbered}

@tbl-framework-execution-models contrasts the three execution models across six dimensions, showing that hybrid JIT compilation achieves most of the performance of static graphs while preserving much of eager execution's flexibility:

: **Execution Model Trade-Offs.** Each execution model occupies a distinct position in the flexibility-optimization trade-off space. Eager execution maximizes debugging flexibility but sacrifices optimization potential; static graphs maximize optimization but sacrifice dynamic control flow; hybrid JIT compilation attempts both by compiling captured regions while falling back to eager for unsupported patterns. {#tbl-framework-execution-models}

| **Aspect**               | **Eager + Autograd Tape (PyTorch default)** | **Static Graph (TensorFlow 1.x)**    | **JIT Compilation (torch.compile)** |
|:-------------------------|:--------------------------------------------|:-------------------------------------|:------------------------------------|
| **Execution Model**      | Immediate                                   | Deferred                             | Hybrid                              |
| **Graph Construction**   | During forward pass                         | Before execution                     | First execution (cached)            |
| **Optimization**         | None (per-operation)                        | Ahead-of-time                        | JIT compilation                     |
| **Dynamic Control Flow** | Full support                                | Limited (in-graph ops like `tf.cond`) | Partial (graph breaks)             |
| **Debugging**            | Easy (standard Python)                      | Difficult (symbolic)                 | Moderate (mixed)                    |
| **Performance**          | Baseline                                    | High (optimized)                     | High (compiled regions)             |

Beyond these core execution trade-offs, @tbl-mlfm-graphs highlights additional systems-level distinctions between static and dynamic approaches:

| **Aspect**                       | **Static Graphs**                                    | **Dynamic Graphs**                            |
|:---------------------------------|:-----------------------------------------------------|:----------------------------------------------|
| **Memory Management**            | Precise allocation planning, optimized memory usage | Flexible but potentially less efficient       |
| **Hardware Utilization**         | Can generate highly optimized hardware-specific code | May sacrifice hardware-specific optimizations |
| **Research Velocity**            | Slower iteration due to define-then-run requirement  | Faster prototyping and model experimentation  |
| **Integration with Legacy Code** | More separation between definition and execution     | Natural integration with imperative code      |

: **Additional Graph Trade-Offs.** Systems-level distinctions between static and dynamic graphs that complement the execution model comparison above. These trade-offs reappear when selecting frameworks in @sec-ml-frameworks-selecting-framework-2949. {#tbl-mlfm-graphs}

These trade-offs are not binary choices. Modern frameworks offer a spectrum of options, which raises the quantitative question: where on this spectrum should a given project operate?

### Quantitative Principles of Execution {#sec-ml-frameworks-compilation-continuum-principle-c106}

\index{Compilation Continuum!principle}

These execution models present a spectrum of trade-offs, but engineers need more than intuition to navigate them. Two quantitative principles formalize the decision. The *Compilation Continuum Principle* establishes when the performance gains from compilation justify its development cost, expressed as a ratio of production executions to development iterations. The *Dispatch Overhead Law* quantifies the per-operation cost of framework flexibility, revealing why small operations in eager mode can spend more time in Python overhead than in actual computation. Together, these principles transform framework selection from subjective preference into measurable engineering analysis.

#### The Compilation Continuum Principle {#sec-ml-frameworks-compilation-continuum-principle-7122}

The Execution Problem demands a quantitative principle: **when should you compile?**

The execution models form a continuum from maximum flexibility to maximum optimization, visualized in @eq-execution-continuum:
|
||
|
||
$$
|
||
\text{Eager} \xrightarrow{\text{tracing}} \text{JIT} \xrightarrow{\text{AOT}} \text{Static Graph} \xrightarrow{\text{synthesis}} \text{Custom Hardware}
|
||
$$ {#eq-execution-continuum}
|
||
|
||
Each step rightward sacrifices flexibility for performance. The practical question is: *where* on this continuum should your project operate? The optimal compilation strategy depends on the ratio of **development iterations** to **production executions** (@eq-compilation-benefit):
|
||
|
||
$$
|
||
\text{Compilation Benefit} = \frac{N_{\text{prod}} \cdot (T_{\text{eager}} - T_{\text{compiled}})}{T_{\text{compile}} + N_{\text{dev}} \cdot T_{\text{compile}}}
|
||
$$ {#eq-compilation-benefit}
|
||
|
||
Where:
|
||
|
||
- $N_{\text{prod}}$ = number of production executions (dimensionless count: inference requests, training steps)
|
||
- $N_{\text{dev}}$ = number of development iterations requiring recompilation (dimensionless count)
|
||
- $T_{\text{eager}}$ = time per execution in eager mode (seconds)
|
||
- $T_{\text{compiled}}$ = time per execution in compiled mode (seconds)
|
||
- $T_{\text{compile}}$ = one-time compilation cost (seconds)
|
||
|
||
**Decision Rule**: Compile when $\text{Compilation Benefit} > 1$. The ratio is dimensionless.
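The benefit ratio is simple enough to evaluate directly. The sketch below encodes @eq-compilation-benefit as a plain Python function. The function name and inputs are illustrative assumptions, not measurements. The inputs are per-image times derived from the ResNet-50 throughputs in @tbl-training-benchmark, an assumed 30-second compile, and a hypothetical count of 20 recompilations.

```{.python}
def compilation_benefit(n_prod, n_dev, t_eager, t_compiled, t_compile):
    """Evaluate the Compilation Benefit ratio (dimensionless).

    Numerator: time saved across n_prod production executions.
    Denominator: the initial compilation plus one recompilation per
    development iteration.
    """
    time_saved = n_prod * (t_eager - t_compiled)
    time_compiling = (1 + n_dev) * t_compile
    return time_saved / time_compiling


# Illustrative inputs: ResNet-50 per-image times from the throughput table
t_eager = 1 / 1450     # seconds per image, eager mode
t_compiled = 1 / 2150  # seconds per image, torch.compile
t_compile = 30.0       # assumed compile time in seconds

# One ImageNet-sized epoch (1.28M images) compiled once
epoch = compilation_benefit(1_280_000, 0, t_eager, t_compiled, t_compile)

# The same epoch if the model were recompiled 20 times during development
# (hypothetical iteration count)
churn = compilation_benefit(1_280_000, 20, t_eager, t_compiled, t_compile)

print(f"Benefit, compiled once:  {epoch:.2f}")
print(f"Benefit, 20 recompiles: {churn:.2f}")
```

With a single compilation, one epoch repays the compile cost almost tenfold (a ratio near 9.6). If the same compile is repeated 20 times during development, the ratio falls to about 0.46, and staying eager would have been cheaper.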

@tbl-training-benchmark provides representative throughput data across execution modes and model architectures:

| **Model** | **Eager** **(samples/sec)** | **torch.compile** **(samples/sec)** | **TensorRT** **(samples/sec)** | **Compile Time** **(seconds)** |
|:-----------------|------------------------:|--------------------------------:|---------------------------:|-------------------------------:|
| **ResNet-50** | 1,450 | 2,150 | 3,800 | 15--30 |
| **BERT-Base** | 380 | 520 | 890 | 30--60 |
| **ViT-B/16** | 620 | 950 | 1,650 | 25--45 |
| **GPT-2 (124M)** | 180 | 260 | 420 | 45--90 |

: **Training and Inference Throughput.** Representative throughput comparison across execution modes for common model architectures on an NVIDIA A100 GPU with batch size 32. Samples are images for the vision models and sequences for the language models. torch.compile typically provides a 1.4 to 1.5× speedup over eager mode. TensorRT provides a 2 to 3× speedup, but it requires longer compilation and is inference-only. Compile times vary with model complexity and optimization level. {#tbl-training-benchmark}

These throughput differences across execution modes raise a practical question: which *framework execution strategy* best serves each workload archetype?

::: {.callout-lighthouse title="Framework Strategy by Archetype"}

The optimal framework execution strategy depends on which **Iron Law** term dominates your workload. @tbl-framework-archetype-strategy aligns each archetype with its recommended execution strategy:

| **Archetype** | **Dominant Iron Law Term** | **Optimal Framework Strategy** | **Rationale** |
|:----------------------|:------------------------------------------|:-----------------------------------|:-------------------------------------------|
| **ResNet-50 (Compute Beast)** | $\frac{O}{R_{peak} \cdot \eta}$ (Compute) | **TensorRT** (inference); **torch.compile** (training) | Kernel fusion maximizes MFU; compute-bound workloads benefit most from optimization |
| **GPT-2 (Bandwidth Hog)** | $\frac{D_{vol}}{BW}$ (Memory Bandwidth) | **torch.compile** | Kernel fusion reduces HBM round-trips, keeping intermediates in on-chip memory to relieve bandwidth pressure |
| **DLRM (Sparse Scatter)** | $\frac{D_{vol}}{BW}$ (Random Access) + $T_{network}$ | **Eager** with specialized kernels (FBGEMM) | Embedding lookups are inherently irregular and dynamic; compilation gains are small |
| **DS-CNN (Tiny Constraint)** | $L_{lat}$ (Overhead) | **AOT compilation** (TFLite, ONNX) | Sub-ms inference; every microsecond of Python overhead is unacceptable |

: **Framework Execution Strategy by Workload.** Recommended execution strategy for each workload archetype, aligned to the dominant Iron Law term. Compute-bound workloads benefit most from compilation, while irregular access patterns favor eager execution. {#tbl-framework-archetype-strategy}

**Key insight**: Compilation benefits scale with how much of your workload is *optimizable*. Compute Beasts benefit most: in @tbl-training-benchmark, ResNet-50 sees a 2.6× speedup from TensorRT. Sparse Scatter workloads gain little because their bottleneck, embedding lookups, is inherently irregular.
:::

This principle has concrete implications across three regimes. In *research prototyping* ($N_{\text{dev}} \gg N_{\text{prod}}$), teams should stay eager. When the architecture changes every few minutes, compilation overhead dominates. A 30-second compile repeated for 10 iterations per hour costs 5 minutes of compilation per hour, which is often more than the runtime it saves.

For *training runs* ($N_{\text{prod}} \gg N_{\text{dev}}$), compilation pays off. A typical training run executes millions of forward/backward passes, so even 60 seconds of compilation amortizes to tens of microseconds or less per step. From @tbl-training-benchmark, torch.compile provides a ~48% speedup on ResNet-50 (2,150 vs. 1,450 img/sec). Setting $N_{\text{dev}} = 0$ and a Compilation Benefit of exactly 1 in @eq-compilation-benefit gives the breakeven point in @eq-compile-breakeven:

$$
N_{\text{breakeven}} = \frac{T_{\text{compile}}}{T_{\text{eager}} - T_{\text{compiled}}} = \frac{30\text{s}}{(1/1450 - 1/2150)\text{s/img}} \approx 134{,}000 \text{ images}
$$ {#eq-compile-breakeven}

For ImageNet (1.28M training images), compilation pays off roughly 10% of the way through the first epoch.
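The breakeven arithmetic generalizes to any row of @tbl-training-benchmark. The helper below is a hypothetical sketch. It converts throughputs to per-sample times and applies @eq-compile-breakeven. The 30-second compile time is an assumed value within the table's range, not a measurement.

```{.python}
def breakeven_samples(t_compile_s, eager_tput, compiled_tput):
    """Samples needed before a one-time compile pays for itself.

    Throughputs are in samples per second. Per-sample times are their
    reciprocals, so the saving per sample is 1/eager - 1/compiled.
    """
    saving_per_sample = 1.0 / eager_tput - 1.0 / compiled_tput
    if saving_per_sample <= 0:
        return float("inf")  # compiled mode is not faster: never pays off
    return t_compile_s / saving_per_sample


# ResNet-50 eager vs. torch.compile throughputs, assumed 30 s compile
n = breakeven_samples(30.0, 1450, 2150)
imagenet_train = 1_281_167  # ImageNet-1k training images
print(f"Breakeven: {n:,.0f} images "
      f"({100 * n / imagenet_train:.1f}% of one epoch)")
```

For ResNet-50 this reproduces the roughly 134,000-image breakeven, about 10% of one ImageNet epoch. If the compiled mode is not actually faster, the function returns infinity, signaling that compilation never pays off.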

For *production inference* ($N_{\text{dev}} \approx 0$, $N_{\text{prod}} \rightarrow \infty$), teams should maximize compilation. With no development iterations and potentially millions of requests, every optimization matters. Using `mode='max-autotune'` is worthwhile even if compilation stretches to an hour, because that cost is amortized over the deployment lifetime.

These three regimes create distinct regions in the compilation decision space. @fig-compilation-continuum maps out these regions so you can identify where each strategy wins. Watch for the crossover points. The eager line (highest per-execution cost) eventually rises above JIT's moderate slope. The AOT line (lowest per-execution cost but largest upfront investment) wins only after millions of executions. The slopes reflect per-execution cost; the vertical offsets reflect compilation overhead. A project's position on the x-axis determines which strategy it should use.

```{python}
#| label: fig-compilation-continuum
#| fig-cap: "**The Compilation Continuum**: The optimal execution strategy depends on how many production executions are available to amortize the compilation cost. Left region (few production executions): eager mode dominates. Right region (many production executions): compilation dominates. The crossover points depend on compilation cost and per-execution speedup."
#| fig-alt: "Graph with x-axis 'Production Executions' (log scale) and y-axis 'Total Time'. Three lines: Eager (steep slope), JIT (moderate slope with offset), AOT (gentle slope with larger offset). Lines cross at different points showing when compilation becomes beneficial."
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ FIG-COMPILATION-CONTINUUM
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @fig-compilation-continuum showing cost trade-offs across
# │ execution strategies (eager, JIT, AOT)
# │
# │ Goal: Visualize the total cost of ownership for different compilation strategies.
# │ Show: The crossover points where JIT and AOT beat eager execution.
# │ How: Plot total time (compile + execute) vs. production volume on a log scale.
# │
# │ Imports: numpy (np), mlsys.viz (viz)
# │ Exports: figure only, no prose variables
# └─────────────────────────────────────────────────────────────────────────────
import numpy as np
from mlsys import viz

fig, ax, COLORS, plt = viz.setup_plot()

# =============================================================================
# PLOT: The Compilation Continuum
# =============================================================================
x = np.logspace(2, 7, 200)  # 100 to 10,000,000 executions

# Cost models: compile_cost + per_exec_cost * x
y_eager = 0 + 10e-6 * x
y_jit = 10 + 5e-6 * x
y_aot = 30 + 2e-6 * x

ax.plot(x, y_eager, color=COLORS['BlueLine'], linewidth=2.5, label=r'Eager (no compile cost, high per-exec cost)')
ax.plot(x, y_jit, color=COLORS['OrangeLine'], linewidth=2.5, label=r'JIT (low compile cost, medium per-exec cost)')
ax.plot(x, y_aot, color=COLORS['GreenLine'], linewidth=2.5, label=r'AOT (high compile cost, low per-exec cost)')

ax.text(1000, 5, "Eager wins", color=COLORS['BlueLine'], ha='center', va='center', fontweight='bold', fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
ax.text(4e6, 34, "JIT wins", color=COLORS['OrangeLine'], ha='center', va='center', fontweight='bold', fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
ax.text(8e6, 42, "AOT wins", color=COLORS['GreenLine'], ha='center', va='top', fontweight='bold', fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))

ax.set_xscale('log')
ax.set_xlabel(r'Production Executions ($N_{prod}$)')
ax.set_ylabel('Total Time (arbitrary units)')
ax.set_xlim(100, 1e7)
ax.set_ylim(0, 100)
ax.legend(loc='upper left', fontsize=8)
plt.show()
```
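The crossover points in @fig-compilation-continuum follow directly from its linear cost models. Two strategies with total cost $c_i + s_i N$ break even at $N = (c_2 - c_1)/(s_1 - s_2)$. The sketch below reuses the figure's illustrative constants, which are arbitrary units rather than measurements, to locate where each strategy takes over. The `crossover` helper is hypothetical.

```{.python}
def crossover(c1, s1, c2, s2):
    """Execution count at which strategy 2 (higher compile cost c2,
    lower per-execution cost s2) matches strategy 1 in total cost."""
    return (c2 - c1) / (s1 - s2)


# (compile_cost, per_exec_cost) pairs copied from the figure's cost models
EAGER = (0.0, 10e-6)
JIT = (10.0, 5e-6)
AOT = (30.0, 2e-6)

eager_to_jit = crossover(*EAGER, *JIT)
jit_to_aot = crossover(*JIT, *AOT)
print(f"JIT beats eager beyond {eager_to_jit:,.0f} executions")
print(f"AOT beats JIT beyond {jit_to_aot:,.0f} executions")
```

With these constants, eager execution is cheapest up to about two million executions. JIT wins between roughly two and 6.7 million, and AOT wins beyond that. This matches the labeled regions in the figure.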

#### The Dispatch Overhead Law {#sec-ml-frameworks-dispatch-overhead-law-9e0a}

\index{Dispatch Overhead!law}
A second principle answers a different question: when does framework overhead, rather than compute or memory, dominate execution time? Let $N_{\text{ops}}$ be the number of operations (count) and $t_{\text{dispatch}}$ the per-operation dispatch overhead (seconds). Let $T_{\text{compute}}$ and $T_{\text{memory}}$ be the total compute and memory times (seconds). @eq-dispatch-overhead compares the accumulated dispatch cost against the useful hardware work:

$$
\text{Overhead Ratio} = \frac{N_{\text{ops}} \cdot t_{\text{dispatch}}}{T_{\text{compute}} + T_{\text{memory}}}
$$ {#eq-dispatch-overhead}

When the Overhead Ratio exceeds 1, the model is **overhead-bound**: its individual operations are so small that dispatch cost dominates. Compilation provides the greatest benefit for overhead-bound workloads because it removes most per-operation dispatch.
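The ratio can be evaluated directly from profiled numbers. This sketch uses hypothetical function names. It computes @eq-dispatch-overhead and the best-case speedup if dispatch were removed entirely. The inputs are the same small-MLP and GPT-3-layer figures used in the dispatch-tax calculation: 6 operations at 5 μs each against 2.6 μs of hardware time, and 50 μs of dispatch against a 100 ms layer. Assigning each case's hardware time wholly to memory or wholly to compute is a simplifying assumption.

```{.python}
def overhead_ratio(n_ops, t_dispatch, t_compute, t_memory):
    """Overhead Ratio: accumulated dispatch time over hardware time."""
    return (n_ops * t_dispatch) / (t_compute + t_memory)


def dispatch_free_speedup(t_sw, t_hw):
    """Best-case speedup if all dispatch overhead disappeared."""
    return (t_sw + t_hw) / t_hw


US = 1e-6  # one microsecond, in seconds

# Small MLP: 6 ops at 5 us each; 2.6 us of hardware time, treated here
# as memory time because it is mostly memory latency
small = overhead_ratio(6, 5 * US, 0.0, 2.6 * US)
small_bound = dispatch_free_speedup(6 * 5 * US, 2.6 * US)

# GPT-3 layer: 50 us of dispatch (lumped into one term) against
# 100 ms of hardware time, treated here as compute time
large = overhead_ratio(1, 50 * US, 100_000 * US, 0.0)
large_bound = dispatch_free_speedup(50 * US, 100_000 * US)

for name, ratio, bound in [("small MLP", small, small_bound),
                           ("GPT-3 layer", large, large_bound)]:
    regime = "overhead-bound" if ratio > 1 else "not overhead-bound"
    print(f"{name}: ratio {ratio:.4f} ({regime}), "
          f"dispatch-free speedup at most {bound:.4f}x")
```

The small MLP spends about 11.5 times as long in dispatch as on hardware, so removing dispatch could speed it up by at most about 12.5×. The GPT-3 layer could gain at most 0.05% from removing dispatch. Any larger speedup there must come from fusion or other optimizations.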

From the case study in @sec-ml-frameworks-putting-together-anatomy-training-step-c7f1, we can quantify this effect.

```{python}
#| label: dispatch-tax-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ DISPATCH TAX CALCULATION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Callout "The Dispatch Tax" comparing overhead-bound vs compute-bound
# │
# │ Goal: Demonstrate why compilation benefits small models disproportionately.
# │ Show: That small MLPs are 92% overhead-bound, while GPT-3 is <0.05% overhead.
# │ How: Calculate the software-to-hardware execution time ratio for both cases.
# │
# │ Imports: mlsys.formatting (fmt, check, md_math)
# │ Exports: dispatch_n_ops_value, dispatch_us_per_op_value, dispatch_hw_time_us_value,
# │ dispatch_sw_time_str, dispatch_ratio_small_str, dispatch_overhead_pct_str,
# │ dispatch_compilation_speedup_str, gpt3_hw_time_us_value, gpt3_sw_time_us_value,
# │ dispatch_ratio_large_str, t_sw_md
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import KIB_TO_BYTES
from mlsys.formatting import fmt, check, md_math

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class DispatchTax:
    """
    Namespace for Dispatch Tax Calculation.
    Scenario: Comparing overhead impact on Small Ops vs Large Ops.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    # Scenario 1: Small MLP (Overhead Bound)
    small_ops_count = 6
    small_dispatch_us = 5.0
    small_hw_us = 2.6

    # Scenario 2: GPT-3 Layer (Compute Bound)
    large_hw_us = 100_000.0  # 100ms
    large_dispatch_us = 50.0

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    # Small Model
    small_sw_total = small_ops_count * small_dispatch_us
    small_total_time = small_sw_total + small_hw_us
    small_overhead_ratio = small_sw_total / small_hw_us
    small_overhead_pct = (small_sw_total / small_total_time) * 100
    small_speedup_limit = small_total_time / small_hw_us

    # Large Model
    large_overhead_ratio = large_dispatch_us / large_hw_us

    # ┌── 3. INVARIANTS (Guardrails) ───────────────────────────────────────────
    check(small_overhead_ratio >= 1.0, f"Small model ratio ({small_overhead_ratio:.1f}) implies it is NOT overhead bound.")
    check(large_overhead_ratio <= 0.01, f"Large model overhead ({large_overhead_ratio:.4f}) is too high.")

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    dispatch_n_ops_value = small_ops_count
    dispatch_us_per_op_value = small_dispatch_us
    dispatch_hw_time_us_value = small_hw_us

    dispatch_sw_time_str = fmt(small_sw_total, precision=0, commas=False)
    dispatch_ratio_small_str = fmt(small_overhead_ratio, precision=1, commas=False)
    dispatch_overhead_pct_str = fmt(small_overhead_pct, precision=0, commas=False)
    dispatch_compilation_speedup_str = fmt(small_speedup_limit, precision=0, commas=False)

    gpt3_hw_time_us_value = large_hw_us
    gpt3_sw_time_us_value = large_dispatch_us
    dispatch_ratio_large_str = fmt(large_overhead_ratio, precision=4, commas=False)
    t_sw_md = md_math(f"T_{{sw}} \\approx {large_dispatch_us} \\, \\mu s")

# Note: Use DispatchTax.dispatch_sw_time_str directly.
```

Per-operation dispatch latency accumulates into what is effectively *a dispatch tax* on execution. We define $T_{\text{hw}}$ as hardware execution time and $T_{\text{sw}}$ as software overhead time, both measured in seconds.

::: {.callout-notebook title="The Dispatch Tax"}
**Problem**: When does Python overhead kill performance?

**Scenario 1: Small MLP (Overhead Bound)**

* **Compute**: `{python} DispatchTax.dispatch_n_ops_value` small matrix/element-wise operations.
* **Hardware Time**: $T_{\text{hw}} \approx$ `{python} DispatchTax.dispatch_hw_time_us_value` μs (mostly memory latency).
* **Software Overhead**: $T_{\text{sw}} \approx$ `{python} DispatchTax.dispatch_n_ops_value` ops × `{python} DispatchTax.dispatch_us_per_op_value` μs/op = `{python} DispatchTax.dispatch_sw_time_str` μs.
* **Ratio**: `{python} DispatchTax.dispatch_sw_time_str` / `{python} DispatchTax.dispatch_hw_time_us_value` ≈ **`{python} DispatchTax.dispatch_ratio_small_str`**.
* **Conclusion**: You spend `{python} DispatchTax.dispatch_overhead_pct_str`% of the time waiting for Python. Eliminating dispatch entirely would yield at most a **`{python} DispatchTax.dispatch_compilation_speedup_str`× speedup**.

**Scenario 2: GPT-3 Layer (Compute Bound)**

* **Compute**: Huge matrix multiplications.
* **Hardware Time**: $T_{\text{hw}} \approx$ 100 ms = `{python} DispatchTax.gpt3_hw_time_us_value` μs.
* **Software Overhead**: `{python} DispatchTax.t_sw_md`.
* **Ratio**: `{python} DispatchTax.gpt3_sw_time_us_value` / `{python} DispatchTax.gpt3_hw_time_us_value` ≈ **`{python} DispatchTax.dispatch_ratio_large_str`**.
* **Conclusion**: Python overhead is negligible. Compilation helps only through kernel fusion (memory bandwidth), not through dispatch elimination.
:::

```{python}
#| label: gpt3-params
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ GPT-3 PARAMETER COUNTS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Comparing compilation benefits across model scales
# │
# │ Goal: Contrast compilation benefits at extreme scale.
# │ Show: That dispatch overhead is negligible for 175B parameter models.
# │ How: Retrieve GPT-3 parameter count to establish the compute-bound regime.
# │
# │ Imports: mlsys (Models), mlsys.constants (Bparam), mlsys.formatting (fmt, check)
# │ Exports: gpt3_params_b_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys import Models
from mlsys.constants import Bparam
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class GPT3Context:
    """
    Namespace for GPT-3 Parameter Counts.
    Scenario: Compilation benefits at scale.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    model = Models.GPT3

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    params_b = model.parameters.to(Bparam).magnitude

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    gpt3_params_b_str = fmt(params_b, precision=0, commas=False)

# Note: Use GPT3Context.gpt3_params_b_str directly.
```

The principle implies that small models benefit *disproportionately* from compilation. A 100-parameter toy model might see a 10× speedup from torch.compile. A `{python} GPT3Context.gpt3_params_b_str` B-parameter model might see only about 1.3×, and that gain comes from kernel fusion rather than from dispatch elimination. This explains why compilation matters most for efficient inference on smaller, deployed models.

The dispatch tax analysis reveals that small operations are overhead-bound regardless of hardware capability. This observation matters most at the extreme edge of the deployment spectrum, where the entire Python runtime is itself an unacceptable overhead.

### Frameworks for the Edge: TinyML and Micro-Runtimes {#sec-ml-frameworks-tinyml-micro-runtimes-2a1b}

\index{TinyML!micro-runtimes}
\index{Edge Deployment!TinyML frameworks}
The compilation continuum reaches its extreme at the far edge. Cloud frameworks like PyTorch and TensorFlow 2.x prioritize flexibility through eager execution. **TinyML**[^fn-tinyml] systems, by contrast, run on microcontrollers (MCUs) with kilobytes of memory and cannot afford the overhead of a Python interpreter or a dynamic runtime.

[^fn-tinyml]: **TinyML**: A term popularized by Pete Warden and Daniel Situnayake's 2019 book *TinyML*, referring to machine learning inference on microcontrollers with power budgets under 1 milliwatt and memory measured in kilobytes. The "Tiny" distinguishes these deployments from edge ML (megabytes of memory, milliwatts to watts) and cloud ML (gigabytes, hundreds of watts). TinyML frameworks like TensorFlow Lite Micro and PyTorch ExecuTorch run models through compact C/C++ runtimes with no dynamic memory allocation during inference. They represent the extreme AOT endpoint of the compilation continuum. See @sec-model-compression for the model optimization techniques that make this possible.

::: {.callout-lighthouse title="Lighthouse Example: Smart Doorbell (TinyML)"}

**The Scenario**: Deploying the **Smart Doorbell's** Keyword Spotting (KWS) model to an ARM Cortex-M4 microcontroller with 256 KB of RAM and 1 MB of Flash.

**The Constraint**: A standard PyTorch runtime occupies ~500 MB, and the Python interpreter alone occupies ~20 MB. Both are orders of magnitude larger than the device's entire memory.

**The Framework Solution**:
Micro-frameworks like **TensorFlow Lite Micro (TFLM)**\index{TensorFlow Lite Micro!extreme AOT} and **PyTorch ExecuTorch** solve this through **Extreme AOT Compilation**:

1. **Static memory planning**: The framework fixes the memory offset of every tensor within a preallocated arena before inference begins. There is no dynamic `malloc()` or garbage collection during inference.
2. **Kernel specialization**: Only the specific kernels used by the model (e.g., Conv2D, DepthwiseConv) are compiled into the binary. Unused code is stripped away.
3. **No Python interpreter**: The model is converted into a flat sequence of function calls or a simple "Command Buffer" that the MCU executes directly in C/C++.

**The Silicon Contract**: On TinyML devices, the contract is strictly **Memory-Bound**. The framework's primary job is to ensure that the model's intermediate activations (the "working set") fit within the MCU's tiny SRAM.
:::
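The static memory planning step described in the callout can be illustrated with a simplified planner. The sketch below is a hypothetical greedy first-fit allocator, not the algorithm TFLM or ExecuTorch actually uses. Each tensor has a size in bytes and a lifetime, given as the first and last operator indices that touch it. Tensors whose lifetimes do not overlap may share the same bytes of a single fixed arena. The tensor names and sizes are made up for illustration.

```{.python}
def plan_arena(tensors):
    """Assign a fixed byte offset to every tensor before execution.

    tensors: list of (name, size_bytes, first_op, last_op).
    Returns (offsets, arena_size). Tensors are placed largest-first at
    the lowest offset that does not collide with any already-placed
    tensor whose lifetime overlaps.
    """
    placed = []   # (offset, size, first_op, last_op)
    offsets = {}
    for name, size, first, last in sorted(tensors, key=lambda t: -t[1]):
        # Address ranges of placed tensors that are live at the same time
        busy = sorted((off, off + sz) for off, sz, f, l in placed
                      if not (last < f or l < first))
        offset = 0
        for start, end in busy:
            if offset + size <= start:
                break  # fits in the gap before this busy range
            offset = max(offset, end)
        placed.append((offset, size, first, last))
        offsets[name] = offset
    arena_size = max((off + sz for off, sz, _, _ in placed), default=0)
    return offsets, arena_size


# Hypothetical 4-op chain: each tensor lives from the op that produces
# it to the op that last consumes it
tensors = [
    ("input", 4096, 0, 0),
    ("act1", 8192, 0, 1),
    ("act2", 8192, 1, 2),
    ("act3", 2048, 2, 3),
    ("output", 512, 3, 3),
]
offsets, arena = plan_arena(tensors)
naive = sum(size for _, size, _, _ in tensors)
print(offsets)
print(f"Arena: {arena} bytes vs. {naive} bytes without reuse")
```

Here the planner fits five tensors totaling 23,040 bytes into a 16,384-byte arena by reusing bytes wherever lifetimes do not overlap. Every offset is known before the first inference, so the device never calls `malloc()`.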

These micro-runtimes represent the "Pure AOT" endpoint of the continuum. By sacrificing all dynamic flexibility, they enable machine learning to run on devices consuming milliwatts of power. They fulfill the **Energy-Movement Invariant** (formalized in @sec-data-engineering) by keeping all data movement local to the chip.

This spectrum of execution strategies, from dynamic eager execution to static graph compilation and specialized micro-runtimes, requires developers to make deliberate trade-offs. The following checkpoint summarizes the key decision points before we turn to the second core problem.

::: {.callout-checkpoint title="Execution Models" collapse="false"}
The choice of execution mode determines both developer velocity and model performance.

**Debuggability vs. Speed**

- [ ] **Eager Mode (Python-First)**: Why does executing ops one by one make debugging easy but optimization hard? (Hint: the compiler cannot see the "future" ops to fuse them.)
- [ ] **Graph Mode (Compiler-First)**: Why does building a static graph enable **Kernel Fusion**? (Hint: merging Conv+ReLU saves memory bandwidth.)

**The Modern Compromise**

- [ ] **JIT Compilation**: How does `torch.compile` bridge the gap? (It captures the graph *just in time* to optimize, while falling back to Python for dynamic parts.)
:::

The execution problem determines *when* computation happens and *what* optimizations are possible. Neural network training, however, requires a capability that no amount of clever scheduling can provide: the ability to compute gradients automatically.

Consider what training actually requires: for each of millions of parameters, compute how a tiny change would affect the loss. Doing this manually for even a simple three-layer network requires deriving and implementing dozens of partial derivatives. For a modern Transformer with billions of parameters, manual differentiation is economically impossible. A framework that executes efficiently but cannot differentiate can run inference but cannot learn.

## Differentiation Problem {#sec-ml-frameworks-differentiation-problem-8b8a}

\index{Differentiation Problem!definition}
\index{Gradient!etymology}
\index{Automatic Differentiation!definition}
The differentiation problem asks: how should frameworks compute gradients[^fn-gradient] automatically? Neural network training requires derivatives of a scalar loss $L$ with respect to millions or billions of parameters, making manual differentiation impractical. Because a single scalar loss depends on all parameters, reverse-mode automatic differentiation (AD)[^fn-auto-diff] is the optimal strategy. One backward pass computes all parameter gradients simultaneously, while forward mode would require a separate pass for each parameter. All major ML frameworks therefore implement reverse-mode AD by default [@baydin2018].

[^fn-gradient]: **Gradient**: From Latin "gradiens" (stepping/walking), related to "gradus" (step). The gradient points in the direction of steepest ascent, as if climbing steps up a hill. The term was introduced by Sylvester in 1854 for the vector of partial derivatives. In ML, we descend this slope toward lower loss, hence "gradient descent" as the algorithm that takes steps downhill.

[^fn-auto-diff]: **Automatic Differentiation**: A technique that computes exact derivatives by applying the chain rule to elementary operations, formalized by Wengert (1964). Reverse-mode autodiff (backpropagation) computes all gradients in O(1) passes (one forward, one backward) regardless of parameter count, making billion-parameter training feasible. Modern implementations like JAX's `grad` and PyTorch's autograd support higher-order derivatives and custom gradient rules.

Building on the backpropagation[^fn-backprop-frameworks] algorithm introduced in @sec-neural-computation (where we established that gradients flow backward through the computation graph via the chain rule), this section shifts focus from the mathematics to the systems engineering of differentiation. It covers how frameworks represent computation graphs, manage memory for intermediate values, and orchestrate the backward pass efficiently across accelerators. The framework's role is not to perform calculus but to manage the bookkeeping at scale that the training algorithms detailed in @sec-model-training depend on. @lst-auto_diff_intro illustrates the core idea with a simple three-operation function:

[^fn-backprop-frameworks]: **Backpropagation**: The algorithm for computing gradients in neural networks, independently discovered multiple times: Paul Werbos (1974 PhD thesis), David Parker (1985), and Yann LeCun (1985), before Rumelhart, Hinton, and Williams popularized it in their landmark 1986 Nature paper "Learning representations by back-propagating errors." The name describes the direction of computation: gradients are "propagated backward" from the output loss through each layer to the input. From a systems perspective, backpropagation's key cost is memory. The forward pass must store all intermediate activations (the "tape") so the backward pass can compute local gradients at each layer. For a transformer with $L$ layers and hidden dimension $d$ processing $N$ tokens per batch, this requires $O(L \cdot d \cdot N)$ memory, which often exceeds the size of the model weights themselves. This memory cost motivates activation checkpointing (@sec-model-training) and is the primary reason training requires 3--4× more memory than inference.

::: {#lst-auto_diff_intro lst-cap="**Automatic Differentiation**: AD decomposes complex functions into elementary operations with known derivatives, enabling gradient computation through arbitrarily deep compositions in O(n) time where n is the number of operations."}
```{.python}
from math import sin


def f(x):
    a = x * x     # Square
    b = sin(x)    # Sine
    return a * b  # Product
```
:::

Frameworks decompose this function into elementary operations, each with a known local derivative, and then combine these local derivatives via the chain rule to compute gradients through arbitrary compositions. The systems challenge is implementing this efficiently. The framework must record the computation graph during the forward pass, store intermediate values, and execute the backward pass with minimal memory overhead. The remainder of this section examines how production frameworks solve each of these problems.
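The three obligations named above can be seen in a toy reverse-mode implementation: recording the graph, storing intermediates, and replaying the graph backward. The sketch below is a minimal illustration under simplifying assumptions, not how PyTorch's autograd is built. It handles scalar values only, supports just multiplication and sine, uses a hypothetical `Var` class, and assumes all values share one tape. Each operation appends a node holding its inputs and local derivatives to the tape. The backward pass then walks the tape in reverse, applying the chain rule.

```{.python}
import math


class Var:
    """Scalar that records how it was computed (hypothetical API).

    Assumes every Var in a computation shares one tape.
    """

    def __init__(self, value, parents=(), tape=None):
        self.value = value
        self.parents = parents  # sequence of (parent Var, local derivative)
        self.grad = 0.0
        self.tape = tape if tape is not None else []
        self.tape.append(self)  # creation order = forward execution order

    def __mul__(self, other):
        # d(uv)/du = v and d(uv)/dv = u, saved from the forward values
        return Var(self.value * other.value,
                   [(self, other.value), (other, self.value)],
                   self.tape)


def sin(v):
    # Save the local derivative cos(x), computed from the stored input
    return Var(math.sin(v.value), [(v, math.cos(v.value))], v.tape)


def backward(output):
    """Reverse pass: propagate adjoints from the output to every input."""
    output.grad = 1.0
    for node in reversed(output.tape):
        for parent, local_grad in node.parents:
            parent.grad += node.grad * local_grad


def f(x):
    a = x * x     # Square
    b = sin(x)    # Sine
    return a * b  # Product


x = Var(2.0)
y = f(x)
backward(y)
print(f"f(2.0) = {y.value:.3f}, df/dx = {x.grad:.3f}")
```

The tape keeps every intermediate `Var` alive until `backward` runs. This stored state is exactly the activation memory that dominates the memory cost of training at scale.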

### Forward and Reverse Mode Differentiation {#sec-ml-frameworks-forward-reverse-mode-differentiation-f70a}

\index{Automatic Differentiation!forward vs. reverse mode}
There are two primary approaches to automatic differentiation: forward mode and reverse mode. The choice between them determines whether the cost of gradient computation scales with the number of inputs or with the number of outputs. That distinction explains why neural network training universally uses reverse mode.

#### Forward Mode {#sec-ml-frameworks-forward-mode-c3ff}

\index{Automatic Differentiation!forward mode}
\index{Forward Mode AD!dual numbers}
Neural network training universally uses reverse mode (covered next), but forward mode illuminates *why* reverse mode is necessary.
\index{Dual Numbers!forward mode AD}
|
||
Forward mode automatic differentiation computes derivatives alongside the original computation, tracking how changes propagate from input to output. This approach mirrors manual derivative computation, making it intuitive to understand and implement.
|
||
|
||
Forward mode's memory requirements are its strength: the method stores only the original value, a single derivative value, and temporary results. Memory usage stays constant regardless of computation depth, making forward mode particularly suitable for embedded systems, real-time applications, and memory-bandwidth-limited systems. However, this comes with a computational cost. Forward mode doubles the Ops term (in **Iron Law** terms) for each input parameter whose derivative is requested. For a model with $N$ parameters, forward mode multiplies total computation by $N$, because each parameter requires a separate forward pass. Reverse mode, by contrast, adds a constant factor of approximately 2 to 3× regardless of $N$. This asymmetry explains why forward mode is never used for training neural networks, where $N$ ranges from millions to hundreds of billions. This combination of computational scaling with input count but constant memory creates a specific niche: forward mode excels in scenarios with few inputs but many outputs, such as sensitivity analysis, feature importance computation, and online learning with single-example updates.
|
||
|
||
To see the mechanism concretely, consider computing both the value and derivative of $f(x) = x^2 \sin(x)$. @lst-forward_mode_ad shows how forward mode propagates derivative computations alongside every operation, applying the chain rule and product rule at each step:

::: {#lst-forward_mode_ad lst-cap="**Forward Mode AD**: Propagates derivatives forward through the computation graph, computing one directional derivative per forward pass with 2× computational overhead."}
```{.python}
from math import cos, sin


def f(x):  # Computing both value and derivative
    # Step 1: x -> x²
    a = x * x  # Value: x²
    da = 2 * x  # Derivative: 2x

    # Step 2: x -> sin(x)
    b = sin(x)  # Value: sin(x)
    db = cos(x)  # Derivative: cos(x)

    # Step 3: Combine using product rule
    result = a * b  # Value: x² * sin(x)
    dresult = a * db + b * da  # Derivative: x²*cos(x) + sin(x)*2x

    return result, dresult
```
:::

Forward mode achieves this systematic derivative computation by augmenting each number with its derivative value, creating what mathematicians call a "dual number." @lst-forward_mode_dual traces a concrete execution with $x = 2.0$, revealing how each intermediate result carries both its value and derivative through the computation:

::: {#lst-forward_mode_dual lst-cap="**Dual Number Computation**: Forward mode augments each value with its derivative, doubling memory per intermediate but yielding the derivative with respect to one input in a single pass."}
```{.python}
import math

x = 2.0  # Initial value
dx = 1.0  # Seed: we track the derivative with respect to x

# Step 1: a = x²
a = x * x  # (2.0)² = 4.0
da = 2 * x * dx  # 2 * 2.0 = 4.0

# Step 2: b = sin(x)
b = math.sin(x)  # sin(2.0) ≈ 0.909
db = math.cos(x) * dx  # cos(2.0) ≈ -0.416

# Final result: product rule
result = a * b  # 4.0 * sin(2.0) ≈ 3.637
dresult = a * db + b * da  # ≈ -1.665 + 3.637 ≈ 1.973
```
:::

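The trace above can be automated with operator overloading, which is one common way to implement forward mode. The following is a minimal sketch, assuming only multiplication, addition, and sine are needed. `Dual`, `sin`, and `f` are illustrative names, not a library API:

```{.python}
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """Hypothetical dual number val + der·ε, where ε² = 0."""

    val: float  # primal value
    der: float  # derivative carried alongside it

    def __add__(self, other):
        return Dual(self.val + other.val, self.der + other.der)

    def __mul__(self, other):
        # (a + a'ε)(b + b'ε) = ab + (a·b' + a'·b)ε: the product rule
        return Dual(self.val * other.val,
                    self.val * other.der + self.der * other.val)


def sin(d):
    # sin(a + a'ε) = sin(a) + cos(a)·a'ε: the chain rule
    return Dual(math.sin(d.val), math.cos(d.val) * d.der)


def f(x):
    return x * x * sin(x)  # same f(x) = x² sin(x) as the listings above


out = f(Dual(2.0, 1.0))  # seed dx/dx = 1
print(round(out.val, 3), round(out.der, 3))  # 3.637 1.973
```

Each overloaded operator performs the value computation and the derivative computation together. This pairing is the 2× per-pass overhead discussed next.
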
The dual number trace demonstrates the 2× computational overhead per input: every arithmetic operation (multiply, sine, product rule combination) is performed twice, once for the value and once for the derivative. For this single-input function, the overhead is acceptable. For a neural network with $N = 100{,}000{,}000$ parameters, computing all gradients would require 100 million such passes, which is why forward mode is restricted to the few-input applications described above.

Forward mode's strength in single-input analysis becomes its fatal weakness for training. A neural network has one scalar loss but millions of parameters, and forward mode would require a separate pass for each one---an intractable $O(N)$ cost that explains why no production framework uses forward mode for training. Forward mode remains useful for targeted analyses such as sensitivity analysis (how does changing one pixel affect the prediction?) and feature importance (which input dimensions most influence the output?), where the number of inputs of interest is small.

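PyTorch exposes forward mode through `torch.func.jvp` (PyTorch 2.0 and later). The sketch below makes the $O(N)$ cost visible on a toy three-parameter loss chosen for illustration. Recovering the full gradient takes one forward-mode pass per parameter, each seeded with a one-hot tangent. `torch.func.grad` obtains all three components from a single reverse-mode pass:

```{.python}
import torch
from torch.func import grad, jvp


def loss(w):
    # Scalar output, N inputs: the shape of every training objective
    return (w ** 2).sum()


w = torch.tensor([1.0, 2.0, 3.0])

# Forward mode: one JVP pass per input, each seeded with a one-hot tangent
forward_grad = torch.zeros_like(w)
for i in range(w.numel()):
    seed = torch.zeros_like(w)
    seed[i] = 1.0
    _, forward_grad[i] = jvp(loss, (w,), (seed,))

# Reverse mode: a single backward pass yields every component at once
reverse_grad = grad(loss)(w)

print(forward_grad)  # tensor([2., 4., 6.])
print(reverse_grad)  # tensor([2., 4., 6.])
```

The loop runs `w.numel()` times. For a model with $N$ parameters, that loop is the $N$-pass cost that rules forward mode out for training.
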
Given forward mode's $O(N)$ scaling with parameter count, we need an entirely different approach for training. Reverse mode provides exactly this: by propagating gradients backward from output to input, it computes all $N$ parameter gradients in a single pass.

#### Reverse Mode {#sec-ml-frameworks-reverse-mode-d328}

\index{Automatic Differentiation!reverse mode}
\index{Reverse Mode AD!backpropagation}
\index{Backpropagation!computational asymmetry}
Why does every modern ML framework default to reverse mode for training? The answer is computational asymmetry, one of the most consequential design decisions in machine learning software.

A neural network has one scalar loss but millions of parameters. Forward mode computes one parameter's gradient per pass, requiring $N$ passes for $N$ parameters. Reverse mode computes all $N$ gradients with one forward pass followed by a single backward pass. For a model with 100 million parameters, that is the difference between 100 million forward-mode passes and one backward pass, a speedup proportional to the parameter count.

This asymmetry makes reverse mode the only viable option for training. Consider a function where $x$ influences the output through two distinct paths. @lst-reverse_simple defines such a function, and @lst-reverse_forward traces its forward and backward computation for a concrete input.

::: {#lst-reverse_simple lst-cap="**Two-Path Function**: A function in which x reaches the output through both x² and sin(x), used to illustrate reverse mode automatic differentiation."}
```{.python}
from math import sin


def f(x):
    a = x * x  # First operation: square x
    b = sin(x)  # Second operation: sine of x
    c = a * b  # Third operation: multiply results
    return c
```
:::

::: {#lst-reverse_forward lst-cap="**Forward and Backward Pass**: The forward pass stores intermediate values; the backward pass propagates gradients from output to input, accumulating contributions from all paths."}
```{.python}
import math

# --- Forward pass: compute and store values ---
x = 2.0  # Input value
a = x * x  # 2.0 * 2.0 = 4.0
b = math.sin(x)  # sin(2.0) ≈ 0.909
c = a * b  # 4.0 * sin(2.0) ≈ 3.637

# --- Backward pass: propagate gradients from output ---
dc_dc = 1.0  # Seed gradient

# Through multiplication c = a * b
dc_da = b * dc_dc  # ∂(a*b)/∂a = b ≈ 0.909
dc_db = a * dc_dc  # ∂(a*b)/∂b = a = 4.0

# Combine contributions from both paths through x
# Path 1: x -> x² -> c      contribution: 2x * dc_da
# Path 2: x -> sin(x) -> c  contribution: cos(x) * dc_db
dc_dx = (2 * x * dc_da) + (math.cos(x) * dc_db)
# = (2 * 2.0 * 0.909) + (cos(2.0) * 4.0)
# ≈ 3.637 + (-1.665) ≈ 1.973
```
:::

The critical observation is that this single backward pass computed $dc/dx$ regardless of how many paths connect $x$ to $c$. In a neural network, each weight can affect the loss through thousands of paths across layers, and reverse mode handles them all in one traversal. This is why training a `{python} GPT3Context.gpt3_params_b_str` B parameter model like GPT-3 is feasible at all: reverse mode needs only $O(1)$ backward passes relative to parameter count, which keeps gradient computation tractable.

Translating this mathematical elegance into a working system requires solving a concrete engineering problem: the backward pass needs values computed during the forward pass, so the framework must decide what to store, when to store it, and when to free it. Modern frameworks accomplish this through computational graphs and automatic gradient accumulation[^fn-gradient-accumulation].

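[^fn-gradient-accumulation]: **Gradient Accumulation**: Within autograd, the summing of gradient contributions that reach a tensor along multiple paths, as with the two paths through $x$ above. The same term also names a training technique that simulates larger batch sizes by summing gradients over multiple mini-batches before a parameter update, covered in detail in @sec-model-training.
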
@lst-reverse_simple_nn illustrates this with a two-layer network, showing both the forward computation that stores intermediate values and the backward pass that consumes them to produce gradients for every parameter simultaneously.

::: {#lst-reverse_simple_nn lst-cap="**Reverse Mode in a Neural Network**: The forward pass computes and stores intermediate values; the backward pass walks the computation in reverse to produce gradients for every parameter."}
```{.python}
def simple_network(x, w1, w2):
    hidden = x * w1  # First layer
    activated = max(0.0, hidden)  # ReLU activation
    output = activated * w2  # Second layer
    return output


# --- Forward pass stores intermediates ---
x, w1, w2 = 1.0, 2.0, 3.0
hidden = x * w1  # 2.0 (saved for backward)
activated = max(0.0, hidden)  # 2.0 (saved for backward)
output = activated * w2  # 6.0

# --- Backward pass consumes them ---
d_output = 1.0  # Seed gradient
d_w2 = activated * d_output  # = 2.0
d_activated = w2 * d_output  # = 3.0
d_hidden = d_activated * (1.0 if hidden > 0 else 0.0)  # ReLU gate: 3.0
d_w1 = x * d_hidden  # = 3.0
d_x = w1 * d_hidden  # = 6.0
```
:::

Three implementation requirements emerge from this example. First, the framework must track dependencies between operations to determine the correct reverse traversal order. Second, intermediate values (hidden, activated) must persist in memory until the backward pass consumes them. Third, every operation needs both a forward implementation and a corresponding backward rule. These requirements define the engineering surface of any AD system, and the second requirement, memory persistence, turns out to be the dominant cost.

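All three requirements fit in a few dozen lines of Python. The following is a minimal sketch of a tape-style reverse-mode engine. `Var`, `sin`, and `backward` are hypothetical names, not any framework's API, and only multiplication and sine are supported:

```{.python}
import math


class Var:
    """Hypothetical minimal reverse-mode node (not any framework's API)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # (parent, local derivative ∂self/∂parent) pairs. Computing and storing
        # them during the forward pass is requirement 2 in miniature.
        self.parents = parents

    def __mul__(self, other):
        # Backward rule for a * b: ∂/∂a = b, ∂/∂b = a (requirement 3)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def sin(v):
    # Backward rule for sin: ∂/∂v = cos(v)
    return Var(math.sin(v.value), ((v, math.cos(v.value)),))


def backward(output):
    # Requirement 1: recover dependency order with a depth-first search
    order, seen = [], set()

    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)  # appended only after all of its inputs

    visit(output)
    output.grad = 1.0  # Seed gradient
    for v in reversed(order):  # every node before any of its inputs
        for parent, local in v.parents:
            parent.grad += local * v.grad  # sum contributions from all paths


x = Var(2.0)
c = (x * x) * sin(x)
backward(c)
print(round(c.value, 3), round(x.grad, 3))  # 3.637 1.973
```

The `parents` tuples keep every intermediate `Var` alive until `backward` runs. At framework scale, that retention is exactly the memory cost examined next.
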
```{python}
#| label: gpt3-memory-footprint
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ GPT-3 MEMORY FOOTPRINT
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Memory Management Strategies section discussing AD memory costs
# │
# │ Goal: Shows that even storing just weights for a 175B model requires 350 GB
# │       (FP16), far exceeding any single GPU. Establishes memory as the binding
# │       constraint on what models frameworks can train.
# │
# │ Imports: mlsys.constants (GPT3_PARAMS, BYTES_FP16), mlsys.formulas (model_memory)
# │ Exports: gpt3_params_b_str, gpt3_fp16_gb_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import GPT3_PARAMS, BYTES_FP16, GB, Bparam
from mlsys.formulas import model_memory
from mlsys.formatting import fmt, check

# --- Inputs (from model specs) ---
gpt3_params_b_value = GPT3_PARAMS.to(Bparam).magnitude  # 175 billion

# --- Process ---
gpt3_fp16_gb_value = model_memory(GPT3_PARAMS, BYTES_FP16, GB)  # 350 GB

# --- Outputs (formatted strings for prose) ---
gpt3_params_b_str = fmt(gpt3_params_b_value, precision=0, commas=False)  # e.g. "175"
gpt3_fp16_gb_str = fmt(gpt3_fp16_gb_value, precision=0, commas=False)  # e.g. "350"
```

#### Memory Management Strategies {#sec-ml-frameworks-memory-management-strategies-b008}

A `{python} GPT3Context.gpt3_params_b_str` B parameter model in FP16 requires `{python} gpt3_fp16_gb_str` GB just for weights, far exceeding any single GPU's memory. Weights, however, are only the beginning: reverse mode AD also stores every intermediate activation from the forward pass for use during the backward pass. For a 100-layer network processing a batch of 64 images, these stored activations can consume 8 to 12 GB on top of the model weights, gradients, and optimizer state. Memory, not compute, is the binding constraint on what models a framework can train.

The problem scales linearly with depth. @lst-reverse_memory shows how each layer in a deeper network adds another activation tensor that must persist until the backward pass reaches that layer.

::: {#lst-reverse_memory lst-cap="**Reverse Mode Memory Management**: Stores intermediate values for gradient computation during backpropagation."}
```{.python}
def deep_network(x, w1, w2, w3):
    # Forward pass - must store intermediates
    hidden1 = x * w1
    activated1 = max(0, hidden1)  # Store for backward
    hidden2 = activated1 * w2
    activated2 = max(0, hidden2)  # Store for backward
    output = activated2 * w3
    return output
```
:::

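Autograd exposes a hook that fires each time an operation saves a tensor for backward, which makes this linear growth measurable. The sketch below builds a toy stack of `Linear`/`ReLU` layers with arbitrarily chosen sizes. It sums the bytes handed to the pack hook of `torch.autograd.graph.saved_tensors_hooks`. The count also includes the weight references that the `Linear` backward needs, and it counts a tensor each time it is saved. Treat it as a proxy for growth, not an exact peak-memory figure:

```{.python}
import torch
import torch.nn as nn


def saved_bytes(depth, width=256, batch=64):
    """Total bytes autograd saves for backward during one forward pass."""
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        layers += [nn.Linear(width, width), nn.ReLU()]
    model = nn.Sequential(*layers)
    x = torch.randn(batch, width, requires_grad=True)

    total = 0

    def pack(t):  # called once per tensor saved for the backward pass
        nonlocal total
        total += t.numel() * t.element_size()
        return t

    def unpack(t):  # called when backward retrieves the tensor
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        model(x).sum()
    return total


for depth in (2, 4, 8):
    print(depth, saved_bytes(depth) / 2**20, "MiB")
```

Doubling the depth doubles the reported bytes, because every added `Linear`/`ReLU` pair saves the same set of tensors.
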
\index{Activation Checkpointing!definition}
Frameworks attack this memory wall with two primary strategies. The first is *activation checkpointing*\index{Gradient Checkpointing!memory-compute trade-off} (also called gradient checkpointing): rather than storing every activation, the framework stores only selected checkpoints and recomputes the intermediate activations during the backward pass. @sec-model-training examines checkpointing strategies in detail, including optimal checkpoint placement algorithms. @lst-checkpoint_scheme shows the pattern: save activations at checkpoint boundaries, recompute everything between them.

::: {#lst-checkpoint_scheme lst-cap="**Checkpointing**: Reduces memory usage by selectively storing intermediate activations during forward passes. Frameworks balance storage needs with computational efficiency to optimize model training."}
```{.python}
# Conceptual representation of checkpointing (pseudocode)
checkpoint1 = save_for_backward(activation1)
# activation2 and activation3 are not stored; the backward pass
# recomputes them from checkpoint1 when it needs them
checkpoint2 = save_for_backward(activation4)
# Framework balances storage vs recomputation
```
:::

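PyTorch exposes activation checkpointing through `torch.utils.checkpoint.checkpoint`. The sketch below wraps a small block with arbitrarily chosen sizes. It checks that the checkpointed backward pass produces the same input gradient as the ordinary one; only the memory/compute balance changes. Passing `use_reentrant=False` selects the non-reentrant implementation that current PyTorch documentation recommends:

```{.python}
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
)
x = torch.randn(8, 64, requires_grad=True)

# Baseline: every intermediate activation inside `block` is saved
block(x).sum().backward()
grad_full = x.grad.clone()
x.grad = None

# Checkpointed: activations inside `block` are discarded after the forward
# pass and recomputed from the block's input during backward
checkpoint(block, x, use_reentrant=False).sum().backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_full, grad_ckpt))  # True
```

In practice, each checkpointed segment would wrap a larger slice of a deep network, such as one transformer block, so that the saved memory outweighs the recomputation.
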
\index{Kernel Fusion!eliminating intermediate allocations}
The second strategy is *operation fusion*\index{Kernel Fusion}[^fn-operation-fusion]. Rather than executing matrix multiplication, bias addition, and ReLU as three separate operations, each writing intermediate results to memory, frameworks fuse them into a single kernel. This eliminates the intermediate memory allocations entirely and can deliver 2 to 3× speedups on modern GPUs by keeping data in registers and caches.

[^fn-operation-fusion]: **Operation Fusion**: Compiler optimization that combines multiple sequential operations into a single kernel to reduce memory bandwidth and latency. For example, fusing matrix multiplication, bias addition, and ReLU activation can eliminate intermediate memory allocations and achieve 2–3× speedup on modern GPUs.

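The bandwidth argument can be made concrete with a back-of-the-envelope traffic model. The function below is a hypothetical helper built on stated simplifications. Each unfused kernel round-trips its tensor through DRAM, and the bias vector and the matmul's own operand reads are ignored. It therefore estimates only the epilogue's traffic, not end-to-end speedup:

```{.python}
def epilogue_traffic_bytes(m, n, bytes_per_elem=2, fused=False):
    """DRAM bytes moved by a matmul's bias-add + ReLU epilogue.

    Simplified model (an assumption, not a profiler measurement): the matmul
    output is m x n; each unfused kernel reads its input from DRAM and writes
    its result back; bias and matmul operand reads are ignored.
    """
    tensor_bytes = m * n * bytes_per_elem  # bytes in one m x n tensor
    if fused:
        # Bias add and ReLU run on values still in registers: one store
        return tensor_bytes
    # matmul writes Y; bias add reads Y, writes Y2; ReLU reads Y2, writes Y3
    return 5 * tensor_bytes


m = n = 4096  # FP16 output: 32 MiB per tensor
unfused = epilogue_traffic_bytes(m, n)
fused = epilogue_traffic_bytes(m, n, fused=True)
print(unfused / 2**20, fused / 2**20)  # 160.0 32.0 (MiB)
```

The 5× reduction applies only to the epilogue. The matrix multiplication's own cost is unchanged, which is one reason measured end-to-end gains are smaller.
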
The backward pass itself benefits from hardware-specific optimization. Rather than directly translating the mathematical definition of a convolution gradient into code, frameworks implement specialized backward kernels that exploit memory access patterns and hardware capabilities of modern accelerators [@chetlur2014cudnn]. These optimizations (checkpointing, fusion, and specialized kernels) work together to make training practical for architectures that would otherwise exhaust GPU memory in a single forward pass.

### Framework Implementation of Automatic Differentiation {#sec-ml-frameworks-framework-implementation-automatic-differentiation-1407}

Checkpointing, fusion, and specialized kernels solve the systems problems of AD. Practitioners, however, never interact with these mechanisms directly. Instead, frameworks expose AD through high-level APIs that hide the underlying machinery behind simple method calls. A PyTorch training loop---`optimizer.zero_grad()`, forward pass, `loss.backward()`, `optimizer.step()`---appears to be four function calls. Behind each call, however, the framework tracks all operations during the forward pass, builds and maintains the computation graph, manages memory for intermediate values, schedules gradient computations efficiently, and interfaces with hardware accelerators. The same graph machinery extends to advanced scenarios: nested `torch.autograd.grad` calls with `create_graph=True` compute second-order derivatives for techniques like natural gradient descent, and mixed-precision contexts (`autocast`) select reduced-precision kernels for compute-intensive operations while maintaining FP32 for numerical stability.

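Second-order derivatives require the first backward computation to record a graph of its own, which is what `create_graph=True` requests. Here is a minimal sketch on the scalar function $y = x^3$, chosen purely for illustration:

```{.python}
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x**3  # y = x³

# create_graph=True records the gradient computation itself,
# so the first derivative can be differentiated again
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)  # 3x² = 27
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)  # 6x = 18

print(dy_dx.item(), d2y_dx2.item())  # 27.0 18.0
```

Without `create_graph=True`, `dy_dx` would be a plain tensor with no `grad_fn`, and the second call would fail.
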
#### PyTorch Autograd Internals {#sec-ml-frameworks-pytorch-autograd-internals-4fa0}

\index{Autograd Engine!internals}
\index{PyTorch!autograd system}
The autograd system is the framework component that solves the differentiation problem described in @sec-ml-frameworks-three-problems-every-framework-must-solve-317d. Three systems principles govern its design: the data structure that enables efficient gradient computation, the memory cost of maintaining that data structure, and the control mechanisms that production systems require. Understanding these principles explains why training can consume roughly 100× more memory than inference for the same model, as the ResNet-50 breakdown in this section shows, and why frameworks provide specific mechanisms to manage that cost.

##### Principle 1: The Reverse-Linked Graph Structure {.unnumbered}

\index{Autograd!reverse-linked graph}
\index{Gradient Function Nodes!computation graph}
During the forward pass, the autograd system constructs a reverse-linked graph of `Function` nodes. Each node records the operation performed and stores references to the tensors it needs for gradient computation. This graph is the data structure that makes reverse-mode automatic differentiation possible: regardless of how many parameters a model has, a single backward pass through this graph computes all gradients. For a model with $N$ parameters, reverse-mode AD requires $O(1)$ backward passes (compared to $O(N)$ for forward-mode), which is why every major framework implements this approach.

Concretely, every tensor produced by a differentiable operation stores a `grad_fn` attribute pointing to the `Function` that created it. Each `Function` links to its inputs through `next_functions`, forming a chain from the loss back to the leaf parameters. @lst-grad-fn-chain illustrates this structure for a simple computation:

::: {#lst-grad-fn-chain lst-cap="**Reverse-Linked Graph Structure**: Each tensor's `grad_fn` links to the `Function` that created it, forming a reverse chain from output to leaf parameters that enables O(1) backward passes."}
```{.python}
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x * 3
z = y.pow(2)

# Traverse the reverse-linked graph
print(z.grad_fn)  # PowBackward0
print(z.grad_fn.next_functions)  # -> MulBackward0
print(
    z.grad_fn.next_functions[0][0].next_functions
)  # -> AccumulateGrad (leaf)
```
:::

The traversal reveals the chain: `PowBackward0` (for `z = y.pow(2)`) links to `MulBackward0` (for `y = x * 3`), which terminates at `AccumulateGrad` for the leaf tensor `x`. Leaf tensors are the endpoints of the graph where gradients accumulate into the `.grad` attribute rather than propagating further. The tuple format `(Function, index)` tracks which output of a multi-output operation each connection corresponds to.

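The leaf/non-leaf distinction is visible from Python. This sketch reuses the listing's computation. `x.is_leaf` is true for the user-created tensor. The intermediate `y` keeps a gradient only if `retain_grad()` is called before the backward pass:

```{.python}
import torch

x = torch.tensor([2.0], requires_grad=True)  # leaf: created by the user
y = x * 3  # non-leaf: produced by an operation
z = y.pow(2)
y.retain_grad()  # opt in to keeping the non-leaf gradient
z.backward()

print(x.is_leaf, y.is_leaf)  # True False
print(x.grad)  # tensor([36.]): dz/dx = 2 * (3x) * 3 = 18x
print(y.grad)  # tensor([12.]): dz/dy = 2y
```

By default, only leaves receive `.grad`. Intermediate gradients are consumed and discarded during the traversal.
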
This reverse-linked structure has a critical systems implication: the entire graph must remain in memory from the time a tensor is created until the backward pass consumes it. The graph itself is lightweight (pointers and metadata), but the tensors it references are not. Because a deeper model saves more of those tensors, memory consumption scales with model depth.

##### Principle 2: The Memory-Compute Trade-off {#sec-ml-frameworks-principle-2-memorycompute-tradeoff-19f2}

Every activation saved for the backward pass persists in memory until consumed by gradient computation. This is the primary reason training memory dwarfs inference memory. Computing the gradient of most operations requires values from the forward pass: multiplication needs both inputs ($\frac{\partial}{\partial x}(x \cdot y) = y$), squaring needs its input ($\frac{\partial}{\partial x}(x^2) = 2x$), and softmax needs its output values. The autograd system stores these tensors on each `Function` node. Custom functions retrieve them through `ctx.saved_tensors`, and built-in nodes expose them through attributes such as `_saved_self`.

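Built-in nodes let you inspect what they saved. The underscore-prefixed attributes below (`_saved_self`, `_saved_result`) are introspection aids described in PyTorch's autograd notes rather than a stable public API, so treat the exact names as version-dependent:

```{.python}
import torch

x = torch.randn(5, requires_grad=True)

y = x.pow(2)  # d(x²)/dx = 2x, so the input is saved
print(y.grad_fn._saved_self.equal(x))  # True

z = x.exp()  # d(eˣ)/dx = eˣ, so the output is saved instead
print(z.grad_fn._saved_result.equal(z))  # True
```

Which tensor is saved depends on the backward formula. This is why an operation like `exp` keeps its output alive, while `pow` keeps its input.
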
```{python}
#| label: resnet50-memory-breakdown
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ RESNET-50 MEMORY BREAKDOWN
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Memory-Compute Trade-off principle showing training vs inference
# │
# │ Goal: Contrast memory requirements for training vs. inference.
# │ Show: That training requires ~100× more memory than inference for ResNet-50.
# │ How: Sum weights, gradients, optimizer state, and activations.
# │
# │ Imports: mlsys (Models), mlsys.constants (BYTES_FP32, BYTES_ADAM_STATE, GB, MB, Mparam)
# │ Exports: resnet_params_m_str, resnet_fp32_mb_str, resnet_adam_mb_str,
# │          resnet_training_min_gb_str, resnet_training_max_gb_str, resnet_training_ratio_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys import Models
from mlsys.constants import BYTES_FP32, BYTES_ADAM_STATE, GB, MB, Mparam
from mlsys.formatting import fmt, check


# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class ResNetMemory:
    """
    Namespace for ResNet-50 Memory Breakdown.
    Scenario: Comparing training vs inference memory costs.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    model = Models.ResNet50
    training_min_gb = 10
    training_max_gb = 15

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    params_m = model.parameters.to(Mparam).magnitude

    fp32_mb = model.size_in_bytes(BYTES_FP32).to(MB).magnitude
    adam_mb = model.size_in_bytes(BYTES_ADAM_STATE).to(MB).magnitude

    # Convert the training footprint from GB to MB before comparing it with
    # the FP32 weight size (KIB_TO_BYTES was undefined and mixed units)
    training_ratio = (training_min_gb * GB).to(MB).magnitude / fp32_mb

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    resnet_params_m_str = fmt(params_m, precision=1, commas=False)
    resnet_fp32_mb_str = fmt(fp32_mb, precision=0, commas=False)
    resnet_adam_mb_str = fmt(adam_mb, precision=0, commas=False)

    resnet_training_min_gb_str = fmt(training_min_gb, precision=0, commas=False)
    resnet_training_max_gb_str = fmt(training_max_gb, precision=0, commas=False)
    resnet_training_ratio_str = fmt(training_ratio, precision=0, commas=False)


# Note: Use ResNetMemory.resnet_params_m_str directly.
```

For a network with $L$ layers, the system must save approximately $L$ activation tensors, one per layer, for the entire batch. Consider a concrete example: ResNet-50 has `{python} ResNetMemory.resnet_params_m_str` M parameters (~`{python} ResNetMemory.resnet_fp32_mb_str` MB in FP32) and processes batch size 64 with 224 × 224 images. The memory breakdown reveals the scale of this trade-off. Forward activations alone consume approximately 8--12 GB (varying by implementation and checkpointing strategy). Parameter gradients add another ~`{python} ResNetMemory.resnet_fp32_mb_str` MB (the same size as the parameters themselves), and Adam optimizer state contributes ~`{python} ResNetMemory.resnet_adam_mb_str` MB for its two momentum buffers per parameter. The total training footprint reaches `{python} ResNetMemory.resnet_training_min_gb_str`--`{python} ResNetMemory.resnet_training_max_gb_str` GB, compared to just ~`{python} ResNetMemory.resnet_fp32_mb_str` MB for inference alone.

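The breakdown is a direct sum of four components, which a few lines of arithmetic reproduce. The following is a minimal sketch with several assumptions. It uses the commonly cited ResNet-50 count of 25.6 M parameters, FP32 weights and gradients, and Adam's two FP32 moment buffers (8 bytes per parameter). It also assumes 10 GB of activations, a value inside the 8--12 GB range above. `training_memory_gb` is a hypothetical helper:

```{.python}
def training_memory_gb(params, activation_gb, bytes_weight=4, bytes_grad=4,
                       bytes_optimizer=8):
    """M_total = M_weights + M_gradients + M_optimizer + M_activations (GB).

    activation_gb must be supplied: it depends on batch size, resolution,
    and checkpointing, so it cannot be derived from the parameter count.
    """
    gb = 1e9
    weights = params * bytes_weight / gb
    grads = params * bytes_grad / gb
    optimizer = params * bytes_optimizer / gb
    return weights + grads + optimizer + activation_gb


params = 25.6e6  # ResNet-50 (commonly cited parameter count)
inference_gb = params * 4 / 1e9  # FP32 weights only
training_gb = training_memory_gb(params, activation_gb=10.0)
print(f"{training_gb:.2f} GB vs {inference_gb:.2f} GB")  # 10.41 GB vs 0.10 GB
```

Weights, gradients, and optimizer state together account for only about 0.4 GB. The activation term dominates, which is why the strategies below target it.
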
This `{python} ResNetMemory.resnet_training_ratio_str`× ratio between training and inference memory helps explain why the Data Movement ($D_{vol}$) term often dominates training latency in the **Iron Law**. During training, the framework must write all activations to memory during the forward pass and read them back during the backward pass, roughly doubling activation memory traffic compared to inference alone. For a complete derivation of the four-component training memory equation ($M_{total} = M_{weights} + M_{gradients} + M_{optimizer} + M_{activations}$) and worked examples at larger model scales, see @sec-algorithm-foundations-true-cost-training-memory-e54e.

Frameworks provide two primary mechanisms to manage this trade-off. **Gradient checkpointing**\index{Gradient Checkpointing!recomputation strategy} [@chen2016training] trades recomputation for memory: instead of saving all activations, the framework saves only a subset and recomputes the rest during the backward pass. This typically reduces activation memory by 50--90% at the cost of 20--33% additional compute (with optimal $\sqrt{n}$ checkpoint placement). In Iron Law terms, checkpointing increases the $O$ term (recomputation) to reduce the $D_{vol}$ term (memory traffic). **Tensor detachment** provides a complementary mechanism: calling `.detach()` on a tensor removes it from the computation graph entirely, preventing the framework from saving activations through that path. This is essential for transfer learning, where pretrained layers should not accumulate gradients, and reduces the $D_{vol}$ term by eliminating unnecessary activation storage.

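Simple arithmetic shows where the $\sqrt{n}$ placement puts the savings. The sketch assumes $n$ equal-sized layers and keeps $\lceil\sqrt{n}\rceil$ checkpoints. During the backward pass it recomputes one segment of at most $\lceil\sqrt{n}\rceil$ layers at a time. Real networks with uneven layer sizes will deviate from this idealized count, and `checkpoint_memory_units` is a hypothetical helper:

```{.python}
import math


def checkpoint_memory_units(num_layers: int) -> tuple[int, int]:
    """Peak stored activations, in units of one layer's activation,
    without and with sqrt(n) checkpointing (equal-sized layers assumed)."""
    k = math.ceil(math.sqrt(num_layers))
    without = num_layers
    with_ckpt = k + k  # k checkpoints + one recomputed segment of <= k layers
    return without, with_ckpt


for n in (16, 100, 400):
    full, ckpt = checkpoint_memory_units(n)
    print(n, full, ckpt, f"{1 - ckpt / full:.0%} saved")
# 16 16 8 50% saved
# 100 100 20 80% saved
# 400 400 40 90% saved
```

The savings grow with depth. The recomputation cost stays near one extra forward pass, which is the source of the 20--33% compute overhead.
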
Mixed-precision training offers a third approach, reducing activation memory by storing values in lower precision formats. The detailed trade-offs of mixed precision are examined later in this chapter.

##### Principle 3: Extensibility and Control {.unnumbered}

Production training systems require fine-grained control over gradient flow that goes beyond the default backward pass. Three categories of control arise in practice. First, **selective gradient computation**: transfer learning and fine-tuning require freezing subsets of parameters, which the framework supports through `requires_grad=False` flags and the `.detach()` mechanism described above. Second, **gradient inspection and modification**: debugging vanishing or exploding gradients, implementing per-tensor gradient clipping, and logging gradient statistics all require intercepting gradients mid-computation, which frameworks expose through hook APIs. Third, **custom differentiation rules**: operations not in the framework's built-in library (custom CUDA kernels, novel activation functions, domain-specific operations) require user-defined forward and backward implementations.

These control mechanisms share a common systems design: they are callback-based extensions that the autograd engine invokes at specific points during graph traversal, without modifying the core differentiation algorithm. This extensibility pattern allows the framework to maintain a single optimized backward pass while supporting arbitrarily complex gradient manipulation. The following examples demonstrate these mechanisms in practice, showing how to inspect and control PyTorch's autograd system.

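Selective gradient computation, the first category above, needs only the `requires_grad` flag. The following is a minimal sketch in which a toy `backbone`/`head` pair, with sizes chosen for illustration, stands in for a pretrained model and a new task layer:

```{.python}
import torch
import torch.nn as nn

torch.manual_seed(0)
backbone = nn.Linear(8, 8)  # stands in for pretrained layers
head = nn.Linear(8, 2)  # new task-specific layer

for p in backbone.parameters():
    p.requires_grad_(False)  # freeze: autograd records no gradient for these

x = torch.randn(4, 8)
loss = head(torch.relu(backbone(x))).sum()
loss.backward()

print(backbone.weight.grad)  # None: frozen parameters receive no gradient
print(head.weight.grad is not None)  # True
```

Because nothing upstream of `head` requires a gradient, autograd also skips saving the backbone's activations for backward. This is the $D_{vol}$ saving described above.
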
###### Retaining the Computation Graph {.unnumbered}

\index{Graph Retention!multiple backward passes}
By default, `backward()` frees the graph's saved tensors after use. To run multiple backward passes (for multi-loss optimization or higher-order derivatives), pass `retain_graph=True`. The cost is that those saved tensors stay in memory until the final backward pass, as shown in @lst-retain-graph.

::: {#lst-retain-graph lst-cap="**Retaining Computation Graph**: Use retain_graph=True to run multiple backward passes on the same graph, useful for multi-loss optimization or higher-order derivatives."}
```{.python}
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x**2

# First backward pass - graph is freed by default
y.backward()
print(x.grad)  # tensor([4.])

# Second backward on SAME y fails - graph was freed
# y.backward()  # RuntimeError: Trying to backward through the graph a second time

# Solution: retain_graph=True keeps graph for multiple passes
x.grad.zero_()
y = x**2
y.backward(retain_graph=True)  # First pass, keep graph
y.backward()  # Second pass works, graph freed after this
print(x.grad)  # tensor([8.]): 4 + 4, gradients accumulate across passes
```
:::

###### Gradient Accumulation Behavior {.unnumbered}

Gradients accumulate across backward passes by default. As @lst-gradient-accumulation demonstrates, without calling `zero_grad()`, successive backward passes sum their gradients:

::: {#lst-gradient-accumulation lst-cap="**Gradient Accumulation Behavior**: Gradients accumulate across backward passes by default. Use zero_grad() to reset gradients before each optimization step."}
```{.python}
import torch

x = torch.tensor([1.0], requires_grad=True)

# First backward pass
y = x * 2
y.backward()
print(x.grad)  # tensor([2.])

# Second backward pass (without zero_grad)
y = x * 3
y.backward()
print(x.grad)  # tensor([5.]) = 2 + 3 (accumulated!)

# Reset gradients
x.grad.zero_()
y = x * 3
y.backward()
print(x.grad)  # tensor([3.])
```
:::

###### Custom Autograd Functions {.unnumbered}

\index{Custom Autograd Functions!backward implementation}
When implementing custom operations, you explicitly specify what to save for the backward pass and how to compute gradients. @lst-custom-autograd-function shows the pattern:

::: {#lst-custom-autograd-function lst-cap="**Custom Autograd Function**: Implement forward and backward methods to define custom differentiable operations, explicitly specifying tensors to save for gradient computation."}
```{.python}
import torch


class MultiplyAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, z):
        # Save tensors needed for backward
        ctx.save_for_backward(x, y)
        return x * y + z

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        x, y = ctx.saved_tensors

        # Compute gradients using chain rule
        grad_x = grad_output * y  # ∂L/∂x = ∂L/∂out * ∂out/∂x
        grad_y = grad_output * x  # ∂L/∂y = ∂L/∂out * ∂out/∂y
        grad_z = grad_output  # ∂L/∂z = ∂L/∂out * 1

        return grad_x, grad_y, grad_z


# Usage
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = torch.tensor([1.0], requires_grad=True)

output = MultiplyAdd.apply(x, y, z)
output.backward()

print(
    x.grad, y.grad, z.grad
)  # tensor([3.]), tensor([2.]), tensor([1.])
```
:::

###### Gradient Hooks {.unnumbered}

\index{Gradient Hooks!inspection and modification}
Register hooks on tensors to inspect or modify gradients during backpropagation, as shown in @lst-gradient-hooks:

::: {#lst-gradient-hooks lst-cap="**Gradient Hooks**: Register hooks on tensors to inspect or modify gradients during backpropagation, useful for debugging, gradient clipping, or custom gradient manipulation."}
```{.python}
import torch


def gradient_hook(grad):
    print(f"Gradient: {grad}")
    # Modify gradient (e.g., gradient clipping)
    return grad.clamp(-1.0, 1.0)


x = torch.tensor([2.0], requires_grad=True)
x.register_hook(gradient_hook)

y = x * 10
y.backward()
# Prints: Gradient: tensor([10.])
# x.grad contains clamped value: tensor([1.])
```
:::

###### Detach vs. Data {.unnumbered}

\index{Gradient Flow!detaching tensors}\index{Gradient Flow!breaking with detach}
Use `.detach()` to break gradient flow. Both `.detach()` and the legacy `.data` attribute return tensors that share storage with the original, but only `.detach()` lets autograd detect later in-place modifications. @lst-detach-vs-data illustrates how `.data` can silently corrupt gradient computation through in-place operations:

::: {#lst-detach-vs-data lst-cap="**Safe Gradient Detachment**: Use `.detach()` to break gradient flow; autograd still detects in-place changes made through it. The legacy `.data` attribute bypasses those checks and can silently corrupt gradients through in-place operations."}
```{.python}
import torch

# DANGEROUS: .data shares storage but is invisible to autograd
x = torch.tensor([1.0], requires_grad=True)
y = x.exp()  # exp's backward reuses its saved output y
z_unsafe = y.data
z_unsafe.mul_(100)  # Modifies y's storage; autograd is not notified
y.backward()
print(x.grad)  # ≈ tensor([271.83]): silently wrong, true gradient is e ≈ 2.718

# SAFER: .detach() shares storage AND the version counter, so autograd
# detects the in-place change instead of computing a wrong gradient
x = torch.tensor([1.0], requires_grad=True)
y = x.exp()
z_safe = y.detach()  # Not part of the computation graph
z_safe.mul_(100)  # Bumps the shared version counter
try:
    y.backward()
except RuntimeError as err:
    print("Autograd caught the in-place modification:", err)

# Inference: torch.no_grad() records no graph, so no .detach() is needed
model = torch.nn.Linear(1, 1)
with torch.no_grad():
    inference_output = model(x)
print(inference_output.requires_grad)  # False
```
:::

These three principles connect directly to the framework's role as a compiler for the **Silicon Contract**. The reverse-linked graph determines which operations the backward pass must execute (the $O$ term). The memory-compute trade-off governs how much data the framework must move through the memory hierarchy (the $D_{vol}$ term). And the extensibility mechanisms allow engineers to tune both terms for their specific workload. The interaction between autograd memory management and numerical precision leads naturally to mixed-precision training, which further reduces the $D_{vol}$ term.

#### Mixed-Precision Training Support {#sec-ml-frameworks-mixedprecision-training-support-d31d}

\index{Mixed Precision!FP16 vs. FP32 trade-offs}
|
||
Mixed precision exploits a hardware asymmetry to improve two Iron Law terms simultaneously: Tensor Cores execute FP16 matrix multiplications at least twice as fast as FP32 (raising the effective $R_{peak}$ and so shrinking the $O/R_{peak}$ term), while FP16 activations halve the memory footprint (reducing $D_{vol}$). Improving both terms at once is rare; most optimizations improve one at the expense of the other.

Frameworks exploit this through automatic mixed-precision APIs that select reduced precision for compute-intensive operations while keeping FP32 where numerical stability demands it. Inside these APIs, frameworks apply per-operation precision rules: matrix multiplications and convolutions run in FP16 to use Tensor Cores and halve memory traffic, while numerically sensitive operations such as softmax and layer normalization remain in FP32. This selective precision preserves accuracy while delivering speedups on GPUs with dedicated low-precision matrix units. Because FP16 has a much narrower dynamic range than FP32, small gradient values can underflow to zero during backpropagation. Loss scaling addresses this by multiplying the loss by a large factor before the backward pass, so that the chain rule scales every gradient by the same factor, and then dividing the gradients by that factor before the optimizer step.

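Underflow and its remedy can be reproduced without a GPU. The sketch below is a minimal illustration, not framework code: it assumes Python's standard `struct` module, whose `'e'` format packs IEEE 754 half precision, and a hypothetical helper `to_fp16`. The scale of $2^{16}$ matches the default initial scale of PyTorch's `GradScaler`, but any sufficiently large power of two shows the effect.

```{.python}
import struct


def to_fp16(value: float) -> float:
    """Round a Python float through IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


grad = 1e-8  # below FP16's smallest subnormal (about 6e-8)
print(to_fp16(grad))  # 0.0: the gradient underflows and is lost

scale = 2.0**16  # loss-scale factor applied before the backward pass
scaled = to_fp16(grad * scale)  # 6.5536e-4 is representable in FP16
recovered = scaled / scale  # unscale in full precision before the update
print(recovered)  # close to 1e-8, up to FP16 rounding error
```

Real loss scalers also adapt the factor during training. PyTorch's `GradScaler`, for example, skips the optimizer step and reduces the scale when it finds infinite or NaN gradients, and periodically tries a larger scale after a run of clean steps.
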
Frameworks also support multiple precision formats including FP16, BF16, and TF32, each with different trade-offs between range and precision. BF16 maintains FP32's dynamic range, which eliminates most gradient underflow issues and usually removes the need for loss scaling. @sec-model-training examines the mechanics of mixed-precision training in detail, including loss scaling algorithms, memory savings analysis, and numerical stability considerations. @lst-autocast-usage demonstrates PyTorch's mixed-precision API: the `autocast` context manager automatically selects FP16 for compute-intensive operations while `GradScaler` prevents gradient underflow by dynamically scaling loss values.

::: {#lst-autocast-usage lst-cap="**Mixed-Precision API**: Modern frameworks provide automatic mixed-precision support through context managers that handle precision selection and numerical stability."}
```{.python}
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

# Stand-ins for a real model and dataset (illustrative only)
model = nn.Linear(512, 10).cuda()
criterion = nn.CrossEntropyLoss()
dataloader = [(torch.randn(32, 512), torch.randint(0, 10, (32,))) for _ in range(8)]

optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler("cuda")

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # Framework automatically selects precision per operation
    with autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # GradScaler scales the loss, unscales gradients before the step,
    # skips steps with inf/NaN gradients, and adapts the scale factor
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
:::

BF16 training typically does not require loss scaling, because BF16 shares FP32's 8-bit exponent and therefore its dynamic range, as @lst-bf16-training demonstrates.

::: {#lst-bf16-training lst-cap="**BF16 Training**: BF16 maintains FP32's dynamic range, typically eliminating the need for the loss scaling that FP16 requires."}
```{.python}
# BF16 training typically does not require loss scaling
# (model, criterion, optimizer, inputs, and targets as in the FP16 example)
optimizer.zero_grad()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)

loss.backward()  # No GradScaler needed
optimizer.step()
```
:::

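The range-versus-precision trade-off is visible at the bit level. BF16 keeps the sign bit, the full 8-bit exponent, and the top 7 mantissa bits of an FP32 value. The helpers below are illustrative assumptions rather than a framework API: `to_bf16` rounds an FP32 bit pattern to its upper 16 bits with round-to-nearest-even (finite inputs only, no NaN handling), and `to_fp16` uses `struct`'s half-precision `'e'` format.

```{.python}
import struct


def to_fp16(value: float) -> float:
    """Round through IEEE 754 half precision (5-bit exponent, 10-bit mantissa)."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def to_bf16(value: float) -> float:
    """Round to bfloat16 (8-bit exponent, 7-bit mantissa); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]  # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bf16_bits = (bits >> 16) & 0xFFFF  # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]


tiny_grad = 1e-8
print(to_fp16(tiny_grad))  # 0.0: outside FP16's range
print(to_bf16(tiny_grad))  # about 1e-8: BF16 shares FP32's exponent range
print(to_bf16(1.0 + 2**-9))  # 1.0: the increment is below BF16's precision
```

BF16 therefore rarely needs loss scaling, but with only 8 significant bits (roughly 2 to 3 decimal digits) it is far less precise than FP32, which is one reason mixed-precision recipes typically keep master weights or optimizer state in FP32.
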
```{python}
#| label: model-7b-optimizer-state
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ 7B MODEL OPTIMIZER STATE MEMORY
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Optimizer State and Checkpointing section on memory requirements
# │
# │ Goal: Quantify the memory dominance of optimizer state.
# │ Show: That Adam's FP32 state is four times the FP16 weights (5x in total).
# │ How: Sum FP16 weights and FP32 optimizer state bytes.
# │
# │ Imports: mlsys.constants (BYTES_FP16, BYTES_ADAM_STATE, GB, ureg),
# │          mlsys.formulas (model_memory), mlsys.formatting (fmt, check)
# │ Exports: model_7b_fp16_gb_str, model_7b_adam_gb_str, model_7b_total_gb_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import BYTES_FP16, BYTES_ADAM_STATE, GB, ureg
from mlsys.formulas import model_memory
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class Model7B:
    """
    Namespace for 7B Model Memory.
    Scenario: Optimizer state overhead.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    params = 7e9 * ureg.param

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    fp16_gb = model_memory(params, BYTES_FP16, GB)
    adam_gb = model_memory(params, BYTES_ADAM_STATE, GB)
    total_gb = fp16_gb + adam_gb

    # ┌── 3. INVARIANTS (Guardrails) ───────────────────────────────────────────
    check(adam_gb > fp16_gb, "Adam state should exceed FP16 weight memory.")

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    model_7b_fp16_gb_str = fmt(fp16_gb, precision=0, commas=False)
    model_7b_adam_gb_str = fmt(adam_gb, precision=0, commas=False)
    model_7b_total_gb_str = fmt(total_gb, precision=0, commas=False)

# Note: Use Model7B.model_7b_total_gb_str directly.
```

Resuming training after interruption requires restoring not just model weights but optimizer state: momentum buffers, adaptive learning rates, and gradient statistics. For Adam, the two FP32 state tensors (momentum and variance, 8 bytes per parameter) are four times the size of the FP16 weights (2 bytes per parameter), so weights plus optimizer state occupy roughly five times the memory of the weights alone. A 7B-parameter model therefore requires approximately `{python} Model7B.model_7b_total_gb_str` GB in total (`{python} Model7B.model_7b_fp16_gb_str` GB weights + `{python} Model7B.model_7b_adam_gb_str` GB optimizer state). Because all of this state must be written to storage and read back, checkpoint size bounds recovery speed after failure, connecting fault tolerance directly to the Iron Law's $D_{vol}$ term.

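The link between checkpoint size and recovery time is simple arithmetic. The sketch below uses the same byte accounting as the text (2 bytes per FP16 weight plus 8 bytes of FP32 Adam state per parameter) and a hypothetical storage read bandwidth of 2 GB/s. Both are assumptions: real checkpoints may also include FP32 master weights, learning-rate scheduler state, and data-loader position, and real bandwidth varies widely across storage systems.

```{.python}
def checkpoint_gb(num_params: float, weight_bytes: int = 2, optimizer_bytes: int = 8) -> float:
    """Approximate checkpoint size in GB: FP16 weights plus Adam's two FP32 moments."""
    return num_params * (weight_bytes + optimizer_bytes) / 1e9


def restore_seconds(size_gb: float, read_gb_per_s: float) -> float:
    """Lower bound on restore time: the checkpoint must at least be read back."""
    return size_gb / read_gb_per_s


size = checkpoint_gb(7e9)  # 7B-parameter model
print(size, restore_seconds(size, read_gb_per_s=2.0))  # 70.0 35.0
```
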
@sec-model-training covers optimizer memory requirements and optimization strategies for large-scale training, where checkpoint size becomes a binding constraint. Frameworks provide the `state_dict()` interface to access optimizer state for serialization (@lst-state-dict-interface), and resuming training requires loading both model parameters and optimizer state (@lst-checkpoint-save-load).

::: {#lst-state-dict-interface lst-cap="**State Dictionary Interface**: Optimizers expose internal state through state_dict(), enabling serialization of momentum buffers and adaptive learning rate estimates for checkpointing."}
```{.python}
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 5)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# After training steps, optimizer accumulates state
loss = model(torch.randn(3, 10)).sum()
loss.backward()
optimizer.step()

# Access state for checkpointing
state = optimizer.state_dict()
# Contains: {'state': {...}, 'param_groups': [{'lr': 0.001, ...}]}
# For Adam, each 'state' entry holds 'step', 'exp_avg', and 'exp_avg_sq'
```
:::

::: {#lst-checkpoint-save-load lst-cap="**Checkpoint Save and Load**: Save both model parameters and optimizer state to properly resume training with correct momentum and adaptive learning rate values."}
```{.python}
# Saving checkpoint (model, optimizer, and epoch come from the training loop)
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Resuming training
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue after the saved epoch
```
:::

The mathematics of automatic differentiation was established decades before deep learning's resurgence. What changed was the systems engineering. Before framework automation, implementing gradient computation for a single fully connected layer meant writing separate forward and backward functions, manually tracking intermediate values, and verifying mathematical correctness across dozens of operations. A modern Transformer involves hundreds of operations with complex dependencies; manual gradient derivation for attention, layer normalization, and residual connections could take months of careful work per architecture variant.

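To see what frameworks automate, consider the manual version for a single fully connected layer. The minimal pure-Python sketch below (the function names are hypothetical) pairs a forward function with a hand-derived backward function, caches the inputs the backward pass needs, and checks one gradient entry against a finite-difference estimate. A real network needs this bookkeeping, and this verification, for every operation it uses.

```{.python}
import random


def linear_forward(W, x, b):
    """y = W x + b for one input vector; also returns what backward needs."""
    y = [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]
    return y, (W, x)


def linear_backward(dy, cache):
    """Hand-derived gradients: dW = dy x^T, db = dy, dx = W^T dy."""
    W, x = cache
    dW = [[dyi * xj for xj in x] for dyi in dy]
    db = list(dy)
    dx = [sum(W[i][j] * dy[i] for i in range(len(dy))) for j in range(len(x))]
    return dW, db, dx


random.seed(0)
W = [[random.uniform(-1, 1) for _ in range(3)] for _ in range(2)]
x = [random.uniform(-1, 1) for _ in range(3)]
b = [0.1, -0.2]

# Loss = sum(y), so the upstream gradient dL/dy is all ones
y, cache = linear_forward(W, x, b)
dW, db, dx = linear_backward([1.0] * len(y), cache)

# Finite-difference check of dL/dW[0][1]
eps = 1e-6
W_plus = [row[:] for row in W]
W_plus[0][1] += eps
y_plus, _ = linear_forward(W_plus, x, b)
numeric = (sum(y_plus) - sum(y)) / eps
print(abs(numeric - dW[0][1]) < 1e-6)  # True: analytic and numeric gradients agree
```
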
The breakthrough was turning this manual process into software infrastructure. A single matrix multiplication requires different gradient computations depending on which inputs require gradients, tensor shapes, hardware capabilities, and memory constraints. Autograd systems handle these variations transparently, which is one reason the pace of architectural innovation accelerated once frameworks matured. The mathematics did not change; software engineering made it practical to apply at scale.

#### Memory Management in Gradient Computation {#sec-ml-frameworks-memory-management-gradient-computation-bb77}

The memory strategies from @sec-ml-frameworks-reverse-mode-d328 (checkpointing, gradient accumulation) exist because reverse-mode differentiation requires preserving computational history. As @lst-reverse_memory demonstrated, each layer adds an activation tensor that persists until the backward pass consumes it, creating a memory wave that peaks at the start of backpropagation and recedes as gradients are computed. Modern frameworks track the lifetime of each intermediate value automatically, freeing memory as soon as it is no longer needed. Even with precise lifetime tracking, however, a deeper problem remains: the cost of acquiring memory from the GPU in the first place.

That cost carries a critical engineering lesson: production systems require **Memory Abstraction**. Requesting memory directly from a GPU is a high-latency operation that can synchronize the entire device, so allocating from the driver for every new tensor creates an "Allocation Bottleneck" that stalls computation. To solve this, modern frameworks implement **Caching Allocators**\index{Memory Management!caching allocators}. Instead of calling the driver for every new tensor, the framework requests large blocks of memory upfront and manages its own internal pool. The same abstraction lets the framework limit **memory fragmentation**\index{Memory Management!fragmentation}, the situation in which free memory is available but scattered in pieces too small to hold a large tensor, so that models can approach the physical limits of the hardware without constant system-level overhead.

::: {.callout-perspective title="Caching Allocator and Utilization"}
The **Caching Allocator** is the framework's primary mechanism for maximizing the **Utilization** term in the Iron Law ($\frac{1}{\text{Utilization}}$). Without it, two factors degrade performance significantly:

1. **Allocation Latency**: `cudaMalloc` is a synchronizing driver call that can cost 10--100 microseconds. In a training loop with thousands of operations per second, this latency would dominate execution time. The caching allocator pays this cost once, then serves subsequent requests from its pool without calling the driver.
2. **Fragmentation**: A "Swiss cheese" memory pattern reduces **Effective Capacity**. If you have 10 GB free but the largest contiguous block is 1 GB, you cannot allocate a 2 GB tensor. By rounding allocations up to a small set of standard block sizes, the allocator ensures that freed memory can be reused for future requests, keeping **Utilization** high.

\index{Memory Fragmentation!OOM errors}
When an "OOM" (Out of Memory) error occurs even though the total free memory exceeds the requested size (for example, when PyTorch reports a large amount of memory as reserved but unallocated), **fragmentation** is often the culprit: the allocator cannot find a contiguous block large enough for the requested tensor.
:::

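The reuse mechanism can be sketched in a few lines. The toy allocator below illustrates the idea under simplifying assumptions and is not PyTorch's implementation: it rounds every request up to a power of two, never splits or merges blocks, and counts the calls that would reach the driver instead of allocating real memory. PyTorch's allocator uses different rounding rules and also splits and coalesces cached blocks.

```{.python}
from collections import defaultdict


class ToyCachingAllocator:
    """Reuses freed blocks by size class instead of returning them to the device."""

    def __init__(self):
        self.cache = defaultdict(list)  # size class -> ids of free cached blocks
        self.driver_calls = 0  # requests that would reach cudaMalloc
        self._next_id = 0

    @staticmethod
    def size_class(nbytes: int) -> int:
        """Round a request up to the next power of two."""
        return 1 << max(nbytes - 1, 0).bit_length()

    def malloc(self, nbytes: int):
        size = self.size_class(nbytes)
        if self.cache[size]:  # fast path: reuse a cached block
            return size, self.cache[size].pop()
        self.driver_calls += 1  # slow path: ask the device for new memory
        self._next_id += 1
        return size, self._next_id

    def free(self, block):
        size, block_id = block
        self.cache[size].append(block_id)  # keep it for the next request


allocator = ToyCachingAllocator()
for _ in range(1000):  # 1,000 iterations that allocate the same sizes
    activations = allocator.malloc(3000)  # rounded up to 4096 bytes
    gradients = allocator.malloc(5000)  # rounded up to 8192 bytes
    allocator.free(activations)
    allocator.free(gradients)
print(allocator.driver_calls)  # 2: only the first iteration reaches the driver
```

Rounding 3,000 bytes up to 4,096 wastes some memory inside each block; that internal waste is the price of making freed blocks interchangeable across future requests.
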
#### Production System Integration Challenges {#sec-ml-frameworks-production-system-integration-challenges-cd6e}

A training iteration that takes 300 ms in profiling may take 500 ms in production because the AD system must coordinate with the memory allocator, the device manager, the operation scheduler, and the optimizer on every single step. Each gradient computation can trigger data movement between CPU and GPU, memory allocation for intermediate tensors, and kernel launches on accelerators. These system interactions dominate wall-clock time for small models and remain significant even at scale. The gap between what the programmer writes (a five-line training loop) and what the system executes (dozens of memory allocations, kernel launches, and synchronization points) is the central tension of AD system design.

Beyond sequential overhead, the AD system must also exploit *concurrency*. Modern networks frequently contain independent branches---two convolutional paths processing the same input before merging, as in Inception-style architectures. On a GPU with sufficient resources, the framework's scheduler can execute both branch backward passes on separate CUDA streams, reducing backward pass time by up to 30--40%. The AD system therefore tracks dependencies for two purposes: correctness (computing the right gradients) and performance (scheduling independent computations concurrently). Frameworks hide this complexity behind `loss.backward()`, but the scheduling, memory allocation, and data movement decisions behind that call determine whether training runs at 40% or 80% of peak hardware utilization.

The memory and system integration challenges examined above (caching allocators, activation storage, and checkpoint overhead) affect all frameworks. Yet *how* frameworks implement automatic differentiation in the first place varies significantly, with consequences for both *optimization potential* and developer experience. The distinction between *tape-based and transform-based autodiff* captures this architectural divergence.

::: {.callout-perspective title="Tape-based vs. Transform-based Autodiff"}
\index{Automatic Differentiation!tape-based vs. transform-based}
**PyTorch (Tape-based)**: Records operations on a dynamic "tape" during the forward pass. This is flexible and easy to debug but makes it hard for a compiler to see the whole graph at once for global optimization.

**JAX (Transform-based)**: Treats automatic differentiation as a high-level function transformation (`grad(f)`). Because JAX traces the whole function before executing it, it can easily chain other transformations like `jit(grad(f))` or `vmap(grad(f))`, producing highly optimized, compiled kernels that often outperform dynamic frameworks on specialized hardware like TPUs.
:::

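The distinction shows up in the shape of the API. The pure-Python sketch below is illustrative only; the `Var` class and the `backward` and `grad` functions are not PyTorch or JAX APIs. Each operation records its inputs and local derivatives as it runs (a graph of parent links, much like PyTorch's `grad_fn` chain), and `backward` replays those records in reverse. The `grad` wrapper on top mimics only the *interface* of a transform-based system: it accepts a function and returns its derivative function.

```{.python}
import math


class Var:
    """A scalar that records its inputs and local derivatives as it is computed."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # (input Var, d(self)/d(input)) pairs
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))


def sin(v):
    return Var(math.sin(v.value), ((v, math.cos(v.value)),))


def backward(output):
    """Replay the recorded operations in reverse, accumulating gradients."""
    order, seen = [], set()

    def visit(node):  # topological order: inputs before the nodes that use them
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):
        for parent, local_grad in node.parents:
            parent.grad += local_grad * node.grad


def grad(f):
    """Transform-style interface: take f, return a function computing df/dx."""
    def df(x):
        x_var = Var(x)
        backward(f(x_var))
        return x_var.grad
    return df


def f(x):
    return x * x + sin(x)  # f(x) = x^2 + sin(x)


df = grad(f)
print(round(df(2.0), 4))  # 3.5839, i.e., 2*2 + cos(2)
```

Because `grad(f)` returns an ordinary function, it can be handed to other higher-order functions, the property JAX builds on when composing `grad` with `jit` and `vmap`. This toy `grad` cannot be nested, since `df` returns a plain float rather than a traced value, which illustrates why a real transform-based system traces functions into an intermediate representation instead of relying on a runtime tape.
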
JAX[^fn-jax] exemplifies the transform-based approach, where composable function transformations replace imperative tape recording.

[^fn-jax]: **JAX**: Developed by Google Research, first released in December 2018. The name originally stood for "Just After eXecution," reflecting its roots in Autograd (a Python autodiff library). JAX's design philosophy differs fundamentally from PyTorch and TensorFlow: rather than providing a neural network library, JAX provides composable *function transformations*, namely `grad` (differentiation), `jit` (compilation), `vmap` (vectorization), and `pmap` (parallelization), that can be freely composed. This functional approach requires "pure functions" (no side effects, no mutation), which constrains programming style but enables mathematical reasoning about program transformations. JAX compiles through XLA, achieving particularly strong performance on Google's TPUs. Its influence extends beyond direct use: PyTorch's `torch.func` module, including `torch.func.grad` and `torch.vmap`, adopts JAX-style composable transforms.

#### How Different Frameworks Implement AD {#sec-ml-frameworks-different-frameworks-implement-ad-2f8e}

The execution models covered in @sec-ml-frameworks-execution-problem-e1e1, namely eager, static graph, and hybrid, directly shape how each framework implements automatic differentiation:

- **PyTorch** [@paszke2019pytorch] builds its autograd tape dynamically during forward execution, providing immediate debugging at the cost of graph-level optimization. The `grad_fn` chain mechanism detailed in @sec-ml-frameworks-pytorch-autograd-internals-4fa0 enables flexible control flow but requires storing the complete graph until backward pass completion.
- **TensorFlow** (in its 1.x incarnation) added gradient operations to the static graph during graph construction, enabling ahead-of-time optimization. Modern TensorFlow 2.x uses eager execution by default, records gradients with `tf.GradientTape`, and provides `tf.function` for graph compilation when performance matters.
- **JAX** [@frostig2018compiling] transforms functions rather than tracking operations. The `jax.grad()` transformation returns a new function that computes gradients, enabling composition with `jax.vmap()` for vectorization and `jax.jit()` for compilation. This approach requires pure functions but enables powerful program transformations.

These implementation differences have direct practical consequences for framework selection, which @sec-ml-frameworks-major-framework-platform-analysis-fe96 examines in detail.

A recurring tension runs through every AD design decision: mathematical correctness demands storing computational history, but hardware imposes strict memory limits. Every framework resolves this tension differently, choosing which activations to checkpoint, which operations to fuse, and how aggressively to trade recomputation for memory. These choices determine which models can train on which hardware, making AD system design one of the most consequential engineering decisions in any framework.

::: {.callout-checkpoint title="The Systems Cost of Gradients" collapse="false"}
Training is *inherently more expensive* than inference because of Automatic Differentiation.

**Computational Reality**

- [ ] **Reverse Mode AD**: Why is this the practical choice for neural networks? (Because we have 1 loss scalar and $10^9$ parameters. Forward mode would require $10^9$ passes, one per parameter; reverse mode needs a single backward pass.)
- [ ] **The Activation Tax**: Do you understand why training memory scales linearly with depth? (We must stash forward activations to compute backward gradients.)

**Optimization Mechanics**

- [ ] **Gradient Checkpointing**: How does re-computing activations save memory? (We discard most of the stash and regenerate it on demand during the backward pass.)
:::

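The memory saving from gradient checkpointing follows from simple counting. The sketch below relies on a deliberately simplified cost model, which is an assumption: every layer's activation has the same size, and only one recomputed segment is resident at a time. It then searches for the segment length that minimizes peak activation storage; the function name is hypothetical.

```{.python}
import math


def peak_stored_activations(num_layers: int, segment: int) -> int:
    """Peak number of layer activations held in memory with checkpointing.

    Simplified model: the forward pass keeps only one checkpoint per segment;
    the backward pass recomputes one segment at a time, holding all of its
    activations while that segment's gradients are computed.
    """
    checkpoints = math.ceil(num_layers / segment)
    return checkpoints + segment


n = 100
without_checkpointing = n  # every layer's activation is stashed
best_segment = min(range(1, n + 1), key=lambda k: peak_stored_activations(n, k))
print(without_checkpointing, best_segment, peak_stored_activations(n, best_segment))
# 100 10 20
```

The optimum lands at a segment length near $\sqrt{n}$, cutting peak activation storage from $O(n)$ to $O(\sqrt{n})$ at the cost of recomputing each segment once, roughly one extra forward pass.
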
The execution and differentiation problems together enable the training loop: the execution model determines when computation happens, while automatic differentiation computes the gradients that drive learning. Both problems, however, quietly assume something that cannot be taken for granted: that the same code can run across diverse hardware. A model trained on an NVIDIA A100 must serve inference on a mobile phone's ARM CPU, a Google TPU, or a microcontroller with kilobytes of memory. The same `torch.matmul` call must dispatch to cuBLAS on one device and a hand-tuned ARM NEON kernel on another. This hardware diversity creates the third problem.

## Abstraction Problem {#sec-ml-frameworks-abstraction-problem-37a5}

\index{Abstraction Problem!definition}
\index{Hardware Abstraction!framework design}
\index{Hardware Abstraction!two dimensions (data, execution)}
The hardware diversity described above is not merely inconvenient; it is architecturally fundamental. A GPU offers orders of magnitude more parallelism than a CPU but has different memory semantics. A TPU provides higher throughput but favors static shapes, because its XLA compiler specializes each program to fixed tensor dimensions. A microcontroller has kilobytes where a server has gigabytes. The abstraction problem asks: how should frameworks hide this complexity behind a single programming interface while still enabling efficient utilization of each target's unique capabilities?

The problem decomposes into two interacting dimensions. The first is *data representation*: how should frameworks represent tensors, parameters, and computational state in ways that work across hardware? The second is *execution mapping*: how should high-level operations translate to hardware-specific implementations? These dimensions are not independent concerns. The way data is represented (memory layout, precision, device placement) directly affects what execution strategies are possible. The same tensor stored in row-major format requires different kernels than one stored in column-major format, and a kernel tuned for a GPU's memory hierarchy differs from one tuned for a CPU's caches. A model quantized to INT8 enables entirely different execution paths than FP32.

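Layout is captured by strides: the number of elements to skip in memory when an index along each dimension increases by one. The minimal sketch below, whose function names are hypothetical, computes strides for row-major (C-order) and column-major (Fortran-order) storage of a 2×3 tensor. PyTorch reports the same per-dimension element strides through `Tensor.stride()`; a contiguous tensor of shape `(2, 3)`, for example, has strides `(3, 1)`.

```{.python}
def contiguous_strides(shape, order="C"):
    """Element strides for dense storage: 'C' = row-major, 'F' = column-major."""
    strides = [0] * len(shape)
    step = 1
    dims = reversed(range(len(shape))) if order == "C" else range(len(shape))
    for d in dims:  # the fastest-varying dimension gets stride 1
        strides[d] = step
        step *= shape[d]
    return tuple(strides)


def offset(index, strides):
    """Flat memory offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


shape = (2, 3)
row_major = contiguous_strides(shape, "C")  # (3, 1)
col_major = contiguous_strides(shape, "F")  # (1, 2)

# Offsets visited when walking along the first row
print([offset((0, c), row_major) for c in range(3)])  # [0, 1, 2]: contiguous
print([offset((0, c), col_major) for c in range(3)])  # [0, 2, 4]: strided
```

A kernel that streams along rows reads contiguous memory in the first layout but jumps by a stride in the second, which is why frameworks either select a layout-specific kernel or copy the tensor into the layout the kernel expects.
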
Solving the abstraction problem requires sophisticated software infrastructure: tensor representations that encode both mathematical semantics and hardware constraints, intermediate representations that enable hardware-specific compilation, and runtime systems that manage data movement across the memory hierarchy.

To make this concrete, trace what must happen when a programmer writes `model(input)`. The framework must answer five questions in rapid succession: *What is the data?* (tensor shape, memory layout, numeric precision), *Where does it live?* (device placement and the bandwidth hierarchy connecting CPU, GPU, and accelerator memory), *How does it arrive fast enough?* (data pipelines that sustain hundreds of MB/s to keep the accelerator fed), *How does it scale beyond one device?* (parameter synchronization and distributed execution contexts), and *What actually runs on the hardware?* (kernel dispatch, scheduling, and resource optimization). The following subsections answer these questions in order, building from the data container up to the hardware execution layer.

### Data Structures and Tensor Abstractions {#sec-ml-frameworks-data-structures-tensor-abstractions-9cbf}

```{python}
#| label: resnet50-params-intro
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ RESNET-50 PARAMETERS FOR ABSTRACTION SECTION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Introducing tensor abstractions with concrete model scale
# │
# │ Goal: Motivate the need for sophisticated framework data structures.
# │ Show: That 25.6M parameters must be managed without manual pointers.
# │ How: Retrieve ResNet-50 parameter count from mlsys.constants.
# │
# │ Imports: mlsys.constants (RESNET50_PARAMS, Mparam), mlsys.formatting (fmt, check)
# │ Exports: resnet_params_m_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import RESNET50_PARAMS, Mparam
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class ResNetAbstraction:
    """
    Namespace for ResNet-50 parameter scale in abstraction section.
    """
    params_m = RESNET50_PARAMS.to(Mparam).magnitude
    check(params_m > 20, f"ResNet-50 parameter count too low: {params_m:.1f}M")
    resnet_params_m_str = fmt(params_m, precision=1, commas=False)

resnet_params_m_str = ResNetAbstraction.resnet_params_m_str
```

A ResNet-50 forward pass touches `{python} ResNetAbstraction.resnet_params_m_str` million parameters, produces intermediate activations at every layer, and must coordinate memory across CPU and GPU address spaces. How do frameworks organize all of this data so that a single Python call like `model(input)` executes millions of operations without the programmer managing a single pointer? Answering this question requires solving four problems in sequence: defining a universal data container (tensors), placing it on the right device (memory management), feeding data fast enough (data pipelines), and dispatching the right hardware kernel (core operations). We trace this path from data representation to hardware execution.

Computational graphs specify the *logical* flow of operations, but data structures determine how those operations access and manipulate data in *physical* memory. This distinction matters because the same mathematical operation can differ by an order of magnitude in throughput depending on whether data is contiguous in cache, pinned for DMA transfer, or scattered across pages.

The first step is the data container itself. Framework data structures must sustain memory bandwidth (hundreds of GB/s on modern GPUs), accommodate architectures from 1D sequences to 5D video tensors, and hide device management behind clean APIs. Tensors are the universal answer.

#### Tensors {#sec-ml-frameworks-tensors-1cb7}

\index{Tensor!n-dimensional arrays}
\index{Tensor!definition}
\index{Tensor!etymology}
At the foundation of every framework's data representation lies a single abstraction: the **tensor**.

::: {.callout-definition title="Tensor"}

***Tensors***\index{Tensor!framework abstraction} are the fundamental unit of **Data Parallelism**. By abstracting n-dimensional arrays into a unified data structure with defined **Strides** and **Types**, they enable frameworks to map mathematical operations onto vectorized hardware instructions without exposing memory layout complexity to the user.

:::

Every computation in a neural network operates on tensors.[^fn-tensor] Training batches, activation maps, parameter gradients, and optimizer states are all tensors. This unified representation lets frameworks optimize a single data structure for hardware rather than managing separate containers for each role.

[^fn-tensor]: **Tensor**: From Latin "tendere" (to stretch), originally describing stress distributions in elastic materials. Mathematicians Ricci and Levi-Civita formalized tensor calculus in 1900, and Einstein later adopted it for general relativity, where tensors describe how spacetime curves. ML uses the term more loosely: practitioners rarely rely on the coordinate-transformation rules that define mathematical tensors and instead treat tensors as n-dimensional arrays with hardware-optimized operations.

But how much memory does this single abstraction actually consume? The answer is far more than the model weights alone suggest. This hidden overhead catches many practitioners off guard: they estimate memory from parameter count alone, allocate accordingly, and encounter out-of-memory errors that seem inexplicable. The following notebook quantifies what we call the *administrative tax*---the shadow tensors for gradients, optimizer momentum, and stored activations that accompany every weight tensor.

```{python}
#| label: admin-tax-calc
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ ADMINISTRATIVE TAX CALCULATION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Callout "The Administrative Tax" explaining hidden memory costs
# │
# │ Goal: Demonstrate the hidden memory costs of large-scale training.
# │ Show: That a 2 GB model requires ~19 GB of VRAM during training.
# │ How: Sum gradients, optimizer states, and activations for a 1B param model.
# │
# │ Imports: mlsys.constants (BYTES_FP16, BYTES_ADAM_STATE, GB, BILLION),
# │          mlsys.formatting (fmt, check, md_math)
# │ Exports: admin_weights_gb_str, admin_grads_gb_str, admin_opt_gb_str, admin_act_gb_str,
# │          admin_total_gb_str, admin_tax_gb_str, admin_batch_str, admin_layers_str,
# │          admin_width_str, admin_act_calc_md
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import BYTES_FP16, BYTES_ADAM_STATE, GB, BILLION
from mlsys.formatting import fmt, check, md_math

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class AdminTax:
    """
    Namespace for Administrative Tax Calculation.
    Scenario: Memory overhead for 1B parameter model.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    params_count = 1 * BILLION
    batch_size = 32
    layers = 100
    width = 1024

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    weights_gb = (params_count * BYTES_FP16).to(GB).magnitude
    grads_gb = weights_gb
    opt_gb = (params_count * BYTES_ADAM_STATE).to(GB).magnitude

    # Activation approximation: Batch * Layers * Width^2 * FP16
    act_bytes = batch_size * layers * (width ** 2) * BYTES_FP16
    act_gb = act_bytes.to(GB).magnitude

    total_gb = weights_gb + grads_gb + opt_gb + act_gb
    tax_gb = total_gb - weights_gb

    # ┌── 3. INVARIANTS (Guardrails) ───────────────────────────────────────────
    check(tax_gb > 15, f"Administrative tax ({tax_gb:.1f} GB) unexpectedly low.")

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    admin_weights_gb_str = fmt(weights_gb, precision=0, commas=False)
    admin_grads_gb_str = fmt(grads_gb, precision=0, commas=False)
    admin_opt_gb_str = fmt(opt_gb, precision=0, commas=False)
    admin_act_gb_str = fmt(act_gb, precision=1, commas=False)
    admin_total_gb_str = fmt(total_gb, precision=0, commas=False)
    admin_tax_gb_str = fmt(tax_gb, precision=0, commas=False)
    admin_batch_str = fmt(batch_size, precision=0, commas=False)
    admin_layers_str = fmt(layers, precision=0, commas=False)
    admin_width_str = fmt(width, precision=0, commas=False)

    admin_act_calc_md = md_math(f"{admin_batch_str} \\times {admin_layers_str} \\times {admin_width_str}^{{2}} \\times 2 \\approx \\mathbf{{{admin_act_gb_str} \\text{{ GB}}}}")

# ┌── EXPORTS (Bridge to Text) ─────────────────────────────────────────────────
admin_weights_gb_str = AdminTax.admin_weights_gb_str
admin_grads_gb_str = AdminTax.admin_grads_gb_str
admin_opt_gb_str = AdminTax.admin_opt_gb_str
admin_act_gb_str = AdminTax.admin_act_gb_str
admin_total_gb_str = AdminTax.admin_total_gb_str
admin_tax_gb_str = AdminTax.admin_tax_gb_str
admin_batch_str = AdminTax.admin_batch_str
admin_layers_str = AdminTax.admin_layers_str
admin_width_str = AdminTax.admin_width_str
admin_act_calc_md = AdminTax.admin_act_calc_md
```

::: {.callout-notebook title="The Administrative Tax"}
The memory breakdown for ResNet-50 in @sec-ml-frameworks-principle-2-memorycompute-tradeoff-19f2 showed a concrete ~`{python} ResNetMemory.resnet_training_ratio_str`x ratio between training and inference memory. Here we generalize that analysis to reveal the full administrative overhead at billion-parameter scale.
|
||
|
||
**Problem**: Why does your GPU utilization drop when training small models?
|
||
|
||
**The Math (The Hidden Tax)**:
|
||
|
||
1. **Model Weights**: `{python} admin_weights_gb_str` GB.
|
||
2. **Gradients**: `{python} admin_grads_gb_str` GB (same size as weights).
|
||
3. **Optimizer States (Adam)**: `{python} admin_opt_gb_str` GB ($2 \times \text{weights}$ for momentum and velocity in FP32).
|
||
4. **Activations**: For a batch size of `{python} admin_batch_str` and a `{python} admin_layers_str`-layer network, you must store every intermediate layer output for the backward pass.
|
||
|
||
$$ \text{Activations} \approx \text{Batch} \times \text{Layers} \times \text{Width}^{2} \times 2 \text{ bytes} $$
|
||
For a `{python} admin_width_str`-width model: `{python} admin_act_calc_md`. (Each layer's activation is a `Width × Width` matrix per sample---appropriate for transformer-style models where intermediate projections scale with hidden dimension squared.)
|
||
|
||
**The Systems Conclusion**: Your `{python} admin_weights_gb_str` GB model carries an **"Administrative Tax"** of ~`{python} admin_tax_gb_str` GB (`{python} admin_grads_gb_str` GB gradients + `{python} admin_opt_gb_str` GB optimizer + `{python} admin_act_gb_str` GB activations) on top of the weights themselves. Gradients and optimizer states persist for the entire run, while activations are rebuilt on every forward pass. Saving and retrieving those activations adds **Data Movement** to every step, and the backward pass costs roughly twice the forward pass in FLOPs; together, these are why training is often 3–4× slower than pure inference.
:::

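
The callout's accounting reduces to a few lines of arithmetic. The sketch below is illustrative only: it assumes FP32 weights, gradients, and Adam states (4 bytes each) and FP16 activations sized by the callout's coarse Batch × Layers × Width² formula. The helper name `training_memory_bytes` and the example model (1 billion parameters, batch 8, 24 layers, width 2048) are hypothetical, not the values the callout itself uses.

```python
def training_memory_bytes(num_params, batch, layers, width,
                          param_bytes=4, act_bytes=2):
    """Coarse training-memory breakdown (hypothetical helper).

    Assumes FP32 weights, gradients, and Adam states, plus the callout's
    Batch * Layers * Width**2 activation estimate at act_bytes per value.
    """
    weights = num_params * param_bytes
    gradients = num_params * param_bytes          # one gradient per weight
    optimizer = 2 * num_params * param_bytes      # Adam: two moment estimates
    activations = batch * layers * width**2 * act_bytes
    return {
        "weights": weights,
        "gradients": gradients,
        "optimizer": optimizer,
        "activations": activations,
        "tax": gradients + optimizer + activations,  # everything beyond weights
    }


mem = training_memory_bytes(num_params=1_000_000_000, batch=8, layers=24, width=2048)
for name, nbytes in mem.items():
    print(f"{name:>12}: {nbytes / 1e9:6.2f} GB")
```

Under these assumptions the tax comes to about 13.6 GB, more than three times the 4 GB of weights, before any framework workspace or fragmentation is counted.
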
#### Tensor Structure and Dimensions {#sec-ml-frameworks-tensor-structure-dimensions-4a14}

\index{Tensor!rank hierarchy}\index{Broadcasting!shape compatibility}
A tensor generalizes scalars, vectors, and matrices to arbitrary dimensions. The hierarchy is straightforward: a scalar is a rank-0 tensor (single value), a vector is rank-1 (sequence of values), and a matrix is rank-2 (rows and columns). Higher ranks extend this pattern through nesting, so a rank-3 tensor is a stack of matrices---compare all four ranks side by side in @fig-tensor-data-structure-a to see how each level adds a new axis of organization.

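
NumPy uses the same rank vocabulary as the major frameworks: a tensor's rank is its `ndim`, and its extent along each axis is its `shape`. A minimal sketch with arbitrary example values:

```python
import numpy as np

scalar = np.array(7.0)                        # rank 0: a single value
vector = np.array([1.0, 3.0, 5.0])            # rank 1: a sequence of values
matrix = np.array([[1.0, 2.0], [3.0, 5.0]])   # rank 2: rows and columns
cube = np.stack([matrix, matrix, matrix])     # rank 3: a stack of matrices

for name, t in [("scalar", scalar), ("vector", vector),
                ("matrix", matrix), ("cube", cube)]:
    print(f"{name:>6}: rank={t.ndim}, shape={t.shape}")
```

Each rank is built by stacking tensors of the rank below along a new leading axis, the same operation that later turns individual images into a batch.
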
::: {#fig-tensor-data-structure-a fig-env="figure" fig-pos="htb" fig-cap="**Tensor Rank Hierarchy.** Four shapes illustrating tensor ranks from left to right: a single value (rank 0, scalar), a column of values (rank 1, vector), a grid of values (rank 2, matrix), and a cube of values (rank 3, three-dimensional tensor)." fig-alt="Four shapes showing tensor ranks left to right: single box labeled Rank 0, vertical column of numbers labeled Rank 1, 2D grid of numbers labeled Rank 2, and 3D cube labeled Rank 3."}
```{.tikz}
\scalebox{0.8}{%
\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}]
\begin{scope}
\pgfmathsetmacro{\cubex}{2.5}
\pgfmathsetmacro{\cubey}{2.5}
\pgfmathsetmacro{\cubez}{2.5}
\draw[BrownLine,fill=BrownL!40] (0,0,0) -- ++(-\cubex,0,0) -- ++(0,-\cubey,0) -- ++(\cubex,0,0) -- cycle;
\draw[BrownLine,fill=BrownL] (0,0,0) -- ++(0,0,-\cubez)coordinate(G) -- ++(0,-\cubey,0) -- ++(0,0,\cubez) -- cycle;
\draw[BrownLine,fill=BrownL!70] (0,0,0) -- ++(-\cubex,0,0) -- ++(0,0,-\cubez) -- ++(\cubex,0,0) -- cycle;
\path[red] (-\cubex,-\cubey,0)coordinate(A) -- (0,-\cubey,0)coordinate(B);
\node[below=0.3of $(A)!0.5!(B)$]{Rank 3};
\end{scope}

\begin{scope}[shift={(-5.5,-0.77)}]
\node[draw=BrownLine,fill=BrownL!40,rectangle,%anchor=north west,
minimum width=98,minimum height=98](R){};
\node[right=2pt of $(R.north west)!0.1!(R.south west)$]{1 \ldots ~2};
\node[right=2pt of $(R.north west)!0.24!(R.south west)$]{3 \ldots ~5};
\node[right=2pt of $(R.north west)!0.39!(R.south west)$]{5 \phantom{\ldots} 3};
\node[right=2pt of $(R.north west)!0.58!(R.south west)$]{$\vdots$ \phantom{\ldots~} $\vdots$};
\node[right=2pt of $(R.north west)!0.9!(R.south west)$]{3 \phantom{\ldots} 3};
\node[below=0.3of $(R.south west)!0.5!(R.south east)$]{Rank 2};
\end{scope}

\begin{scope}[shift={(-8.75,-0.77)}]
\node[draw=BrownLine,fill=BrownL!40,rectangle,%anchor=north west,
minimum width=18,minimum height=98](R){};
\node[right=2pt of $(R.north west)!0.1!(R.south west)$]{1};
\node[right=2pt of $(R.north west)!0.24!(R.south west)$]{3};
\node[right=2pt of $(R.north west)!0.39!(R.south west)$]{5};
\node[right=2pt of $(R.north west)!0.58!(R.south west)$]{$\vdots$};
\node[right=2pt of $(R.north west)!0.9!(R.south west)$]{3};
\node[below=0.3of $(R.south west)!0.5!(R.south east)$](R1){Rank 1};
\end{scope}

\begin{scope}[shift={(-10.5,-0.77)}]
\node[draw=BrownLine,fill=BrownL!40,rectangle,%anchor=north west,
minimum width=18,minimum height=18](3R){0};
\end{scope}
\path[red](R1)-|coordinate(P)(3R);
\node[]at(P){Rank 0};
\end{tikzpicture}}
```
:::

This rank hierarchy maps directly onto ML data. A color image is a rank-3 tensor: height × width × 3 channels (red, green, blue). @fig-tensor-data-structure-b breaks this apart, stacking the three color channels so you can see how a single photograph becomes a three-layer numerical grid. Stacking a batch of $N$ images adds a fourth dimension; in the channels-first (NCHW) convention that PyTorch uses, the result is a rank-4 tensor of shape $[N, 3, H, W]$. Every convolutional layer in a vision model consumes and produces rank-4 tensors of this $[N, C, H, W]$ form (only the channel count $C$ and spatial size change from layer to layer), which is why the tensor abstraction is so central to framework design.

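
Image decoders usually return pixels in height × width × channels (HWC) order, so a channels-first model needs an axis permutation before batching. A minimal NumPy sketch, assuming four random 224 × 224 RGB arrays as stand-ins for decoded photographs:

```python
import numpy as np

H, W, N = 224, 224, 4
rng = np.random.default_rng(0)

# Hypothetical decoded images in HWC order (height, width, channels)
images_hwc = [rng.random((H, W, 3), dtype=np.float32) for _ in range(N)]

# Move channels to the front: (H, W, C) -> (C, H, W)
images_chw = [np.transpose(img, (2, 0, 1)) for img in images_hwc]

# Stack along a new leading batch axis: N x (C, H, W) -> (N, C, H, W)
batch = np.stack(images_chw)
print(batch.shape, batch.dtype)
```

`np.transpose` only permutes strides and returns a view; `np.stack` then copies the pixels into one contiguous rank-4 array.
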
::: {#fig-tensor-data-structure-b fig-env="figure" fig-pos="htb" fig-cap="**Image as RGB Tensor.** Three stacked grids representing the red, green, and blue color channels of an image, with dimension labels showing width, height, and channel depth forming a rank-3 tensor. *Credit: Niklas Lang [https://towardsdatascience.com/what-are-tensors-in-machine-learning-5671814646ff](https://towardsdatascience.com/what-are-tensors-in-machine-learning-5671814646ff)*." fig-alt="Three stacked 3 × 3 grids in red, green, and blue representing RGB color channels. Dimension labels show width 3 pixels, height 3 pixels, and 3 color channels forming a 3D tensor for image data."}
```{.tikz}
\scalebox{0.7}{%
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\Large]
%
\tikzset{
Line/.style={line width=1.0pt,black!70,font=\usefont{T1}{phv}{m}{n}\footnotesize
},
Box/.style={align=flush center,
inner xsep=4pt,
node distance=0,
draw=white,
line width=0.75pt,
fill=red!80,
minimum width=10mm,
minimum height=10mm
},
}
\node[Box](B1){\textbf{6}};
\node[Box,right=of B1](B2){\textbf{2}};
\node[Box,right=of B2](B3){\textbf{5}};
\node[Box,below=of B1](B4){\textbf{32}};
\node[Box,right=of B4](B5){\textbf{15}};
\node[Box,right=of B5](B6){\textbf{4}};
\node[Box,below=of B4](B7){\textbf{1}};
\node[Box,right=of B7](B8){\textbf{8}};
\node[Box,right=of B8](B9){\textbf{3}};
%%
\node[Box,fill= OliveLine, draw= white,above=of B2](2B1){\textbf{8}};
\node[Box,fill= OliveLine, draw= white,right=of 2B1](2B2){\textbf{7}};
\node[Box,fill= OliveLine, draw= white,right=of 2B2](2B3){\textbf{5}};
\node[Box,fill= OliveLine, draw= white,below=of 2B3](2B4){\textbf{1}};
\node[Box,fill= OliveLine, draw= white,below=of 2B4](2B5){\textbf{2}};
%%
\node[Box,fill= BlueLine!80, draw= white,above=of 2B2](3B1){\textbf{2}};
\node[Box,fill= BlueLine!80, draw= white,right=of 3B1](3B2){\textbf{1}};
\node[Box,fill= BlueLine!80, draw= white,right=of 3B2](3B3){\textbf{9}};
\node[Box,fill= BlueLine!80, draw= white,below=of 3B3](3B4){\textbf{4}};
\node[Box,fill= BlueLine!80, draw= white,below=of 3B4](3B5){\textbf{3}};
%
\draw[dashed,Line,latex-latex]([yshift=-3mm]B7.south west)--
node[below=1mm]{Width: 3 Pixels}([yshift=-3mm]B9.south east);
\draw[dashed,Line,latex-latex]([xshift=-4mm]B7.south west)--
node[left]{Height: 3 Pixels}([xshift=-4mm]B1.north west);
\draw[dashed,Line,latex-latex,shorten <=2mm]([xshift=-4mm]B1.north west)--
node[left=3mm,pos=0.6]{3 Color Channels}([xshift=-4mm]3B1.north west);
\end{tikzpicture}}
```
:::

Framework tensors carry more than raw numbers. Each tensor stores metadata that the runtime uses to validate operations and select fast execution paths: a *shape* tuple (e.g., `[64, 3, 224, 224]` for a batch of images), a *dtype* (float32, float16, int8), and a *device* tag (CPU, cuda:0). A matrix multiplication, for instance, checks shape compatibility at dispatch time and uses the dtype to route to the correct hardware kernel, whether a standard FP32 GEMM or a Tensor Core FP16 path.

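
NumPy arrays carry the same shape and dtype metadata (though not a device tag, which belongs to accelerator frameworks, for example PyTorch's `tensor.device`), and NumPy's matrix multiply validates shapes before any arithmetic runs. A minimal sketch with arbitrary shapes:

```python
import numpy as np

a = np.ones((64, 128), dtype=np.float16)
b = np.ones((128, 32), dtype=np.float16)
c = np.ones((64, 32), dtype=np.float16)

print(a.shape, a.dtype)        # metadata travels with the data
out = a @ b                    # inner dimensions agree: 128 == 128
print(out.shape, out.dtype)    # (64, 32) float16

shape_error = None
try:
    a @ c                      # inner dimensions disagree: 128 vs 64
except ValueError as err:
    shape_error = err          # rejected before any arithmetic runs
print("mismatch rejected:", shape_error is not None)
```

The result keeps the inputs' float16 dtype, mirroring how a framework uses dtype to pick the kernel that produces the output.
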
\index{Memory Layout!stride patterns}
Memory layout poses a distinct challenge in tensor design. Tensors present multi-dimensional data, but physical memory is linear. Stride patterns\index{Tensor!stride patterns}\index{Memory Layout!row-major vs. column-major} bridge this gap by mapping each multi-dimensional index to a linear memory address, and because they fix the order in which operations touch memory, they strongly affect performance. @fig-tensor-memory-layout makes this concrete with a 2×3 tensor: follow the same six values as they map into two different linear orderings (row-major and column-major) and note how the stride values change to compensate.

::: {#fig-tensor-memory-layout fig-env="figure" fig-pos="htb" fig-cap="**Tensor Memory Layout**: A 2×3 tensor can be stored in linear memory using either row-major (C-style) or column-major (Fortran-style) ordering. Strides give the number of elements to skip along each dimension, so the element tensor[i,j] lies i×stride[0] + j×stride[1] elements past the base address (multiply by the element size to obtain a byte offset). The choice of memory layout significantly impacts cache performance and computational efficiency." fig-alt="Left: 2 × 3 tensor grid with values 1-6. Right: two linear arrays showing row-major layout (1,2,3,4,5,6) and column-major layout (1,4,2,5,3,6). Below: stride calculations for row-major [3,1] and column-major [1,2]."}
```{.tikz}
\begin{tikzpicture}[font=\footnotesize\usefont{T1}{phv}{m}{n}]
% Define colors
\definecolor{col1}{RGB}{135, 206, 250}
\definecolor{col2}{RGB}{255, 182, 193}
\definecolor{col3}{RGB}{152, 251, 152}
% 2x3 tensor visualization (LEFT SIDE)
\foreach \row in {0,1} {
\foreach \col in {0,1,2} {
\pgfmathsetmacro{\val}{\row * 3 + \col + 1}
\node[draw, minimum width=15mm, minimum height=10mm,
fill=col1!50](B\row\col) at (\col*1.7, 1-\row*1.2) {\val};
}
}
\node[above=2pt of B01]{\textbf{2D Tensor (2 $\times$ 3)}};
\path[red](B02.north east)--++(1.35,0)coordinate(CR);
\path[red](B12.340)--++(1.35,0)coordinate(ZE);
% Row-major memory layout (RIGHT SIDE)
\foreach \i in {0,1,2,3,4,5} {
\pgfmathsetmacro{\val}{\i + 1}
\node[draw, minimum width=10mm, minimum height=8mm,
anchor=north west,fill=col2!50](CB\i) at ($(CR)+(\i*1.1, 0)$) {\val};
\node[below=0pt of CB\i, font=\tiny\usefont{T1}{phv}{m}{n}] {[\i]};
}
\node[above=2pt of CB2.north east]{\textbf{Row-Major Layout}};
% Column-major memory layout (RIGHT SIDE)
\foreach \i in {0,1,2,3,4,5} {
\pgfmathsetmacro{\val}{int(mod(\i,2)*3 + int(\i/2) + 1)}
\node[draw, minimum width=10mm, minimum height=8mm,
anchor=north west,fill=col3!50](ZE\i) at ($(ZE)+(\i*1.1, 0)$) {\val.0};
\node[below=0pt of ZE\i, font=\tiny\usefont{T1}{phv}{m}{n}] {[\i]};
}
\node[above=2pt of ZE2.north east]{\textbf{Column-Major Layout}};
% Strides explanation (BOTTOM)
\node[anchor=north west,align=left,inner sep=0pt] at ($(B10.south west)+(0,-0.2)$) {%
\textbf{Stride Calculation:}\\
Row-major strides: [3, 1]\\
Column-major strides: [1, 2]\\
Element [i,j] offset = i $\times$ stride[0] + j $\times$ stride[1]
};
\end{tikzpicture}
```
:::

These memory layout patterns are crucial for framework performance optimization. Row-major layout (the default in NumPy and PyTorch) stores elements row by row, making row-wise operations more cache-friendly. Column-major layout (the Fortran convention, also assumed by BLAS and LAPACK interfaces) stores elements column by column, optimizing column-wise access patterns. The stride values encode this layout information: in row-major layout for a 2×3 tensor, moving to the next row requires skipping 3 elements (stride[0]=3), while moving to the next column requires skipping 1 element (stride[1]=1).

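
The offset rule from the figure is short enough to write out directly. The sketch below derives element strides for either ordering and checks them against NumPy, whose `strides` attribute reports bytes rather than elements (hence the division by `itemsize`). The helper names `strides_for` and `offset` are illustrative, not a framework API.

```python
import numpy as np


def strides_for(shape, order="C"):
    """Element strides of a contiguous tensor: row-major ("C") or column-major ("F")."""
    strides, step = [], 1
    dims = list(reversed(shape)) if order == "C" else list(shape)
    for size in dims:
        strides.append(step)   # elements skipped when this index grows by 1
        step *= size
    return tuple(reversed(strides)) if order == "C" else tuple(strides)


def offset(index, strides):
    """Linear element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


t = np.arange(1, 7, dtype=np.float32).reshape(2, 3)   # the figure's 2x3 tensor
for order in ("C", "F"):
    arr = np.asarray(t, order=order)
    elem_strides = tuple(s // arr.itemsize for s in arr.strides)  # NumPy reports bytes
    assert elem_strides == strides_for(t.shape, order)
    memory_order = arr.ravel(order="K").tolist()   # values as laid out in memory
    print(order, elem_strides, memory_order, "t[1,0] at", offset((1, 0), elem_strides))
```

Transposing `t` yields a view whose element strides are `(1, 3)`, column-major for the transposed 3×2 shape, without moving any data; this is why frameworks track contiguity (`Tensor.is_contiguous()` in PyTorch, `flags["C_CONTIGUOUS"]` in NumPy).
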
These memory layout details have direct performance implications. A convolution weight tensor shaped [out_channels, in_channels, kH, kW] in row-major layout keeps its innermost dimension contiguous, so a kernel that walks along that dimension can issue efficient vectorized loads, whereas a kernel that walks along an outer dimension jumps by a large stride on every access and falls back to slower gather-like patterns. Aligning stride patterns with the hardware memory hierarchy maximizes cache efficiency and memory throughput: well-matched layouts can reach 80--90% of theoretical memory bandwidth (roughly 1.5--3.4 TB/s on data-center GPUs such as the A100 and H100), while poorly matched access patterns may achieve only 20--30%.

\index{Tensor!data types (dtypes)}
Tensor implementations use type systems to control numerical precision and memory consumption. The standard choice in machine learning has been 32-bit floating-point numbers (`float32`), offering a balance of precision and efficiency. Modern frameworks extend this with multiple numeric types for different needs. Integer types support indexing and embedding operations. Reduced-precision types like 16-bit floating-point numbers enable efficient mobile deployment. 8-bit integers allow fast inference on specialized hardware.

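
Precision translates directly into memory. A quick NumPy check of bytes per element, applied to a hypothetical 10-million-element tensor (NumPy has no built-in bfloat16 or int4, so those framework types are omitted here):

```python
import numpy as np

num_elements = 10_000_000  # hypothetical tensor size

for dtype in (np.float32, np.float16, np.int8):
    itemsize = np.dtype(dtype).itemsize   # bytes per element
    print(f"{np.dtype(dtype).name:>8}: {itemsize} B/element -> "
          f"{num_elements * itemsize / 1e6:.0f} MB")
```

Halving the element size halves both the storage and the bytes every kernel must move, which is where most of the speedup from lower precision comes from.
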
The choice of numeric type affects both model behavior and computational efficiency.
\index{Quantization!inference precision}
Neural network training has traditionally used float32 precision to keep gradient computations numerically stable. Inference can often use lower precision (`int8` or even `int4`), reducing memory usage and increasing processing speed. Mixed-precision training combines these benefits by performing most computations at lower precision while keeping critical accumulations in float32.

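
Why accumulations in particular need float32 is easy to demonstrate. The sketch below adds the same small increment 10,000 times in float16 and in float32; the increment and count are arbitrary choices, picked so the exact sum is 1.0. Once the float16 running sum reaches about 0.25, the increment is smaller than half the gap between adjacent float16 values, so every further addition rounds away to nothing.

```python
import numpy as np

STEPS, INCREMENT = 10_000, 1e-4   # exact sum would be 1.0

acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
inc16, inc32 = np.float16(INCREMENT), np.float32(INCREMENT)

# Sequential accumulation, as a naive reduction loop would perform it
for _ in range(STEPS):
    acc16 = np.float16(acc16 + inc16)
    acc32 = np.float32(acc32 + inc32)

print(f"float16 sum: {float(acc16):.4f}")   # stalls far below 1.0
print(f"float32 sum: {float(acc32):.4f}")   # close to 1.0
```

Mixed-precision schemes avoid exactly this failure by multiplying in low precision but summing into a float32 accumulator.
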
Type conversions between numeric representations require careful management. Frameworks define type-promotion rules for mixed-type operations (adding a float16 tensor to a float32 tensor, for example, produces float32), but every conversion costs time, and casting down to a narrower type can silently lose precision. Frameworks supply the casting machinery; keeping enough precision where it matters remains the developer's responsibility.

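
NumPy applies comparable promotion rules, which makes both halves of the trade-off visible: mixing float16 and float32 promotes the result, while an explicit downcast discards digits. A minimal sketch; PyTorch's own rules (see `torch.result_type`) are similar but not identical, so treat it as illustrative:

```python
import numpy as np

x16 = np.ones(4, dtype=np.float16)
x32 = np.ones(4, dtype=np.float32)

mixed = x16 + x32
print(mixed.dtype)                           # promoted to float32

value = np.float32(0.1)
round_trip = np.float32(np.float16(value))   # downcast, then back up
error = abs(float(round_trip) - float(value))
print(f"{float(value):.10f} -> {float(round_trip):.10f} (error {error:.2e})")
```

Promotion protects correctness by widening, but it also silently doubles the memory traffic of the float16 operand, one reason frameworks make dtype an explicit, inspectable property.
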
Tensors answer the first question---*what is the data?*---by encoding shape, layout, and precision into a single abstraction. But a perfectly shaped tensor on the wrong device, or one that must cross a 60× bandwidth gap to reach the GPU, can erase every layout optimization. The next question is *where* data lives and *how* it moves.

#### Device and Memory Management {#sec-ml-frameworks-device-memory-management-9404}

```{python}
#| label: device-bandwidth-hierarchy
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ DEVICE BANDWIDTH HIERARCHY
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Device and Memory Management section introducing bandwidth costs
# │
# │ Goal: Quantify the bandwidth hierarchy across system interconnects.
# │ Show: The 60× bandwidth gap between PCIe (32 GB/s) and HBM (2 TB/s).
# │ How: Compare transfer rates for standard interfaces and on-chip memory.
# │
# │ Imports: mlsys.constants (PCIE_GEN4_BW, NVLINK_A100_BW, A100_MEM_BW, A100_MEM_CAPACITY,
# │          A100_FLOPS_FP16_TENSOR), mlsys.formatting (fmt)
# │ Exports: pcie4_gbs_str, pcie4_bidir_gbs_str, nvlink_a100_gbs_str, a100_bw_gbs_str,
# │          a100_bw_tbs_str, a100_tflops_fp16_str, pcie4_4mb_ms_str, nvlink_4mb_ms_str,
# │          hbm_4mb_ms_str, pcie4_1gb_ms_str, pcie4_1gb_equiv_ops_str, a100_mem_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import (PCIE_GEN4_BW, NVLINK_A100_BW, A100_MEM_BW, A100_MEM_CAPACITY,
                             A100_FLOPS_FP16_TENSOR, GB, TB, GiB, TFLOPs, flop, byte, second,
                             BILLION, MILLION, THOUSAND)
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class DeviceBandwidthHierarchy:
    """
    Namespace for Device Bandwidth Hierarchy.
    Scenario: Comparing PCIe vs NVLink vs HBM speeds.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    tensor_4mb = 4 * MILLION

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    pcie4_gbs = PCIE_GEN4_BW.to(GB/second).magnitude
    nvlink_a100_gbs = NVLINK_A100_BW.to(GB/second).magnitude
    a100_bw_gbs = A100_MEM_BW.to(GB/second).magnitude
    a100_bw_tbs = A100_MEM_BW.to(TB/second).magnitude
    a100_mem = A100_MEM_CAPACITY.to(GiB).magnitude
    a100_tflops_fp16 = A100_FLOPS_FP16_TENSOR.to(TFLOPs/second).magnitude

    # Transfer times for 4 MB tensor
    pcie4_4mb_ms = (tensor_4mb / (PCIE_GEN4_BW.to(byte/second).magnitude)) * THOUSAND
    nvlink_4mb_ms = (tensor_4mb / (NVLINK_A100_BW.to(byte/second).magnitude)) * THOUSAND
    hbm_4mb_ms = (tensor_4mb / (A100_MEM_BW.to(byte/second).magnitude)) * THOUSAND

    # 1 GB transfer cost analysis
    # 1 GB / BW (B/s) * 1000 ms
    pcie4_1gb_ms = (BILLION / (PCIE_GEN4_BW.to(byte/second).magnitude)) * THOUSAND
    # Equiv Ops: (ms / 1000) * FLOPS
    pcie4_1gb_equiv_ops = (pcie4_1gb_ms / THOUSAND) * A100_FLOPS_FP16_TENSOR.to(flop/second).magnitude

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    pcie4_gbs_str = fmt(pcie4_gbs, precision=0, commas=False)
    pcie4_bidir_gbs_str = fmt(pcie4_gbs * 2, precision=0, commas=False)
    nvlink_a100_gbs_str = fmt(nvlink_a100_gbs, precision=0, commas=False)
    a100_bw_gbs_str = fmt(a100_bw_gbs, precision=0, commas=False)
    a100_bw_tbs_str = fmt(a100_bw_tbs, precision=1, commas=False)
    a100_tflops_fp16_str = fmt(a100_tflops_fp16, precision=0, commas=False)
    a100_mem_str = fmt(a100_mem, precision=0, commas=False)
    pcie4_4mb_ms_str = fmt(pcie4_4mb_ms, precision=3, commas=False)
    nvlink_4mb_ms_str = fmt(nvlink_4mb_ms, precision=3, commas=False)
    hbm_4mb_ms_str = fmt(hbm_4mb_ms, precision=3, commas=False)
    pcie4_1gb_ms_str = fmt(pcie4_1gb_ms, precision=0, commas=False)
    pcie4_1gb_equiv_ops_str = fmt((pcie4_1gb_equiv_ops * flop).to(TFLOPs).magnitude, precision=1, commas=False)

# ┌── EXPORTS (Bridge to Text) ─────────────────────────────────────────────────
pcie4_gbs_str = DeviceBandwidthHierarchy.pcie4_gbs_str
pcie4_bidir_gbs_str = DeviceBandwidthHierarchy.pcie4_bidir_gbs_str
nvlink_a100_gbs_str = DeviceBandwidthHierarchy.nvlink_a100_gbs_str
a100_bw_gbs_str = DeviceBandwidthHierarchy.a100_bw_gbs_str
a100_bw_tbs_str = DeviceBandwidthHierarchy.a100_bw_tbs_str
a100_tflops_fp16_str = DeviceBandwidthHierarchy.a100_tflops_fp16_str
pcie4_4mb_ms_str = DeviceBandwidthHierarchy.pcie4_4mb_ms_str
nvlink_4mb_ms_str = DeviceBandwidthHierarchy.nvlink_4mb_ms_str
hbm_4mb_ms_str = DeviceBandwidthHierarchy.hbm_4mb_ms_str
pcie4_1gb_ms_str = DeviceBandwidthHierarchy.pcie4_1gb_ms_str
pcie4_1gb_equiv_ops_str = DeviceBandwidthHierarchy.pcie4_1gb_equiv_ops_str
a100_mem_str = DeviceBandwidthHierarchy.a100_mem_str
```

\index{Device Management!CPU-GPU transfers}
\index{Memory Management!device placement}
Tensors and their memory layouts establish *what* the framework computes with. Where that data physically resides, and how it moves between locations, determines whether computation happens at full speed or crawls.

Every tensor resides on a specific device, and cross-device operations incur transfer costs that can dominate execution time. On an A100-class system, PCIe 4.0 delivers `{python} pcie4_gbs_str` GB/s between CPU and GPU, while the GPU's own HBM2e provides `{python} MemoryWallSpecs.a100_bw_tbs_str` TB/s. This bandwidth gap, exceeding 60×, means a single misplaced tensor transfer can erase the entire speedup from GPU acceleration.

Why does this matter for framework design? Because the framework must track where every tensor lives and enforce that operations only combine tensors on the same device. When data must move, the framework must decide whether to block execution or overlap the transfer with other work. These decisions, invisible to most users, determine whether a training loop achieves 30% or 80% of theoretical hardware throughput.

Three systems principles govern effective device and memory management: understanding the bandwidth hierarchy that constrains data movement, overlapping computation with communication to hide transfer latency, and using fine-grained synchronization to maintain correctness without sacrificing concurrency. The remainder of this section develops each principle, with quantitative analysis grounded in the **Iron Law's** data movement term.

##### Principle 1: The Device Bandwidth Hierarchy { .unnumbered}

The cost of moving data between devices varies by orders of magnitude depending on the interconnect.[^fn-nvlink-frameworks] Before examining optimization strategies, we need to understand these costs quantitatively. @tbl-device-transfer-overhead shows transfer times for a 1000 × 1000 float32 tensor (4 MB)---roughly the size of a typical activation tensor in a moderately sized model. The numbers reveal why careless device placement can erase any speedup from GPU acceleration:

[^fn-nvlink-frameworks]: **NVLink**: NVIDIA's high-bandwidth interconnect for GPU-to-GPU communication (see @sec-hardware-acceleration for architecture details), providing `{python} nvlink_a100_gbs_str` GB/s bidirectional bandwidth (NVLink 3.0 on A100) compared to `{python} pcie4_bidir_gbs_str` GB/s bidirectional for PCIe 4.0 x16. Critical for multi-GPU training where gradient synchronization requires moving gigabytes per iteration. NVSwitch extends NVLink so that all 8 GPUs in a DGX system are fully connected and can communicate all-to-all at full NVLink bandwidth. The ~10× bandwidth advantage over PCIe determines whether tensor parallelism is practical for a given model size.

| **Interconnect** | **Bandwidth** | **Transfer Time** | **Relative to Compute** |
|:-----------------|--------------------------------------------------:|--------------------------------:|:----------------------------|
| **PCIe 3.0 x16** | 16 GB/s | 0.25 ms | 10× slower than GPU compute |
| **PCIe 4.0 x16** | `{python} pcie4_gbs_str` GB/s | `{python} pcie4_4mb_ms_str` ms | 5× slower than GPU compute |
| **NVLink 3.0** | `{python} nvlink_a100_gbs_str` GB/s bidirectional | `{python} nvlink_4mb_ms_str` ms | Comparable to GPU compute |
| **GPU Memory** | `{python} a100_bw_gbs_str` GB/s | `{python} hbm_4mb_ms_str` ms | Optimal |

: **Device Transfer Overhead.** Transfer time for a 4 MB tensor across different interconnects. PCIe bandwidth shown is unidirectional (typical for GPU transfers), with full-duplex operation providing 2× total bandwidth. NVLink bandwidth is the aggregate bidirectional figure (300 GB/s per direction), so a strictly one-way NVLink transfer takes about twice the time listed. Transfer times dominate for small operations, making device placement critical for performance. {#tbl-device-transfer-overhead}

These numbers connect directly to the **Iron Law** of performance. Every cross-device transfer adds to the data movement term ($D_{vol}/BW$), but with $BW$ set by the interconnect, only a small fraction of the on-device bandwidth. At PCIe 4.0's `{python} pcie4_gbs_str` GB/s, moving a 1 GB activation tensor adds approximately `{python} pcie4_1gb_ms_str` ms to the data movement cost, time in which a GPU delivering `{python} A100BLAS.dense_tflops_str` TFLOPS could have performed roughly `{python} pcie4_1gb_equiv_ops_str` trillion operations. For a model forward pass taking 0.5 ms on GPU, moving a 4 MB input and a 4 MB output over PCIe 3.0 (0.25 ms each) doubles the total latency. When batches are small or models are lightweight, transfer overhead can exceed computation time entirely.

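
The arithmetic behind the table and this paragraph fits in one helper. A minimal sketch using round-number bandwidths from the table (16 GB/s for PCIe 3.0, 32 GB/s for PCIe 4.0) and an assumed 312 TFLOPS compute rate (the A100's dense FP16 Tensor Core peak) for the opportunity-cost line; substitute measured values for real hardware:

```python
def transfer_ms(num_bytes, bandwidth_bytes_per_s):
    """Idealized transfer time in ms (ignores latency and protocol overhead)."""
    return num_bytes / bandwidth_bytes_per_s * 1e3


MB, GB = 1e6, 1e9
PCIE3, PCIE4 = 16 * GB, 32 * GB     # assumed unidirectional bandwidths
GPU_FLOPS = 312e12                   # assumed FP16 Tensor Core peak

# 4 MB tensor over PCIe 3.0 vs. PCIe 4.0
print(f"{transfer_ms(4 * MB, PCIE3):.3f} ms vs {transfer_ms(4 * MB, PCIE4):.3f} ms")

# 1 GB over PCIe 4.0 and the compute it displaces
t_ms = transfer_ms(1 * GB, PCIE4)
print(f"{t_ms:.2f} ms ~ {t_ms / 1e3 * GPU_FLOPS / 1e12:.2f} trillion FLOPs")

# 0.5 ms forward pass plus a 4 MB input and a 4 MB output over PCIe 3.0
forward_ms = 0.5
total_ms = forward_ms + 2 * transfer_ms(4 * MB, PCIE3)
print(f"end-to-end: {total_ms:.2f} ms ({total_ms / forward_ms:.1f}x the compute alone)")
```

Because the model ignores per-transfer latency and protocol overhead, real small transfers are somewhat slower than these estimates, which only strengthens the case for careful placement.
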
The systems implication is clear: every tensor should reside on the device where it will be consumed, and transfers should occur only when unavoidable. Frameworks track device placement for every tensor and raise errors when operations attempt to combine tensors from different devices, enforcing this discipline at the API level.

##### Principle 2: Overlapping Computation and Communication { .unnumbered}

\index{Execution Streams!overlapping computation}
\index{Asynchronous Execution!hiding transfer latency}
When transfers are unavoidable, the next optimization is to hide their latency by executing them concurrently with computation. Modern GPUs contain independent hardware units for computation (the streaming multiprocessors, or SMs) and for data transfer (copy engines), so the two can genuinely run at the same time. The framework abstraction that exposes this hardware parallelism is the *CUDA stream*\index{Execution Streams!definition}: an independent execution queue where operations execute sequentially within a stream but concurrently across streams.

Without explicit concurrency control, the GPU serializes all operations on a single default stream, leaving execution units idle while data transfers complete. By placing data transfers on one stream and computation on another, the effective latency approaches the theoretical minimum of $\max(\text{compute\_time}, \text{transfer\_time})$ rather than their sum. Stream-based overlap effectively hides the $D_{vol}/BW$ penalty when computation is the longer operation (see @lst-overlap-compute-transfer):

::: {#lst-overlap-compute-transfer lst-cap="**Overlapping Computation and Transfer**: Use separate streams for data transfer and computation to hide transfer latency. Pinned memory enables truly asynchronous non-blocking transfers."}
```{.python}
import torch

# model, criterion, current_batch, labels, and next_batch_cpu are assumed to
# exist; next_batch_cpu should live in pinned memory (see below)
compute_stream = torch.cuda.Stream()
transfer_stream = torch.cuda.Stream()

# Transfer next batch while computing current batch
with torch.cuda.stream(transfer_stream):
    next_batch = next_batch_cpu.to("cuda", non_blocking=True)

with torch.cuda.stream(compute_stream):
    output = model(current_batch)
    loss = criterion(output, labels)

# Before next_batch is consumed, order the compute stream after the copy
compute_stream.wait_stream(transfer_stream)

# Pinned memory enables non_blocking transfers
x_pinned = torch.randn(1000, 1000).pin_memory()
x_gpu = x_pinned.to("cuda", non_blocking=True)  # Asynchronous

# Regular (pageable) memory: the copy is synchronous with the host
y_regular = torch.randn(1000, 1000)
y_gpu = y_regular.to("cuda", non_blocking=True)  # Still blocks
```
:::

\index{Pinned Memory!definition}
\index{DMA Transfer!GPU copy engine}
The `non_blocking=True` flag enables asynchronous transfers that return immediately without waiting for completion. This works only when the source tensor uses *pinned memory*\index{Pinned Memory!DMA transfers} (page-locked memory that enables DMA transfers). Without pinned memory, the transfer blocks even when `non_blocking=True` is specified: the GPU's copy engine cannot DMA directly from pageable host memory, so the data must first be staged through a pinned buffer on the host.

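
The $\max(\text{compute\_time}, \text{transfer\_time})$ claim can be checked with a small timing model of double buffering, in which the copy for batch $k+1$ runs while batch $k$ computes. This is a simplified sketch that assumes fixed per-batch times (the 2 ms and 3 ms values are arbitrary) and perfect overlap; it models the schedule, not CUDA itself.

```python
def sequential_ms(n_batches, transfer_ms, compute_ms):
    """Single stream: each batch waits for its own transfer, then computes."""
    return n_batches * (transfer_ms + compute_ms)


def overlapped_ms(n_batches, transfer_ms, compute_ms):
    """Two streams with double buffering: only the first transfer and the
    last compute are exposed; every other step costs the slower operation."""
    if n_batches == 0:
        return 0.0
    return transfer_ms + (n_batches - 1) * max(transfer_ms, compute_ms) + compute_ms


n, t_xfer, t_comp = 100, 2.0, 3.0
seq = sequential_ms(n, t_xfer, t_comp)
ovl = overlapped_ms(n, t_xfer, t_comp)
print(f"sequential: {seq:.0f} ms, overlapped: {ovl:.0f} ms, speedup {seq / ovl:.2f}x")
```

As the batch count grows, the overlapped cost per batch approaches 3 ms (the compute time) instead of 5 ms; if the transfer were the longer operation, it would set the pace instead.
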
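This overlap principle extends naturally to pipeline parallelism within a single node. Different model stages on separate GPUs can process different microbatches concurrently, with each stage's computation overlapping the next stage's data reception. The listing keeps every stage on one GPU so that only the stream-and-event mechanics are visible; the same synchronization pattern applies when stages sit on separate devices (see @lst-pipeline-parallelism-streams):
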
::: {#lst-pipeline-parallelism-streams lst-cap="**Pipeline Parallelism with Streams**: Overlap multiple model stages across microbatches using streams and events for inter-stage synchronization."}
```{.python}
import torch

# Pipeline parallelism: overlap stages across microbatches.
# Stage1-Stage3 (nn.Module subclasses), num_microbatches, and the list
# microbatches (inputs already on the GPU) are placeholders.
stages = [Stage1().cuda(), Stage2().cuda(), Stage3().cuda()]
streams = [torch.cuda.Stream() for _ in stages]
events = [
    [torch.cuda.Event() for _ in range(num_microbatches)]
    for _ in stages
]
# outputs[s][mb] holds stage s's output for microbatch mb
outputs = [[None] * num_microbatches for _ in stages]

for mb in range(num_microbatches):
    for stage_idx, (stage, stream) in enumerate(zip(stages, streams)):
        with torch.cuda.stream(stream):
            if stage_idx == 0:
                x = microbatches[mb]
            else:
                # Wait for previous stage to complete this microbatch
                events[stage_idx - 1][mb].wait()
                x = outputs[stage_idx - 1][mb]

            outputs[stage_idx][mb] = stage(x)
            events[stage_idx][mb].record()
```
:::

Extending this pattern across multiple machines requires distributed training techniques that constitute an advanced topic, but the single-node implementation above illustrates the core synchronization principles that underlie all pipeline-parallel systems.

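
Even with perfect overlap, a pipeline must fill and drain. Under the simplifying assumption that every stage takes the same time per microbatch, a forward-only schedule needs $(S + M - 1)$ stage-steps for $S$ stages and $M$ microbatches, versus $S \times M$ if nothing overlaps. A minimal sketch; the stage count, microbatch count, and stage time are arbitrary:

```python
def pipeline_ms(num_stages, num_microbatches, stage_ms):
    """Forward-only pipeline with equal stage times: fill + steady state + drain."""
    return (num_stages + num_microbatches - 1) * stage_ms


def serial_ms(num_stages, num_microbatches, stage_ms):
    """No overlap: every microbatch passes through every stage one at a time."""
    return num_stages * num_microbatches * stage_ms


S, M, t = 3, 8, 1.0
pipe, serial = pipeline_ms(S, M, t), serial_ms(S, M, t)
bubble = (S - 1) / (S + M - 1)   # fraction of the schedule each stage sits idle
print(f"pipelined {pipe:.0f} ms vs serial {serial:.0f} ms; bubble fraction {bubble:.0%}")
```

More microbatches shrink the idle "bubble" fraction, which is why pipeline-parallel training splits each batch into many microbatches.
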
With computation and communication overlapping effectively, the remaining challenge is ensuring correctness when operations complete out of order.

##### Principle 3: Synchronization and Correctness { .unnumbered}

\index{Synchronization Events!inter-stream}
\index{Synchronization!device vs. stream}
Concurrent execution introduces ordering constraints. When one stream's output becomes another stream's input, the system must enforce a happens-before relationship without unnecessarily serializing independent work. Two synchronization mechanisms exist, with dramatically different performance implications.

Full device synchronization (`torch.cuda.synchronize()`) blocks the CPU until every queued operation on every stream of the device has completed. This creates a global serialization point that eliminates all overlap benefits. CUDA events\index{Synchronization Events!fine-grained} provide the alternative: fine-grained synchronization that blocks only the dependent stream, allowing other streams and the CPU to continue execution (see @lst-cuda-events):

::: {#lst-cuda-events lst-cap="**CUDA Events for Synchronization**: Events enable fine-grained producer-consumer patterns between streams without blocking the entire device."}
```{.python}
import torch

# data1, expensive_computation, and dependent_computation are placeholders
# Create streams and event
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()
event = torch.cuda.Event()

# Stream 1: producer
with torch.cuda.stream(stream1):
    result1 = expensive_computation(data1)
    event.record()  # Mark completion point on stream1

# Stream 2: consumer (waits only for stream1's event)
with torch.cuda.stream(stream2):
    event.wait()  # Block stream2 until stream1's work before the event finishes
    result2 = dependent_computation(result1)  # Safe to use result1
```
:::

The performance difference between these approaches is not incremental but categorical. Full synchronization after every operation converts a concurrent pipeline into a sequential one, entirely negating the hardware parallelism that streams expose. Event-based synchronization preserves the concurrent execution model while enforcing only the dependencies that correctness requires.

###### Device Placement Discipline {.unnumbered}

Every tensor carries a `device` attribute, and frameworks enforce a strict rule: operations can only combine tensors on the *same* device. Mixing `cuda:0` and `cuda:1` tensors raises a `RuntimeError` instead of triggering a silent cross-device transfer. The `.to()` method moves tensors between devices and copies only when necessary: calling `.to("cuda")` on a tensor already on the GPU returns the same tensor without copying. Calling `.to()` on a module recursively moves all parameters and buffers, ensuring the entire model hierarchy lands on a single device. Three placement principles prevent transfer bottlenecks: (1) allocate tensors on the target device from the start rather than creating on CPU and transferring, (2) reuse GPU memory across iterations rather than re-allocating, and (3) colocate all inputs, labels, and model parameters on the same device to avoid device-mismatch errors and ad hoc transfers. Violating any of these principles inserts PCIe transfers into the critical path, which at `{python} pcie4_gbs_str` GB/s can dominate a training iteration whose on-device data movement otherwise proceeds at `{python} MemoryWallSpecs.a100_bw_tbs_str` TB/s.

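
The same-device rule and the copy-only-when-needed behavior of `.to()` can be mimicked in a few lines of plain Python. This is a hypothetical toy: `ToyTensor` is not a framework class, its error message is invented, and no data actually moves. It exists only to make the dispatch-time check concrete.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for a framework tensor: values plus a device tag."""
    values: tuple
    device: str = "cpu"

    def to(self, device):
        # Return the same object when already on the target device (no copy)
        if device == self.device:
            return self
        return ToyTensor(self.values, device)   # simulated transfer

    def __add__(self, other):
        # Dispatch-time check: never combine tensors on different devices
        if self.device != other.device:
            raise RuntimeError(f"device mismatch: {self.device} vs {other.device}")
        summed = tuple(a + b for a, b in zip(self.values, other.values))
        return ToyTensor(summed, self.device)


x = ToyTensor((1.0, 2.0)).to("cuda:0")
y = ToyTensor((3.0, 4.0), device="cuda:1")

print(x.to("cuda:0") is x)           # True: no copy when already in place

mismatch_error = None
try:
    x + y
except RuntimeError as err:
    mismatch_error = err
print("rejected:", mismatch_error)

z = x + y.to("cuda:0")               # explicit move, then the add is legal
print(z.values, z.device)            # (4.0, 6.0) cuda:0
```

Failing loudly at dispatch time is the design choice that keeps PCIe transfers visible in user code rather than hidden inside arithmetic.
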
###### Synchronization Patterns {.unnumbered}

As @lst-cuda-events demonstrated, event-based synchronization preserves parallelism by enforcing only the dependencies that correctness requires. A common mistake in production code is inserting `torch.cuda.synchronize()` calls for debugging and forgetting to remove them, silently converting an overlapped pipeline into a serialized one.

###### Profiling Transfer Bottlenecks {.unnumbered}

When overlap is insufficient, profiling reveals where time is lost. NVIDIA provides two complementary tools: **Nsight Systems** (`nsys profile`) captures system-wide timelines correlating CPU activity, GPU kernel execution, and memory transfers, identifying *which* kernels dominate runtime. **Nsight Compute** (`ncu`) provides kernel-level analysis with hardware counters, revealing *why* those kernels underperform. @tbl-nsight-metrics lists the key metrics to examine when optimizing ML kernels.

| **Metric**             | **Meaning**                  | **Optimization Target**            |
|:-----------------------|:-----------------------------|:-----------------------------------|
| **SM Occupancy**       | Active warps / maximum warps | Increase parallelism if low        |
| **Memory Throughput**  | Achieved / peak bandwidth    | Optimize memory access patterns    |
| **Compute Throughput** | Achieved / peak FLOPS        | Reduce memory bottlenecks          |
| **Tensor Core Active** | Time in Tensor Core ops      | Verify mixed-precision utilization |

: **Nsight Compute Metrics.** Key metrics for ML kernel optimization. Low values indicate specific optimization opportunities. Nsight Systems identifies which kernels dominate runtime, and Nsight Compute reveals why those kernels underperform. {#tbl-nsight-metrics}

#### Data Pipelines and Loading {#sec-ml-frameworks-domainspecific-data-organizations-48d9}

```{python}
#| label: dataloader-throughput-calc
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ DATALOADER THROUGHPUT CALCULATIONS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Domain-Specific Data Organizations section on sustaining GPU throughput
# │
# │ Goal: Quantify the data ingestion requirements for high-speed training.
# │ Show: That sustaining 1000 images/s requires 150 MB/s continuous throughput.
# │ How: Calculate bandwidth from image resolution, batch size, and rate.
# │
# │ Imports: mlsys.constants (PCIE_GEN4_BW, BYTES_FP32), mlsys.formatting (fmt)
# │ Exports: img_res, dataloader_mbs_str, batch_mb_str, batch_transfer_ms_str, pcie4_gbs_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import PCIE_GEN4_BW, BYTES_FP32, MB, GB, byte, second, MS_PER_SEC
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class DataloaderThroughput:
    """
    Namespace for Dataloader Throughput.
    Scenario: GPU data ingestion requirements.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    img_per_sec = 1000
    img_res = 224
    img_channels = 3
    batch_size = 64

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    # Throughput requirement
    throughput_bytes_sec = img_per_sec * img_res * img_res * img_channels * byte
    dataloader_mbs = throughput_bytes_sec.to(MB).magnitude

    # Batch transfer
    batch_bytes = batch_size * img_res * img_res * img_channels * BYTES_FP32.magnitude * byte
    batch_mb = batch_bytes.to(MB).magnitude
    batch_transfer_ms = (batch_bytes / PCIE_GEN4_BW).to(second).magnitude * MS_PER_SEC

    # PCIe Ref
    pcie4_gbs = PCIE_GEN4_BW.to(GB/second).magnitude

    # ┌── 3. INVARIANTS (Guardrails) ───────────────────────────────────────────
    check(dataloader_mbs > 100, f"Throughput requirement ({dataloader_mbs:.1f} MB/s) too low.")

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    dataloader_mbs_str = fmt(dataloader_mbs, precision=0, commas=False)
    batch_mb_str = fmt(batch_mb, precision=0, commas=False)
    batch_transfer_ms_str = fmt(batch_transfer_ms, precision=1, commas=False)
    pcie4_gbs_str = fmt(pcie4_gbs, precision=0, commas=False)

# ┌── EXPORTS (Bridge to Text) ─────────────────────────────────────────────────
img_res = DataloaderThroughput.img_res
dataloader_mbs_str = DataloaderThroughput.dataloader_mbs_str
batch_mb_str = DataloaderThroughput.batch_mb_str
batch_transfer_ms_str = DataloaderThroughput.batch_transfer_ms_str
pcie4_gbs_str = DataloaderThroughput.pcie4_gbs_str
```

\index{Data Pipeline!throughput optimization}
Streams and events answer the second question---*where does data live, and how does it move?*---by overlapping transfers with computation so that the GPU rarely stalls on a single tensor. But scheduling alone cannot help if data arrives too slowly in the first place. The third question is *how does data arrive fast enough?*

The core systems principle is straightforward: the data pipeline must sustain the accelerator's consumption rate. A GPU processing 1,000 images per second at `{python} img_res`×`{python} img_res` resolution requires approximately `{python} dataloader_mbs_str` MB/s of sustained throughput for the decoded 8-bit pixels alone. If the pipeline cannot maintain this rate, the accelerator idles and the effective utilization term in the **Iron Law** drops below 1.

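The bandwidth figure follows from simple arithmetic. The helper below is a sketch that assumes uncompressed pixels and 1 MB = 10^6 bytes. The FP32 line shows how the requirement grows fourfold if samples are converted to 32-bit floats before they cross a bandwidth-limited link. The `utilization` function is a simple hypothetical cap: the fraction of time the accelerator is fed.

```python
def ingest_mb_per_s(images_per_s: float, height: int, width: int,
                    channels: int = 3, bytes_per_value: int = 1) -> float:
    """Sustained input bandwidth in MB/s (1 MB = 1e6 bytes)."""
    return images_per_s * height * width * channels * bytes_per_value / 1e6


def utilization(pipeline_images_per_s: float, gpu_images_per_s: float) -> float:
    """Fraction of time the accelerator is fed, capped at 1."""
    return min(1.0, pipeline_images_per_s / gpu_images_per_s)


as_uint8 = ingest_mb_per_s(1000, 224, 224)                    # decoded 8-bit pixels
as_fp32 = ingest_mb_per_s(1000, 224, 224, bytes_per_value=4)  # after float conversion
print(f"{as_uint8:.1f} MB/s as uint8, {as_fp32:.1f} MB/s as FP32")
print(f"utilization if the pipeline delivers 750 img/s: {utilization(750, 1000):.2f}")
```
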
Frameworks address this throughput requirement through three mechanisms. The first is *parallel worker processes*: the DataLoader spawns multiple CPU processes, each independently loading and preprocessing samples. Data loading involves disk I/O and CPU-bound transformations (decoding, augmentation, normalization), so a single process cannot saturate a modern GPU. Multiple workers overlap I/O wait times with preprocessing computation, collectively sustaining throughput that no single process could achieve.

When `num_workers > 0`, PyTorch's DataLoader works as follows:

1. The main process sends the sample indices for each batch to a per-worker index queue.
2. Each worker loads those samples, collates them into a complete batch, and places the batch on a shared result queue.
3. The main process takes batches from that queue and hands them to the training loop in sampler order.

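The dispatch pattern can be modeled in a few lines. This simplified sketch assumes that batches are handed to workers in round-robin order, which is roughly what PyTorch's multi-process iterator does. The real implementation also bounds the number of outstanding batches per worker and reorders results. Both helper functions are hypothetical.

```python
from collections import defaultdict


def plan_batches(num_samples: int, batch_size: int, drop_last: bool = False):
    """Split sample indices 0..num_samples-1 into consecutive batches."""
    indices = list(range(num_samples))
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


def dispatch_round_robin(batches, num_workers: int):
    """Assign batch k to worker k % num_workers; each worker builds whole batches."""
    assignment = defaultdict(list)
    for k, batch in enumerate(batches):
        assignment[k % num_workers].append(batch)
    return dict(assignment)


plan = dispatch_round_robin(plan_batches(10, batch_size=3), num_workers=2)
for worker_id, batches in sorted(plan.items()):
    print(worker_id, batches)
```

With shuffling, the index lists would come from a random permutation rather than `range`. The assignment pattern stays the same.
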
The second mechanism is *prefetching*. The `prefetch_factor` parameter (default 2 when `num_workers > 0`) controls how many batches each worker prepares in advance. With 4 workers and `prefetch_factor=2`, the pipeline keeps up to 8 batches in flight. While the model processes batch $N$ on the GPU, workers simultaneously load and preprocess batches $N+1$ through $N+8$ on the CPU cores. This hides data loading latency behind computation as long as the workers' average throughput keeps pace with the GPU. The cost is host memory consumption proportional to batch size times prefetch depth.

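Prefetch depth and its memory cost are easy to estimate. The sketch assumes every in-flight batch is the 64-image FP32 batch used elsewhere in this section. It ignores per-worker overheads such as each worker's copy of the Dataset object.

```python
def in_flight_batches(num_workers: int, prefetch_factor: int = 2) -> int:
    """Upper bound on batches prepared ahead across all workers."""
    return num_workers * prefetch_factor


def prefetch_host_mb(num_workers: int, prefetch_factor: int, batch_bytes: int) -> float:
    """Host memory (MB, 1e6 bytes) held by prefetched batches."""
    return in_flight_batches(num_workers, prefetch_factor) * batch_bytes / 1e6


batch_bytes = 64 * 224 * 224 * 3 * 4  # FP32 batch, ~38.5 MB
for workers, factor in [(4, 2), (8, 4)]:
    print(workers, factor, in_flight_batches(workers, factor),
          f"{prefetch_host_mb(workers, factor, batch_bytes):.0f} MB")
```
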
The third mechanism is *pinned memory for DMA transfers*. The `pin_memory=True` option places each batch in page-locked (pinned) host memory rather than pageable memory. The operating system can swap pageable memory to disk, so the CUDA runtime must first copy such data into a temporary pinned staging buffer before it can start the GPU transfer.

Pinned memory bypasses this intermediate copy and enables direct memory access (DMA): the GPU's copy engines read directly from host memory. When the copy is issued with `.to(device, non_blocking=True)`, the CPU continues other work while the transfer proceeds.

For a batch of 64 images at 224×224×3 in FP32 (`{python} batch_mb_str` MB), a pinned-memory transfer takes approximately `{python} batch_transfer_ms_str` ms over PCIe 4.0 x16 (`{python} pcie4_gbs_str` GB/s). The same transfer takes ~3.0 ms with pageable memory, so pinning gives a 2–3× speedup. The cost is reduced available system memory, because pinned pages cannot be swapped.

These three mechanisms appear together in the DataLoader configuration. Understanding how each parameter connects to the underlying systems principle helps practitioners diagnose data pipeline bottlenecks. @lst-dataloader-throughput shows a typical setup where `num_workers` enables parallel loading, `prefetch_factor` controls pipeline depth, and `pin_memory` enables DMA transfers:

::: {#lst-dataloader-throughput lst-cap="**DataLoader Throughput Configuration**: Each parameter addresses a specific throughput bottleneck. num_workers parallelizes I/O and preprocessing across CPU cores, prefetch_factor controls pipeline depth, and pin_memory enables DMA transfers to the GPU."}
```{.python}
import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # Seed NumPy and `random` from the seed PyTorch assigned to this worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# `dataset` is any map-style Dataset defined elsewhere.
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,               # Parallel worker processes (mechanism 1)
    prefetch_factor=2,           # Batches prepared ahead per worker (mechanism 2)
    pin_memory=True,             # Page-locked memory for DMA (mechanism 3)
    worker_init_fn=seed_worker,  # Reproducible augmentation per worker
    generator=torch.Generator().manual_seed(0),  # Fixes the base seed
)

# Pipeline effect: while GPU processes batch N,
# 4 workers load batches N+1..N+8 into pinned memory,
# ready for DMA transfer when the GPU finishes.
```
:::

A practical starting point is to set `num_workers` to the number of available CPU cores, divided among the training processes when several share a node. The optimal value depends on whether loading is I/O-bound or CPU-bound:

- For I/O-bound workloads, such as reading images from network storage, more workers overlap disk latency and improve throughput.
- For CPU-bound workloads involving heavy augmentation, the benefit saturates once all cores are utilized.

Too many workers waste memory, since each maintains a copy of the Dataset object.

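One way to turn this guidance into a number is to measure how many samples per second a single worker can produce. Divide the GPU's consumption rate by that figure and cap the result at the available cores. The helper below is a hypothetical sketch of that heuristic. It assumes throughput scales linearly with the number of workers, which holds best for CPU-bound loading and breaks down once storage becomes the bottleneck.

```python
import math
import os
from typing import Optional


def suggest_num_workers(gpu_samples_per_s: float,
                        worker_samples_per_s: float,
                        cpu_cores: Optional[int] = None) -> int:
    """Workers needed to match GPU demand, capped at the available cores."""
    if cpu_cores is None:
        cpu_cores = os.cpu_count() or 1
    needed = math.ceil(gpu_samples_per_s / worker_samples_per_s)
    return max(1, min(needed, cpu_cores))


print(suggest_num_workers(1000, 150, cpu_cores=16))  # demand-limited
print(suggest_num_workers(1000, 50, cpu_cores=8))    # core-limited
```
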
Worker process management introduces several subtle issues. The first concerns randomness. Workers are separate processes, so the random number generators used in data augmentation need explicit care. PyTorch gives its own generator a different seed in each worker, but generators belonging to other libraries can start from identical states copied from the parent process. In that case workers produce identical augmentation sequences, and effective data diversity shrinks. The `worker_init_fn` hook is the place to seed those generators per worker, and fixing the DataLoader's base seed makes the resulting augmentation reproducible.

Shared state between workers presents a separate challenge. Each worker has its own memory space, so modifications to global variables in one worker do not propagate to others or to the main process. For large datasets where caching matters, memory-mapped files or shared memory regions that persist across processes are the usual solution.

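The duplicate-stream problem can be reproduced with Python's standard `random` module alone. In this sketch, a forked worker is modeled as a generator restored from a copy of the parent's state. `crop_offsets` is a hypothetical augmentation. Deriving each worker's seed as `base_seed + worker_id` mirrors how PyTorch seeds its own generator in each worker.

```python
import random

BASE_SEED = 1234  # stands in for the DataLoader's base seed


def crop_offsets(rng: random.Random, n: int = 4) -> list:
    """Hypothetical augmentation: draw n random crop offsets in [0, 31]."""
    return [rng.randint(0, 31) for _ in range(n)]


# A forked worker inherits an exact copy of the parent's generator state.
PARENT_STATE = random.Random(BASE_SEED).getstate()


def naive_worker(worker_id: int) -> list:
    rng = random.Random()
    rng.setstate(PARENT_STATE)  # identical state in every worker
    return crop_offsets(rng)


def seeded_worker(worker_id: int) -> list:
    rng = random.Random(BASE_SEED + worker_id)  # distinct but reproducible
    return crop_offsets(rng)


naive = [naive_worker(w) for w in range(4)]
seeded = [seeded_worker(w) for w in range(4)]
print("naive workers identical:", all(draw == naive[0] for draw in naive))
print("seeded workers identical:", all(draw == seeded[0] for draw in seeded))
```
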
The DataLoader wraps a Dataset object that defines how individual samples are accessed. PyTorch supports two dataset paradigms:

- *Map-style* datasets implement `__len__` and `__getitem__`, enabling random access to samples by index. This pattern works well for datasets that fit in memory or support efficient random access on disk.
- *Iterable-style* datasets implement `__iter__` instead, yielding samples sequentially for streaming data sources where random access is impractical.

The choice between paradigms determines whether the DataLoader can shuffle samples (map-style only) or must process them in arrival order (iterable-style).

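The two protocols can be illustrated without importing PyTorch. The classes below are hypothetical stand-ins that implement the same methods PyTorch's `Dataset` and `IterableDataset` rely on. `shuffled_epoch` plays the role of a random sampler, and it works only because the map-style class supports `len()` and indexing.

```python
import random


class SquaresMapStyle:
    """Map-style: random access by index through __len__ and __getitem__."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx


class SquaresStream:
    """Iterable-style: samples arrive in order through __iter__ only."""

    def __init__(self, n: int):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield i * i


def shuffled_epoch(dataset, seed: int) -> list:
    """Random-sampler stand-in: needs len() and indexing, i.e. a map-style dataset."""
    order = list(range(len(dataset)))
    random.Random(seed).shuffle(order)
    return [dataset[i] for i in order]


print(shuffled_epoch(SquaresMapStyle(5), seed=0))  # some permutation of the squares
print(list(SquaresStream(5)))                      # always in arrival order
```
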
A final detail is *collation*: the `collate_fn` parameter determines how individual samples are combined into batches. The default collation stacks tensors along a new batch dimension, which works when all samples have identical shapes. For variable-length data such as text sequences, custom collation handles padding, sorting by length, or creating attention masks---directly affecting both memory usage and training throughput.

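A custom collation function for token sequences might look like the hypothetical sketch below. It pads with an assumed `pad_id` of 0 and returns plain lists; a real `collate_fn` passed to a PyTorch DataLoader would typically return tensors. The reported padding fraction is the share of the batch spent on padding. That share is the memory and compute overhead that length-based sorting tries to reduce.

```python
def pad_collate(batch, pad_id=0):
    """Pad variable-length token lists to the batch's max length.

    Returns (padded, lengths, mask); mask[i][j] is 1 for a real token
    and 0 for padding.
    """
    lengths = [len(seq) for seq in batch]
    max_len = max(lengths, default=0)
    padded = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in batch]
    mask = [[1] * n + [0] * (max_len - n) for n in lengths]
    return padded, lengths, mask


def padding_fraction(lengths):
    """Share of the padded batch occupied by padding tokens."""
    if not lengths:
        return 0.0
    total = len(lengths) * max(lengths)
    return 1 - sum(lengths) / total if total else 0.0


padded, lengths, mask = pad_collate([[5, 6, 7], [8], [9, 10]])
print(padded)
print(mask)
print(f"padding: {padding_fraction(lengths):.0%}")
```
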
```{python}
#| label: gpt3-parameter-structures
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ GPT-3 PARAMETER STRUCTURES
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Parameter Structures section discussing parameter management scale
# │
# │ Goal: Demonstrate the scale challenge of modern LLM parameter management.
# │ Show: That 175B parameters require 350 GB of storage, necessitating distributed systems.
# │ How: Calculate total weight bytes in FP16 for the GPT-3 architecture.
# │
# │ Imports: mlsys.constants (GPT3_PARAMS, BYTES_FP16), mlsys.formulas (model_memory)
# │ Exports: gpt3_params_b_str, gpt3_fp16_gb_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import GPT3_PARAMS, BYTES_FP16, GB, Bparam
from mlsys.formulas import model_memory
from mlsys.formatting import fmt, check

# --- Inputs (from model specs) ---
gpt3_params_b_value = GPT3_PARAMS.to(Bparam).magnitude # 175 billion

# --- Process ---
gpt3_fp16_gb_value = model_memory(GPT3_PARAMS, BYTES_FP16, GB) # 350 GB

# --- Outputs (formatted strings for prose) ---
gpt3_params_b_str = fmt(gpt3_params_b_value, precision=0, commas=False) # e.g. "175" # Note: also defined in gpt3-params; produces same value
gpt3_fp16_gb_str = fmt(gpt3_fp16_gb_value, precision=0, commas=False) # e.g. "350"
```

DataLoaders, Datasets, and collation functions answer the third question---*how does data arrive fast enough?*---by sustaining accelerator-rate throughput through parallelism, prefetching, and DMA. But these structures handle *ephemeral* data: samples flow through the pipeline once per epoch and are discarded. The fourth question asks how frameworks manage data that *persists*---the model's own weights---especially when those weights exceed the memory of any single device.

##### Parameter Structures {.unnumbered}

A GPT-3 scale model stores `{python} GPT3Context.gpt3_params_b_str` billion parameters, occupying `{python} gpt3_fp16_gb_str` GB in FP16. Managing these parameters across devices, keeping gradients synchronized, and maintaining optimizer state (which can triple the memory footprint, as the Administrative Tax notebook showed) is a core framework responsibility.

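A quick accounting shows why optimizer state matters as much as the weights themselves. The per-parameter byte counts below follow the common mixed-precision Adam breakdown: FP16 weights and gradients, plus FP32 master weights and two FP32 Adam moments, for 16 bytes in total. The sketch ignores activations, temporary buffers, and allocator fragmentation.

```python
# Bytes per parameter for mixed-precision training with Adam (assumed breakdown).
BYTES_PER_PARAM = {
    "fp16 weights": 2,
    "fp16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}


def training_state_gb(num_params: float) -> dict:
    """Memory per component in GB (1 GB = 1e9 bytes)."""
    return {name: num_params * nbytes / 1e9 for name, nbytes in BYTES_PER_PARAM.items()}


state = training_state_gb(175e9)  # GPT-3 scale
for name, gb in state.items():
    print(f"{name:>24}: {gb:,.0f} GB")
print(f"{'total':>24}: {sum(state.values()):,.0f} GB")
```

The total, about 2.8 TB, is eight times the FP16 weights alone. This is why such models must be spread across many devices.
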
Because parameters persist throughout training and inference, frameworks organize them into compact structures that minimize memory while enabling fast read and write access [@li2014communication]. During multi-GPU training, frameworks may replicate parameters across devices for parallel computation while keeping a synchronized master copy. Synchronizing multi-billion parameter models can require transferring tens of GB of gradients per step, which is why frameworks implement gradient compression and efficient communication patterns like ring all-reduce.

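Ring all-reduce is simple enough to simulate in pure Python. The sketch below models each worker's gradient buffer as a list and runs the two classic phases: reduce-scatter, then all-gather. It illustrates the algorithm, not any framework's implementation, and it assumes the buffer length divides evenly by the number of workers. The communication counter shows what makes the ring attractive: each worker sends $2(N-1)/N$ of the buffer, which stays nearly constant as the worker count $N$ grows.

```python
def ring_allreduce(buffers):
    """Sum equal-length lists across simulated workers arranged in a ring.

    Returns (results, elements_sent_per_worker). Worker i only ever sends
    to worker (i + 1) % n.
    """
    n = len(buffers)
    length = len(buffers[0])
    if any(len(b) != length for b in buffers) or length % n:
        raise ValueError("buffers must share a length divisible by the worker count")
    chunk = length // n
    data = [list(b) for b in buffers]

    def span(c):
        return range(c * chunk, (c + 1) * chunk)

    def ring_step(chunk_of, combine):
        # Snapshot every outgoing chunk first: all sends in a step are simultaneous.
        outgoing = []
        for i in range(n):
            c = chunk_of(i)
            outgoing.append((i, c, [data[i][k] for k in span(c)]))
        for i, c, payload in outgoing:
            dst = (i + 1) % n
            for k, value in zip(span(c), payload):
                data[dst][k] = combine(data[dst][k], value)

    sent = 0
    # Phase 1, reduce-scatter: afterwards worker i holds the full sum of chunk (i + 1) % n.
    for step in range(n - 1):
        ring_step(lambda i: (i - step) % n, lambda old, new: old + new)
        sent += chunk
    # Phase 2, all-gather: pass the finished chunks around the ring.
    for step in range(n - 1):
        ring_step(lambda i: (i + 1 - step) % n, lambda old, new: new)
        sent += chunk
    return data, sent


grads = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
reduced, sent = ring_allreduce(grads)
print(reduced)  # every worker ends with the elementwise sum
print(sent)     # 2 * (n - 1) * chunk elements per worker
```
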
Parameter structures must also adapt to varying precision requirements. Training has traditionally used FP32 for gradient stability. Large-scale training increasingly uses FP16 or BF16 mixed precision, and inference commonly runs in FP16 or INT8. Frameworks implement type casting and mixed-precision management, for example by keeping an FP32 master copy of the weights, so that these lower-precision formats can be used with minimal loss of numerical accuracy.

##### Distributed Execution Contexts {.unnumbered}

\index{Distributed Training!execution contexts}
The computational graph defines *what* to compute, but *where* and *how* that computation runs across devices is the job of execution contexts. On a single node, execution contexts manage CUDA streams and events (discussed earlier in this chapter) to overlap computation and data transfer across GPUs.

When training scales beyond a single machine, these same abstractions extend to manage process groups and communication primitives. Frameworks use constructs like `ProcessGroup` (PyTorch) or `jax.sharding.Mesh` (JAX) to define how devices communicate. These constructs maintain state for collective operations such as AllReduce, which synchronizes gradients across thousands of GPUs. Their responsibilities include partitioning computational graphs, synchronizing gradients, and redistributing data as needed.

We introduce these concepts here because they shape framework API design even for single-node code. The implementation details of distributed training---including gradient compression, communication topologies, and fault tolerance---constitute advanced topics that build on these single-node foundations.

\index{Data Parallelism!definition}
\index{Pipeline Parallelism!definition}
When models exceed single-device memory, frameworks combine multiple parallelism strategies simultaneously. A GPT-3 scale model, for instance, cannot fit on a single GPU. Its `{python} GPT3Context.gpt3_params_b_str` B parameters alone require `{python} gpt3_fp16_gb_str` GB in FP16, far exceeding any GPU's memory. How do practitioners train such models? By distributing computation across multiple devices using three complementary strategies.

@fig-3d-parallelism lays out how large-scale training distributes computation across three orthogonal dimensions, each addressing a different scaling need:

- **Data Parallelism**\index{Data Parallelism!model replication} (the two rows, each a full replica of the model) scales throughput by processing different batches in parallel.
- **Pipeline Parallelism**\index{Pipeline Parallelism!layer splitting} (the four columns, each holding a consecutive group of layers) distributes a single model's depth across devices.
- **Model Parallelism**\index{Model Parallelism!tensor sharding} (sharding tensors within each cluster) partitions individual layers that are too large for one device.

This "3D" approach allows frameworks to scale beyond the memory limits of any single device [@mcmahan2017federated]. @sec-model-training examines these parallelism strategies in depth, including their implementation trade-offs and communication patterns.

::: {#fig-3d-parallelism fig-env="figure" fig-pos="htb" fig-cap="**3D Parallelism.** A grid of eight accelerator clusters arranged in two rows and four columns, each drawn in its own color and containing four stacked GPUs (32 GPUs in total). Braces mark the three parallelism dimensions: ZeRO data parallelism across the two rows, pipeline parallelism across the four columns, and model parallelism within each cluster." fig-alt="Grid of 8 GPU clusters in 2 rows and 4 columns, each containing 4 stacked cubes and labeled GPU 0, 8, 16, 24 in the top row and GPU 4, 12, 20, 28 in the bottom row. Bottom row colors: blue, red, green, orange; top row: olive, pink, light green, red. A horizontal brace below the grid reads Pipeline Parallel, a vertical brace at the left reads ZeRO Data Parallel, and a brace along the bottom-right stack reads Model Parallel."}
```{.tikz}
\resizebox{0.70\textwidth}{!}{
\begin{tikzpicture}[line cap=round,line join=round,font=\small\usefont{T1}{phv}{m}{n}]
\tikzset{
pics/square/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[local bounding box=SQUARE,scale=\scalefac,every node/.append style={transform shape}]
% Right Face
\draw[fill=\channelcolor!70,line width=\Linewidth]
(\Depth,0,0)coordinate(\picname-ZDD)--(\Depth,\Width,0)--(\Depth,\Width,\Height)--(\Depth,0,\Height)--cycle;
% Front Face
\draw[fill=\channelcolor!40,line width=\Linewidth]
(0,0,\Height)coordinate(\picname-DL)--(0,\Width,\Height)coordinate(\picname-GL)--
(\Depth,\Width,\Height)coordinate(\picname-GD)--(\Depth,0,\Height)coordinate(\picname-DD)--(0,0,\Height);
% Top Face
\draw[fill=\channelcolor!20,line width=\Linewidth]
(0,\Width,0)coordinate(\picname-ZGL)--(0,\Width,\Height)coordinate(\picname-ZGL)--
(\Depth,\Width,\Height)--(\Depth,\Width,0)coordinate(\picname-ZGD)--cycle;
\end{scope}
}
}
}
\pgfkeys{
/channel/.cd,
Depth/.store in=\Depth,
Height/.store in=\Height,
Width/.store in=\Width,
channelcolor/.store in=\channelcolor,
drawchannelcolor/.store in=\drawchannelcolor,
scalefac/.store in=\scalefac,
Linewidth/.store in=\Linewidth,
picname/.store in=\picname,
Depth=1.6,
Height=1.1,
Width=1.4,
channelcolor=BrownLine,
drawchannelcolor=BrownLine,
scalefac=1,
Linewidth=1.0pt,
picname=C
}
\def\ras{0.95}
\def\dis{2.2}
\begin{scope}[local bounding box=BELOW,shift={($(0,0)+(0,0)$)},scale=1,every node/.append style={transform shape}]
\begin{scope}[local bounding box=GPU0,shift={($(0,0)+(0,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=4,channelcolor=BlueLine,Linewidth=0.7pt}};
}
\end{scope}
\begin{scope}[local bounding box=GPU8,shift={($(0,0)+(\dis,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=12,channelcolor=RedLine,Linewidth=0.7pt}};
}
\end{scope}
\begin{scope}[local bounding box=GPU16,shift={($(0,0)+(2*\dis,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=20,channelcolor=GreenLine,Linewidth=0.7pt}};
}
\end{scope}
\begin{scope}[local bounding box=GPU16,shift={($(0,0)+(3*\dis,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=28-\i,channelcolor=OrangeLine,Linewidth=0.7pt}};
}
\end{scope}
\end{scope}
%%%%ABOVE
\begin{scope}[local bounding box=ABOVE,shift={($(0,0)+(0,2.2)$)},scale=1,every node/.append style={transform shape}]
\begin{scope}[local bounding box=GPU0,shift={($(0,0)+(0,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=0,channelcolor=OliveLine,Linewidth=0.7pt}};
}
\end{scope}
\begin{scope}[local bounding box=GPU8,shift={($(0,0)+(1*\dis,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=8,channelcolor=pink,Linewidth=0.7pt}};
}
\end{scope}
\begin{scope}[local bounding box=GPU16,shift={($(0,0)+(2*\dis,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=16,channelcolor=green!70!,Linewidth=0.7pt}};
}
\end{scope}
\begin{scope}[local bounding box=GPU16,shift={($(0,0)+(3*\dis,0)$)},scale=1,every node/.append style={transform shape}]
\foreach \i in {1,...,4} {
\pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=24,channelcolor=red,Linewidth=0.7pt}};
}
\end{scope}
\end{scope}
\node[]at($(28-4-GL)!0.5!(28-4-DD)$){GPU 28};
%
\foreach \i in {0,8,16,24,4,12,20} {
\node[]at($(\i-GL)!0.5!(\i-DD)$){GPU \i};
}
\draw[thick,decoration={brace,amplitude=5pt,mirror},decorate]([yshift=-2mm]4-DL)--
([yshift=-2mm]28-4-DD) node [midway,below=2mm] {Pipeline Parallel};
\draw[thick,decoration={brace,amplitude=5pt},decorate]([xshift=-2mm]4-DL)--
([xshift=-2mm]0-GL) node [midway,above=5mm, sloped,pos=0.9,anchor=east] {ZeRO Data Parallel};
\draw[thick,decoration={brace,amplitude=5pt,mirror},decorate]([xshift=2mm]28-4-DD)--
([xshift=2mm]28-1-ZDD)node[midway, below=4mm, anchor=west, sloped,pos=0.25] {Model Parallel};
\end{tikzpicture}}
```
:::

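How a framework assigns global ranks to positions in such a grid is a configuration choice. The sketch below assumes a hypothetical layout for 32 GPUs split as 2 data-parallel replicas × 4 pipeline stages × 4-way model parallelism. The model-parallel index varies fastest, so each model-parallel group occupies consecutive ranks, often the GPUs within one node that share the fastest links. The sketch also estimates the share of the 350 GB of FP16 weights that each GPU holds.

```python
def rank_to_coords(rank: int, mp: int, pp: int, dp: int):
    """Map a global rank to (data, pipeline, model) parallel indices.

    Hypothetical layout: model-parallel index varies fastest, then the
    data-parallel replica, then the pipeline stage.
    """
    world = mp * pp * dp
    if not 0 <= rank < world:
        raise ValueError(f"rank {rank} outside world of size {world}")
    mp_idx = rank % mp
    dp_idx = (rank // mp) % dp
    pp_idx = rank // (mp * dp)
    return dp_idx, pp_idx, mp_idx


def weight_gb_per_gpu(total_weight_gb: float, mp: int, pp: int) -> float:
    """Weights are sharded across model- and pipeline-parallel ranks and
    replicated across data-parallel replicas."""
    return total_weight_gb / (mp * pp)


MP, PP, DP = 4, 4, 2
for rank in (0, 4, 8, 31):
    print(rank, rank_to_coords(rank, MP, PP, DP))
print(f"{weight_gb_per_gpu(350, MP, PP):.1f} GB of FP16 weights per GPU")
```

Under this layout, ranks that differ only in the data-parallel index hold identical weight shards and synchronize gradients with all-reduce. Ranks in the same model-parallel group exchange activations within every layer.
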
The data structures examined so far---tensors, device managers, data pipelines, parameter structures, and distributed execution contexts---define *what* data a framework manages and *where* it lives. What remains is the final question: *what actually runs on the hardware?*

### Core Operations {#sec-ml-frameworks-core-operations-914f}

\index{Operator Kernels!dispatch hierarchy}
When you write `y = torch.matmul(x, w)`, what actually happens between Python and the GPU? Three layers working in coordination bridge the gap between a single line of Python and thousands of parallel GPU threads. @fig-mlfm-core-ops arranges them from left to right, moving from the application side toward the hardware:

- System-level operations coordinate resources and execution.
- Basic numerical operations implement mathematical computations.
- Hardware abstraction operations manage computing platform complexity.

::: {#fig-mlfm-core-ops fig-env="figure" fig-pos="htb" fig-cap="**Core Operations Stack.** Three groups of operations, connected by arrows from left to right, showing how frameworks bridge Python code to hardware. The left group contains system-level operations (scheduling, memory management, resource optimization), the middle group holds basic numerical operations (GEMM, BLAS, element-wise), and the right group provides hardware abstraction (compute kernel management, memory abstraction, execution control)." fig-alt="Three grouped boxes connected left to right by arrows. System-Level: Scheduling, Memory Management, Resource Optimization. Numerical: GEMM, BLAS, Element-wise Operations. Hardware: Compute Kernel Management, Memory Abstraction, Execution Control."}
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{Line/.style={line width=1.0pt,black!50
},
Box/.style={align=flush center,
inner xsep=2pt,
node distance=0.3,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,
text width=34mm,
minimum width=30mm,
minimum height=10mm
},
}
\begin{scope}[local bounding box=box1]
\node[Box,](B1){Scheduling};
\node[Box,below=of B1](B2){Memory Management};
\node[Box,below=of B2](B3){Resource Optimization};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=5mm,yshift=3mm,
fill=BackColor,fit=(B1)(B2)(B3),line width=0.75pt](BB1){};
\node[below=2pt of BB1.north,anchor=north]{System-Level Operations};
\end{scope}

\begin{scope}[local bounding box=box2,shift={(5.5,0)}]
\node[Box,fill=BrownL,draw=BrownLine,](B1){GEMM Operations};
\node[Box,fill=BrownL,draw=BrownLine,below=of B1](B2){BLAS Operations};
\node[Box,fill=BrownL,draw=BrownLine,below=of B2](B3){Element-wise Operations};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=5mm,yshift=3mm,
fill=BackColor,fit=(B1)(B2)(B3),line width=0.75pt](BB2){};
\node[below=2pt of BB2.north,anchor=north]{Basic Numerical Operations};
\end{scope}

\begin{scope}[local bounding box=box3,shift={(11,0)}]
\node[Box,fill=OrangeL,draw=OrangeLine,](B1){Compute Kernel Management};
\node[Box,fill=OrangeL,draw=OrangeLine,below=of B1](B2){Memory Abstraction};
\node[Box,fill=OrangeL,draw=OrangeLine,below=of B2](B3){Execution Control};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=5mm,yshift=3mm,
fill=BackColor,fit=(B1)(B2)(B3),line width=0.75pt](BB3){};
\node[below=2pt of BB3.north,anchor=north]{Hardware Operations};
\end{scope}

\foreach \x/\y in{1/2,2/3}
\draw[-latex,Line](box\x)--(box\y);
\end{tikzpicture}
```
:::

#### Hardware Abstraction Operations {#sec-ml-frameworks-hardware-abstraction-operations-1204}

The hardware abstraction layer isolates framework code from platform-specific details. It solves three concrete problems: selecting the right compute kernel, moving data through the memory hierarchy, and coordinating execution across processing units.

##### Compute Kernel Management {.unnumbered}

\index{Kernel Dispatch!hardware selection}
The kernel manager dispatches each operation to the fastest available implementation for the current hardware. When a framework encounters a matrix multiplication, it selects among AVX-512 vector instructions on modern CPUs, [cuBLAS](https://developer.nvidia.com/cublas) on NVIDIA GPUs, or dedicated tensor processing instructions on AI accelerators. The dispatch decision depends on input dimensions, data layout, and hardware capabilities.

A 4096 × 4096 GEMM on an A100 GPU routes to cuBLAS Tensor Core kernels that sustain up to `{python} A100BLAS.dense_tflops_str` TFLOPS in FP16. The same operation on a CPU runs on an AVX-512 code path at roughly 2 TFLOPS. When no specialized kernel exists for the current device, the manager falls back to a generic implementation rather than failing.

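The dispatch-with-fallback behavior can be captured by a small registry keyed on operation and device. This is a toy model, not PyTorch's dispatcher: the registry, the decorator, and the two pure-Python kernels are hypothetical. A real kernel manager would also key on dtype, layout, and shape.

```python
from typing import Callable, Dict, List, Tuple

Matrix = List[List[float]]
Kernel = Callable[[Matrix, Matrix], Matrix]
_REGISTRY: Dict[Tuple[str, str], Kernel] = {}


def register(op: str, device: str):
    """Decorator that records a kernel for (op, device)."""
    def decorator(fn: Kernel) -> Kernel:
        _REGISTRY[(op, device)] = fn
        return fn
    return decorator


@register("matmul", "generic")
def matmul_generic(a: Matrix, b: Matrix) -> Matrix:
    # Reference implementation: correct everywhere, fast nowhere.
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]


@register("matmul", "cpu")
def matmul_cpu(a: Matrix, b: Matrix) -> Matrix:
    # Stand-in for a tuned CPU kernel: walks B column-wise via a transpose.
    b_cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in b_cols] for row in a]


def dispatch(op: str, device: str) -> Kernel:
    """Pick the device-specific kernel, else fall back to the generic one."""
    kernel = _REGISTRY.get((op, device)) or _REGISTRY.get((op, "generic"))
    if kernel is None:
        raise NotImplementedError(f"no kernel registered for {op!r}")
    return kernel


a, b = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
print(dispatch("matmul", "cpu").__name__, dispatch("matmul", "cpu")(a, b))
print(dispatch("matmul", "tpu").__name__, dispatch("matmul", "tpu")(a, b))
```
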
##### Memory System Abstraction {.unnumbered}

\index{Memory Layout!NCHW vs. NHWC}
The memory abstraction layer moves tensors between memory types and transforms data layouts to match hardware preferences. The memory types include pageable host memory, pinned host memory, GPU device memory, and unified memory.

Layout conversion is a common case. A convolutional layer may store activations in NCHW format (batch, channels, height, width) by default. It may convert them to NHWC (channels-last) for backends whose kernels prefer that layout, such as Tensor Core convolutions on recent NVIDIA GPUs.

Alignment requirements vary from 4 bytes on CPUs to 128 bytes on some accelerators, and misaligned access can halve effective memory bandwidth. The layer also enforces cache coherency when multiple execution units read and write the same tensor, preventing silent data corruption during concurrent operations.

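The difference between the two layouts is purely one of address arithmetic. The sketch below computes flat offsets for a contiguous tensor in each layout and converts between them. Neighboring channels of one pixel are H×W elements apart in NCHW but adjacent in NHWC. This is one reason kernels that reduce over channels tend to favor NHWC. The function names are illustrative.

```python
def offset_nchw(n, c, h, w, C, H, W):
    """Flat offset of element (n, c, h, w) in a contiguous NCHW tensor."""
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, C, H, W):
    """Flat offset of the same element in a contiguous NHWC tensor."""
    return ((n * H + h) * W + w) * C + c


def nchw_to_nhwc(flat, N, C, H, W):
    """Reorder a flat NCHW buffer into NHWC order."""
    out = [None] * len(flat)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    out[offset_nhwc(n, c, h, w, C, H, W)] = flat[offset_nchw(n, c, h, w, C, H, W)]
    return out


N, C, H, W = 1, 3, 2, 2
print(nchw_to_nhwc(list(range(N * C * H * W)), N, C, H, W))
# Distance between channel 0 and channel 1 of pixel (0, 0):
print(offset_nchw(0, 1, 0, 0, C, H, W), offset_nhwc(0, 1, 0, 0, C, H, W))
```
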
##### Execution Control {.unnumbered}

The execution controller coordinates work across multiple processing units and memory spaces. On a modern GPU, this means managing concurrent CUDA streams. When two independent convolutions are both ready to execute, the controller can launch them on separate streams so they overlap on the GPU's streaming multiprocessors. When each kernel alone is too small to fill the GPU, such overlap can raise utilization substantially, for example from as low as 40% (sequential) to over 80% (concurrent).

The controller inserts synchronization barriers only where true data dependencies exist and tracks event completions to trigger dependent operations. It also routes hardware errors (ECC failures, timeout watchdogs) to the framework's error handling path.

```{python}
#| label: resnet50-gflops
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ RESNET-50 GFLOPS FOR GEMM DISCUSSION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Basic Numerical Operations section discussing GEMM dominance
# │
# │ Goal: Demonstrate the computational dominance of GEMM in vision models.
# │ Show: That ResNet-50 performs 8.2 GFLOPs per forward pass.
# │ How: Retrieve ResNet-50 GFLOPs from mlsys.constants.
# │
# │ Imports: mlsys.constants (RESNET50_FLOPs), mlsys.formatting (fmt)
# │ Exports: resnet_gflops_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import RESNET50_FLOPs, GFLOPs
from mlsys.formatting import fmt, check

# ┌── P.I.C.O. ISOLATED SCENARIO ───────────────────────────────────────────────
class ResNetGFLOPS:
    """
    Namespace for ResNet GFLOPS.
    Scenario: Compute intensity check.
    """

    # ┌── 1. PARAMETERS (Inputs) ───────────────────────────────────────────────
    flops = RESNET50_FLOPs

    # ┌── 2. CALCULATION (The Physics) ─────────────────────────────────────────
    gflops = flops.to(GFLOPs).magnitude

    # ┌── 4. OUTPUTS (Formatting) ──────────────────────────────────────────────
    resnet_gflops_str = fmt(gflops, precision=1, commas=False)

# ┌── EXPORTS (Bridge to Text) ─────────────────────────────────────────────────
resnet_gflops_str = ResNetGFLOPS.resnet_gflops_str
```

#### Basic Numerical Operations {#sec-ml-frameworks-basic-numerical-operations-5516}

\index{GEMM!arithmetic intensity}
With hardware abstraction managing the platform-specific details, frameworks build a layer of mathematical operations on top. General Matrix Multiply (GEMM)\index{GEMM!matrix multiplication}\index{Tensor Operations!GEMM} dominates ML computation (see @sec-algorithm-foundations-general-matrix-multiply-gemm-b55d for arithmetic intensity analysis and the roofline implications). The update $C \leftarrow \alpha AB + \beta C$ accounts for the vast majority of arithmetic in neural networks. A single ResNet-50 forward pass performs approximately `{python} resnet_gflops_str` billion floating-point operations, nearly all of which reduce to GEMM.

Frameworks optimize GEMM in three ways:

- Cache-aware tiling splits matrices into blocks that fit in L1/L2 cache.
- Loop unrolling exposes instruction-level parallelism.
- Shape-specific kernels handle particular matrix dimensions.

Fully connected layers use standard dense GEMM. Convolutional layers often use im2col transformations, which reshape input patches into matrix columns and so convert convolution into GEMM.

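Cache-aware tiling changes only the order of the loops, not the result. The sketch below contrasts a naive triple loop with a blocked version for the special case $\alpha = 1, \beta = 0$. In pure Python the blocked version will not run faster: the benefit comes from keeping tiles resident in L1/L2 cache in compiled code. The tile size here is an arbitrary illustrative choice.

```python
def matmul_naive(a, b):
    """Reference GEMM: C = A @ B with a plain triple loop."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            for p in range(k):
                c[i][j] += a[i][p] * b[p][j]
    return c


def matmul_tiled(a, b, tile=2):
    """Blocked GEMM: iterate over tile x tile blocks of C, A, and B so that
    the three blocks in use at any moment are small enough to stay in cache."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0] * m for _ in range(n)]
    for i0 in range(0, n, tile):
        for j0 in range(0, m, tile):
            for p0 in range(0, k, tile):
                for i in range(i0, min(i0 + tile, n)):
                    for p in range(p0, min(p0 + tile, k)):
                        a_ip = a[i][p]  # reused across the inner j loop
                        for j in range(j0, min(j0 + tile, m)):
                            c[i][j] += a_ip * b[p][j]
    return c


a = [[(3 * i + j) % 5 for j in range(7)] for i in range(5)]  # 5 x 7
b = [[(i * j + 1) % 4 for j in range(3)] for i in range(7)]  # 7 x 3
print(all(matmul_tiled(a, b, t) == matmul_naive(a, b) for t in (1, 2, 3, 4)))
```
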
Beyond GEMM, frameworks implement BLAS operations\index{BLAS!vector and matrix operations} (AXPY for scaled vector addition, GEMV for matrix-vector products) and element-wise operations\index{Element-wise Operations!memory bandwidth} (activation functions, normalization). Element-wise operations are individually cheap but collectively expensive, because they are limited by memory bandwidth rather than arithmetic. Each unfused operation reads and writes the full tensor, so a sequence of five element-wise operations on a 100 MB tensor moves 1 GB of data. Fusing those five operations into a single kernel reads the input once and writes the result once. That reduces memory traffic to 200 MB, a 5× bandwidth savings that translates directly into faster execution.

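A back-of-the-envelope traffic model makes the fusion arithmetic explicit. The sketch assumes that each unfused element-wise kernel reads its input once and writes its output once, while a fused kernel keeps intermediates in registers. The five operations are arbitrary stand-ins. In practice, compilers and libraries perform fusion; Python loops do not.

```{.python}
MB = 10**6


def elementwise_traffic_bytes(tensor_bytes, num_ops, fused):
    """Memory traffic of a chain of element-wise ops under a simple model.

    Each separate kernel reads one full tensor and writes one full tensor.
    Fusion collapses the chain into a single kernel (one read, one write).
    """
    kernels = 1 if fused else num_ops
    return kernels * 2 * tensor_bytes


# Five arbitrary element-wise ops (scale, shift, ReLU, square, clamp).
OPS = [
    lambda v: v * 2.0,
    lambda v: v + 1.0,
    lambda v: max(v, 0.0),
    lambda v: v * v,
    lambda v: min(v, 10.0),
]


def run_unfused(xs):
    for op in OPS:  # each op materializes a full intermediate "tensor"
        xs = [op(v) for v in xs]
    return xs


def run_fused(xs):
    out = []
    for v in xs:  # one pass over the data; intermediates stay in locals
        for op in OPS:
            v = op(v)
        out.append(v)
    return out


data = [-1.0, 0.5, 3.0]
assert run_fused(data) == run_unfused(data)  # same math, fewer passes

unfused = elementwise_traffic_bytes(100 * MB, num_ops=5, fused=False)
fused = elementwise_traffic_bytes(100 * MB, num_ops=5, fused=True)
print(unfused // MB, fused // MB, unfused // fused)  # 1000 200 5
```
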
Numerical precision adds another dimension. Storing parameters in FP32 takes 4 bytes each. Quantizing to INT8 reduces this to 1 byte, cutting memory by 4× and enabling 2–4× throughput improvements on hardware with INT8 acceleration. Training typically keeps FP32 master weights for gradient stability, even when mixed precision runs most arithmetic in FP16 or BF16. Inference, by contrast, commonly runs at FP16 or INT8 with minimal accuracy loss. Frameworks maintain separate kernel implementations for each precision format and handle mixed-precision workflows where different layers operate at different bit widths within a single forward pass.

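The sketch below covers both the bytes-per-parameter arithmetic and the conversion to INT8. The quantizer uses per-tensor symmetric scaling. That is one common scheme among several; per-channel and asymmetric schemes are also widely used. The function names are hypothetical.

```{.python}
def quantize_int8_symmetric(values):
    """Map floats to int8 codes with one scale per tensor.

    The largest magnitude maps to 127, and zero maps exactly to zero.
    """
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    return [c * scale for c in codes]


weights = [0.52, -1.27, 0.0, 0.031, 1.0]
codes, scale = quantize_int8_symmetric(weights)
restored = dequantize(codes, scale)
# Rounding error is bounded by half a quantization step.
assert all(abs(w - r) <= scale / 2 + 1e-9 for w, r in zip(weights, restored))

# Parameter storage for a 7B-parameter model at each precision.
num_params = 7 * 10**9
for name, bytes_per_param in [("FP32", 4), ("FP16", 2), ("INT8", 1)]:
    print(f"{name}: {num_params * bytes_per_param / 1e9:.0f} GB")
# FP32: 28 GB, FP16: 14 GB, INT8: 7 GB
```
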
#### System-Level Operations {#sec-ml-frameworks-systemlevel-operations-6a1c}

Hardware abstraction and numerical operations provide the building blocks; system-level operations orchestrate them. The system layer ties scheduling, memory management, and resource optimization into a coherent execution engine.

The operation scheduler analyzes the computational graph to find parallelism while respecting data dependencies. In a static graph, the scheduler sees the full dependency structure before execution begins and can plan its ordering ahead of time. In a dynamic graph, dependencies emerge at runtime, forcing the scheduler to make greedy, local decisions. Concretely, in a downsampling ResNet block the main convolutional branch and the 1×1 projection shortcut do not depend on each other. A scheduler that issues them on separate streams can therefore run both concurrently instead of serializing them, reducing idle cycles on the GPU's streaming multiprocessors.

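Finding independent work can be sketched as a topological sort that groups operations into "waves" whose members do not depend on one another. This is a simplified model. It assumes a dependency map known up front (the static-graph case) and unlimited streams. Real schedulers also weigh stream availability, memory pressure, and kernel occupancy. The op names are hypothetical.

```{.python}
from collections import defaultdict


def schedule_levels(deps):
    """Group ops into waves; ops in the same wave can run concurrently.

    deps maps each op name to the list of ops whose outputs it consumes.
    Uses Kahn's topological sort, emitting one wave per iteration.
    """
    indegree = {op: len(inputs) for op, inputs in deps.items()}
    consumers = defaultdict(list)
    for op, inputs in deps.items():
        for src in inputs:
            consumers[src].append(op)
    ready = sorted(op for op, d in indegree.items() if d == 0)
    waves = []
    while ready:
        waves.append(ready)
        next_ready = []
        for op in ready:
            for consumer in consumers[op]:
                indegree[consumer] -= 1
                if indegree[consumer] == 0:
                    next_ready.append(consumer)
        ready = sorted(next_ready)
    if sum(len(w) for w in waves) != len(deps):
        raise ValueError("graph has a cycle")
    return waves


# A downsampling residual block: the 1x1 projection shortcut does not
# depend on the main-branch convolutions, so the two can overlap.
block = {
    "input": [],
    "conv1": ["input"],
    "conv2": ["conv1"],
    "shortcut_proj": ["input"],
    "add": ["conv2", "shortcut_proj"],
}
print(schedule_levels(block))
# [['input'], ['conv1', 'shortcut_proj'], ['conv2'], ['add']]
```
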
The memory manager allocates and reclaims GPU memory across the computational graph's lifetime. Model parameters (a 7B-parameter model consumes approximately `{python} Model7B.model_7b_fp16_gb_str` GB in FP16) persist for the entire training run, while activation tensors live only until the backward pass consumes them. PyTorch's caching allocator maintains a memory pool, subdividing and reusing freed blocks without returning them to CUDA. This avoids repeated `cudaMalloc` and `cudaFree` calls, which can cost on the order of 1 ms each; `cudaFree` also synchronizes the device. When activations exceed GPU memory, frameworks offer gradient checkpointing (activation recomputation). It discards selected activations during the forward pass and recomputes them during the backward pass, trading roughly 20--33% additional compute for 60% or more memory savings (with optimal checkpoint placement).

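A toy version of the caching-allocator idea fits in a small class. Freed blocks go onto a free list keyed by rounded size, so later requests of the same size class are served without a device call. This sketch rests on simplifying assumptions. The 512-byte rounding is illustrative only. The real PyTorch allocator also splits and coalesces blocks and tracks which stream each block belongs to. The class and method names are hypothetical.

```{.python}
from collections import defaultdict


class CachingAllocator:
    """Toy caching allocator: reuse freed blocks instead of releasing them."""

    ROUND = 512  # illustrative size-class granularity, in bytes

    def __init__(self):
        self.free_blocks = defaultdict(list)  # rounded size -> block ids
        self.device_allocs = 0                # calls that would hit cudaMalloc
        self._next_id = 0

    def _round(self, nbytes):
        return -(-nbytes // self.ROUND) * self.ROUND  # ceil to a multiple

    def malloc(self, nbytes):
        size = self._round(nbytes)
        if self.free_blocks[size]:
            return size, self.free_blocks[size].pop()  # cache hit
        self.device_allocs += 1                         # cache miss
        self._next_id += 1
        return size, self._next_id

    def free(self, block):
        size, block_id = block
        self.free_blocks[size].append(block_id)  # keep for reuse


alloc = CachingAllocator()
for step in range(100):  # 100 "iterations" with the same activation sizes
    blocks = [alloc.malloc(n) for n in (1000, 4000, 1000)]
    for b in blocks:
        alloc.free(b)
print(alloc.device_allocs)  # 3: only the first iteration touches the device
```
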
The resource optimizer integrates these scheduling and memory decisions. When convolutions with different shapes are ready to execute, it can select the algorithm variant (for example, implicit GEMM, FFT-based, or Winograd convolution) that best fits each shape and the workspace memory currently available. A poorly scheduled graph wastes compute; a poorly managed memory pool triggers out-of-memory errors on hardware that theoretically has capacity to spare.

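Benchmark-and-cache algorithm selection can be sketched as follows. In PyTorch, setting `torch.backends.cudnn.benchmark = True` enables this style of selection for cuDNN convolutions. The sketch only illustrates the caching logic. Its cost functions are hypothetical stand-ins for measured kernel timings, and the shape tuple layout is an assumption.

```{.python}
def select_algorithm(shape, candidates, cache, benchmark):
    """Return the fastest candidate for `shape`, benchmarking only once per shape."""
    if shape not in cache:
        timings = {name: benchmark(impl, shape) for name, impl in candidates.items()}
        cache[shape] = min(timings, key=timings.get)
    return cache[shape]


calls = []


def fake_benchmark(cost_model, shape):
    """Stand-in for timing a real kernel; records each benchmark run."""
    calls.append(shape)
    return cost_model(shape)


# Hypothetical cost models keyed by (channels, spatial size, kernel size).
candidates = {
    "implicit_gemm": lambda s: s[0] * s[1] * 1.0,
    "winograd": lambda s: s[0] * s[1] * (0.6 if s[2] == 3 else 5.0),
}
cache = {}
print(select_algorithm((64, 56, 3), candidates, cache, fake_benchmark))  # winograd
print(select_algorithm((64, 56, 7), candidates, cache, fake_benchmark))  # implicit_gemm
select_algorithm((64, 56, 3), candidates, cache, fake_benchmark)  # served from cache
print(len(calls))  # 4: two candidates x two distinct shapes
```

The first call for each new shape pays the benchmarking cost. Every later call with that shape is a dictionary lookup. This is why such autotuning pays off when input shapes stay fixed across iterations.
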
The preceding sections examined what happens *below* the API surface: tensors manage data layout, streams overlap computation with communication, and kernel dispatch routes operations to hardware. These mechanisms operate at the level of individual tensors and operations---the raw materials of machine learning computation. But practitioners rarely write code at this level. A ResNet-50 has `{python} ResNetMemory.resnet_params_m_str` million parameters organized into dozens of layers; manually tracking each tensor, registering it with an optimizer, and handling device placement would be error-prone and tedious. The abstraction problem is not fully solved by hardware-level mechanisms alone; it also requires a *programming model* that organizes these low-level primitives into the clean APIs that practitioners actually use.

:::: {.callout-checkpoint title="Hardware Abstraction" collapse="false"}
The abstraction problem is the bridge between *portable code* and *efficient execution*.

- [ ] **Two dimensions**: Can you distinguish **data representation** (layout, dtype, placement) from **execution mapping** (kernel selection, scheduling), and explain how they constrain each other?
- [ ] **Kernel dispatch**: Can you explain why the same high-level operation (e.g., GEMM) needs multiple implementations (CPU vector path vs GPU Tensor Core path) and how shapes/dtypes affect the choice?
- [ ] **Memory abstraction**: Can you explain why frameworks use caching allocators and layout transforms (NCHW↔NHWC) rather than calling device allocators on every tensor?
- [ ] **Execution control**: Can you describe what the runtime is doing when it overlaps independent work (streams) and inserts synchronization only where dependencies require it?
::::

Individual operations---matrix multiplications, activations, normalizations---are the atoms of deep learning computation. But building models from individual operations would be like building a house from individual atoms. Frameworks need an organizational abstraction that lets engineers compose operations into reusable, nestable building blocks. That abstraction is the *module*.

## nn.Module Abstraction {#sec-ml-frameworks-nnmodule-abstraction-2622}

\index{Module Abstraction!PyTorch nn.Module}
\index{Module Abstraction!hierarchical composition}
The hardware-facing half of the abstraction problem---tensors, kernels, streams, and memory managers---makes individual operations fast on diverse silicon. But a ResNet-50 contains fifty layers, each with multiple parameter tensors, buffers, and mode-dependent behaviors. Manually wiring each tensor to the correct device, registering it with an optimizer, toggling dropout behavior between training and inference, and serializing state for checkpointing---for every layer---would drown practitioners in bookkeeping that has nothing to do with model design. The upper layer of the abstraction problem is organizational: how do frameworks compose thousands of low-level primitives into the clean, composable APIs that practitioners actually use?

Every major framework answers this question through a *module abstraction* that bundles parameters, forward computation, and state management into a single reusable unit. PyTorch's `nn.Module`[^fn-nn-module] provides an instructive case study because its design patterns recur across frameworks: Keras uses similar layer abstractions, JAX's Flax employs analogous module structures, and TensorFlow's functional API shares conceptual parallels. Rather than catalog its API, we extract three enduring design principles that every framework must address regardless of its syntax or programming paradigm.

[^fn-nn-module]: **nn.Module**: PyTorch's base class for all neural network components, introduced in PyTorch's first release (2016). The `nn` namespace abbreviates "neural network," and "Module" borrows from software engineering's module pattern—a self-contained unit with defined inputs, outputs, and internal state. The class provides three mechanisms automatically: parameter registration (via an overridden `__setattr__` method that intercepts attribute assignment), recursive tree traversal (for device placement, serialization, and optimizer access), and mode propagation (training vs. evaluation flags). Every layer, block, and full model inherits from `nn.Module`, creating a uniform interface that enables the three principles described in this section.

### Principle 1: Automatic Parameter Discovery { .unnumbered}

\index{nn.Module!parameter discovery}
\index{Parameter Registration!automatic tracking}
A modern neural network may contain millions of trainable parameters spread across dozens of layers. Without automation, a programmer would need to enumerate every parameter tensor and pass it to the optimizer manually, an error-prone process that scales poorly with model complexity. Frameworks solve this through *automatic parameter discovery*: the system walks the module tree, collecting every parameter tensor so the optimizer can update them in a single call.

This is a graph traversal problem at its core. When a developer assigns an `nn.Parameter` to a module attribute (for example, `self.weight = nn.Parameter(...)`), the module's overridden `__setattr__` method intercepts the assignment and registers the tensor in an internal dictionary. A call to `.parameters()` then performs a recursive depth-first traversal of the module tree, yielding every registered parameter. The same pattern appears in every major framework: Keras layers maintain a `trainable_weights` list, JAX's Flax modules use `init()` to return a nested parameter dictionary, and TensorFlow's `tf.Module` provides `trainable_variables`. The mechanism differs but the principle is universal.

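The interception mechanism can be reproduced in plain Python. The sketch is a deliberately simplified model, not PyTorch's implementation. Among other differences, PyTorch stores parameters only in internal dictionaries and resolves them through `__getattr__`. The classes `Parameter`, `Module`, `Linear`, and `MLP` here are hypothetical stand-ins.

```{.python}
class Parameter:
    """Stand-in for a trainable tensor; holds a flat list of floats."""

    def __init__(self, data):
        self.data = list(data)


class Module:
    """Minimal sketch of registration by attribute interception."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the registries.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Every attribute assignment passes through here.
        if isinstance(value, Parameter):
            self._parameters[name] = value
        elif isinstance(value, Module):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def named_parameters(self, prefix=""):
        # Depth-first traversal: own parameters, then each child's.
        for name, p in self._parameters.items():
            yield prefix + name, p
        for name, child in self._modules.items():
            yield from child.named_parameters(prefix + name + ".")

    def parameters(self):
        return [p for _, p in self.named_parameters()]


class Linear(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = Parameter([0.0] * (n_in * n_out))
        self.bias = Parameter([0.0] * n_out)


class MLP(Module):
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(4, 8)
        self.fc2 = Linear(8, 2)


model = MLP()
print([name for name, _ in model.named_parameters()])
# ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
```

No registration call appears in `Linear` or `MLP`. Ordinary attribute assignment is enough, which is the property that lets an optimizer receive every parameter through one `parameters()` call.
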
The systems consequence is significant. Parameters are discovered as whole tensors rather than as individual scalars. As a result, `optimizer.step()` can update millions of parameters with one vectorized update per tensor, or with only a handful of multi-tensor kernels when using PyTorch's `foreach` and `fused` optimizer implementations. Python dispatch overhead therefore scales with the number of tensors, not the number of scalar parameters. Without this abstraction, each update would require separate Python-level bookkeeping, and interpreter overhead alone could dominate training time for large models. @lst-parameter_registration demonstrates the core mechanism: attribute assignment triggers registration, and `.parameters()` returns all discovered tensors.

::: {#lst-parameter_registration lst-cap="**Parameter Registration**: Automatic parameter tracking through attribute assignment enables optimizer access to all trainable weights without manual enumeration."}
```{.python}
import torch
import torch.nn as nn


class CustomLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(output_size, input_size)
        )
        self.bias = nn.Parameter(torch.randn(output_size))
        self.register_buffer("running_mean", torch.zeros(output_size))

    def forward(self, x):
        return torch.matmul(x, self.weight.t()) + self.bias


layer = CustomLayer(10, 20)
# Framework discovers both parameters automatically:
for name, param in layer.named_parameters():
    print(f"{name}: shape {param.shape}")
# The buffer is tracked separately and excluded from named_parameters():
for name, buf in layer.named_buffers():
    print(f"{name} (buffer): shape {buf.shape}")
```
:::

The distinction between *parameters* and *buffers* illustrates a subtlety of the discovery mechanism. Parameters carry `requires_grad=True` and participate in gradient computation. Buffers, registered through `register_buffer()`, travel with the model during device transfers and are saved in its `state_dict()`, but remain excluded from gradient updates. This separation is essential for normalization layers, where running statistics must persist across batches but must not receive gradients. The same dual-track design appears in Keras (via `non_trainable_weights`) and Flax (via separate variable collections, such as `params` versus `batch_stats`).

::: {.callout-perspective title="Cross-Framework Parameter Discovery" collapse="true"}
The same principle manifests differently across frameworks:

| **Framework**  | **Parameter Access**           | **Non-Trainable State**                 |
|:---------------|:-------------------------------|:----------------------------------------|
| **PyTorch**    | `model.parameters()`           | `register_buffer()`                     |
| **Keras**      | `layer.trainable_weights`      | `layer.non_trainable_weights`           |
| **JAX/Flax**   | `model.init(key, x)["params"]` | Other collections (e.g., `batch_stats`) |
| **TensorFlow** | `module.trainable_variables`   | `module.non_trainable_variables`        |

Despite syntactic differences, all frameworks solve the same problem: enabling optimizers to discover and update trainable parameters while preserving non-trainable state across forward passes.
:::

### Principle 2: Mode-Dependent Behavior { .unnumbered}

\index{nn.Module!train vs. eval modes}
\index{Dropout!training vs. inference}
\index{BatchNormalization!mode-dependent behavior}
Training and inference require different computational behavior from the same model graph. During training, dropout layers randomly zero elements with probability $p$ to regularize the network, while during inference those same layers must perform identity mapping to produce deterministic outputs. Batch normalization uses per-batch statistics during training but switches to accumulated running statistics during inference. If these behavioral changes are left to the programmer, forgetting a single mode switch produces silently incorrect predictions in production.

Frameworks solve this with a *state flag* that propagates through the module hierarchy. A single call to `.eval()` on the root module recursively sets `self.training = False` on every descendant, and each layer queries this flag to select its behavior. This is an instance of a broader systems principle: the same computation graph must produce different execution behavior depending on context. Compilers face the same challenge when the same source code must produce debug builds (with bounds checking and symbol tables) versus release builds (with aggressive optimization). The flag-propagation pattern ensures correctness by centralizing the mode decision at the root rather than requiring per-layer coordination.

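Flag propagation is easy to model with a tree of objects. The sketch below is not PyTorch's code. It assumes a hand-maintained `children` list and a hypothetical `Dropout` class that uses inverted dropout: survivors are scaled by 1/(1-p) during training, which is also what PyTorch's dropout does.

```{.python}
import random


class Module:
    def __init__(self):
        self.training = True
        self.children = []

    def train(self, mode=True):
        # Set the flag on this node, then recurse into every descendant.
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Dropout(Module):
    def __init__(self, p, rng):
        super().__init__()
        self.p, self.rng = p, rng

    def __call__(self, xs):
        if not self.training:
            return list(xs)  # identity at inference
        keep = 1.0 - self.p
        # Inverted dropout: scale survivors so the expected value is unchanged.
        return [x / keep if self.rng.random() < keep else 0.0 for x in xs]


class Net(Module):
    def __init__(self, rng):
        super().__init__()
        self.drop = Dropout(0.5, rng)
        self.children = [self.drop]

    def __call__(self, xs):
        return self.drop(xs)


net = Net(random.Random(0))
x = [1.0, 2.0, 3.0, 4.0]
net.eval()                         # one call at the root...
assert net.drop.training is False  # ...reaches the nested layer
assert net(x) == x                 # deterministic identity in eval mode
net.train()
print(net(x))                      # stochastic: some zeros, survivors doubled
```
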
This principle extends to parameter freezing for transfer learning. Setting `requires_grad=False` on specific parameters excludes them from gradient computation, effectively creating a third behavioral mode where some parameters train while others remain fixed. Selective freezing saves computation by pruning the backward pass graph. Frozen parameters need no gradient buffers, and no optimizer state when they are left out of the optimizer. When the frozen layers sit at the front of the network, as in typical transfer learning, autograd skips backpropagating through them entirely.

### Principle 3: Hierarchical Composition and Serialization { .unnumbered}

\index{nn.Module!hierarchical composition}
\index{Model Serialization!state dictionary pattern}
Complex models compose from reusable submodules, creating a tree structure. A ResNet is not implemented as a monolithic block of operations but as a hierarchy: the root module contains a sequence of residual blocks, each block contains convolution layers and normalization layers, and each layer contains parameter tensors. This hierarchical composition must support two critical operations: *recursive parameter collection* for training and *state serialization* for checkpointing and deployment.

Hierarchical composition also pays off at the systems level. Because each submodule owns its parameters, a subtree can be loaded, moved, or placed independently of the rest of the model, which enables model parallelism across devices. When a model is too large for a single GPU, the framework can assign different subtrees of the module hierarchy to different devices, with the tree structure providing natural partition boundaries.

The state dictionary mechanism provides the serialization half of this principle. The `state_dict()` method produces a flat key-value mapping of the full module tree, where dotted path names (e.g., `blocks.0.conv1.weight`) encode the hierarchy. This flat structure enables efficient serialization: a 7B-parameter model's approximately `{python} Model7B.model_7b_fp16_gb_str` GB FP16 checkpoint can be written as a sequential byte stream, maximizing storage bandwidth utilization. The inverse operation, `load_state_dict()`, reconstructs the hierarchy from the flat mapping, enabling checkpoint recovery and cross-framework model exchange via formats like ONNX. @lst-nested_modules demonstrates how the module tree enables both recursive parameter access and hierarchical state serialization.

::: {#lst-nested_modules lst-cap="**Nested Module Composition**: Hierarchical module composition enables recursive parameter collection and flat state serialization across the module tree."}
```{.python}
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return torch.relu(x + residual)


class ResNet(nn.Module):
    def __init__(self, num_blocks, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(3, channels, 7, padding=3)
        self.blocks = nn.ModuleList(
            [ResidualBlock(channels) for _ in range(num_blocks)]
        )
        self.fc = nn.Linear(channels, 10)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = x.mean(dim=[2, 3])  # Global average pooling
        return self.fc(x)


model = ResNet(num_blocks=4)
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total}")
# state_dict() flattens the tree: 'blocks.0.conv1.weight', etc.
print(list(model.state_dict().keys())[:4])
```
:::

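The flattening performed by `state_dict()`, and the rebuilding performed by `load_state_dict()`, can be sketched over nested dictionaries. Here nested dicts stand in for modules and lists stand in for tensors. The sketch illustrates only the dotted-key naming scheme. It is not how PyTorch stores or serializes checkpoints, and both function bodies are hypothetical.

```{.python}
def state_dict(tree, prefix=""):
    """Flatten a nested {name: subtree-or-tensor} mapping into dotted keys."""
    flat = {}
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict):
            flat.update(state_dict(value, key + "."))  # recurse into submodule
        else:
            flat[key] = value                          # leaf: a "tensor"
    return flat


def load_state_dict(flat):
    """Rebuild the nested hierarchy from dotted keys."""
    tree = {}
    for key, value in flat.items():
        *parents, leaf = key.split(".")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


# Nested dicts stand in for modules; lists stand in for tensors.
model = {
    "conv_in": {"weight": [0.1, 0.2], "bias": [0.0]},
    "blocks": {
        "0": {"conv1": {"weight": [0.3]}, "bn1": {"running_mean": [0.0]}},
    },
    "fc": {"weight": [0.4], "bias": [0.5]},
}
flat = state_dict(model)
print(list(flat))
# ['conv_in.weight', 'conv_in.bias', 'blocks.0.conv1.weight',
#  'blocks.0.bn1.running_mean', 'fc.weight', 'fc.bias']
assert load_state_dict(flat) == model  # round trip restores the hierarchy
```
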
The hierarchical structure also enables module-level traversal for systematic operations. Methods like `.named_modules()` iterate the entire tree, supporting bulk transformations such as replacing all BatchNorm layers with GroupNorm or applying Xavier initialization to every Linear layer. These traversal operations depend on the same tree structure that enables parameter discovery, illustrating how a single design decision propagates benefits across multiple use cases.

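A bulk transformation such as swapping every BatchNorm for GroupNorm is a recursive walk over that same tree. The sketch uses a hypothetical `Node` class rather than real modules. In PyTorch, the equivalent walk would typically use `named_children()` together with `setattr` on the parent module.

```{.python}
class Node:
    """Hypothetical stand-in for a module: a kind label plus named children."""

    def __init__(self, kind, **children):
        self.kind = kind
        self.children = dict(children)


def named_modules(node, prefix=""):
    """Pre-order traversal yielding (dotted_name, node), like named_modules()."""
    yield prefix, node
    for name, child in node.children.items():
        yield from named_modules(child, f"{prefix}.{name}" if prefix else name)


def replace_kind(node, old_kind, make_new):
    """Swap every descendant of kind `old_kind` for make_new(old_child)."""
    for name, child in list(node.children.items()):
        if child.kind == old_kind:
            node.children[name] = make_new(child)  # replace in the parent
        else:
            replace_kind(child, old_kind, make_new)


model = Node(
    "ResNet",
    conv_in=Node("Conv2d"),
    block0=Node("ResidualBlock", conv1=Node("Conv2d"), bn1=Node("BatchNorm2d")),
    block1=Node("ResidualBlock", conv1=Node("Conv2d"), bn1=Node("BatchNorm2d")),
)
replace_kind(model, "BatchNorm2d", lambda old: Node("GroupNorm"))
print([(name, n.kind) for name, n in named_modules(model) if name])
```

The replacement must happen in the parent, because only the parent holds the reference to the child being swapped. That is why the walk iterates over each node's children instead of the nodes themselves.
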
These three principles (automatic parameter discovery, mode-dependent behavior, and hierarchical composition with serialization) are not PyTorch-specific. Every framework must solve them. Keras layers, JAX's Flax modules, and even functional approaches all address the same problems of parameter management, state tracking, and compositional design. The differences lie not in *what* problems they solve but in *how* they prioritize among competing solutions. Two practical patterns built on these principles deserve attention: *selective parameter freezing* for transfer learning (@lst-parameter_freezing) and *module hooks* for non-invasive inspection (@lst-module_hooks).

::: {#lst-parameter_freezing lst-cap="**Parameter Freezing**: Demonstrates selective parameter freezing for transfer learning, where pretrained layers remain fixed while new layers train."}
```{.python}
import torch
import torch.nn as nn

# Freeze all parameters in a pretrained model
pretrained_model = torch.hub.load(
    "pytorch/vision", "resnet18", weights="IMAGENET1K_V1"
)

for param in pretrained_model.parameters():
    param.requires_grad = False

# Replace final layer with trainable parameters
pretrained_model.fc = nn.Linear(512, 10)  # New layer is trainable

# Only fc.parameters() will receive gradients during training
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, pretrained_model.parameters()),
    lr=0.001,
)
```
:::

\index{Module Hooks!forward and backward}
Forward and backward hooks intercept intermediate computations without modifying model code, enabling gradient flow diagnosis and activation monitoring. @lst-module_hooks illustrates both hook types.

::: {#lst-module_hooks lst-cap="**Module Hooks**: Shows forward and backward hooks for inspecting activations and gradients during training."}
```{.python}
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))


# Forward hook to inspect activations
def forward_hook(module, input, output):
    print(
        f"Layer: {module.__class__.__name__}, "
        f"Output shape: {output.shape}, "
        f"mean={output.mean():.3f}, "
        f"std={output.std():.3f}"
    )


# Backward hook to inspect gradients
def backward_hook(module, grad_input, grad_output):
    print(f"Gradient norm: {grad_output[0].norm():.3f}")


# Register hooks on specific layer
handle_fwd = model[0].register_forward_hook(forward_hook)
handle_bwd = model[0].register_full_backward_hook(backward_hook)

# Execute forward and backward pass
x = torch.randn(32, 10)
y = model(x)
loss = y.sum()
loss.backward()

# Remove hooks when done
handle_fwd.remove()
handle_bwd.remove()
```
:::

Together, these patterns---parameter discovery, freezing, and hooks---demonstrate how the three principles translate into practical APIs. The `nn.Module` patterns above illustrate PyTorch's approach to the abstraction problem. PyTorch, however, is only one of several major frameworks, and its choices (mutable state, class inheritance, eager execution by default) are not the only valid design points. TensorFlow centralizes state differently, and JAX avoids mutable state entirely. These are not superficial API differences; they reflect deeply different answers to the three problems we examined at the chapter's start.

## Framework Platform Analysis {#sec-ml-frameworks-major-framework-platform-analysis-fe96}

\index{ML Framework!platform comparison}
Each major framework represents a distinct point in the design space defined by the three core problems: TensorFlow prioritizes the **Abstraction Problem** through its comprehensive deployment ecosystem, PyTorch prioritizes the **Execution Problem** through its dynamic graph approach, and JAX reframes the **Differentiation Problem** through composable function transformations. These are not just API differences but fundamental capability trade-offs that determine what each framework can and cannot do well.

### TensorFlow: The Graph-First Production Machine {#sec-ml-frameworks-tensorflow-ecosystem-063c}

\index{TensorFlow!graph-first architecture}
\index{TensorFlow!production deployment}
TensorFlow's architecture reflects a comprehensive solution to the **Abstraction Problem**: how do you target diverse hardware---from cloud TPUs to microcontrollers---using a single interface? Its design philosophy is built on the **Static Graph** (or "Define-and-Run") principle. By requiring the model to be represented as a complete computational graph before execution, TensorFlow enables ahead-of-time (AOT) compilation and optimization.

This approach prioritizes the **Deployment Spectrum**. Because the framework sees the entire graph, it can perform aggressive optimizations like constant folding, operator fusion, and memory layout optimization before the first byte of data is processed. This whole-graph view is a major reason TensorFlow has long been a common choice for complex production ecosystems. @fig-tensorflow-architecture maps the full training-to-deployment pipeline. Trace how a model flows from data preprocessing through distributed training on the left, then fans out to serving, mobile (TF Lite), browser (TF.js), and language bindings on the right.

::: {#fig-tensorflow-architecture fig-env="figure" fig-pos="htb" fig-cap="**TensorFlow Training-to-Deployment Pipeline.** Two-column diagram showing the training path (left) from data preprocessing through tf.keras and distribution strategy across CPU, GPU, and TPU, and the deployment path (right) from SavedModel export to TensorFlow Serving, Lite, JS, and language bindings. Source: [TensorFlow.](https://blog.tensorflow.org/2019/01/whats-coming-in-tensorflow-2-0.html)." fig-alt="Two-column diagram. Training: data preprocessing, tf.keras, TensorFlow Hub, Premade Estimators, Distribution Strategy across CPU/GPU/TPU. Deployment via SavedModel to TensorFlow Serving, Lite, JS, and language bindings."}
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{Line/.style={line width=1.0pt,black!50
},
Box/.style={align=flush center,
inner xsep=4pt,
node distance=0.8,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,,
minimum height=11mm
},
}

\node[Box,text width=70mm,fill= BrownL,
draw= BrownLine](B1){\textbf{Read \& Preprocess Data}\\ tf.data, feature columns};
\node[Box,fill= BrownL,draw= BrownLine,below=of B1.south west,minimum width=20mm,
anchor=north west](B2){\textbf{tf.keras}};
\node[Box,fill= BrownL,draw= BrownLine,below=of B1.south east,,minimum width=20mm,
anchor=north east](B3){\textbf{Premade}\\\textbf{Estimators}};
\node[Box,fill= BrownL,draw= BrownLine,
minimum width=20mm](B4)at($(B2.east)!0.5!(B3.west)$){\textbf{TensorFlow}\\\textbf{Hub}};
%
\node[Box,text width=70mm,fill= BrownL,below=of B4,
draw= BrownLine](B5){\textbf{Distribution Strategy}};
\node[Box,fill= BrownL,draw= BrownLine,below=of B5.south west,minimum width=18mm,
anchor=north west](B6){\textbf{CPU}};
\node[Box,fill= BrownL,draw= BrownLine,below=of B5.south east,minimum width=18mm,
anchor=north east](B7){\textbf{TPU}};
\node[Box,fill= BrownL,draw= BrownLine,minimum width=18mm](B8)at($(B6.east)!0.5!(B7.west)$){\textbf{GPU}};
%
\node[Box,fill= BlueL,draw= BlueLine,right=1.0 of $(B1.east)!0.5!(B7.east)$](B9){\textbf{SavedModel}};
%
\def\di{4.35}
\node[Box,text width=50mm,fill= RedL,right=\di of B1,
draw= RedLine](L1){\textbf{TensorFlow Serving}\\ Cloud, on-prem};
\node[Box,text width=50mm,fill= RedL,right=\di of B3,
draw= RedLine](L2){\textbf{TensorFlow Lite}\\ Android, iOS, Raspberry Pi};
\node[Box,text width=50mm,fill= RedL,right=\di of B5,
draw= RedLine](L3){\textbf{TensorFlow.js}\\ Browser and Node Server};
\node[Box,text width=50mm,fill= RedL,right=\di of B7,
draw= RedLine](L4){\textbf{Other Language Bindings}\\ C, Java, Go, C\#, Rust, R,\ldots};
%
\node[above=2mm of B1]{\textbf{TRAINING}};
\node[above=2mm of L1]{\textbf{DEPLOYMENT}};
%
\draw[latex-,Line](B2)--(B1.south-|B2);
\draw[latex-,Line](B3)--(B1.south-|B3);
\draw[-latex,Line](B4)--(B2);
\draw[-latex,Line](B4)--(B3);
\draw[-latex,Line](B2)--(B5.north-|B2);
\draw[-latex,Line](B3)--(B5.north-|B3);
\draw[latex-,Line](B6)--(B5.south-|B6);
\draw[latex-,Line](B7)--(B5.south-|B7);
\draw[latex-,Line](B8)--(B5.south-|B8);
\draw[Line](B6)--++(270:1)-|(B7);
\draw[-latex,Line](B8)-++(270:1.35)-|(B9);
\foreach \x in {1,2,3,4}
\draw[-latex,Line](B9.east)--(L\x.west);
\end{tikzpicture}
```
:::

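Constant folding, one of the whole-graph optimizations mentioned above, can be illustrated with a minimal pass over a tiny expression graph. The tuple-based graph encoding and the two supported ops are assumptions made for brevity. TensorFlow's actual graph optimizer (Grappler) works on its own graph representation and runs many more passes.

```{.python}
def fold_constants(node):
    """Recursively evaluate subtrees whose inputs are all constants.

    A node is ("const", value), ("input", name), or (op, left, right)
    with op in {"add", "mul"}.
    """
    kind = node[0]
    if kind in ("const", "input"):
        return node
    op, left, right = node
    left, right = fold_constants(left), fold_constants(right)
    if left[0] == "const" and right[0] == "const":
        a, b = left[1], right[1]
        return ("const", a + b if op == "add" else a * b)
    return (op, left, right)  # depends on runtime input: keep the op


# y = x * (2 * 3) + (4 + 1): the constant subexpressions are computed once,
# before any data arrives, instead of on every execution.
graph = ("add",
         ("mul", ("input", "x"), ("mul", ("const", 2), ("const", 3))),
         ("add", ("const", 4), ("const", 1)))
print(fold_constants(graph))
# ('add', ('mul', ('input', 'x'), ('const', 6)), ('const', 5))
```

An eager framework executes each operation as it is encountered, so it has no comparable opportunity to evaluate constant subexpressions once, ahead of time.
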
While TensorFlow 2.0 made eager execution the default to narrow the gap between research and production, its core strength remains the robust, compiled path from research to global-scale deployment. @sec-model-training examines how TensorFlow's distribution strategies enable large-scale training, while @sec-model-serving covers its production serving infrastructure.

### PyTorch: The Eager Research Standard {#sec-ml-frameworks-pytorch-85cd}

\index{PyTorch!eager execution default}
\index{PyTorch!research-first design}
Where TensorFlow's graph-first approach prioritizes production optimization, PyTorch makes the opposite trade-off: it prioritizes *developer experience*. PyTorch's architecture represents a sharply different answer to the **Execution Problem**, built on **Dynamic Graphs** (or "Define-by-Run"). Instead of building a blueprint before execution, PyTorch builds the computational graph on-the-fly as the code runs.

This approach won the research community for a simple reason: it treats deep learning as standard Python programming. You can use Python loops, conditionals, and debuggers (like `pdb`) directly within your model's forward pass, with no special syntax, no separate compilation step, and no waiting to see if your code works. This "Eager Execution" model enables rapid iteration and intuitive model design, which is essential for the trial-and-error nature of frontier AI research.

PyTorch's answer to the **Differentiation Problem** is the tape-based autograd system examined in @sec-ml-frameworks-pytorch-autograd-internals-4fa0: flexible and debuggable, but harder to optimize globally because the tape is rebuilt each iteration. Its answer to the **Abstraction Problem** is more pragmatic than comprehensive: strong GPU support through cuBLAS and cuDNN, but deployment to mobile, edge, and browser environments requires exporting through ONNX or specialized runtimes rather than a native path.

The trade-off is therefore a more fragmented deployment path. Because the graph is dynamic, the framework cannot easily perform global optimizations before execution. A model that works perfectly in development may hit performance walls in production when dispatch overhead dominates small operations. To bridge this research-to-production gap, PyTorch introduced **TorchScript** and **PyTorch 2.0 (with `torch.compile`)**, which allow developers to capture a dynamic model and turn it into an optimized, static representation for deployment. This evolution shows PyTorch moving toward the production end of the compilation continuum while preserving the eager experience that made it dominant in research.

### JAX: The Functional Transformation Engine {#sec-ml-frameworks-jax-functional-transformation-engine-242e-functional-transformation-engine-242e}

\index{JAX!functional transformations}
\index{JAX!composable program transformations}
\index{JAX!pure functions requirement}
\index{JAX!transformation composition}
PyTorch's eager execution and TensorFlow's graph compilation represent two points on a spectrum, yet both share an imperative programming heritage where computation proceeds as a sequence of stateful operations. JAX represents a radically different approach, one built on functional programming principles and composable program transformations rather than computational graphs [@jax2018github]. Developed by Google Research, JAX has gained significant traction in research settings, particularly for work requiring custom differentiation, advanced optimization research, and large-scale distributed training.

JAX's architecture reframes the **Differentiation Problem** entirely. Rather than implementing automatic differentiation as a tape-based system (PyTorch) or a graph transformation pass (TensorFlow), JAX treats differentiation as one of several *composable function transformations*. The `jax.grad` function does not compute gradients directly; it returns a *new function* that computes gradients. This subtle distinction enables powerful compositions: you can differentiate a differentiated function (higher-order derivatives), vectorize a gradient computation (`vmap(grad(f))`), or compile a vectorized gradient to XLA (`jit(vmap(grad(f)))`).

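The shape of this function-in, function-out interface, which is what makes such compositions possible, can be mimicked in plain Python. The toy below is not JAX. Its `grad` uses central finite differences instead of automatic differentiation, its `vmap` is a Python loop rather than true vectorization, and there is no compilation. Only the composition pattern carries over.

```{.python}
def grad(f, eps=1e-6):
    """Return a NEW function approximating df/dx by central differences."""
    def df(x):
        return (f(x + eps) - f(x - eps)) / (2 * eps)
    return df


def vmap(f):
    """Return a NEW function that applies f to every element of a batch."""
    def batched(xs):
        return [f(x) for x in xs]
    return batched


def loss(w):
    return (3.0 * w - 6.0) ** 2  # minimized at w = 2; loss'(w) = 18w - 36


per_example_grad = vmap(grad(loss))                        # transform a transform
second_derivative = grad(grad(loss, eps=1e-4), eps=1e-4)  # higher-order

print(per_example_grad([0.0, 2.0, 3.0]))  # approximately [-36.0, 0.0, 18.0]
print(second_derivative(1.0))             # approximately 18.0
```

Because each transformation consumes and returns an ordinary function, nesting them requires no special support. The same property, applied to exact differentiation and compiled vectorization, is what JAX builds on.
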
JAX's functional paradigm requires a genuine mental shift from "tracking state through objects" to "transforming pure functions." The conceptual introduction here covers JAX's core design; transformation composition, pytree handling, and XLA tracing mechanics each warrant dedicated study for production use.

#### Transformations over State {#sec-ml-frameworks-transformations-state-3b46}

While PyTorch and TensorFlow build computational graphs (dynamically or statically), JAX transforms functions. The core insight is that automatic differentiation, vectorization, and JIT compilation are all *program transformations* that can compose. @lst-jax-transformations demonstrates this composable approach.

::: {#lst-jax-transformations lst-cap="**JAX Function Transformations**: JAX treats differentiation, vectorization, and compilation as composable function transformations rather than graph operations."}
```{.python}
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)


# Transform: compute gradients
grad_fn = jax.grad(loss_fn)

# Transform: vectorize over batch dimension
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))

# Transform: compile to XLA
fast_batched_grad = jax.jit(batched_grad)

# Compose all three: fast, batched gradient computation
params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
x = jnp.ones((8, 3))  # 8 examples, 3 features each
y = jnp.ones(8)
per_example_grads = fast_batched_grad(params, x, y)
print(per_example_grads["w"].shape)  # (8, 3): one gradient per example
```
:::

This functional approach requires **pure functions** (no side effects) and **immutable data** (arrays cannot be modified in place). These constraints may seem restrictive coming from PyTorch's mutable object model, but they enable powerful guarantees: the compiler can safely reorder, fuse, and parallelize operations because function outputs depend only on inputs. The restriction is the feature; purity is what makes transformation composition possible.
|
||
|
||
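Two consequences of purity are easy to see in a few lines. JAX arrays are updated functionally through the `.at[...]` indexing interface, which returns a new array, and Python side effects inside a `jax.jit` function run only while JAX traces the function, not on every call. The sketch below assumes a recent JAX release; the list `trace_log` is a hypothetical helper used only to count traces.

```python
import jax
import jax.numpy as jnp

# Immutable update: .at[...].set(...) returns a new array
x = jnp.zeros(3)
y = x.at[0].set(1.0)
# x is still [0., 0., 0.]; y is [1., 0., 0.]

trace_log = []  # hypothetical side-effect sink


@jax.jit
def double(v):
    trace_log.append("traced")  # Python side effect: runs only during tracing
    return v * 2.0


double(jnp.ones(3))  # first call: traces, compiles, caches
double(jnp.ones(3))  # same shape/dtype: reuses the cached compiled function
num_traces = len(trace_log)  # 1
```

A function that mutated a global or incremented a Python counter would therefore behave differently compiled versus uncompiled, which is exactly why JAX requires purity.
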
#### Key Transformations {#sec-ml-frameworks-key-transformations-4105}

JAX's power emerges from composition. Start with a loss function `f` and apply `jax.grad`\index{JAX!gradient transformation} to obtain a new function that computes gradients. Unlike PyTorch's tape-based autograd, `grad` returns a *function*, not a value. `grad` itself uses reverse-mode differentiation, and JAX also exposes forward-mode (`jacfwd`, `jvp`) and reverse-mode (`jacrev`, `vjp`) transformations directly. Apply `jax.vmap`\index{JAX!automatic vectorization}\index{Vectorization!automatic batching} to the gradient function and it automatically vectorizes across a batch dimension, transforming single-example code into batched code without manual reshaping. Wrap the batched gradient function in `jax.jit`\index{JAX!JIT compilation} and JAX traces it once, compiles it to optimized XLA code, caches the result, and eliminates Python dispatch overhead on subsequent calls with the same input shapes and dtypes. Finally, `jax.pmap`\index{JAX!device-parallel mapping} maps the compiled, batched gradient function across multiple GPUs or TPUs, running one replica per device; cross-device reductions such as gradient averaging are expressed with collectives like `jax.lax.pmean`. The result, `pmap(jit(vmap(grad(f))))`, expresses distributed, compiled, batched gradient computation as a single composed expression. Few frameworks make this degree of composability the core programming model.

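The forward- and reverse-mode transformations compute the same Jacobian by different routes: `jax.jacfwd` pushes tangents forward (efficient when inputs are few), while `jax.jacrev` pulls cotangents backward (efficient when outputs are few). The sketch below uses a hypothetical two-input, two-output function whose Jacobian can be checked by hand.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical map R^2 -> R^2: f(x) = [x0^2, x0 * x1]
    return jnp.array([x[0] ** 2, x[0] * x[1]])


x = jnp.array([2.0, 3.0])
J_fwd = jax.jacfwd(f)(x)  # forward mode: one JVP per input dimension
J_rev = jax.jacrev(f)(x)  # reverse mode: one VJP per output dimension
# Analytic Jacobian at (2, 3): [[2*x0, 0], [x1, x0]] = [[4, 0], [3, 2]]
modes_agree = bool(jnp.allclose(J_fwd, J_rev))
```
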
#### Ecosystem and Libraries {#sec-ml-frameworks-ecosystem-libraries-5dae}

JAX's minimalist core delegates neural network abstractions to companion libraries (Flax, Haiku, Equinox) and optimization to Optax. This separation reflects the functional philosophy: the core provides transformations, while libraries build conventional abstractions on top. The ecosystem is younger and smaller than PyTorch's or TensorFlow's, which affects the availability of pre-built components for production use.

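The division of labor is easy to see in a hand-written update step. Libraries such as Optax package optimizers as pure functions over parameter pytrees; the sketch below instead hand-rolls plain SGD with `jax.tree_util.tree_map` so that it depends only on the JAX core. The linear model, data, and learning rate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # Pure update: params in, new params out; nothing is mutated in place
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
x = jnp.ones((4, 2))   # 4 examples, 2 features (illustrative data)
y = jnp.full(4, 3.0)   # targets

loss_before = loss_fn(params, x, y)  # (0 - 3)^2 = 9.0
params = sgd_step(params, x, y)      # w -> [0.6, 0.6], b -> 0.6
loss_after = loss_fn(params, x, y)   # (1.8 - 3)^2 = 1.44 (approximately, in float32)
```
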
#### Trade-offs and Use Cases {#sec-ml-frameworks-tradeoffs-use-cases-6453}

\index{JAX!XLA compilation}
The functional constraints that JAX imposes become advantages in specific domains. Custom differentiation---higher-order gradients, custom VJP/JVP rules---composes cleanly because pure functions make differentiation rules predictable. Research on optimization algorithms benefits from transformations that let researchers manipulate gradient computation as naturally as they manipulate data. Large-scale distributed training, particularly on TPUs, leverages XLA compilation to extract maximum hardware utilization. Scientific computing with AD requirements benefits from functional purity that enables mathematical reasoning about code. JAX requires more upfront investment than PyTorch: the functional paradigm has a learning curve, state management requires explicit patterns, and debugging compiled code is harder than eager execution. Teams should choose JAX when its strengths align with project requirements, not as a default.

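Custom differentiation rules illustrate why purity matters. With `jax.custom_vjp`, a function declares its own backward pass, and that rule composes with `grad` and `jit` like any built-in operation. The sketch below defines a hypothetical `clip_gradient` helper that leaves the forward value unchanged but clips the incoming gradient to the range [-1, 1].

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def clip_gradient(x):
    # Forward pass: identity
    return x


def clip_gradient_fwd(x):
    # Return the output plus residuals for the backward pass (none needed)
    return x, None


def clip_gradient_bwd(residuals, g):
    # Backward pass: clip the upstream cotangent. Must return a tuple
    # with one entry per primal input.
    return (jnp.clip(g, -1.0, 1.0),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def scaled(x):
    return 10.0 * clip_gradient(x)


value = scaled(1.0)                          # forward value: 10.0
grad_value = jax.grad(scaled)(1.0)           # upstream gradient 10.0 clipped to 1.0
jit_grad_value = jax.jit(jax.grad(scaled))(1.0)  # same rule under compilation
```
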
### Quantitative Platform Performance Analysis {#sec-ml-frameworks-quantitative-platform-performance-analysis-816d}

The preceding sections described each framework's design philosophy in qualitative terms: graph-first versus eager-first, stateful versus functional. Design philosophy claims, however, are only meaningful when backed by measurement. @tbl-mlfm-comparison quantifies how the architectural choices of TensorFlow, PyTorch, and JAX translate to system characteristics. When examining this comparison, note particularly the differences in execution mode, compilation optimization potential, and distributed scalability---these dimensions most directly impact production deployment decisions.

| **Aspect**                    | **TensorFlow**                             | **PyTorch**                           | **JAX**                    |
|:------------------------------|:-------------------------------------------|:--------------------------------------|:---------------------------|
| **Graph Type**                | Static (1.x), Dynamic (2.x)                | Dynamic                               | Functional transformations |
| **Programming Model**         | Imperative (2.x), Symbolic (1.x)           | Imperative                            | Functional                 |
| **Core Data Structure**       | Tensor (immutable), Variable (mutable)     | Tensor (mutable)                      | Array (immutable)          |
| **Execution Mode**            | Eager (2.x default), Graph (`tf.function`) | Eager (`torch.compile` optional)      | Just-in-time compilation   |
| **Automatic Differentiation** | Reverse mode (forward mode available)      | Reverse mode (forward mode available) | Forward and reverse mode   |
| **Hardware Acceleration**     | CPU, GPU, TPU                              | CPU, GPU (TPU via PyTorch/XLA)        | CPU, GPU, TPU              |
| **Compilation Optimization**  | XLA: 3--10× speedup                        | TorchScript: 2×                       | XLA: 3--10× speedup        |
| **Memory Efficiency**         | 70--90% (workload dependent)               | 70--90% (varies)                      | 75--95% (with XLA fusion)  |
| **Distributed Scalability**   | High (1024+ GPUs)                          | High                                  | Very High (1024+ GPUs)     |

: **Framework Characteristics.** Each column reflects a distinct answer to the three core problems. TensorFlow's static graph roots enable XLA compilation but constrain dynamic control flow; PyTorch's eager default maximizes debugging flexibility but limits ahead-of-time optimization; JAX's functional model enables composable transformations and achieves the highest memory efficiency through XLA fusion. {#tbl-mlfm-comparison}

An important caveat applies to these numbers: GPU utilization and compilation speedups vary significantly by model architecture, batch size, and operation mix. JAX/XLA achieves higher utilization for TPU workloads through aggressive fusion, while PyTorch and TensorFlow perform similarly for most deep learning workloads. These framework-level generalizations provide useful orientation but cannot substitute for profiling specific workloads on target hardware.

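Profiling involves one detail that is easy to get wrong: accelerator frameworks dispatch work asynchronously, so a timer must wait for results before stopping, and the first calls include compilation. The framework-agnostic harness below is a minimal sketch; the `sync` hook is an assumption about where a framework-specific barrier goes, such as calling `.block_until_ready()` on a JAX result or `torch.cuda.synchronize()` for PyTorch on a GPU.

```python
import statistics
import time
from typing import Any, Callable


def benchmark(
    fn: Callable[..., Any],
    *args: Any,
    sync: Callable[[Any], Any] = lambda result: result,
    warmup: int = 3,
    iters: int = 20,
) -> dict:
    """Time fn(*args), excluding warm-up calls (JIT compilation, caching)."""
    for _ in range(warmup):
        sync(fn(*args))  # absorbs one-time compilation and cache population
    samples_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        sync(fn(*args))  # wait for asynchronous work to finish before timing
        samples_ms.append((time.perf_counter() - start) * 1e3)
    return {
        "median_ms": statistics.median(samples_ms),
        "p90_ms": statistics.quantiles(samples_ms, n=10)[-1],
        "min_ms": min(samples_ms),
    }


# Usage with a plain Python workload; for a JAX function, pass
# sync=lambda r: r.block_until_ready()
stats = benchmark(sum, range(100_000))
```
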
How do these architectural differences look in practice? @lst-framework-hello-world implements the same neural network (a single linear layer mapping 10 inputs to 1 output) across all three frameworks, revealing how design philosophy shapes even the simplest code:

::: {#lst-framework-hello-world lst-cap="**Framework Comparison: Hello World**: The same simple neural network implemented in PyTorch (object-oriented), TensorFlow/Keras (declarative), and JAX (functional), illustrating each framework's distinct design philosophy."}
```{.python}
# PyTorch - Dynamic, Pythonic
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


# TensorFlow/Keras - High-level API
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.Input(shape=(10,)), tf.keras.layers.Dense(1)]
)

# JAX - Functional approach
import jax.numpy as jnp
from jax import random


def simple_net(params, x):
    return jnp.dot(x, params["w"]) + params["b"]


key = random.PRNGKey(0)
key_w, key_b = random.split(key)  # never reuse a PRNG key for two draws
params = {
    "w": random.normal(key_w, (10, 1)),
    "b": random.normal(key_b, (1,)),
}
```
:::

These three implementations solve the same mathematical problem but reveal distinct answers to the Three Problems. The differences are not cosmetic; they shape debugging workflows, deployment options, and optimization potential.

**PyTorch** binds state and computation together through class inheritance (`nn.Module`), solving the Execution Problem through eager evaluation: the graph builds as Python runs, making standard debuggers and control flow work naturally. The cost is that no optimizer sees the full computation before execution begins.

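Because the graph is recorded as Python executes, ordinary control flow that depends on tensor values works with autograd unchanged. In this minimal sketch the branch taken depends on the data, and `backward()` differentiates only the path that actually ran; the function `f` is a hypothetical example.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: decided at runtime from tensor values
    if x.sum() > 0:
        return (x ** 2).sum()
    return (-x).sum()


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = f(x)       # the sum is positive, so the x**2 branch runs and is recorded
y.backward()   # gradient of sum(x^2) is 2x
print(x.grad)  # tensor([2., 4.])
```
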
**TensorFlow/Keras** inverts this priority through the `Sequential` API, which declares structure without executing it, solving the Abstraction Problem first: the same declaration compiles to server GPUs, mobile NPUs, or browser WebGL backends. Eager mode (default in TensorFlow 2.x) recovers some of PyTorch's debugging flexibility, but production deployment still relies on graph capture for optimization.

**JAX** makes the most radical trade-off by treating the model as a pure function[^fn-pure-function] with immutable data[^fn-immutable-data] and no internal state[^fn-stateless-function]. This functional purity solves the Differentiation Problem most elegantly: `grad`, `vmap` (automatic vectorization[^fn-vectorization]), and `jit` (just-in-time compilation[^fn-jit-compilation]) are composable transformations on stateless functions, not infrastructure bolted onto an object system. The cost is explicit parameter management and a programming model unfamiliar to most engineers.

[^fn-immutable-data]: **Immutable Data Structures**: Cannot be modified after creation. Any operation that appears to change the data actually creates a new copy, ensuring that the original data remains unchanged. This prevents accidental modifications and enables safe parallel processing.

[^fn-stateless-function]: **Stateless Function**: Produces the same output for the same inputs every time, without relying on or modifying any external state. This predictability enables mathematical optimization and parallel execution.

[^fn-vectorization]: **Automatic Vectorization**: Transforms operations on single data points into operations on entire arrays or batches, improving computational efficiency by using SIMD (Single Instruction, Multiple Data) processor capabilities.

[^fn-jit-compilation]: **Just-in-Time (JIT) Compilation**: Translates high-level code into optimized machine code at runtime, enabling performance optimizations based on actual data shapes and hardware characteristics.

[^fn-pure-function]: **Pure Function**: Has no side effects and always returns the same output for the same inputs. Pure functions enable mathematical reasoning about code behavior and safe program transformations.

[^fn-xla]: **XLA (Accelerated Linear Algebra)**: Google's domain-specific compiler released in March 2017, optimizing tensor operations across CPUs, GPUs, and TPUs. The name emphasizes that linear algebra operations (matrix multiplies, convolutions) dominate ML computation. Achieves 3–10× speedups through operation fusion and hardware-specific codegen. Now part of OpenXLA (2022), a cross-industry effort including Google, Meta, NVIDIA, and Apple.

[^fn-onnx]: **ONNX (Open Neural Network Exchange)** [@bai2019onnx]: Launched September 2017 by Facebook and Microsoft to solve framework fragmentation. Originally named "Toffee" internally at Facebook, the name change emphasized its role as an exchange format. Became a Linux Foundation project in 2019. Enables training in PyTorch and deploying via runtimes such as ONNX Runtime or TensorRT, which consume ONNX graphs directly.

No framework optimizes all three problems simultaneously; each makes deliberate trade-offs that shape everything from API design to performance characteristics. PyTorch prioritizes the **Execution Problem** (eager debugging, dynamic graphs) at the cost of optimization potential. TensorFlow prioritizes the **Abstraction Problem** (unified deployment from cloud to microcontroller) at the cost of development flexibility. JAX reframes the **Differentiation Problem** (composable function transformations) at the cost of a steeper learning curve. These are the same design tensions examined in the subsections above, now visible even in a ten-line program. Exploratory research favors PyTorch's debugging immediacy, production deployment favors TensorFlow's optimization depth, and algorithmic research favors JAX's composable transformations. Each philosophy shapes not just code syntax but team workflows, debugging practices, and deployment pipelines, which is why framework migration costs are measured in engineer-months rather than engineer-days.

### Quantitative Framework Efficiency Comparison {#sec-ml-frameworks-quantitative-framework-efficiency-comparison-3b77}

How large are these differences in practice? @tbl-framework-efficiency-matrix compares major frameworks across efficiency dimensions using benchmark workloads representative of production deployment scenarios.

| **Framework**             | **Inference Latency (ms)** | **Memory Usage (MB)** | **Energy (mJ/inference)** | **Model Size Reduction** | **Hardware Utilization (%)** |
|:--------------------------|---------------------------:|----------------------:|--------------------------:|-------------------------:|-----------------------------:|
| **TensorFlow**            |                         45 |                 2,100 |                       850 |                     None |                           35 |
| **TensorFlow Lite**       |                         12 |                   180 |                       120 |           4× (quantized) |                           65 |
| **TensorFlow Lite Micro** |                          8 |                    32 |                        45 |        8× (pruned+quant) |                           75 |
| **PyTorch**               |                         52 |                 1,800 |                       920 |                     None |                           32 |
| **PyTorch Mobile**        |                         18 |                   220 |                       180 |           3× (quantized) |                           58 |
| **ONNX Runtime**          |                         15 |                   340 |                       210 |           2× (optimized) |                           72 |
| **TensorRT**              |                          3 |                   450 |                        65 |       2× (precision opt) |                           88 |
| **Apache TVM**            |                          6 |                   280 |                        95 |            3× (compiled) |                           82 |

: **Framework Efficiency Comparison.** Quantitative comparison of major ML frameworks across efficiency dimensions using ResNet-50 inference on representative hardware (NVIDIA A100 GPU for server frameworks, ARM Cortex-A78 for mobile). Metrics reflect production workloads with accuracy maintained within 1% of baseline. Hardware utilization represents percentage of theoretical peak performance on typical operations. {#tbl-framework-efficiency-matrix}

\index{TensorRT!hardware utilization}
\index{Apache TVM!ML compiler}
The efficiency data reveals several important patterns. First, specialized inference frameworks (TensorRT, Apache TVM) achieve roughly 7.5--17× lower latency than general-purpose training frameworks (PyTorch, TensorFlow) on identical hardware (3--6 ms versus 45--52 ms), demonstrating that framework selection has quantitative performance implications beyond qualitative design preferences. Second, mobile-optimized variants (TF Lite, PyTorch Mobile) reduce memory requirements by roughly 8--12× compared to their full counterparts (2,100 MB to 180 MB for TensorFlow; 1,800 MB to 220 MB for PyTorch) while maintaining accuracy within 1% through quantization and graph optimization. Third, hardware utilization varies dramatically: TensorRT achieves 88% GPU utilization through aggressive kernel fusion while vanilla PyTorch achieves only 32%, a 2.75× efficiency gap that directly translates to cost differences in production deployment.

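The ratios quoted above follow directly from @tbl-framework-efficiency-matrix. The short script below recomputes them from the table's values, so the arithmetic can be checked or rerun with new measurements.

```python
# Values copied from the framework efficiency table (ResNet-50 inference)
latency_ms = {"TensorFlow": 45, "PyTorch": 52, "TensorRT": 3, "Apache TVM": 6}
memory_mb = {
    "TensorFlow": 2100, "TensorFlow Lite": 180,
    "PyTorch": 1800, "PyTorch Mobile": 220,
}
utilization_pct = {"TensorRT": 88, "PyTorch": 32}

# Latency speedup of each specialized runtime over each training framework
speedups = {
    (fast, slow): latency_ms[slow] / latency_ms[fast]
    for fast in ("TensorRT", "Apache TVM")
    for slow in ("TensorFlow", "PyTorch")
}
print(f"latency speedup range: {min(speedups.values()):.1f}x "
      f"to {max(speedups.values()):.1f}x")  # 7.5x to 17.3x

# Memory reduction of the mobile variants vs. their full frameworks
tf_ratio = memory_mb["TensorFlow"] / memory_mb["TensorFlow Lite"]  # ~11.7
pt_ratio = memory_mb["PyTorch"] / memory_mb["PyTorch Mobile"]      # ~8.2

# Utilization gap between TensorRT and vanilla PyTorch
gap = utilization_pct["TensorRT"] / utilization_pct["PyTorch"]     # 2.75
```
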
These efficiency gaps, significant in the data center, become existential as we move beyond the server room. A 17× latency difference between PyTorch and TensorRT is an optimization opportunity on a cloud GPU; on a microcontroller with 256 KB of RAM, a framework that requires 1.8 GB of memory simply cannot run at all. The question shifts from "which framework is fastest?" to "which framework fits?"

## Deployment Targets {#sec-ml-frameworks-deployment-targets-13f1}

\index{Deployment Targets!cloud to edge spectrum}
\index{Edge Deployment!framework selection}
\index{Edge Deployment!three problems reweighting}
As ML models move from cloud servers to edge devices, the efficiency gaps measured above transform from optimization opportunities into hard deployment constraints. The three core problems reweight dramatically at the edge. The *execution problem* shifts from "eager vs. graph" to "can we execute at all within 10ms and 50KB?" The *differentiation problem* often disappears entirely, since edge devices run inference only. The *abstraction problem* intensifies: targeting ARM vs. x86, mobile NPUs vs. edge TPUs, microcontrollers with kilobytes of memory.

@tbl-deployment-frameworks summarizes framework choices by deployment target:

| **Environment**              | **Primary Frameworks**           | **Key Optimizations**                 | **Typical Constraints**           |
|:-----------------------------|:---------------------------------|:--------------------------------------|:----------------------------------|
| **Cloud/Server**             | PyTorch, TensorFlow, JAX         | Distributed training, mixed precision | Throughput, cost                  |
| **Edge**                     | TensorFlow Lite, ONNX Runtime    | Quantization (INT8), static graphs    | Latency <10ms, limited memory     |
| **Mobile**                   | TF Lite, Core ML, PyTorch Mobile | NPU acceleration, model compression   | Battery, thermal, app size limits |
| **Microcontroller (TinyML)** | TF Lite Micro, uTensor           | 4-bit quantization, static allocation | <256KB RAM, no dynamic memory     |

: **Framework Selection by Deployment Target.** Recommended frameworks, optimization techniques, and key constraints for each deployment tier, from cloud servers to microcontrollers. {#tbl-deployment-frameworks}

@tbl-deployment-frameworks reveals a fragmented landscape: different deployment targets favor different frameworks. The Smart Doorbell KWS model from @sec-ml-frameworks-tinyml-micro-runtimes-2a1b exemplifies the Microcontroller tier, where a minimal runtime such as TF Lite Micro, with static memory allocation and no operating-system dependency, is among the few viable paths. This fragmentation creates a practical problem when organizations train in one framework but deploy on a target best served by another.

\index{ONNX!definition}
\index{ONNX!hub-and-spoke interoperability}
The Open Neural Network Exchange (ONNX)\index{ONNX!cross-framework portability}[^fn-onnx] format addresses this fragmentation by enabling model portability across frameworks: train in PyTorch, then deploy via ONNX Runtime or, through converters, TensorFlow Lite. This standardization reduces manual conversion work when moving between development and production environments. @fig-onnx captures this hub-and-spoke interoperability model: ONNX sits at the center, accepting models from training frameworks on the left and dispatching them to specialized runtimes on the right. Detailed deployment optimization (quantization, pruning, hardware-specific compilation) appears in @sec-model-compression and @sec-model-serving.

{#fig-onnx fig-pos="htb" width="70%" fig-alt="Hub diagram with ONNX logo at center. Left side: PyTorch, TensorFlow, Keras with arrows pointing inward. Right side: TF Lite, ONNX Runtime with arrows outward."}

ONNX reduces the cost of framework fragmentation, but it does not eliminate the initial selection decision. With the deployment landscape mapped and interoperability options understood, we can now address the practical question: given a specific project's requirements, how should an engineer select a framework?

## Framework Selection {#sec-ml-frameworks-selecting-framework-2949}

\index{Framework Selection!trade-off analysis}
\index{Framework Selection!constrained optimization}
Framework selection is a constrained optimization problem across technical capabilities, operational requirements, and organizational factors. No single framework dominates across all criteria, which means the goal is not to find the "best" framework but to find the one whose trade-offs align with your project's constraints.

### The Framework Selection Trade-off Space {#sec-ml-frameworks-framework-selection-tradeoff-space-c76e}

Framework selection involves three interconnected tensions. The first is between development velocity and production performance: eager execution (PyTorch) prioritizes iteration speed, while graph compilation (TensorFlow/XLA, JAX/JIT) prioritizes runtime optimization. Research teams that need to test ten architecture variants per day cannot afford minutes of compilation between experiments; production teams that deploy a single model for months cannot afford the throughput penalty of eager dispatch. The optimal point shifts as a project moves through its lifecycle.

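The velocity-performance tension can be made concrete with a break-even calculation: compilation pays off only when the per-step time it saves, summed over the steps actually run, exceeds its one-time cost. The numbers below are hypothetical placeholders, not measurements; substitute profiled values for a real decision.

```python
def compile_break_even_steps(compile_s: float, eager_step_s: float,
                             compiled_step_s: float) -> float:
    """Number of steps before a one-time compilation cost is recovered."""
    saving_per_step = eager_step_s - compiled_step_s
    if saving_per_step <= 0:
        return float("inf")  # compilation never pays off
    return compile_s / saving_per_step


# Hypothetical: 120 s to compile, 50 ms per eager step, 30 ms per compiled step
steps = compile_break_even_steps(120.0, 0.050, 0.030)
print(f"break-even after {steps:,.0f} steps")  # 6,000 steps

# A short exploratory run vs. a long production run (hypothetical lengths)
for run_steps in (1_000, 100_000):
    decision = "compile" if run_steps > steps else "stay eager"
    print(run_steps, decision)
```
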
This velocity-performance tension leads directly to the second: flexibility versus optimization depth. Dynamic graphs enable the arbitrary control flow that makes eager development fast, but they limit the compiler's scope. Static graphs constrain expressiveness but enable aggressive fusion and hardware-specific code generation. As @tbl-mlfm-graphs demonstrated, this trade-off cascades through memory management, utilization, and debugging workflows---it is not a single design decision but a system-wide constraint.

The flexibility-optimization tension, in turn, exposes a third: ecosystem breadth versus specialization. General-purpose frameworks cover broad operation sets but underperform specialized runtimes. TensorRT achieves 88% GPU utilization versus PyTorch's 32% (@tbl-framework-efficiency-matrix) precisely because it optimizes for a narrower problem. ONNX bridges this gap through standardized interchange, but the underlying trade-off remains: the more a runtime specializes, the faster it runs and the less it supports.

::: {.callout-perspective title="Framework Selection Constraints"}
Rather than seeking the "best" framework, effective selection identifies the framework that satisfies hard constraints (deployment target, required operations, team expertise) while optimizing soft preferences (performance, development speed, ecosystem). Hard constraints eliminate options; soft preferences rank remaining candidates.
:::

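The hard-constraint/soft-preference procedure can be sketched as a filter followed by a sort. The candidate figures below reuse latency, memory, and energy values from @tbl-framework-efficiency-matrix; the deployment budget is a hypothetical example, and a real selection would add further hard constraints such as required operators.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    name: str
    latency_ms: float
    memory_mb: float
    energy_mj: float


# Values from the efficiency table (inference-oriented runtimes only)
CANDIDATES = [
    Candidate("TensorFlow Lite", 12, 180, 120),
    Candidate("PyTorch Mobile", 18, 220, 180),
    Candidate("ONNX Runtime", 15, 340, 210),
    Candidate("Apache TVM", 6, 280, 95),
]


def select(candidates, max_memory_mb, max_latency_ms):
    # Hard constraints eliminate options outright
    feasible = [
        c for c in candidates
        if c.memory_mb <= max_memory_mb and c.latency_ms <= max_latency_ms
    ]
    # Soft preference (here: lower energy per inference) ranks the survivors
    return sorted(feasible, key=lambda c: c.energy_mj)


# Hypothetical mobile budget: 300 MB memory, 20 ms latency
ranked = select(CANDIDATES, max_memory_mb=300, max_latency_ms=20)
print([c.name for c in ranked])
# ['Apache TVM', 'TensorFlow Lite', 'PyTorch Mobile']
```
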
The TensorFlow ecosystem illustrates how these axes interact concretely. Its three variants (TensorFlow\index{TensorFlow!full framework}, TensorFlow Lite\index{TensorFlow Lite!mobile deployment}, TensorFlow Lite Micro\index{TensorFlow Lite Micro!microcontroller deployment}) trace a single design philosophy across progressively tighter constraints, a pattern that generalizes to any framework family. @tbl-tf-comparison quantifies the trade-offs.

|                                 | **TensorFlow**                  | **TensorFlow Lite**   | **TensorFlow Lite for Microcontrollers** |
|:--------------------------------|:--------------------------------|:----------------------|:-----------------------------------------|
| **Training**                    | Yes                             | No                    | No                                       |
| **Inference**                   | Yes (*but inefficient on edge*) | Yes (*and efficient*) | Yes (*and even more efficient*)          |
| **How Many Ops**                | ~1400                           | ~130                  | ~50                                      |
| **Native Quantization Tooling** | No                              | Yes                   | Yes                                      |

: **TensorFlow Variant Software Comparison.** Design trade-offs across TensorFlow, TensorFlow Lite, and TensorFlow Lite Micro, balancing model expressiveness, binary size, and resource constraints. Supported operations decrease from approximately 1,400 in full TensorFlow to approximately 50 in TensorFlow Lite Micro, reflecting a shift from training capability to efficient edge inference. Native quantization tooling enables further optimization for constrained environments. {#tbl-tf-comparison}

The principle is progressive constraint leading to progressive optimization: fewer supported operations enable smaller binaries, tighter memory budgets, and native quantization.

### Framework Selection Criteria {#sec-ml-frameworks-model-requirements-2e01}

Three dimensions structure systematic framework evaluation: what the model requires (supported operations and graph semantics), what the software environment provides (OS, memory management, accelerator delegation), and what the hardware physically permits (compute, memory, power). Each dimension acts as a filter: hard constraints eliminate candidates, and soft preferences rank the survivors.

#### Model Requirements {#sec-ml-frameworks-model-requirements-ca8b}

The first question is whether a framework can express the models your project requires. As @tbl-tf-comparison shows, operator count drops from approximately $10^3$ (full TensorFlow) to $10^2$ (TensorFlow Lite) to $10^1$ (TensorFlow Lite Micro). Each reduction eliminates training capability and general-purpose operations while adding native quantization tooling. The engineering principle is that expressiveness and efficiency trade against each other: fewer supported operations enable tighter code generation, smaller binaries, and hardware-specific optimization paths. This progressive constraint model applies to any framework family, not just TensorFlow. The choice between *dynamic and static computational graphs* further shapes which optimizations each constraint level permits.

::: {.callout-perspective title="Dynamic vs Static Computational Graphs"}
The static-versus-dynamic graph distinction (examined in @sec-ml-frameworks-execution-problem-e1e1) has direct implications for model requirements analysis. Static graphs constrain which operations are expressible but enable ahead-of-time optimization for deployment. Dynamic graphs support arbitrary Python control flow but require explicit compilation steps (e.g., `torch.compile`, `tf.function`) to recover optimization potential.
:::

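Checking model requirements against a runtime often reduces to a set difference: which operators does the model use that the target runtime does not implement? The operator names and supported sets below are hypothetical, abbreviated stand-ins for the three tiers; real lists come from each runtime's documentation or converter output.

```python
from __future__ import annotations

# Hypothetical, abbreviated operator inventories (not the real lists)
RUNTIME_OPS = {
    "full_framework": {"Conv2D", "DepthwiseConv2D", "FullyConnected",
                       "Softmax", "LSTM", "Einsum", "Relu", "Reshape"},
    "mobile_runtime": {"Conv2D", "DepthwiseConv2D", "FullyConnected",
                       "Softmax", "LSTM", "Relu", "Reshape"},
    "micro_runtime": {"Conv2D", "DepthwiseConv2D", "FullyConnected",
                      "Softmax", "Relu", "Reshape"},
}


def unsupported_ops(model_ops: set[str], runtime: str) -> set[str]:
    """Operators the model needs that the runtime does not provide."""
    return model_ops - RUNTIME_OPS[runtime]


# Hypothetical model graphs: a convolutional keyword spotter and an RNN
kws_model_ops = {"Conv2D", "DepthwiseConv2D", "Relu", "Reshape",
                 "FullyConnected", "Softmax"}
rnn_model_ops = {"FullyConnected", "LSTM", "Softmax"}

for runtime in RUNTIME_OPS:
    print(runtime,
          sorted(unsupported_ops(kws_model_ops, runtime)),
          sorted(unsupported_ops(rnn_model_ops, runtime)))
```

An empty result is a necessary (hard-constraint) pass, not a guarantee: the runtime may still lack a specific data type or kernel variant for an operator it nominally supports.
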
#### Software Dependencies {#sec-ml-frameworks-software-dependencies-5245}

Once model requirements are satisfied, the framework must integrate with the target software environment. @tbl-tf-sw-comparison reveals how operating system requirements, memory management, and accelerator support vary across TensorFlow variants.

|                                | **TensorFlow** | **TensorFlow Lite** | **TensorFlow Lite for Microcontrollers** |
|:-------------------------------|:---------------|:--------------------|:-----------------------------------------|
| **Needs an OS**                | Yes            | Yes                 | No                                       |
| **Memory Mapping of Models**   | No             | Yes                 | Yes                                      |
| **Delegation to accelerators** | Yes            | Yes                 | No                                       |

: **TensorFlow Variant Capability Comparison.** Capabilities of TensorFlow, TensorFlow Lite, and TensorFlow Lite Micro regarding operating system dependence, memory management, and hardware acceleration. Progressive constraint across variants enables selection by deployment context, from full-scale servers to resource-constrained edge devices. {#tbl-tf-sw-comparison}

The key distinctions follow the same progressive constraint pattern. TensorFlow Lite Micro eliminates the OS requirement entirely, enabling bare-metal execution on microcontrollers (though it integrates with RTOSes like FreeRTOS and Zephyr when available). Both Lite variants support memory-mapped model access from flash storage, avoiding the RAM overhead of loading full models. Accelerator delegation drops out at the microcontroller tier, where specialized hardware is rarely available. Each software dependency removed is a deployment target gained.

#### Hardware Constraints {#sec-ml-frameworks-hardware-constraints-e449}

Software compatibility alone does not guarantee deployment; the framework must fit within physical hardware limits. @tbl-tf-hw-comparison quantifies this final constraint dimension.

|                             | **TensorFlow**                                        | **TensorFlow Lite**    | **TensorFlow Lite for Microcontrollers** |
|:----------------------------|:------------------------------------------------------|:-----------------------|:-----------------------------------------|
| **Base Binary Size**        | A few MB (varies by platform and build configuration) | Tens to hundreds of KB | On the order of 10 KB                    |
| **Base Memory Footprint**   | Several MB (minimum runtime overhead)                 | Hundreds of KB         | Tens of KB                               |
| **Optimized Architectures** | x86, TPUs, GPUs                                       | Arm Cortex-A, x86      | Arm Cortex-M, DSPs, MCUs                 |

: **TensorFlow Hardware Optimization.** Resource requirements (binary size and memory footprint) decrease across TensorFlow variants as they target increasingly constrained hardware, from servers to microcontrollers. Optimized architectures shift from general-purpose CPUs and GPUs to ARM Cortex-M processors and digital signal processors for resource-limited environments. {#tbl-tf-hw-comparison}

Binary size spans roughly two to three orders of magnitude: from a few MB (full TensorFlow) to on the order of 10 KB (TensorFlow Lite Micro). Memory footprint follows the same pattern. Processor architecture support shifts correspondingly from x86/GPU/TPU (data center) through Arm Cortex-A (mobile/edge) to Arm Cortex-M, DSPs, and MCUs (embedded). These are not arbitrary engineering tiers---they mirror the physical constraints (Light Barrier, Power Wall, Memory Wall) that carve the deployment spectrum into distinct paradigms (@sec-ml-systems-deployment-spectrum-71be). The engineering lesson generalizes beyond TensorFlow: every framework family that spans deployment tiers makes analogous trade-offs between capability and resource footprint, and the framework's job is to make those trade-offs navigable rather than invisible.

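Whether a framework "fits" is ultimately arithmetic against the device's budgets. The sketch below checks flash (runtime binary plus model weights) and RAM (runtime overhead plus activation arena) against a device's limits. Every number is a hypothetical placeholder loosely within the ranges of @tbl-tf-hw-comparison, not a measured footprint, and it assumes weights stay in flash through memory mapping.

```python
import math

KB = 1024
MB = 1024 * KB

# Hypothetical placeholders within the ranges of the hardware table
runtime_binary = {"full": 3 * MB, "lite": 300 * KB, "micro": 16 * KB}
runtime_ram = {"full": 8 * MB, "lite": 400 * KB, "micro": 20 * KB}


def fits(tier, model_weights, activation_arena, flash_budget, ram_budget):
    """True if code + weights fit in flash and runtime + activations fit in RAM.

    Assumes weights are memory-mapped from flash rather than copied to RAM.
    """
    flash_needed = runtime_binary[tier] + model_weights
    ram_needed = runtime_ram[tier] + activation_arena
    return flash_needed <= flash_budget and ram_needed <= ram_budget


# Hypothetical microcontroller: 1 MB flash, 256 KB RAM.
# Hypothetical INT8 keyword-spotting model: 200 KB weights, 60 KB activations.
for tier in ("full", "lite", "micro"):
    print(tier, fits(tier, 200 * KB, 60 * KB, 1 * MB, 256 * KB))

# Span of binary sizes across tiers, in orders of magnitude
span = math.log10(runtime_binary["full"] / runtime_binary["micro"])
print(f"{span:.1f} orders of magnitude")
```
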
#### Production-Ready Evaluation Factors {#sec-ml-frameworks-productionready-evaluation-factors-e088}

Technical specifications establish necessary but not sufficient conditions for selection. Production deployments also require evaluating operational factors: migration cost (typically 3--6 engineer-months for production systems), maintenance burden, and deployment success rates.

The hardware constraints examined above cascade into performance trade-offs that are tightly coupled. Inference latency (tens of milliseconds for mobile image classification, sub-millisecond for industrial control), memory footprint (MB for full TensorFlow down to tens of KB for TF Lite Micro), power consumption (INT8 inference consuming several-fold less energy than FP32), and hardware utilization (operator fusion improving FLOPS utilization from 10--20% to 60--80% of peak) are not independent dimensions. Quantization simultaneously reduces memory, latency, and energy at the cost of precision, and framework selection determines which of these optimization levers are available in the first place. Scalability introduces a further concern: consistent deployment from microcontrollers to servers, smooth prototype-to-production transitions, and version management across deployed fleets all depend on the framework's deployment toolchain. The three-dimension methodology illustrated here---model requirements, software dependencies, hardware constraints---applies to any framework ecosystem, not just TensorFlow.

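The coupling between precision and footprint starts with simple arithmetic: weight storage scales linearly with bits per parameter. The sketch below assumes a ResNet-50-sized model of roughly 25.6 million parameters and ignores per-tensor metadata such as quantization scales and zero-points, so the figures are approximations.

```python
PARAMS = 25_600_000  # approximate ResNet-50 parameter count

BITS_PER_PARAM = {"FP32": 32, "FP16": 16, "INT8": 8, "INT4": 4}


def weight_megabytes(params: int, bits: int) -> float:
    """Weight storage in MB (10^6 bytes), ignoring scales and zero-points."""
    return params * bits / 8 / 1e6


for fmt, bits in BITS_PER_PARAM.items():
    size = weight_megabytes(PARAMS, bits)
    reduction = BITS_PER_PARAM["FP32"] / bits
    print(f"{fmt}: {size:6.1f} MB ({reduction:.0f}x reduction vs FP32)")
# FP32 ~102.4 MB, FP16 ~51.2 MB, INT8 ~25.6 MB, INT4 ~12.8 MB
```
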
### Development Support and Long-term Viability Assessment {#sec-ml-frameworks-development-support-longterm-viability-assessment-d1d7}
|
||
|
||
What determines whether a framework remains viable five years into a production deployment? Technical capabilities are necessary but not sufficient. Community composition shapes framework evolution in measurable ways: PyTorch's academic community drives research-oriented features and reproducibility tools, though production tooling (PyTorch Lightning, TorchServe) has historically lagged; TensorFlow's enterprise community emphasizes production reliability through TFX pipelines, TensorBoard visualization, and TensorFlow Model Analysis; JAX's smaller community concentrates on mathematical rigor, producing powerful research tools but with a steeper onboarding curve.
|
||
|
||
A framework's practical utility, however, often depends more on its surrounding ecosystem than on its core capabilities. Hugging Face provides consistent model APIs across all three major frameworks, making pretrained model availability a near-commodity. Cross-framework tools (Weights & Biases, MLflow for experiment tracking; ONNX Runtime for serving) reduce lock-in, while framework-native tools (XLA, TorchScript, TensorFlow Serving) offer deeper optimization at the cost of portability. Cloud ML services (SageMaker, Google AI Platform, Azure ML) provide native integration for specific frameworks, creating operational advantages that compound over time.
|
||
|
||
These compounding effects make framework migration progressively harder. Integration with existing CI/CD pipelines, monitoring infrastructure, and cloud providers creates operational inertia that resists change. The measurable indicators of viability---contributor diversity (single-company dependence is a risk), backward compatibility track record, and hiring pool alignment with organizational needs---should be evaluated before commitment, not after. The mitigation strategy is defensive: use standardized formats (ONNX), maintain framework-agnostic data pipelines, and document framework-specific customizations to preserve future flexibility.
|
||
|
||
We have now examined the three core problems individually, compared how major frameworks resolve them, and established criteria for choosing among alternatives. What remains is to see all three problems interact inside a single execution---to watch the machinery we have studied operate as one integrated system.
|
||
|
||
## Anatomy of a Training Step {#sec-ml-frameworks-putting-together-anatomy-training-step-c7f1}

\index{Training Step!anatomy}
\index{Forward Pass!kernel dispatch}
\index{Backward Pass!gradient computation}
The concepts developed throughout this chapter (eager versus graph execution, reverse-mode autodiff, tensor abstractions, kernel dispatch) remain abstract until we see them interact inside a real execution. To make them concrete, we trace a single training step through the PyTorch stack, revealing how eight lines of Python trigger the execution, differentiation, and abstraction machinery simultaneously.

@lst-training-step-anatomy presents a minimal training iteration for a two-layer multilayer perceptron. Though the step itself is only eight lines of Python, it exercises the entire framework stack: tensor allocation, kernel dispatch, autograd recording, gradient computation, and parameter updates. Tracing each phase reveals the three problems in action and connects the quantitative principles developed earlier to concrete execution.

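::: {#lst-training-step-anatomy lst-cap="**Training Step Anatomy**: A minimal training iteration for a two-layer MLP, exercising tensor allocation, kernel dispatch, autograd recording, gradient computation, and parameter updates."}

```{.python}
import torch
import torch.nn.functional as F

# Setup (not part of the step): parameters of a 784 -> 256 -> 10 MLP
W1 = torch.randn(784, 256, device="cuda", requires_grad=True)
b1 = torch.zeros(256, device="cuda", requires_grad=True)
W2 = torch.randn(256, 10, device="cuda", requires_grad=True)
b2 = torch.zeros(10, device="cuda", requires_grad=True)
optimizer = torch.optim.SGD([W1, b1, W2, b2], lr=0.01)

# Single training step for a 2-layer MLP
x = torch.randn(32, 784, device="cuda")  # Input batch
y = torch.randint(0, 10, (32,), device="cuda")  # Labels

# Forward pass (each operation is recorded on the autograd tape)
h = torch.relu(x @ W1 + b1)  # Hidden layer
logits = h @ W2 + b2  # Output layer
loss = F.cross_entropy(logits, y)

# Backward pass
optimizer.zero_grad()  # Clear gradients accumulated by earlier steps
loss.backward()

# Parameter update
optimizer.step()
```

:::
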
```{python}
#| echo: false
#| label: training-step-dims
# ┌─────────────────────────────────────────────────────────────────────────────
# │ TRAINING STEP DIMENSIONS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Phase 1 prose references matrix dimensions inline
# │
# │ Goal: Define model dimensions for the training step example.
# │ Show: The matrix sizes for a 2-layer MLP (784 -> 256 -> 10).
# │ How: Set constants for input batch, hidden layers, and output classes.
# │
# │ Imports: (none)
# │ Exports: train_batch, train_input, train_hidden, train_output
# └─────────────────────────────────────────────────────────────────────────────
train_batch = 32  # batch size
train_input = 784  # MNIST input (28 × 28)
train_hidden = 256  # hidden layer
train_output = 10  # 10 classes
```

### Phase 1: Forward Pass (Solving the Execution Problem) {.unnumbered}

When `h = torch.relu(x @ W1 + b1)` executes, PyTorch's eager execution triggers immediate computation. For the matrix multiply `x @ W1`, the sequence is:

1. **Python Dispatch**\index{Python Dispatch!overhead} (~1 μs): The Python interpreter calls `torch.matmul`, which routes through PyTorch's dispatcher to select the CUDA backend.

2. **Kernel Selection**\index{BLAS Library!kernel selection} (~0.5 μs): cuBLAS selects an optimized GEMM kernel based on the matrix dimensions (a `{python} train_batch`×`{python} train_input` matrix times a `{python} train_input`×`{python} train_hidden` matrix). For these dimensions, it might choose a tiled algorithm optimized for L2 cache.

3. **Kernel Launch** (~5 μs): The selected kernel is queued in the GPU's command buffer, and the CPU continues immediately (asynchronous execution).

4. **GPU Execution** (~15 μs):
   - Load W1 from HBM[^fn-hbm-frameworks] into L2 cache (~200 GB/s effective bandwidth)
   - Perform the matrix multiply on tensor cores (if available)
   - Write the result back to HBM

5. **Autograd Recording**: During the same call, PyTorch's autograd engine records an `MmBackward0` node on the tape, keeping references to `x` and `W1` for gradient computation.

The bias addition and ReLU follow similar patterns, each adding its own node (`AddBackward0`, `ReluBackward0`) to the autograd tape.

[^fn-hbm-frameworks]: HBM (introduced in @sec-network-architectures) provides 2--3 TB/s of bandwidth on modern GPUs. For framework execution, HBM bandwidth determines whether operations are memory-bound or compute-bound. The `{python} a100_mem_str` GB capacity of an A100 sets practical limits on model size, since weights, activations, and gradients must all fit in HBM during execution.

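Because kernel launch is asynchronous, the CPU-side costs in steps 1--3 overlap with GPU execution in step 4. The sketch below models that overlap under simplifying assumptions: a single in-order stream, a fixed CPU cost of about 6.5 μs per operation (the sum of the dispatch, selection, and launch estimates above), and no limit on queue depth. The function name `simulate_async_launch` is illustrative, not a PyTorch API.

```python
def simulate_async_launch(cpu_us_per_op, gpu_us_per_op):
    """Return the wall-clock time (μs) until the last kernel finishes.

    Assumes one in-order GPU stream: the CPU enqueues kernels back to back,
    and each kernel starts once it has been enqueued AND the previous kernel
    has finished.
    """
    cpu_clock = 0.0  # time at which the CPU finishes enqueuing the current op
    gpu_free_at = 0.0  # time at which the GPU finishes its previous kernel
    for gpu_us in gpu_us_per_op:
        cpu_clock += cpu_us_per_op  # dispatch + kernel selection + launch
        start = max(cpu_clock, gpu_free_at)  # needs both the CPU and the GPU
        gpu_free_at = start + gpu_us
    return gpu_free_at


CPU_US = 1.0 + 0.5 + 5.0  # dispatch + selection + launch, from the estimates above

# Long kernels: the GPU is the bottleneck and CPU overhead is hidden.
print(simulate_async_launch(CPU_US, [15.0] * 3))  # 51.5
# Short kernels: the GPU idles while it waits for the CPU to enqueue work.
print(simulate_async_launch(CPU_US, [1.0] * 3))  # 20.5
```

With 15 μs kernels the total is essentially the sum of GPU time plus the first launch; with 1 μs kernels it is essentially three CPU launch costs. Small kernels therefore expose framework overhead, which is the regime Phase 3 quantifies.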
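### Phase 2: Backward Pass (Solving the Differentiation Problem) {.unnumbered}

When `loss.backward()`\index{Backward Pass!triggering gradient computation} executes:

1. **Tape Traversal**\index{Autograd Tape!reverse traversal}: The autograd engine traverses the recorded graph in reverse topological order, starting from `loss` with an initial gradient of 1.

2. **Gradient Computation**: For each node, it calls the registered backward function:
   - `NllLossBackward0` and `LogSoftmaxBackward0` (the two nodes that `F.cross_entropy` records): compute $\frac{\partial L}{\partial \text{logits}}$, which simplifies to the softmax output minus the one-hot labels, scaled by $1/32$ for the batch mean
   - `AddBackward0` (b2): passes $\frac{\partial L}{\partial \text{logits}}$ through and sums it over the batch to form $\frac{\partial L}{\partial b_2}$
   - `MmBackward0` (W2): computes $\frac{\partial L}{\partial W_2} = h^T \cdot \frac{\partial L}{\partial \text{logits}}$ and $\frac{\partial L}{\partial h}$
   - `ReluBackward0`: applies the ReLU derivative mask (zero where $h \le 0$)
   - `AddBackward0` (b1) and `MmBackward0` (W1): compute $\frac{\partial L}{\partial b_1}$ and $\frac{\partial L}{\partial W_1}$; $\frac{\partial L}{\partial x}$ is skipped because `x` does not require gradients

3. **Gradient Accumulation**: Gradients are *accumulated* (added) into the `.grad` attributes of leaf tensors rather than overwritten, which is why a training loop must clear them, for example with `optimizer.zero_grad()`, before each backward pass.

4. **Memory Management**: After each backward node completes, its saved tensors are freed (unless `retain_graph=True` is passed), allowing memory reuse.
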
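The four backward steps can be reproduced in miniature. The sketch below is a scalar, pure-Python tape and is not how PyTorch is implemented (PyTorch records a graph of C++ `grad_fn` nodes over tensors), but it follows the same logic: operations append entries during the forward pass, `backward` replays them in reverse, gradients accumulate with `+=`, and the tape is cleared afterward to release saved values. The names `Var`, `TAPE`, `backward`, and `step` are hypothetical.

```python
class Var:
    """A scalar that can take part in tape-based reverse-mode autodiff."""

    def __init__(self, value, requires_grad=False):
        self.value = value
        self.requires_grad = requires_grad
        self.grad = 0.0  # accumulated gradient, like a tensor's .grad


TAPE = []  # (output, inputs, local_grads) entries in forward execution order


def _record(out, inputs, local_grads):
    # Only record operations whose inputs need gradients, as autograd does.
    if any(v.requires_grad for v in inputs):
        out.requires_grad = True
        TAPE.append((out, inputs, local_grads))
    return out


def mul(a, b):
    # d(a*b)/da = b and d(a*b)/db = a; the lambda keeps references to a and b,
    # playing the role of "saved tensors".
    return _record(Var(a.value * b.value), (a, b), lambda: (b.value, a.value))


def add(a, b):
    return _record(Var(a.value + b.value), (a, b), lambda: (1.0, 1.0))


def relu(a):
    return _record(Var(max(a.value, 0.0)), (a,), lambda: (1.0 if a.value > 0 else 0.0,))


def backward(loss):
    loss.grad = 1.0  # dL/dL
    # Reverse execution order is a valid reverse topological order.
    for out, inputs, local_grads in reversed(TAPE):
        for inp, local in zip(inputs, local_grads()):
            if inp.requires_grad:  # skip inputs such as x that need no gradient
                inp.grad += out.grad * local  # chain rule, accumulated
    TAPE.clear()  # release saved references once backward is done


def step(w, b, x):
    h = relu(add(mul(x, w), b))  # relu(x*w + b)
    return mul(h, h)  # loss = h^2


w, b, x = Var(2.0, requires_grad=True), Var(-1.0, requires_grad=True), Var(3.0)
backward(step(w, b, x))
print(w.grad, b.grad, x.grad)  # 30.0 10.0 0.0
backward(step(w, b, x))  # no zeroing in between
print(w.grad)  # 60.0: gradients accumulated across two backward passes
```

The second `backward` call doubles `w.grad` because nothing reset it, which is exactly why PyTorch training loops clear gradients (typically with `optimizer.zero_grad()`) before each backward pass.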
### Phase 3: Memory Traffic Analysis (The Physics at Work) {.unnumbered}

Before applying the Dispatch Overhead Equation (@eq-dispatch-overhead) to this step, we need the work and data movement of each operation. @tbl-training-step-roofline breaks down the FLOPs, memory traffic, and arithmetic intensity for each one:

```{python}
#| label: training-step-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ TRAINING STEP FLOPS AND MEMORY ANALYSIS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Training Step Anatomy table showing FLOPs and memory per operation
# │
# │ Goal: Quantify the arithmetic intensity of each operation in a forward pass.
# │ MatMul has high AI (~14), ReLU has near-zero AI (0.125), demonstrating
# │ why element-wise ops are memory-bound and benefit most from fusion.
# │
# │ Imports: mlsys.constants (byte, KB, MB, flop, KFLOPs, MFLOPs), mlsys.formatting (fmt)
# │ Exports: train_mm1_flops_str, train_mm1_mem_str, train_mm1_ai_str,
# │ train_relu_flops_str, train_relu_mem_str, train_relu_ai_str,
# │ train_mm2_flops_str, train_mm2_mem_str, train_mm2_ai_str, train_ce_flops_str,
# │ train_ce_mem_str, train_ce_ai_str, train_bwd_flops_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.formatting import fmt, check
from mlsys.constants import byte, KB, MB, flop, KFLOPs, MFLOPs

# --- Inputs (matching training-step-dims) ---
train_batch = 32  # batch size
train_input = 784  # MNIST input (28 × 28)
train_hidden = 256  # hidden layer
train_output = 10  # 10 classes

# --- Process: MatMul 1 [32, 784] @ [784, 256] ---
train_mm1_flops = 2 * train_batch * train_input * train_hidden  # ~12.8M FLOPs
train_mm1_mem_bytes = (train_batch * train_input + train_input * train_hidden + train_batch * train_hidden) * 4  # FP32

# --- Process: ReLU [32, 256] ---
train_relu_flops = train_batch * train_hidden  # ~8K FLOPs
train_relu_mem_bytes = (train_batch * train_hidden) * 2 * 4  # read + write

# --- Process: MatMul 2 [32, 256] @ [256, 10] ---
train_mm2_flops = 2 * train_batch * train_hidden * train_output  # ~164K FLOPs
train_mm2_mem_bytes = (train_batch * train_hidden + train_hidden * train_output + train_batch * train_output) * 4

# --- Process: Cross Entropy [32, 10] ---
train_ce_flops = train_batch * train_output * 3  # ~1K FLOPs (exp + sum + div)
train_ce_mem_bytes = (train_batch * train_output) * 2 * 4

# --- Process: Backward pass (2x forward FLOPs) ---
train_bwd_flops = 2 * (train_mm1_flops + train_relu_flops + train_mm2_flops + train_ce_flops)
train_bwd_mem_str = "~3.2 MB"  # approximation
train_bwd_ai_str = "~8.0"  # approximation

# --- Outputs (formatted strings for table) ---
train_mm1_flops_str = fmt((train_mm1_flops * flop).to(MFLOPs).magnitude, precision=1, commas=False) + "M"
train_mm1_mem_str = fmt((train_mm1_mem_bytes * byte).to(MB).magnitude, precision=1, commas=False) + " MB"
train_mm1_ai_str = fmt(train_mm1_flops / train_mm1_mem_bytes, precision=1, commas=False)

train_relu_flops_str = fmt((train_relu_flops * flop).to(KFLOPs).magnitude, precision=0, commas=False) + "K"
train_relu_mem_str = fmt((train_relu_mem_bytes * byte).to(KB).magnitude, precision=0, commas=False) + " KB"
train_relu_ai_str = fmt(train_relu_flops / train_relu_mem_bytes, precision=3, commas=False)

train_mm2_flops_str = fmt((train_mm2_flops * flop).to(KFLOPs).magnitude, precision=0, commas=False) + "K"
train_mm2_mem_str = fmt((train_mm2_mem_bytes * byte).to(KB).magnitude, precision=0, commas=False) + " KB"
train_mm2_ai_str = fmt(train_mm2_flops / train_mm2_mem_bytes, precision=1, commas=False)

train_ce_flops_str = fmt((train_ce_flops * flop).to(KFLOPs).magnitude, precision=0, commas=False) + "K"
train_ce_mem_str = fmt((train_ce_mem_bytes * byte).to(KB).magnitude, precision=0, commas=False) + " KB"
train_ce_ai_str = fmt(train_ce_flops / train_ce_mem_bytes, precision=1, commas=False)

train_bwd_flops_str = fmt((train_bwd_flops * flop).to(MFLOPs).magnitude, precision=0, commas=False) + "M"
```

| **Component** | **FLOPs** | **Memory Traffic (FP32)** | **Arithmetic Intensity (FLOPs/byte)** |
|:--------------------------|----------------------------------------------:|------------------------------:|-----------------------------:|
| **MatMul (x @ W1)** | 2×32×784×256 = `{python} train_mm1_flops_str` | `{python} train_mm1_mem_str` | `{python} train_mm1_ai_str` |
| **ReLU** | 32×256 = `{python} train_relu_flops_str` | `{python} train_relu_mem_str` | `{python} train_relu_ai_str` |
| **MatMul (h @ W2)** | 2×32×256×10 = `{python} train_mm2_flops_str` | `{python} train_mm2_mem_str` | `{python} train_mm2_ai_str` |
| **Cross-entropy** | ~`{python} train_ce_flops_str` | `{python} train_ce_mem_str` | `{python} train_ce_ai_str` |
| **Backward (2× forward)** | ~`{python} train_bwd_flops_str` | `{python} train_bwd_mem_str` | `{python} train_bwd_ai_str` |

: **Per-Operation Roofline Analysis.** FLOPs, memory traffic, and arithmetic intensity for each operation in a two-layer MLP training step. The matrix multiplies have arithmetic intensity one to two orders of magnitude higher than ReLU and cross-entropy, which sit far below 1 FLOP/byte. At this small batch size, however, every operation falls below the A100's ridge point of roughly 150 FLOPs/byte (312 TFLOPS ÷ ~2 TB/s), so all of them are memory-bound on that hardware; the element-wise operations are the extreme case, which is why kernel fusion targets them first. {#tbl-training-step-roofline}

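The table's arithmetic intensities, and how they compare with the A100's ridge point, can be reproduced in a few lines. The sketch assumes FP32 tensors, that each matrix is read or written exactly once (ideal caching), an A100 80 GB SXM with 2.039 TB/s of HBM bandwidth, and the 312 TFLOPS FP16 tensor-core peak used elsewhere in this chapter. The helper names are hypothetical.

```python
BYTES_FP32 = 4
PEAK_FLOPS = 312e12  # A100 FP16 tensor-core peak, FLOP/s (assumed)
HBM_BW = 2.039e12  # A100 80 GB SXM HBM bandwidth, bytes/s (assumed)
RIDGE = PEAK_FLOPS / HBM_BW  # FLOPs/byte where the roofline bends (~153)


def matmul_stats(m, k, n, nbytes=BYTES_FP32):
    """FLOPs and ideal memory traffic for an [m, k] @ [k, n] multiply."""
    flops = 2 * m * k * n  # one multiply and one add per term
    traffic = (m * k + k * n + m * n) * nbytes  # read A, read B, write C once
    return flops, traffic


def elementwise_stats(numel, nbytes=BYTES_FP32):
    """One FLOP per element, reading the input once and writing the output once."""
    return numel, 2 * numel * nbytes


def attainable_fraction(ai):
    """Roofline: fraction of peak compute reachable at arithmetic intensity ai."""
    return min(1.0, ai * HBM_BW / PEAK_FLOPS)


ops = {
    "x @ W1": matmul_stats(32, 784, 256),
    "ReLU": elementwise_stats(32 * 256),
    "h @ W2": matmul_stats(32, 256, 10),
}
for name, (flops, traffic) in ops.items():
    ai = flops / traffic
    bound = "compute-bound" if ai >= RIDGE else "memory-bound"
    print(f"{name:7s} AI={ai:7.3f} FLOPs/byte  {bound}  "
          f"max {attainable_fraction(ai):.2%} of peak")
```

Even the first matrix multiply can reach at most about 9% of peak at this batch size. Assuming TF32 (156 TFLOPS) instead halves the ridge point to roughly 77 FLOPs/byte, and the conclusion is unchanged.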
```{python}
#| label: mnist-training-step-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ MNIST TRAINING STEP ROOFLINE ANALYSIS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Showing that small model training is overhead-bound, not compute-bound
# │
# │ Goal: On an A100, 40M FLOPs takes 0.1 μs and 5 MB memory takes 2.5 μs, but
# │ 6 ops × 5 μs dispatch = 30 μs overhead. The step is overhead-bound!
# │ This explains why torch.compile provides 2-3x speedup for small models.
# │
# │ Imports: mlsys.constants (A100_FLOPS_FP16_TENSOR, A100_MEM_BW), mlsys.formatting (fmt)
# │ Exports: mnist_total_flops_str, mnist_mem_traffic_str, a100_peak_tflops_str,
# │ a100_mem_bw_tbs_str, mnist_t_compute_us_str, mnist_t_memory_us_str,
# │ mnist_n_ops_str, mnist_us_per_op_str, mnist_t_overhead_us_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import A100_FLOPS_FP16_TENSOR, A100_MEM_BW, TRILLION, flop, byte, second
from mlsys.formatting import fmt, check

MILLION = 1_000_000  # plain-number constant used below (defined locally)

# --- Inputs: workload characteristics ---
mnist_total_flops_value = 40e6  # ~40M FLOPs total
mnist_mem_traffic_bytes_value = 5e6  # ~5 MB traffic
mnist_n_ops_value = 6  # 6 kernel launches
mnist_us_per_op_value = 5  # 5 μs dispatch/op

# --- Inputs: A100 hardware specs ---
a100_peak_tflops_value = A100_FLOPS_FP16_TENSOR.to(flop/second).magnitude  # 312e12 FLOPS
a100_mem_bw_bytes_value = A100_MEM_BW.to(byte/second).magnitude  # ~2e12 B/s

# --- Process: timing breakdown ---
mnist_t_compute_us_value = mnist_total_flops_value / a100_peak_tflops_value * MILLION  # ~0.1 μs
mnist_t_memory_us_value = mnist_mem_traffic_bytes_value / a100_mem_bw_bytes_value * MILLION  # ~2.5 μs
mnist_t_overhead_us_value = mnist_n_ops_value * mnist_us_per_op_value  # 30 μs (dominant!)

# --- Outputs (formatted strings for prose) ---
mnist_t_compute_us_str = fmt(mnist_t_compute_us_value, precision=1, commas=False)
mnist_t_memory_us_str = fmt(mnist_t_memory_us_value, precision=1, commas=False)
mnist_t_overhead_us_str = fmt(mnist_t_overhead_us_value, precision=0, commas=False)
mnist_total_flops_str = f"{mnist_total_flops_value/MILLION:.0f}M"
mnist_mem_traffic_str = f"{mnist_mem_traffic_bytes_value/MILLION:.0f} MB"
mnist_n_ops_str = fmt(mnist_n_ops_value, precision=0, commas=False)
mnist_us_per_op_str = fmt(mnist_us_per_op_value, precision=0, commas=False)
a100_peak_tflops_str = fmt(a100_peak_tflops_value/TRILLION, precision=0, commas=False)
a100_mem_bw_tbs_str = fmt(a100_mem_bw_bytes_value/TRILLION, precision=0, commas=False)
```

In total, the step performs ~`{python} mnist_total_flops_str` FLOPs and generates ~`{python} mnist_mem_traffic_str` of memory traffic. On an A100:

- $T_{\text{compute}} \approx$ `{python} mnist_total_flops_str` FLOPs / `{python} a100_peak_tflops_str` TFLOPS ≈ `{python} mnist_t_compute_us_str` μs
- $T_{\text{memory}} \approx$ `{python} mnist_mem_traffic_str` / `{python} a100_mem_bw_tbs_str` TB/s ≈ `{python} mnist_t_memory_us_str` μs
- $T_{\text{overhead}} \approx$ `{python} mnist_n_ops_str` ops × `{python} mnist_us_per_op_str` μs ≈ `{python} mnist_t_overhead_us_str` μs

The overhead estimate is conservative: eager mode launches separate kernels for each matrix multiply and each bias addition in the forward pass alone, and the backward pass launches several more.

The training step is therefore overhead-bound\index{Overhead-bound!small model training}: for small models, the fixed cost of Python dispatch and kernel launch exceeds the time the GPU spends computing and moving data. This explains why:

- `torch.compile` can provide a 2--3× speedup by fusing operations and reducing kernel launches
- Larger batches amortize the fixed per-step overhead over more samples
- Production training uses much larger models, where compute and memory time dominate

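A simple step-time model makes the overhead argument concrete. The sketch assumes that compute and memory time overlap (so only the larger of the two counts) while per-kernel launch overhead adds serially, and it uses the rounded figures above (40 MFLOPs, 5 MB, 312 TFLOPS, 2 TB/s, 5 μs per launch). The fused kernel count of 2 is a hypothetical outcome of compilation, not a measured `torch.compile` result, and `step_time_us` is an illustrative name.

```python
def step_time_us(flops, bytes_moved, n_kernels, peak_flops=312e12,
                 bandwidth=2e12, launch_us=5.0):
    """Estimated step time in microseconds (illustrative model, not a profiler)."""
    compute_us = flops / peak_flops * 1e6
    memory_us = bytes_moved / bandwidth * 1e6
    overhead_us = n_kernels * launch_us  # fixed cost per kernel launch
    return max(compute_us, memory_us) + overhead_us


eager = step_time_us(40e6, 5e6, n_kernels=6)
fused = step_time_us(40e6, 5e6, n_kernels=2)  # hypothetical post-fusion count
print(f"eager {eager:.1f} μs, fused {fused:.1f} μs, speedup {eager / fused:.1f}x")
```

Under these assumptions, cutting six launches to two yields roughly a 2.6× speedup, in line with the 2--3× range quoted above, even though the FLOPs and bytes moved are unchanged.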
### Phase 4: Hardware Abstraction (Solving the Abstraction Problem) {.unnumbered}

The same Python code runs on different hardware through abstraction layers:

- **CUDA GPU**\index{GPU Backend!CUDA implementation}: cuBLAS GEMM kernels, CUDA streams for asynchronous execution
- **CPU**\index{CPU Backend!optimized BLAS libraries}: Intel MKL or OpenBLAS, with OpenMP for parallelism
- **TPU**\index{TPU!XLA compilation}: XLA compilation, which lowers the traced computation to HLO (XLA's hardware-independent intermediate representation) and then to TPU machine code
- **Apple Silicon**: Metal Performance Shaders via the MPS backend

Each backend implements the same tensor operations with hardware-specific optimizations. The framework's abstraction layer (@sec-ml-frameworks-abstraction-problem-37a5) ensures numerically equivalent results across platforms: agreement within floating-point tolerance rather than bit-for-bit identity, because backends may order reductions and round intermediate values differently. This is the abstraction problem solved: a single `loss.backward()` call triggers completely different code paths depending on the hardware, yet produces mathematically equivalent gradients.

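Why "within floating-point tolerance" rather than bit-for-bit? Floating-point addition is not associative, so two backends that sum the same values in different orders (as parallel reductions on different hardware do) can return slightly different results. The pure-Python sketch below shows the effect, along with a tolerance check that uses the same rule as `numpy.allclose` (|a - b| <= atol + rtol * |b|). The function name `close_enough` is hypothetical; the default tolerances mirror NumPy's.

```python
import random


def close_enough(a, b, rtol=1e-5, atol=1e-8):
    """Elementwise |a - b| <= atol + rtol * |b|, the rule numpy.allclose uses."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# Addition is not associative in floating point.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False

# Summing the same values in two orders, as two backends with different
# reduction strategies might, can differ in the last few bits.
random.seed(0)
values = [random.uniform(-1.0, 1.0) for _ in range(100_000)]
forward_sum = sum(values)
reverse_sum = sum(reversed(values))
print(forward_sum == reverse_sum, close_enough([forward_sum], [reverse_sum]))
```

Cross-backend tests therefore compare results with a tolerance check like this one rather than with `==`.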
::: {.callout-perspective title="The Three Problems in Action"}
This trace reveals the three problems in concrete terms:

- **Execution**: Eager mode enables line-by-line debugging but incurs dispatch overhead
- **Differentiation**: Autograd tape records operations during forward, replays in reverse during backward
- **Abstraction**: Same code runs on GPU/CPU/TPU through backend-specific kernel implementations

Understanding this flow enables informed optimization: fuse operations to reduce overhead, use appropriate batch sizes, and match model scale to hardware capabilities.

:::

This detailed trace through a single training step demonstrates how deeply the three core problems interact. Even simple code exercises the full framework stack, and seemingly minor decisions---device placement, batch size, compilation mode---cascade through execution, differentiation, and abstraction layers in ways that are difficult to predict without systems-level understanding. The following section catalogs the most common misconceptions that arise when engineers lack this understanding.

## Fallacies and Pitfalls {#sec-ml-frameworks-fallacies-pitfalls-61ef}

Framework selection involves subtle trade-offs where intuitions from conventional software engineering fail. The memory wall, kernel fusion constraints, and the diversity of deployment targets create pitfalls that waste months of engineering effort and cause production systems to miss latency targets by 10× or more.

```{python}
#| label: framework-gaps-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ FRAMEWORK PERFORMANCE GAPS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Fallacies section demonstrating framework non-equivalence
# │
# │ Goal: Shows 17x performance gap (52ms vs 3ms) between PyTorch and TensorRT
# │ on same model, and 7040x memory gap between PyTorch Mobile (220MB) and
# │ TFLite Micro (32KB). Frameworks are NOT interchangeable.
# │
# │ Imports: mlsys.formatting (fmt)
# │ Exports: pytorch_ms_str, tensorrt_ms_str, perf_gap_str, pytorch_mobile_mb_str,
# │ tflite_micro_kb_str, memory_ratio_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.formatting import fmt, check

# --- Inputs: framework performance (from @tbl-framework-efficiency-matrix) ---
pytorch_ms_value = 52  # 52 ms inference
tensorrt_ms_value = 3  # 3 ms inference

# --- Inputs: deployment memory spectrum ---
pytorch_mobile_mb_value = 220  # 220 MB runtime
tflite_micro_kb_value = 32  # 32 KB runtime
KB_PER_MB = 1024  # binary convention (1 MB = 1024 KB), defined locally

# --- Process ---
perf_gap_value = pytorch_ms_value / tensorrt_ms_value  # ~17x gap
memory_ratio_value = pytorch_mobile_mb_value * KB_PER_MB / tflite_micro_kb_value  # ~7040x gap

# --- Outputs (formatted strings for prose) ---
pytorch_ms_str = fmt(pytorch_ms_value, precision=0, commas=False)  # e.g. "52"
tensorrt_ms_str = fmt(tensorrt_ms_value, precision=0, commas=False)  # e.g. "3"
pytorch_mobile_mb_str = fmt(pytorch_mobile_mb_value, precision=0, commas=False)  # e.g. "220"
tflite_micro_kb_str = fmt(tflite_micro_kb_value, precision=0, commas=False)  # e.g. "32"
perf_gap_str = fmt(perf_gap_value, precision=0, commas=False)  # e.g. "17"
memory_ratio_str = fmt(memory_ratio_value, precision=0, commas=False)  # e.g. "7040"
```

**Fallacy:** *"All frameworks provide equivalent performance for the same model architecture."*

Engineers assume that ResNet-50 yields identical performance across frameworks since the mathematics is the same. In production, framework implementation matters enormously. @tbl-framework-efficiency-matrix shows PyTorch achieving `{python} pytorch_ms_str` ms inference at 32% hardware utilization while TensorRT delivers `{python} tensorrt_ms_str` ms at 88% utilization, a **`{python} perf_gap_str`× performance gap** on identical hardware. Utilization alone explains less than 3× of that gap (88% / 32% ≈ 2.75×); the remainder must come from doing less work or using faster arithmetic, for example fused or eliminated operations and lower-precision kernels. Kernel fusion depth, graph optimization strategy, and memory access patterns vary dramatically between frameworks. Organizations that assume equivalence miss latency SLAs and face costly last-minute framework migrations.

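The 32% and 88% utilization figures account for only part of the latency gap, which follows from a one-line model: latency = work / (utilization × peak throughput). The sketch assumes both frameworks run against the same peak on the same hardware and attributes whatever utilization does not explain to a combined residual factor (fewer operations, fewer bytes, or cheaper arithmetic). The inputs come from the table cited above; the residual is derived, not measured, and `residual_speedup` is a hypothetical helper.

```python
def residual_speedup(latency_slow_ms, latency_fast_ms, util_slow, util_fast):
    """Split a latency ratio into a utilization factor and a residual factor.

    Model: latency = work / (utilization * peak), with the same peak for both,
    so latency_slow / latency_fast = (work ratio) * (util_fast / util_slow).
    """
    total = latency_slow_ms / latency_fast_ms
    from_utilization = util_fast / util_slow
    return total, from_utilization, total / from_utilization


total, util_part, residual = residual_speedup(52, 3, 0.32, 0.88)
print(f"total {total:.1f}x = utilization {util_part:.2f}x * residual {residual:.1f}x")
```

A residual factor of about 6 means the optimized pipeline performs substantially less (or cheaper) work per inference, which is what graph-level fusion and precision lowering are designed to achieve.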
**Pitfall:** *Choosing frameworks based on popularity rather than project requirements.*

Engineers assume the most popular framework works for any project. In reality, deployment constraints dominate. @tbl-framework-efficiency-matrix shows PyTorch Mobile requiring `{python} pytorch_mobile_mb_str` MB of memory while TensorFlow Lite Micro runs in `{python} tflite_micro_kb_str` KB, a **`{python} memory_ratio_str`× difference**. Teams that prototype edge applications in PyTorch face either memory bloat that exceeds device capacity or 2--3 month framework migrations after development completes. Evaluate deployment targets per @sec-ml-frameworks-deployment-targets-13f1 *before* selecting a training framework.

**Fallacy:** *"Framework abstractions eliminate the need for systems knowledge."*

Engineers assume high-level APIs handle all optimization automatically. The **Roofline Model** (@sec-machine-foundations-roofline-model-2529) shows otherwise: element-wise operations like ReLU have an arithmetic intensity of 0.125 FLOPs/byte, so memory bandwidth caps them at **under 0.1%** of an A100's peak compute (0.125 FLOPs/byte × ~2 TB/s ≈ 0.25 TFLOPS, against 312 TFLOPS) regardless of framework sophistication. @sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8 explains why: memory bandwidth, not compute, is the bottleneck for most operations. Engineers who lack this understanding leave 80--90% of hardware capacity unused, which translates directly into 5--10× higher inference costs at production scale.

**Pitfall:** *Ignoring vendor lock-in from framework-specific formats.*

Engineers assume framework migration is straightforward since models are "just math." Converting a TensorFlow SavedModel to PyTorch requires rewriting custom operations, validating numerical equivalence across 10,000+ test cases, and retraining when operations lack exact equivalents, typically **3--6 engineer-months** for production systems. ONNX (@sec-ml-frameworks-deployment-targets-13f1) provides portability but supports only 80--85% of operations. Organizations that ignore this during initial framework selection face costly migrations when deployment requirements change or better frameworks emerge.

**Pitfall:** *Selecting development frameworks without evaluating production infrastructure.*

Engineers assume training framework choice is independent of deployment infrastructure. In practice, framework-infrastructure mismatches impose substantial operational overhead. TensorFlow Serving provides atomic model swaps with zero downtime, whereas PyTorch deployments often require container restarts that impose 30--60 second outages. TensorFlow integrates natively with monitoring tools, whereas PyTorch requires custom instrumentation that adds 2--4 weeks of development. Per @sec-ml-frameworks-major-framework-platform-analysis-fe96, evaluate the *complete deployment stack* during framework selection, not just training APIs.

```{python}
#| label: model-7b-memory
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ 7B MODEL MEMORY FOOTPRINT
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Fallacy about batch size being free optimization
# │
# │ Goal: Dispel the myth that "small" models leave infinite room for batching.
# │ Show: That 14 GB of weights consumes 18% of an A100 before training starts.
# │ How: Calculate weight memory for 7B parameters in FP16.
# │
# │ Imports: mlsys.constants (A100_MEM_CAPACITY, BYTES_FP16), mlsys.formulas (model_memory)
# │ Exports: model_7b_fp16_gb_str, a100_remaining_7b_gb_str, a100_mem_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.constants import A100_MEM_CAPACITY, BYTES_FP16, GB, GiB
from mlsys.formulas import model_memory
from mlsys.formatting import fmt, check

# --- Inputs (from scenario) ---
model_7b_params = 7e9  # 7 billion parameters
a100_mem_value = A100_MEM_CAPACITY.to(GiB).magnitude  # 80 GB

# --- Process ---
model_7b_fp16_gb_value = model_memory(model_7b_params, BYTES_FP16, GB)  # 14 GB

# --- Outputs (formatted strings for prose) ---
model_7b_fp16_gb_str = fmt(model_7b_fp16_gb_value, precision=0, commas=False)  # e.g. "14"
a100_remaining_7b_gb_str = fmt(a100_mem_value - model_7b_fp16_gb_value, precision=0, commas=False)  # e.g. "66"
a100_mem_str = fmt(a100_mem_value, precision=0, commas=False)  # e.g. "80"
```

**Fallacy:** *"Increasing batch size is a free throughput optimization within framework memory limits."*

Engineers assume that if memory is available, larger batches always improve throughput. The Dispatch Overhead Equation (@eq-dispatch-overhead) captures only the benefit side: larger batches amortize fixed per-kernel overhead over more samples. The cost side is memory. A 7B-parameter model in FP16 consumes `{python} model_7b_fp16_gb_str` GB for its weights alone, leaving `{python} a100_remaining_7b_gb_str` GB on an `{python} a100_mem_str` GB A100. Activation memory grows linearly with batch size, so increasing the batch from 8 to 32 quadruples it, and for transformers the $O(S^2)$ scaling of attention in sequence length $S$ already makes each sample's share large. Exceeding the remaining memory can trigger recomputation strategies that **reduce throughput by 20--30%** despite the larger batch. Teams that blindly maximize batch size often achieve *lower* throughput than smaller batches that avoid these memory management pathways.

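How quickly activation memory outgrows the headroom can be estimated with the per-layer formula from Korthikanti et al. (2022): $sbh(34 + 5as/h)$ bytes for a standard transformer layer with 16-bit activations and no parallelism or recomputation, where $s$ is sequence length, $b$ batch size, $h$ hidden size, and $a$ the number of attention heads. The sketch assumes a LLaMA-7B-like shape (32 layers, $h = 4096$, $a = 32$) and a sequence length of 512, and, like the paragraph above, treats the entire post-weights headroom as available for activations, which is optimistic for training because gradients and optimizer state also need memory. The helper names are hypothetical.

```python
GB = 1e9


def activation_bytes_per_layer(s, b, h, a):
    """Korthikanti et al. (2022) estimate: s*b*h*(34 + 5*a*s/h) bytes per layer.

    The 5*a*s/h term comes from the s-by-s attention score, softmax, and
    dropout tensors, which is where the O(S^2) dependence enters.
    """
    return s * b * h * (34 + 5 * a * s / h)


def activation_gb(batch, seq_len=512, layers=32, hidden=4096, heads=32):
    return layers * activation_bytes_per_layer(seq_len, batch, hidden, heads) / GB


HEADROOM_GB = 80 - 14  # 80 GB A100 minus 14 GB of FP16 weights, as in the text

for batch in (8, 16, 32):
    need = activation_gb(batch)
    verdict = "fits" if need <= HEADROOM_GB else "exceeds headroom: recompute or shrink"
    print(f"batch {batch:2d}: {need:6.1f} GB of activations, {verdict}")
```

Quadrupling the batch from 8 to 32 quadruples the estimate (about 29 GB to about 116 GB), pushing it past the available memory; at that point the framework must recompute activations or fall back to smaller micro-batches, which is where the throughput loss comes from.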
**Pitfall:** *Treating compilation overhead as negligible.*

```{python}
#| label: compilation-overhead-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ COMPILATION OVERHEAD BREAK-EVEN ANALYSIS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Pitfall about treating compilation overhead as negligible
# │
# │ Goal: Demonstrate that compilation costs can outweigh execution benefits.
# │ Show: That frequent recompilation makes eager mode ~44× faster than compiled.
# │ How: Model total time = (compile_time * recompiles) + (exec_time * samples).
# │
# │ Imports: mlsys.formatting (fmt)
# │ Exports: n_images_str, n_recompilations_str, eager_throughput_str,
# │ compiled_throughput_str, compilation_time_s_str, eager_total_str,
# │ compiled_total_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsys.formatting import fmt, check

# --- Inputs: prototyping scenario ---
n_images_value = 10_000  # small experiment
eager_throughput_value = 1_450  # images/sec (eager)
compiled_throughput_value = 2_150  # images/sec (compiled)
n_recompilations_value = 10  # graph changes, each triggering recompilation
compilation_time_s_value = 30  # seconds per compile

# --- Process: total execution time ---
eager_total_value = n_images_value / eager_throughput_value  # ~6.9 s
compiled_total_value = (n_images_value / compiled_throughput_value +
                        n_recompilations_value * compilation_time_s_value)  # ~304.7 s

# --- Outputs (formatted strings for prose) ---
n_images_str = f"{n_images_value:,}"  # e.g. "10,000"
n_recompilations_str = fmt(n_recompilations_value, precision=0, commas=False)  # e.g. "10"
eager_throughput_str = f"{eager_throughput_value:,}"  # e.g. "1,450"
compiled_throughput_str = f"{compiled_throughput_value:,}"  # e.g. "2,150"
compilation_time_s_str = fmt(compilation_time_s_value, precision=0, commas=False)  # e.g. "30"
eager_total_str = fmt(eager_total_value, precision=1, commas=False)  # e.g. "6.9"
compiled_total_str = fmt(compiled_total_value, precision=1, commas=False)  # e.g. "304.7"
```

Engineers assume compilation overhead is a one-time cost that pays off quickly. @tbl-training-benchmark shows `torch.compile` achieving 48% higher ResNet-50 throughput, but at a cost of 15--60 seconds of compilation per graph change. Consider a `{python} n_images_str`-image experiment during which `{python} n_recompilations_str` graph changes each trigger recompilation: eager mode completes in `{python} eager_total_str` seconds, while compiled mode requires `{python} compiled_total_str` seconds, including `{python} n_recompilations_str` × `{python} compilation_time_s_str` s of recompilation overhead. Teams that enable compilation during rapid prototyping waste hours waiting for recompilations that negate any throughput gains.

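Whether compilation pays off reduces to a break-even calculation: compiled mode wins once the time saved per image, multiplied by the number of images processed, exceeds the total compilation cost. The sketch below uses the throughput and compile-time figures from this example; `break_even_images` is a hypothetical helper name.

```python
def break_even_images(n_recompiles, compile_s, eager_ips, compiled_ips):
    """Images needed before compiled mode's total time drops below eager mode's.

    Solves n / eager_ips = n / compiled_ips + n_recompiles * compile_s for n.
    """
    saved_per_image_s = 1 / eager_ips - 1 / compiled_ips
    return n_recompiles * compile_s / saved_per_image_s


for recompiles in (1, 10):
    n = break_even_images(recompiles, 30, 1_450, 2_150)
    print(f"{recompiles:2d} recompilation(s): break-even after ~{n:,.0f} images")
```

Even a single 30-second compilation needs roughly 134,000 images (about 90 seconds of eager execution) to pay for itself, and each additional recompilation raises the bar proportionally, so a 10,000-image prototype never reaches break-even.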
## Summary {#sec-ml-frameworks-summary-07f0}

Machine learning frameworks exist to solve three fundamental problems that would otherwise make deep learning impractical:

1. **The Execution Problem**: When and how should computation happen? Frameworks navigate the trade-off between eager execution (immediate, debuggable, flexible) and graph execution (deferred, optimizable, deployable). Modern hybrid approaches like `torch.compile` aim to provide both flexibility during development and optimization for production.

2. **The Differentiation Problem**: How do we compute gradients automatically? Frameworks implement reverse-mode automatic differentiation, which computes exact gradients for arbitrary compositions of operations. This turns the chain rule into a software primitive, enabling training of billions of parameters with a single `loss.backward()` call.

3. **The Abstraction Problem**: How do we target diverse hardware from a single interface? Frameworks provide tensor abstractions, intermediate representations, and runtime systems that hide hardware complexity while enabling efficient utilization across CPUs, GPUs, TPUs, and specialized accelerators.

These problems are interconnected and constrained by the **Iron Law** of performance (@sec-introduction-iron-law-ml-systems-c32a): execution strategy determines dispatch overhead ($L_{lat}$), differentiation determines memory traffic ($D_{vol}$), and abstraction determines hardware utilization ($\eta$). Because the memory wall often makes data movement more expensive than computation, frameworks invest heavily in kernel fusion, activation checkpointing, mixed-precision training, and compilation pipelines.

::: {.callout-takeaways title="The Layer Between Math and Hardware"}
|
||
|
||
* **Three problems define every framework**: Execution (how to run), differentiation (how to train), and abstraction (how to target diverse hardware). TensorFlow prioritizes abstraction for deployment breadth, PyTorch prioritizes execution for research velocity, and JAX reframes differentiation through composable function transformations. These priorities are infrastructure commitments, not tooling preferences.

* **The memory wall drives optimization**: Compute has grown approximately 1000× faster than memory bandwidth. Kernel fusion, activation checkpointing, mixed-precision training, and data layout optimizations primarily target memory and the data movement term ($D_{vol}$) in the Iron Law, not the compute term.

* **Compilation pays off only at scale**: The Compilation Continuum principle (@eq-compilation-benefit) quantifies when compilation benefits exceed costs. Research prototyping favors eager mode; production training and inference favor progressive compilation from JIT to AOT. The Dispatch Overhead Law (@eq-dispatch-overhead) explains why small models, whose runtime is dominated by per-operation dispatch overhead rather than arithmetic, benefit disproportionately from compilation.

* **The `nn.Module` pattern is widely adopted**: Automatic parameter discovery, mode-dependent behavior, and hierarchical composition with serialization appear across major frameworks, so a single `optimizer.step()` call can update millions of parameters regardless of API syntax.

* **Framework choice constrains deployment by orders of magnitude**: A 17× latency gap (PyTorch vs. TensorRT) and a 7,040× memory gap (PyTorch Mobile vs. TFLite Micro) on identical models demonstrate that frameworks are not interchangeable. Evaluate the deployment target before selecting a framework.
:::
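The differentiation problem summarized above rests on one mechanism: record the forward computation as a graph, then sweep that graph in reverse, applying the chain rule once per edge. The following is a minimal, scalar-only sketch in plain Python under that assumption; it is not PyTorch's implementation, and the `Value` class and its methods are hypothetical names chosen for illustration. Production frameworks apply the same idea to tensors, with a backward rule registered for each operator. Note the `+=` in the reverse sweep: a node consumed more than once must accumulate every gradient contribution.

```python
import math


class Value:
    """A scalar node in a dynamically built computation graph (illustrative sketch)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents          # nodes this value was computed from
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def tanh(self):
        t = math.tanh(self.data)
        return Value(t, (self,), (1.0 - t * t,))

    def backward(self):
        # Topological order: every node appears after all of its parents.
        order, seen = [], set()

        def visit(node):
            if node not in seen:
                seen.add(node)
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # seed: d(output)/d(output) = 1
        # Reverse sweep: apply the chain rule once per edge, accumulating
        # (+=) so that nodes used more than once receive every contribution.
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += local * node.grad


# Usage: one "neuron", y = tanh(w * x + b).
w, x, b = Value(0.5), Value(2.0), Value(-0.3)
y = (w * x + b).tanh()
y.backward()  # a single call fills .grad on every upstream node
print(f"y={y.data:.4f}  dy/dw={w.grad:.4f}  dy/dx={x.grad:.4f}  dy/db={b.grad:.4f}")
```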
Understanding framework internals transforms how practitioners approach performance debugging and optimization. When a training job runs slower than expected, engineers who understand execution graphs can identify whether the bottleneck lies in eager-mode overhead, insufficient kernel fusion, or suboptimal memory layout. When deployment fails on target hardware, the compilation pipeline reveals whether the issue is operator support, quantization compatibility, or runtime configuration. This knowledge is essential for diagnosing and resolving performance issues in production systems.

::: {.callout-chapter-connection title="From Control Room to Power Plant"}
We have established the software substrate of ML: the frameworks that translate abstract architectures into executable kernels. The computational graphs, autograd tapes, and kernel dispatch pipelines examined here are the control room instruments---they give engineers visibility into and control over the training process. A control room with no power plant behind it, however, has nothing to control. We turn next to @sec-model-training, where the concepts introduced here---mixed-precision training, gradient checkpointing, compilation pipelines, and distributed execution contexts---scale from single-device examples to the massive multi-GPU and multi-node orchestration that powers modern AI.

:::

::: { .quiz-end }
:::