mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-03-09 07:15:51 -05:00
4936 lines
397 KiB
Plaintext
---
|
||
quiz: frameworks_quizzes.json
|
||
concepts: frameworks_concepts.yml
|
||
glossary: frameworks_glossary.json
|
||
engine: jupyter
|
||
---
|
||
|
||
# ML Frameworks {#sec-ml-frameworks}
|
||
|
||
::: {layout-narrow}
|
||
::: {.column-margin}
|
||
|
||
\chapterminitoc
|
||
|
||
:::
|
||
|
||
\noindent
|
||
{fig-alt="Colorful illustration showing machine learning framework concepts with icons representing TensorFlow, Keras, PyTorch, and other tools, featuring geometric shapes, charts, and training and inference workflow labels."}
|
||
|
||
:::
|
||
|
||
## Purpose {.unnumbered}
|
||
|
||
\begin{marginfigure}
|
||
\mlsysstack{35}{90}{40}{45}{15}{0}{0}{10}
|
||
\end{marginfigure}
|
||
|
||
_Why does the framework---the layer between your math and your hardware---silently constrain every decision that follows?_
|
||
|
||
Neural networks are defined by mathematics (matrix multiplications, gradient computations, activation functions), but mathematics does not execute itself. Between the equations and the silicon lies a translation layer that decides how operations are scheduled on hardware, how memory is allocated across the compute hierarchy, and how gradients flow backward through the computational graph. The framework *is* this translation layer, and the translation is not neutral. An eager-mode framework that prioritizes debugging flexibility sacrifices the graph-level optimizations that can halve inference latency. A framework lacking support for the target accelerator renders the hardware investment useless. A framework with a rich training API but no export path to edge devices means the model cannot reach the deployment target it was designed for. Architecture choices are at least visible: engineers debate model size, layer count, and attention mechanisms explicitly. Framework choices are more insidious because they operate below the level of daily attention, silently determining which optimizations are possible, which hardware is reachable, and which deployment paths exist. In the AI Triad (@sec-introduction), the framework is the invisible mediator between Algorithm and Machine, and its design choices---baked into its compilation stack, memory management, and operator libraries---are difficult to reverse. Migrating between frameworks invalidates data pipelines, serving infrastructure, model checkpoints, and team expertise, typically requiring months of engineering effort for production systems. Framework selection is therefore an infrastructure commitment that determines what the system *can* do on the hardware it *must* run on.
|
||
|
||
::: {.content-visible when-format="pdf"}
|
||
|
||
\newpage
|
||
|
||
:::
|
||
|
||
::: {.callout-tip title="Learning Objectives"}
|
||
|
||
- Explain how ML frameworks solve three core problems: execution (**computational graphs**), differentiation (**automatic differentiation**), and abstraction (hardware-optimized operations)
|
||
- Compare eager, static, and hybrid (**JIT**) execution strategies using the **Compilation Continuum** and **Dispatch Overhead** principles to determine when compilation benefits outweigh costs
|
||
- Describe the **nn.Module** abstraction pattern for hierarchical composition, automatic parameter discovery, and mode-dependent behavior
|
||
- Analyze how the **Memory Wall** drives framework optimization strategies including **kernel fusion**, **mixed-precision training**, and **activation checkpointing**
|
||
- Evaluate major framework architectures (**TensorFlow**, **PyTorch**, **JAX**) based on their execution models, differentiation approaches, and deployment trade-offs
|
||
- Evaluate framework selection trade-offs by matching model requirements, hardware constraints, and deployment targets across the cloud-to-edge spectrum
|
||
|
||
:::
|
||
|
||
```{python}
#| label: a100-specs-blas
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ A100 SPECS FOR BLAS FOOTNOTE
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: fn-blas-performance footnote + fn-gemm-utilization footnote in Ladder of Abstraction;
# │          also @sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8
# │          and @sec-ml-frameworks-device-memory-management-9404 prose
# │
# │ Goal: Quantify A100 peak FP16 TFLOPS (dense ~312 and sparse ~624) to ground
# │       the BLAS-hardware connection and Memory Wall discussion.
# │ Show: "up to 312 TFLOPS (FP16/BF16/TF32) … or 624 TFLOPS with structured
# │       sparsity", in fn-blas-performance footnote and Device Bandwidth Hierarchy prose.
# │ How: Hardware.Cloud.A100.peak_flops → m_as(TFLOPs/second); sparse = 2×.
# │
# │ Note: PERSISTENT: A100BLAS.dense_tflops_str used again at line ~284
# │       (Memory Wall), line ~2627 (Device Bandwidth Hierarchy), line ~3083
# │       (Kernel Manager subsection). Also feeds GraphOptimizationStats.
# │
# │ Imports: mlsysim (Hardware), mlsysim.core.constants (TFLOPs, second),
# │          mlsysim.fmt (fmt, check)
# │ Exports: GraphOptimizationStats.flop_reduction_range_str,
# │          A100BLAS.dense_tflops_str, A100BLAS.sparse_tflops_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim import Hardware
from mlsysim.fmt import fmt, check
from mlsysim.core.constants import TFLOPs, second

# ┌── LEGO ───────────────────────────────────────────────
class GraphOptimizationStats:
    """
    Namespace for graph-level optimization statistics.
    """
    flop_reduction_min_pct = 5
    flop_reduction_max_pct = 10

    flop_reduction_range_str = f"{flop_reduction_min_pct}-{flop_reduction_max_pct}%"

# ┌── LEGO ───────────────────────────────────────────────
class A100BLAS:
    """
    Namespace for A100 BLAS Specs.
    Scenario: Dense vs Sparse Tensor Core throughput.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    gpu = Hardware.Cloud.A100

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    dense_flops = gpu.peak_flops.m_as(TFLOPs/second)
    sparse_flops = dense_flops * 2  # structured sparsity doubles peak throughput

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    dense_tflops_str = fmt(dense_flops, precision=0, commas=False)
    sparse_tflops_str = fmt(sparse_flops, precision=0, commas=False)

# Note: Use A100BLAS.dense_tflops_str directly.
```
|
||
|
||
## Three Framework Problems {#sec-ml-frameworks-three-problems-every-framework-must-solve-317d}
|
||
|
||
Two lines of code: `model = Transformer(...)` followed by `loss.backward()`. Between them, invisible to the programmer, the framework orchestrates billions of floating-point operations across memory hierarchies, computes exact gradients through millions of parameters using **automatic differentiation** (the systematic application of the chain rule to compute derivatives), schedules thousands of GPU kernel launches, and manages gigabytes of intermediate state. The simplicity is an illusion. Those two lines trigger machinery as complex as a compiler, because that is exactly what a modern ML framework is.
|
||
|
||
The architectures defined in @sec-network-architectures specify *what* computations neural networks perform, but knowing *what* to compute is entirely different from knowing *how* to compute it efficiently. A Transformer's attention mechanism (introduced in @sec-network-architectures) requires coordinating computation across memory hierarchies and accelerator cores in patterns that naive implementations would execute 100$\times$ slower than optimized ones. Implementing these operations from scratch for every model would make deep learning economically infeasible. ML frameworks exist to bridge this gap by translating high-level model definitions into hardware-specific execution plans that extract maximum performance from silicon.
|
||
|
||
\index{ML Framework!compiler analogy}
|
||
A framework is to machine learning *what* a compiler is to traditional programming. A C compiler translates human-readable code into optimized machine instructions, managing register allocation, instruction scheduling, and memory layout. An ML framework translates high-level model definitions into hardware-specific execution plans, managing operator fusion, memory reuse, and device placement. This analogy is more than metaphor: modern frameworks literally include compilers, as we will see throughout this chapter.
|
||
|
||
\index{Deferred Execution!graph construction}
|
||
Every ML framework, regardless of API or design philosophy, must solve three core problems\index{Framework!core problems (execution, differentiation, abstraction)}. First, the *execution problem*: when and how should computation happen? Should operations execute immediately as written (**eager execution**[^fn-eager-execution-tradeoff]), or should the framework build a complete description first---a **computational graph**[^fn-computational-graph-frameworks] (a structured representation of operations and their dependencies)---and optimize before executing (graph execution)? This choice shapes debugging capability, optimization potential, and deployment flexibility. Second, the *differentiation problem*: how should the framework compute gradients automatically? As established in @sec-neural-computation, training (the complex orchestration detailed in @sec-model-training) requires derivatives of a loss function with respect to millions or billions of parameters, and manual differentiation is error-prone at this scale. Frameworks must implement **automatic differentiation** systems that compute exact gradients for arbitrary compositions of operations while managing the memory overhead of storing intermediate values. Third, the *hardware abstraction problem*: how should the framework target diverse hardware from a single interface? The same model definition should run on CPUs, GPUs, TPUs, and mobile devices, each with different memory constraints and optimal execution patterns.
|
||
|
||
[^fn-computational-graph-frameworks]: **Computational Graph**: The "optimize before executing" distinction is the key design choice. By capturing the full program as a data structure (pioneered by Theano in 2010), the framework can fuse multiple operations into a single GPU kernel before any code runs, which can reduce dispatch overhead by over 10$\times$. The engineering cost of this visibility is that the executed program differs from the source code, making debugging significantly harder, a trade-off every graph-based framework must justify against the performance gain. \index{Computational Graph!framework optimization}
|
||
|
||
[^fn-eager-execution-tradeoff]: **Eager Execution**: This mode executes each operation sequentially and immediately, which enables direct debugging with standard tools but sacrifices the global view needed for graph-level optimizations. Without seeing the full sequence of computations, the framework cannot fuse operations or pre-plan memory, forfeiting potential speedups of over 30% that compilers like `torch.compile` can provide. \index{Eager Execution!optimization trade-off}
|
||
|
||
These three problems are deeply interconnected. The execution model determines *when* differentiation occurs and *what* optimizations are possible. The abstraction layer must support both execution styles across all hardware targets. Solving any one problem in isolation leads to frameworks that excel in narrow contexts but fail in broader deployment. Because these problems are ultimately about translating mathematics into efficient hardware execution, a useful perspective is to view frameworks not as libraries but as compilers.
|
||
|
||
::: {.callout-perspective title="The ML Compiler"}
|
||
|
||
In the context of the **Iron Law** (@sec-introduction-iron-law-ml-systems-c32a), a framework is a **Compiler for the Silicon Contract**\index{Framework!compiler for the Silicon Contract}.
|
||
|
||
Your "Source Code" is the model architecture (the **$O$** term). The framework's job is to take this high-level math and compile it into a series of hardware-specific kernel launches that:
|
||
|
||
1. Minimize **Data Movement ($D_{\text{vol}}$)** through techniques like kernel fusion.
|
||
2. Maximize **Utilization ($\eta$)** by matching operations to specialized hardware units like Tensor Cores.
|
||
3. Minimize **Overhead ($L_{\text{lat}}$)** through efficient asynchronous dispatch and graph capture.
|
||
|
||
Choosing a framework means choosing the compiler that determines *how* efficiently a model uses hardware.
|
||
|
||
:::
|
||
|
||
With these three problems in mind, we can now define *what* a machine learning framework fundamentally is.
|
||
|
||
::: {.callout-definition title="Machine Learning Frameworks"}
|
||
|
||
***Machine Learning Frameworks***\index{ML Framework!definition} are software systems that translate high-level mathematical model definitions into hardware-optimized execution plans by managing the computational graph, automatic differentiation, kernel dispatch, and memory allocation across the hardware hierarchy.
|
||
|
||
1. **Significance (Quantitative):** Frameworks directly determine the system efficiency ($\eta$) term in the Iron Law. XLA's operator fusion, for example, eliminates intermediate memory writes between consecutive operations: fusing a matrix multiplication, bias add, and ReLU into a single kernel reduces the total data movement ($D_{\text{vol}}$) by 2–3$\times$ versus three separate kernel launches, yielding observed end-to-end speedups of 1.5–2$\times$ on Transformer training without any model changes.
|
||
2. **Distinction (Durable):** Unlike a numerical library such as NumPy, which executes each operation immediately (eager evaluation), an ML framework can defer execution to analyze the full computational graph and apply global optimizations: operator fusion, memory layout transformations, and parallel scheduling. These optimizations are impossible when operations are evaluated one at a time.
|
||
3. **Common Pitfall:** A frequent misconception is that frameworks are interchangeable API wrappers. Framework choice determines which hardware optimizations are available: a PyTorch model running in the default eager execution mode cannot benefit from graph-level fusion until it is explicitly compiled with `torch.compile()`, and the resulting throughput difference can exceed 2$\times$ on the same hardware.
|
||
|
||
:::
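The 2–3$\times$ data-movement figure for fusing a matrix multiplication, bias add, and ReLU can be sanity-checked with simple counting. The sketch below is an idealized model, not a measurement: it assumes every kernel reads each input and writes each output exactly once, counts traffic in tensor elements rather than bytes, and uses hypothetical helper names (`unfused_traffic`, `fused_traffic`).

```python
# Idealized traffic model (an assumption): each kernel touches each tensor once.
# Shapes: A is (m x k), W is (k x n), bias b has n entries, output Y is (m x n).

def unfused_traffic(m: int, k: int, n: int) -> int:
    """Elements moved to and from main memory by three separate kernels."""
    matmul = m * k + k * n + m * n   # read A and W, write Y
    bias_add = m * n + n + m * n     # read Y and b, write Y
    relu = m * n + m * n             # read Y, write Y
    return matmul + bias_add + relu


def fused_traffic(m: int, k: int, n: int) -> int:
    """Elements moved when matmul + bias add + ReLU run as one kernel."""
    # Intermediates stay on chip; only the final Y is written back.
    return m * k + k * n + n + m * n


m = k = n = 4096
ratio = unfused_traffic(m, k, n) / fused_traffic(m, k, n)
print(f"data-movement reduction: {ratio:.2f}x")
```

For square 4096-sized operands the ratio comes out near 7/3, inside the 2–3$\times$ range quoted above. Real kernels reread tiles and overlap transfers with compute, so measured ratios can differ.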
|
||
|
||
The compiler metaphor is not decorative. An ML framework translates logical intent into physical execution under the constraints of the Iron Law, deciding how to partition computation across memory hierarchies, when to trade numerical precision for throughput, and how to schedule operations so that the dominant term (data movement, computation, or overhead) is minimized. The framework is where the governing physics developed throughout this book becomes executable code.
|
||
|
||
The scale of this translation is not obvious from the API surface. A single call to `loss.backward()` triggers operation recording, memory allocation for gradients, reverse-order graph traversal, and hardware-optimized kernel dispatch---machinery that would require hundreds of lines of manual calculus for even a three-layer network. For a contemporary language model, the framework additionally orchestrates billions of floating-point operations across distributed hardware, coordinating memory hierarchies, communication protocols, and numerical precision. Building this from scratch would be economically prohibitive for most organizations, which is why the history of ML frameworks is a history of progressively automating these layers.
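To make the machinery behind `loss.backward()` concrete, the following is a minimal sketch of reverse-mode automatic differentiation on scalars. It is not how any production framework is implemented; the `Value` class and its attributes are hypothetical names chosen for illustration. It shows the two steps named above: operations are recorded as they run, and gradients are propagated by traversing that record in reverse order.

```python
class Value:
    """A scalar that records the operations applied to it (hypothetical class)."""

    def __init__(self, data, parents=(), grad_fns=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents      # inputs that produced this value
        self._grad_fns = grad_fns    # maps upstream gradient to each parent's share

    def __add__(self, other):
        return Value(self.data + other.data, (self, other),
                     (lambda g: g, lambda g: g))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other),
                     (lambda g: g * other.data, lambda g: g * self.data))

    def backward(self):
        # Step 1: topologically order the recorded graph (inputs first).
        order, seen = [], set()

        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for parent in v._parents:
                    visit(parent)
                order.append(v)

        visit(self)

        # Step 2: seed d(self)/d(self) = 1, then apply the chain rule in reverse.
        self.grad = 1.0
        for v in reversed(order):
            for parent, grad_fn in zip(v._parents, v._grad_fns):
                parent.grad += grad_fn(v.grad)


w, x, b = Value(3.0), Value(2.0), Value(1.0)
loss = w * x + b
loss.backward()
print(w.grad, x.grad, b.grad)  # 2.0 3.0 1.0
```

Even this toy version must keep every intermediate `Value` alive until the backward pass finishes, which is the memory overhead of stored intermediates that real frameworks have to manage at scale.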
|
||
|
||
The three problems---execution, differentiation, and abstraction---did not emerge simultaneously. Each arose as a response to scaling limitations in the previous generation of tools. Tracing this evolution reveals why modern frameworks are designed as they are and why the particular trade-offs they embody were, in hindsight, inevitable.
|
||
|
||
## The Ladder of Abstraction {#sec-ml-frameworks-frameworks-evolved-ac68}
|
||
|
||
\index{ML Framework!historical evolution}
|
||
\index{NumPy!framework foundation}
|
||
In 1979, writing a matrix multiplication in Fortran that saturated the hardware required deep knowledge of cache lines, register scheduling, and vector units. By 2016, a single line of Python (`torch.matmul(A, B)`) achieved the same peak throughput without the programmer knowing anything about the silicon. That compression of effort did not happen in one step; it accumulated across four decades of abstraction, each layer solving a bottleneck that made the previous generation impractical for scaling. The result is a **Ladder of Abstraction** where each rung automates what the rung below exposed.
|
||
|
||
1. **Solving Performance (1979–1992)**: The **Basic Linear Algebra Subprograms (BLAS)**\index{BLAS!historical foundation}[^fn-blas-performance] and **LAPACK**[^fn-lapack-algebra] solved the problem of *Hardware Primitives*. They provided standardized, highly optimized implementations of matrix operations (like GEMM[^fn-gemm-utilization]). This layer ensures that `C = A @ B` runs at near-peak silicon speed, regardless of the language calling it.
|
||
|
||
[^fn-gemm-utilization]: **GEMM (General Matrix Multiply)**: The single operation that the "near-peak silicon speed" claim rests on. Hardware vendors hand-tune GEMM for their specific chips because nearly every layer in a neural network reduces to matrix multiplication, making this one routine the performance floor for all frameworks above it on the ladder. The catch: GEMM achieves peak throughput only when matrix dimensions satisfy strict alignment constraints (multiples of 8 for NVIDIA Tensor Cores), and violating these rules drops a framework from over 90% to roughly 30% of $R_{\text{peak}}$. \index{GEMM!hardware utilization}
|
||
|
||
[^fn-lapack-algebra]: **LAPACK (Linear Algebra PACKage)**: Extends BLAS by providing a standard API for higher-level routines (SVD, eigendecomposition, least-squares) that vendors implement with chip-specific code layered on top of fast GEMM kernels. This layered design is the architectural pattern every ML framework inherits: high-level operations delegate downward to hand-tuned primitives, so a vendor-optimized LAPACK call can execute over 10$\times$ faster than a naive implementation without the framework author writing a single line of hardware-specific code. \index{LAPACK!ML initialization}
|
||
|
||
[^fn-blas-performance]: **BLAS (Basic Linear Algebra Subprograms)**: The 1979 API specification that forms the bottom rung of the ladder described here. By decoupling `C = A @ B` from its hardware implementation, BLAS forced vendors to compete on optimized libraries (NVIDIA cuBLAS, Intel MKL) for a fixed set of primitives. Every framework above it on the ladder inherits this bargain: a single BLAS call from any language can saturate an A100, approaching its peak of `{python} A100BLAS.dense_tflops_str` TFLOPS for GEMM alone, without the framework knowing anything about the silicon. \index{BLAS!performance foundation}
|
||
|
||
2. **Solving Usability (2006)**: **NumPy**[^fn-numpy-abstraction] solved the problem of *Developer Velocity*. By wrapping low-level BLAS routines in high-level Python, it allowed scientists to write code in a friendly language while executing it in optimized C/Fortran. This "Vectorization" pattern, where the slow language handles logic and the fast language handles loops, became the standard contract for scientific computing.
|
||
|
||
[^fn-numpy-abstraction]: **NumPy (Numerical Python)**: In 2005, Travis Oliphant unified two competing Python array libraries (Numeric and Numarray) into a single package, giving the scientific computing community one BLAS-backed array standard at the moment it needed to scale. The "vectorization" contract this created (write logic in Python, execute loops in C/Fortran via BLAS) became the design template for every ML framework that followed: PyTorch tensors and TensorFlow arrays are direct descendants, extending the same n-dimensional array abstraction to GPUs. Python's dominance in ML is a direct inheritance from this consolidation decision. \index{NumPy!framework lineage}
|
||
|
||
3. **Solving Differentiation (2007–present)**: **Deep Learning Frameworks** (Theano[^fn-theano-origin], TensorFlow, PyTorch) solved the problem of *Gradient Computation*. While NumPy required manual derivation of backpropagation gradients (error-prone and slow), these frameworks introduced **Automatic Differentiation** via the computational graph [@rumelhart1986learning]. This turned the chain rule into a software primitive, allowing researchers to define *forward* passes and get *backward* passes for free.
|
||
|
||
[^fn-theano-origin]: **Theano**: Developed at the Montreal Institute for Learning Algorithms (MILA) under Yoshua Bengio starting in 2007, Theano was the first framework to compile symbolic mathematical expressions into optimized CPU and GPU code via computational graphs [@bergstra2010theano]. Its key insight -- that a Python-defined computation graph could be compiled to CUDA without the researcher writing GPU code -- became the architectural template for TensorFlow (2015) and influenced PyTorch's autograd design. Theano was retired in 2017, but every modern framework inherits its core abstraction. \index{Theano!computational graph origin}
|
||
|
||
As @fig-mlfm-timeline illustrates, this progression reveals a critical insight: frameworks exist to bridge the gap between mathematical intent and silicon reality. As we move up the ladder, we gain productivity but lose transparency, a trade-off we explore in the Execution Problem (@sec-ml-frameworks-execution-problem-e1e1).
|
||
|
||
::: {#fig-mlfm-timeline fig-env="figure" fig-pos="htb" fig-cap="**Computational Library Evolution**: Modern machine learning frameworks build on decades of numerical computing advancements, transitioning from low-level routines like BLAS and LAPACK to high-level abstractions in NumPy and SciPy, and finally to deep learning frameworks such as Theano [@bergstra2010theano], TensorFlow, and PyTorch. SciPy was first released in 2001; the 2007 entry marks the period when both SciPy's maturing ecosystem and Theano's introduction of computational graphs jointly established Python as the dominant language for scientific and machine learning computing." fig-alt="Horizontal timeline from 1979 to 2018 with colored boxes marking key years. Dashed arrows connect to milestones below: 1979 BLAS introduced, 1992 LAPACK extends BLAS, 2006 NumPy becomes Python's numerical backbone, 2007 SciPy and Theano introduce computational graphs, 2015 TensorFlow revolutionizes distributed ML, 2016 PyTorch introduces dynamic graphs, 2018 JAX introduces functional paradigms."}
|
||
|
||
```{.tikz}
|
||
\begin{tikzpicture}[node distance=1mm,outer sep=0pt,font=\small\usefont{T1}{phv}{m}{n}]
|
||
\tikzset{%
|
||
Line/.style={line width=1.0pt,black!50
|
||
},
|
||
Box/.style={inner xsep=1pt,
|
||
draw=none,
|
||
fill=#1,
|
||
anchor=west,
|
||
text width=27mm,align=flush center,
|
||
minimum width=28mm, minimum height=13mm
|
||
},
|
||
Box/.default=red
|
||
}
|
||
\definecolor{col1}{RGB}{128, 179, 255}
|
||
\definecolor{col2}{RGB}{255, 255, 128}
|
||
\definecolor{col3}{RGB}{204, 255, 204}
|
||
\definecolor{col4}{RGB}{230, 179, 255}
|
||
\definecolor{col5}{RGB}{255, 153, 204}
|
||
\definecolor{col6}{RGB}{245, 82, 102}
|
||
\definecolor{col7}{RGB}{255, 102, 102}
|
||
|
||
\node[Box={col1}](B1){1979};
|
||
\node[Box={col2},right=of B1](B2){1992};
|
||
\node[Box={col3},right=of B2](B3){2006};
|
||
\node[Box={col4},right=of B3](B4){2007};
|
||
\node[Box={col5},right=of B4](B5){2015};
|
||
\node[Box={col6},right=of B5](B6){2016};
|
||
\node[Box={col7},right=of B6](B7){2018};
|
||
%%
|
||
\foreach \x in{1,2,...,7}
|
||
\draw[dashed,thick,-latex](B\x)--++(270:6);
|
||
|
||
\path[red]([yshift=-8mm]B1.south west)coordinate(P)-|coordinate(K)(B7.south east);
|
||
|
||
\draw[line width=2pt,-latex](P)--(K)--++(0:3mm);
|
||
|
||
\node[Box={col1!50},below=2 of B1](BB1){BLAS introduced};
|
||
\node[Box={col2!50},below=2 of B2](BB2){LAPACK extends BLAS};
|
||
\node[Box={col3!50},below=2 of B3](BB3){NumPy becomes Python's numerical backbone};
|
||
\node[Box={col4!50},below=2 of B4](BB4){SciPy adds advanced computations};
|
||
\node[Box={col4!50},below= 2mm of BB4](BBB4){Theano introduces computational graphs};
|
||
\node[Box={col5!50},below=2 of B5](BB5){TensorFlow revolutionizes distributed ML};
|
||
\node[Box={col6!50},below=2 of B6](BB6){PyTorch introduces dynamic graphs};
|
||
\node[Box={col7!50},below=2 of B7](BB7){JAX introduces functional paradigms};
|
||
\end{tikzpicture}
|
||
```
|
||
|
||
:::
|
||
|
||
Each generation abstracted away details that consumed engineering effort in the previous one, yet each abstraction introduced new trade-offs. BLAS hid assembly-level optimization but fixed the interface. NumPy hid memory management but required manual differentiation. Modern frameworks hide gradient computation but introduce the execution model choice we examine next.
|
||
|
||
All modern frameworks converge on the same three core problems: *how* to execute computation, *how* to differentiate it, and *how* to abstract across hardware. We begin with the most visible of these: the execution problem, because its resolution determines what optimizations the other two problems can exploit.
|
||
|
||
## Execution Problem {#sec-ml-frameworks-execution-problem-e1e1}
|
||
|
||
\index{Execution Problem!definition}
|
||
Consider two engineers writing the same neural network. The first debugs interactively, printing tensor shapes after each operation, inspecting intermediate values, and stepping through code with `pdb`. The second waits 30 seconds for compilation, then watches the model run 3$\times$ faster with no ability to inspect any intermediate state. Both are correct; they have simply made different choices about the execution problem, the question of whether operations should execute immediately as written or be recorded for later execution. This choice creates a cascade of engineering trade-offs that shape every aspect of framework behavior, from debugging workflows to deployment options to peak hardware utilization.
|
||
|
||
### Why Execution Strategy Matters: The Memory Wall {#sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8}
|
||
|
||
```{python}
#| label: a100-memory-wall
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ A100 SPECS FOR MEMORY WALL DISCUSSION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8,
# │          opening paragraph of "Why Execution Strategy Matters: The Memory Wall"
# │
# │ Goal: Quantify the compute-vs-bandwidth gap on the A100 (312 TFLOPS vs 2 TB/s
# │       = 156 ops/byte ridge point) to show why element-wise ops hit <1% utilization.
# │ Show: "312 TFLOPS … 2.0 TB/s", inline in Memory Wall opening paragraph;
# │       a100_bw_tbs_str also reused in Device Bandwidth Hierarchy (line ~2615).
# │ How: Hardware.Cloud.A100 attributes → m_as(TFLOPs/second) and m_as(TB/second);
# │      ridge_point() → m_as(flop/byte) for the invariant check.
# │
# │ Note: PERSISTENT: MemoryWallSpecs.a100_bw_tbs_str reused at line ~2615
# │       (Device and Memory Management section prose).
# │
# │ Imports: mlsysim (Hardware), mlsysim.core.constants (TFLOPs, TB, second, flop, byte),
# │          mlsysim.fmt (fmt, check)
# │ Exports: MemoryWallSpecs.a100_tflops_fp16_str, MemoryWallSpecs.a100_bw_tbs_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim import Hardware
from mlsysim.core.constants import TFLOPs, TB, second, flop, byte
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class MemoryWallSpecs:
    """
    Namespace for A100 Memory Wall Specs.
    Scenario: Demonstrating the 150x gap between compute and bandwidth.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    gpu = Hardware.Cloud.A100

    flops_fp16 = gpu.peak_flops.m_as(TFLOPs/second)
    bw_tbs = gpu.memory_bw.m_as(TB/second)

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    # Step 1: Arithmetic Intensity "Ridge Point" (Ops / Byte)
    ridge_point = gpu.ridge_point().m_as(flop/byte)

    # ┌── 3. GUARD (Invariants) ───────────────────────────────────────────
    check(ridge_point >= 100, f"A100 ridge point ({ridge_point:.1f}) is too low to claim a 'Memory Wall'.")

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    a100_tflops_fp16_str = fmt(flops_fp16, precision=0, commas=False)
    a100_bw_tbs_str = fmt(bw_tbs, precision=1, commas=False)

# Note: Use MemoryWallSpecs.a100_tflops_fp16_str directly.
```
|
||
|
||
\index{Memory Wall!execution strategy impact}
|
||
To understand why execution strategy matters so much, consider the **Memory Wall**\index{Memory Wall!definition} (first introduced in @sec-neural-computation), the growing gap between processor computational speed and memory bandwidth. Modern GPUs can perform arithmetic far faster than they can fetch data. On an A100 GPU with `{python} A100BLAS.dense_tflops_str` TFLOPS of compute and `{python} MemoryWallSpecs.a100_bw_tbs_str` TB/s of memory bandwidth, element-wise operations like ReLU achieve less than 1% of peak compute capacity, not because the hardware is slow, but because they spend nearly all their time waiting for data. The Roofline Model (@sec-machine-foundations-roofline-model-2529) formalizes this trade-off, showing exactly when operations are memory-bound versus compute-bound.
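The "less than 1%" figure can be checked with a back-of-the-envelope Roofline estimate. The derivation below is a sketch under stated assumptions, not a measurement: tensors are FP16 (2 bytes per element), ReLU counts as one operation per element, the nominal A100 figures are 312 TFLOPS and 2.0 TB/s, and $BW$ is shorthand introduced here for memory bandwidth.

$$
I_{\text{ReLU}} = \frac{1\ \text{FLOP}}{2\ \text{B (read)} + 2\ \text{B (write)}} = 0.25\ \text{FLOP/B}
$$

$$
R_{\text{attainable}} = \min\left(R_{\text{peak}},\ I_{\text{ReLU}} \times BW\right) = \min\left(312,\ 0.25 \times 2.0\right)\ \text{TFLOPS} = 0.5\ \text{TFLOPS}
$$

$$
\frac{R_{\text{attainable}}}{R_{\text{peak}}} = \frac{0.5}{312} \approx 0.16\%
$$

Under these assumptions the arithmetic units sit idle more than 99% of the time, no matter how well the ReLU kernel itself is written.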
|
||
|
||
The memory wall creates a critical classification: operations are either **compute-bound**\index{Compute-bound operations} (limited by arithmetic throughput, like large matrix multiplications) or **memory-bound**\index{Memory-bound operations} (limited by data movement, like activation functions and normalization). Most individual neural network operation types (activations, normalizations, element-wise operations) are memory-bound, though the large matrix multiplications that dominate total compute time can be compute-bound.
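The classification can be expressed as a comparison between an operation's arithmetic intensity (FLOPs per byte moved) and the hardware's ridge point. The sketch below assumes the nominal A100 figures used in this section (312 TFLOPS, 2.0 TB/s), FP16 operands, and an idealized traffic model in which each tensor is read or written once; the function names are hypothetical.

```python
# Nominal A100 figures quoted in this section (assumed, not measured).
PEAK_FLOPS = 312e12            # FP16 dense, FLOP/s
MEM_BW = 2.0e12                # bytes/s
RIDGE = PEAK_FLOPS / MEM_BW    # 156 FLOP/byte


def classify(flops: float, bytes_moved: float) -> str:
    """Label an operation by comparing its arithmetic intensity to the ridge point."""
    intensity = flops / bytes_moved
    return "compute-bound" if intensity >= RIDGE else "memory-bound"


def matmul_cost(n: int, bytes_per_elem: int = 2):
    """FLOPs and idealized traffic for multiplying two n x n matrices."""
    flops = 2 * n**3                       # one multiply and one add per term
    traffic = 3 * n * n * bytes_per_elem   # read A and B, write C, each once
    return flops, traffic


def relu_cost(num_elems: int, bytes_per_elem: int = 2):
    """One operation per element; read the input and write the output."""
    return num_elems, 2 * num_elems * bytes_per_elem


for name, cost in [("matmul 4096", matmul_cost(4096)),
                   ("matmul 128", matmul_cost(128)),
                   ("ReLU 4096x4096", relu_cost(4096 * 4096))]:
    print(f"{name:>15}: {classify(*cost)}")
```

In this model a square matrix multiplication has intensity $n/3$, so the 4096-sized case lands well above the ridge point while the 128-sized case falls below it. Matrix multiplications *can* be compute-bound, but only when they are large enough.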
|
||
|
||
The key optimization for memory-bound operations is **kernel fusion**\index{Kernel Fusion!optimizing memory-bound ops}, combining multiple operations into a single GPU function (called a *kernel*)[^fn-kernel-gpu-dispatch] to avoid intermediate memory traffic. Fusing a sequence of LayerNorm, Dropout, and ReLU into one kernel can yield a speedup of up to 5$\times$ by eliminating intermediate writes between operations. FlashAttention[^fn-flashattention-fusion-fw] fuses the entire attention computation, reducing HBM traffic by 10--20$\times$ and achieving 2--4$\times$ wall-clock speedup.
|
||
|
||
[^fn-kernel-gpu-dispatch]: **Kernel (GPU)**: In GPU programming, a kernel is the function dispatched to execute in parallel across thousands of threads. Each kernel launch incurs 5--20 $\mu$s of CPU-side overhead for parameter assembly and GPU signaling, which means that small, unfused operations spend more time on launch overhead ($L_{\text{lat}}$) than on useful arithmetic. Reducing kernel count through fusion is therefore a direct attack on the overhead term of the Iron Law. \index{Kernel!GPU dispatch overhead}
|
||
|
||
[^fn-flashattention-fusion-fw]: **FlashAttention**: Kernel fusion taken to its logical extreme, fusing the entire attention computation (the $QK^\top$ score matrix, softmax, and the weighted sum over $V$) into a single kernel that tiles data to fit in SRAM (introduced in @sec-network-architectures). By reducing HBM traffic 10--20$\times$, FlashAttention transforms a memory-bound operation into a compute-bound one, demonstrating that framework-level fusion can shift an operation's position on the Roofline Model from bandwidth-limited to throughput-limited. \index{FlashAttention!kernel fusion}
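The launch-overhead argument in the kernel footnote can be made concrete with a rough timing model. The sketch below is illustrative only: it assumes 2.0 TB/s of memory bandwidth, FP16 elements, a launch cost of 10 $\mu$s (mid-range of the 5--20 $\mu$s quoted above), and no overlap between launch and execution, which real asynchronous dispatch would partly hide. The function names are hypothetical.

```python
MEM_BW = 2.0e12  # bytes/s, nominal A100 bandwidth (assumed)


def elementwise_time_us(num_elems: int, bytes_per_elem: int = 2) -> float:
    """Time for one memory-bound elementwise pass: read input, write output."""
    bytes_moved = 2 * num_elems * bytes_per_elem
    return bytes_moved / MEM_BW * 1e6


def chain_time_us(num_ops: int, num_elems: int, launch_us: float, fused: bool) -> float:
    """Total time for a chain of elementwise ops, with or without fusion."""
    kernels = 1 if fused else num_ops   # fusion collapses the chain into one kernel
    return kernels * (launch_us + elementwise_time_us(num_elems))


n = 1_000_000
print(f"one pass over {n} elements: {elementwise_time_us(n):.1f} us")
print(f"3 ops unfused: {chain_time_us(3, n, 10.0, fused=False):.1f} us")
print(f"3 ops fused:   {chain_time_us(3, n, 10.0, fused=True):.1f} us")
```

With one million elements, each pass moves about 4 MB and takes roughly 2 $\mu$s, so the assumed 10 $\mu$s launch cost dominates every unfused kernel. Fusion removes both the extra launches and the extra passes over memory.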
|
||
|
||
A framework can only fuse operations it can see together. If operations execute immediately one at a time (eager execution)\index{Eager Execution!optimization limitations}, the framework cannot fuse them. If operations are recorded first into a graph (deferred execution)\index{Graph Execution!optimization advantages}, the framework can analyze and optimize the entire computation. This is why execution strategy matters so much: it determines *what* optimizations are even possible.
|
||
|
||
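
The memory-traffic argument for fusion can be made concrete with back-of-the-envelope arithmetic. The sketch below assumes a simplified model in which every element-wise kernel reads its input from HBM once and writes its output once; the function name `hbm_bytes` and the tensor size are illustrative choices, not measurements of any real kernel. The model ignores launch overhead, reductions, and cache effects, so measured speedups can differ from the traffic ratio it reports.

```python
# Simplified traffic model: each element-wise kernel performs one full read
# and one full write of the tensor. All sizes here are assumed for illustration.

def hbm_bytes(num_elements, bytes_per_element, num_kernels):
    """Bytes moved when every kernel reads the tensor once and writes it once."""
    return num_kernels * 2 * num_elements * bytes_per_element

n = 1 << 20   # 1,048,576 activations (assumed)
width = 4     # bytes per FP32 value

unfused = hbm_bytes(n, width, num_kernels=3)  # three separate kernels
fused = hbm_bytes(n, width, num_kernels=1)    # one fused kernel

print(f"unfused: {unfused / 2**20:.0f} MiB")
print(f"fused:   {fused / 2**20:.0f} MiB")
print(f"traffic reduction: {unfused / fused:.0f}x")
```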
### The Computational Graph {#sec-ml-frameworks-computational-graph-00f7}

Kernel fusion is the key optimization for memory-bound operations, but fusion requires seeing multiple operations together. How do frameworks represent computation in a way that makes this visibility possible? The answer is the **computational graph**\index{Computational Graph!definition (DAG)}, a directed acyclic graph (DAG) where nodes represent operations and edges represent data dependencies. This graph is the framework's internal model of the computation.

To ground this abstraction, examine @fig-comp-graph: computing $z = x \times y$ maps onto two input nodes ($x$ and $y$), one operation node (multiplication), and one output node ($z$). The execution problem asks: *when* is this graph constructed, and *when* is it executed?

::: {#fig-comp-graph fig-env="figure" fig-pos="htb" fig-cap="**Simple Computational Graph.** The computation $z = x \\times y$ represented as a graph, where nodes define operations and edges specify data flow." fig-alt="Simple directed graph with nodes x and y flowing into function f(x,y) which outputs z."}

```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{%
Line/.style={line width=1.0pt,black!50,rounded corners
},
Box/.style={align=flush center,
shape=circle,
inner xsep=1pt,
node distance=1.4,
draw=BlueLine,
line width=0.75pt,
fill=BlueL,
minimum width=8mm,
},
}
\node[Box,fill=GreenFill,draw=GreenLine,minimum width=13mm, ](B1){$f(x,y)$};
\node[Box,right=of B1,fill=GreenFill,draw=GreenLine](B2){$z$};
\node[Box,above left=0.1 and 2 of B1,fill=GreenFill,draw=GreenLine](B3){$x$};
\node[Box,below left=0.1 and 2 of B1,fill=GreenFill,draw=GreenLine](B4){$y$};
\draw[-latex,Line](B1)--(B2);
\draw[-latex,Line](B3)to[bend left=25](B1);
\draw[-latex,Line](B4)to[bend right=25](B1);
\end{tikzpicture}
```

:::

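
The graph for $z = x \times y$ can also be written down as a small data structure. The following is a minimal sketch in plain Python, not the representation any particular framework uses: the dictionary layout and the names `deps`, `ops`, and `run` are assumptions made for illustration. It relies only on the standard-library `graphlib` module (Python 3.9 or later) to order nodes so that each one runs after its dependencies.

```python
from graphlib import TopologicalSorter
import operator

# Hypothetical minimal graph for z = x * y: each node lists the nodes it
# depends on. The layout is illustrative, not a framework's internal format.
deps = {"x": [], "y": [], "mul": ["x", "y"], "z": ["mul"]}
ops = {"mul": operator.mul}

def run(graph, inputs):
    values = dict(inputs)
    # static_order() yields every node after all of its dependencies.
    for node in TopologicalSorter(graph).static_order():
        if node in values:
            continue  # input node: value supplied by the caller
        args = [values[d] for d in graph[node]]
        fn = ops.get(node)
        # Operation nodes apply their function; the output node forwards its input.
        values[node] = fn(*args) if fn else args[0]
    return values

result = run(deps, {"x": 3.0, "y": 4.0})
print(result["z"])
```

Because `deps` exists before `run` is called, the structure can be inspected or rewritten without performing any arithmetic, which is the property the rest of this section builds on.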
Real machine learning models require much more complex graph structures. @fig-mlfm-comp-graph extends this representation to show a neural network computation graph alongside the system components that reason about it. In the left panel, notice how data flows through six operation nodes in a directed acyclic graph, where each node's output feeds the nodes downstream of it. The right panel reveals what the framework gains by having this graph: it can query the structure to plan memory allocation for each tensor's lifetime, and it can assign operations to devices based on data dependencies rather than execution order. The critical insight is that the graph exists independently of execution, enabling the framework to optimize *before* any arithmetic occurs.

::: {#fig-mlfm-comp-graph fig-env="figure" fig-pos="htb" fig-cap="**Computation Graph with System Interactions.** A neural network computation graph (left) alongside system components including memory management and device placement (right) that interact with the graph to optimize resource allocation before execution." fig-alt="Left side shows computational graph with 6 operation nodes connected by data flow edges. Right side shows system components box with Memory Management and Device Placement nodes that interact with the computational graph."}

```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]

\tikzset{
Box/.style={ inner xsep=2pt,
node distance=1.4,
draw=none,
line width=0.5pt,
fill=none,
minimum width=27mm, minimum height=15mm
},
LineA/.style={violet!40,line width=5pt,{-{Triangle[width=1.0*11pt,length=1.0*8pt]}},shorten <=1pt,shorten >=1pt},
graphpanel/.style={
draw=olive!60!black,
fill=olive!02,
line width=0.9pt,
rounded corners=4pt,
inner sep=14pt,yshift=9pt
},
syspanel/.style={
draw=orange!70!black,
fill=orange!03,
line width=0.9pt,
rounded corners=4pt,
inner sep=10pt
},
opnode/.style={
circle,node distance=7mm,
draw=BlueLine,
fill=cyan!15,
minimum size=7mm,
line width=0.9pt
},
compbox/.style={
draw=orange!80!black,
fill=orange!12,
rounded corners=2pt,
minimum width=2.7cm,
minimum height=0.9cm,
align=center,
line width=0.8pt
},
flow/.style={
-{Latex[length=2.2mm]},
draw=BrownLine!75,
line width=0.9pt
},
}
%CPU style
\tikzset{
pics/cpu/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[local bounding box = CPU,shift={($(0,0)+(0,0)$)},scale=\scalefac,every node/.append style={transform shape}]
\node[fill=\filllcolor,minimum width=66, minimum height=66,
rounded corners=2,outer sep=2pt] (C1) {};
\node[fill=white,minimum width=54, minimum height=54] (C2) {};
\node[fill=\filllcolor!50,minimum width=44, minimum height=44] (C3) {\Large\bfseries GPU};

\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=4, minimum height=15,
inner sep=0pt,anchor=south](GO\y)at($(C1.north west)!\x!(C1.north east)$){};
}
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=4, minimum height=15,
inner sep=0pt,anchor=north](DO\y)at($(C1.south west)!\x!(C1.south east)$){};
}
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=15, minimum height=4,
inner sep=0pt,anchor=east](LE\y)at($(C1.north west)!\x!(C1.south west)$){};
}
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=15, minimum height=4,
inner sep=0pt,anchor=west](DE\y)at($(C1.north east)!\x!(C1.south east)$){};
}
\end{scope}
}
}
}
\tikzset{
pics/dram/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[shift={($(0,0)+(0,0)$)},scale=\scalefac,every node/.append style={transform shape}]
\node[draw=\drawcolor,fill=\filllcolor!70,line width=1.5*\Linewidth,inner sep=0pt,outer sep=0pt,
minimum width=56mm,minimum height=14mm](DRAM\picname)at(0,0){};
\node[draw=\drawcolor,fill=\filllcolor!30,line width=1.5*\Linewidth,inner sep=0pt,outer sep=0pt,anchor=north,
minimum width=52mm,minimum height=6mm](MDRAM\picname)at(DRAM\picname.south){};
%
\pgfmathsetmacro{\spacing}{56/(6+1)}
\foreach \i in {1,...,6} {
\pgfmathsetmacro{\x}{\i * \spacing}
\node[draw=\drawcolor,fill=\filllcolor!20,line width=\Linewidth, inner sep=0pt, outer sep=0pt,
minimum width=6mm, minimum height=8mm]
at ([xshift=\x mm]DRAM\picname.west) {};
}
%
\foreach \i in {1,...,19} {
\pgfmathsetmacro{\x}{\i*(52/20)}
\draw[draw=\drawcolor, line width=3*\Linewidth]
([xshift=\x mm,yshift=1pt]MDRAM\picname.south west) -- ++(0,2mm);
}
\end{scope}
}
}
}
\pgfkeys{
/channel/.cd,
Depth/.store in=\Depth,
Height/.store in=\Height,
Width/.store in=\Width,
filllcirclecolor/.store in=\filllcirclecolor,
filllcolor/.store in=\filllcolor,
drawcolor/.store in=\drawcolor,
drawcircle/.store in=\drawcircle,
scalefac/.store in=\scalefac,
Linewidth/.store in=\Linewidth,
picname/.store in=\picname,
filllcolor=BrownLine,
filllcirclecolor=cyan!40,
drawcolor=black,
drawcircle=violet,
scalefac=1,
Linewidth=0.5pt,
Depth=1.3,
Height=0.8,
Width=1.1,
picname=C
}

% graph nodes
\node[opnode] (a) {};
\node[opnode,below left=of a] (b) {};
\node[opnode,below right =of a] (c) {};
\node[opnode,below=of b] (d) {};
\node[opnode,below =of c] (e) {};
\node[opnode,below right=of d] (f) {};
% edges
\draw[flow] (a) -- (b);
\draw[flow] (a) -- (c);
\draw[flow] (b) -- (d);
\draw[flow] (c) -- (e);
\draw[flow] (d) -- (f);
\draw[flow] (e) -- (f);
% optional cross-edge
\draw[flow] (b) -- (e);

\node[opnode,minimum size=3mm,below left= 0.33 and 1.65of f] (l) {};
\node[right=0pt of l,font=\usefont{T1}{phv}{m}{n}\footnotesize](O){Operations};
\draw[flow]($(O.east)+(1mm,0)$)--++(0:7mm)coordinate(A);
\node[right=0pt of A,font=\usefont{T1}{phv}{m}{n}\footnotesize](AA){Data flow};
\scoped[on background layer]
\node[graphpanel,fit=(a)(l)(O)(AA),inner xsep=6pt](FF){};
\node[below=1pt of FF.north,font=\usefont{T1}{phv}{b}{n}\footnotesize]{Computational Graph};
%
\node[right=29mm of FF.27,Box](MM){};
\node[below=-6pt of MM](T1){Memory Management};
\pic[shift={(0,0.1)}] at (MM){dram={scalefac=0.43,picname=1,
drawcolor=black,filllcolor=OrangeLine!50!,Linewidth=0.5pt}};
\node[right=29mm of FF.338,Box](DP){};
\node[below=1pt of DP](T2){Device Placement};
\pic[shift={(0,0)}] at (DP) {cpu={scalefac=0.45,picname=1,
drawcolor=RedLine,filllcolor=BlueLine!80!,Linewidth=0.5pt}};

\scoped[on background layer]
\node[fit=(MM)(T1)(T2),syspanel,yshift=1mm,inner xsep=6pt](DD){};
\node[below=1pt of DD.north,font=\usefont{T1}{phv}{b}{n}\footnotesize]{System Components};

\draw[LineA](FF)--
node[above,pos=0.45,text=black!70,font=\usefont{T1}{phv}{m}{n}\footnotesize]{Interacts with}
(DD);
\end{tikzpicture}
```

:::

This graph representation is more than a visualization; it is the data structure that enables both efficient execution and automatic differentiation. The answer to *when* this graph is constructed creates a design choice with cascading implications:

- **For debugging**: Can you print intermediate values? Step through code with a debugger?
- **For optimization**: Can the framework see multiple operations at once to fuse them?
- **For deployment**: Can the model run without a Python interpreter?
- **For flexibility**: Can control flow depend on computed tensor values?

No single execution model optimizes all these dimensions. Frameworks must choose their position in this trade-off space, and practitioners must understand these trade-offs to select appropriate tools and write efficient code. The following sections examine *how* different execution strategies navigate these constraints.

### Three Execution Strategies {#sec-ml-frameworks-three-execution-strategies-5934}

The computational graph representation enables global optimization, but it raises a critical design question: *when* should the framework build this graph? Consider a simple operation like `y = x * 2`. Two distinct approaches exist:

1. **Immediate execution**: Perform the multiplication right now, storing the result in `y`. Natural and debuggable, but the framework sees only one operation at a time.

2. **Deferred execution**: Record the intention to multiply, building a graph of operations. Execute later when explicitly requested. Less intuitive, but the framework sees the complete computation, enabling optimization.

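
The difference between the two approaches can be sketched without any framework. The example below is a toy, assuming a chain of scalar multiplications only; the `Deferred` class and its `mul`, `optimize`, and `run` methods are hypothetical names invented for illustration. It shows the one property that matters here: a recorded chain can be rewritten (two multiplies folded into one) before any arithmetic touches the input.

```python
# Immediate execution: each product exists as soon as its line runs.
x = 3.0
y_eager = x * 2
w_eager = y_eager * 5

# Deferred execution: record the intent first, run it on request.
class Deferred:
    """Hypothetical recorder for a chain of scalar multiplications."""

    def __init__(self):
        self.factors = []  # recorded operations; nothing computed yet

    def mul(self, factor):
        self.factors.append(factor)
        return self  # allow chaining: Deferred().mul(2).mul(5)

    def optimize(self):
        # With the whole chain visible, consecutive multiplies fold into one.
        folded = 1.0
        for f in self.factors:
            folded *= f
        self.factors = [folded]
        return self

    def run(self, value):
        for f in self.factors:
            value *= f
        return value

graph = Deferred().mul(2).mul(5)  # two recorded ops, no arithmetic on x yet
ops_before = len(graph.factors)
w_deferred = graph.optimize().run(x)
ops_after = len(graph.factors)

print(w_eager, w_deferred, ops_before, ops_after)
```

Both paths produce the same value, but only the deferred path had the opportunity to reduce two operations to one.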
Neither approach dominates; each embodies different trade-offs between **flexibility** and **optimization potential**. Modern frameworks have explored three primary execution strategies: eager execution with dynamic graphs, static computation graphs, and hybrid approaches that combine JIT compilation with eager development. We examine each through its systems implications.

#### Eager Execution with Dynamic Graphs {#sec-ml-frameworks-eager-execution-dynamic-graphs-29c2}

\index{Eager Execution!define-by-run model}
\index{Dynamic Graphs!eager execution}

::: {.callout-example title="Eager vs. Graph Execution Code Comparison"}

**PyTorch (Eager Execution)**:
```{.python}
import torch

x = torch.tensor([1.0, 2.0])
y = x * 2
print(f"Intermediate value: {y}")  # Works immediately
z = y.sum()
```

**TensorFlow 1.x (Static Graph)**:

```{.python}
import tensorflow as tf

x = tf.placeholder(tf.float32)
y = x * 2
# print(y) -> Prints Tensor("mul:0"...), not value!
z = tf.reduce_sum(y)

with tf.Session() as sess:
    result = sess.run(z, feed_dict={x: [1.0, 2.0]})
```

:::

Eager execution runs operations immediately as encountered, building the computation graph dynamically during execution. When a programmer writes `y = x * 2`, the multiplication happens instantly and the result is available for immediate use.

This provides the flexibility of normal programming: developers can print intermediate values, use conditionals based on computed results, and debug with standard tools. The framework records operations as they happen, constructing a **dynamic graph**\index{Dynamic Graphs!runtime construction} that reflects the actual execution path taken.

\index{Autograd Tape!definition}
For gradient computation, the framework records a history of operations in what is called an **autograd tape**\index{Autograd Tape!dynamic graph construction}[^fn-autograd-tape-memory], a transient data structure built during execution. Each tensor operation creates a node that records: the operation performed, references to input tensors, and how to compute gradients. These nodes form a directed acyclic graph (DAG) of operations built **during** forward pass execution, not before.

[^fn-autograd-tape-memory]: **Autograd Tape**: A transient data structure built during forward execution, where each node records the operation type, input tensor references, saved intermediate values, and the backward function for chain rule application. The tape's memory cost scales linearly with network depth, and the tape is destroyed after the backward pass. For deep models, this transient graph can consume more memory than the model weights themselves, which is why activation checkpointing (trading recomputation for memory) becomes necessary for training models that would otherwise exhaust accelerator memory. \index{Autograd Tape!memory cost}

Consider this example using PyTorch, which implements eager execution as its default mode. @lst-autograd-tape-example shows *how* operations are recorded as they execute.

::: {#lst-autograd-tape-example lst-cap="**Autograd Tape Construction**: Each operation executes immediately while recording a backward node to the autograd tape for later gradient computation."}

```{.python}
import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2  # Executes immediately; records MulBackward0 node
z = y + 1  # Executes immediately; records AddBackward0 node
# The autograd tape exists NOW, built during execution
```

:::

After these two operations, the framework has constructed an autograd tape with two nodes: one for the multiplication and one for the addition. The tape records that `z` depends on `y`, and `y` depends on `x`.

Calling `z.backward()` traverses this tape in reverse topological order, applying the chain rule at each node:

1. Compute $\frac{\partial z}{\partial z} = 1$ (seed gradient)
2. Run the `AddBackward0` node $\rightarrow \frac{\partial z}{\partial y} = 1$
3. Run the `MulBackward0` node $\rightarrow \frac{\partial z}{\partial x} = 2$
4. Accumulate gradient in `x.grad`

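
The four steps above can be traced with a toy tape written in plain Python. This is a sketch for scalars only, built for illustration; it is not how PyTorch implements autograd internally, and every name in it (`Var`, `tape`, `mul_const`, `add_const`, `backward`) is hypothetical.

```python
# A toy scalar tape. Each forward operation runs immediately and appends a
# record describing how to propagate gradients through it later.

class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

tape = []  # entries: (name, output Var, input Var, local derivative)

def mul_const(x, c):
    out = Var(x.value * c)
    tape.append(("MulBackward", out, x, c))    # d(out)/d(x) = c
    return out

def add_const(x, c):
    out = Var(x.value + c)
    tape.append(("AddBackward", out, x, 1.0))  # d(out)/d(x) = 1
    return out

def backward(output):
    output.grad = 1.0                            # step 1: seed gradient
    for _name, out, inp, local in reversed(tape):
        inp.grad += out.grad * local             # steps 2-4: chain rule, accumulate
    tape.clear()                                 # the tape is discarded after use

x = Var(1.0)
y = mul_const(x, 2.0)
z = add_const(y, 1.0)
backward(z)
print(z.value, y.grad, x.grad)
```

For this linear chain, walking the list in reverse is a valid reverse topological order; a general graph would need an explicit ordering step.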
After `backward()` completes, the autograd tape is **destroyed** to free memory. The next forward pass builds a completely new tape. This design keeps memory use bounded: the system pays for gradient computation storage only from the forward pass until the matching backward pass completes, never across iterations.

::: {.callout-war-story title="The Silent Gradient Killer"}

**The Context**: An ML engineer at Facebook AI Research (FAIR) implemented a custom activation function in PyTorch. To save memory, they used the in-place operation `x += 1` instead of `x = x + 1`.

**The Failure**: In-place operations modify the data directly in memory. However, the autograd tape (the computational graph) often needs the *original* value of `x` to compute gradients for previous layers. By overwriting `x`, the engineer destroyed the history needed for the chain rule.

**The Consequence**: The framework did not crash. Instead, it computed gradients using the *modified* value of `x`, resulting in mathematically incorrect updates. The model trained, but its loss plateaued at a high value. The team spent weeks debugging hyperparameters, never suspecting that a "memory optimization" had silently corrupted the calculus.

**The Systems Lesson**: Frameworks are graph construction engines, and in-place operations violate the immutability required for automatic differentiation. Writing `x += 1` does not merely add a number: it sabotages the graph's history [@paszke2019pytorch].

:::

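
How an in-place write interacts with the tape can be checked directly. The sketch below assumes a recent PyTorch release, where autograd keeps a version counter on tensors saved for the backward pass and raises a `RuntimeError` when one of them has been overwritten. It is a minimal illustration rather than a reproduction of the incident above, and code paths that bypass this check (for example, writes through `.data`) can still corrupt gradients without any error.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2   # non-leaf tensor
w = y * y   # this multiplication saves y for its backward pass
y += 1      # in-place write changes the saved tensor's version

try:
    w.sum().backward()
    detected = False
except RuntimeError as err:
    detected = True
    print(f"autograd refused to run backward: {type(err).__name__}")

# Out-of-place alternative: the saved tensor is left untouched.
x2 = torch.tensor([1.0], requires_grad=True)
y2 = x2 * 2
w2 = y2 * y2
y3 = y2 + 1  # allocates a new tensor; y2 keeps its original value
w2.sum().backward()
print(x2.grad)
```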
Follow this "define-by-run" execution model step by step in @fig-mlfm-dynamic-graph-flow. Notice the alternating pattern: define, execute, define, execute. Each operation completes entirely before the next begins, which is why standard Python debuggers work: a developer can set a breakpoint between any two operations and inspect the actual tensor values. This contrasts sharply with static graphs, where all operations must be defined before any execution occurs.

::: {#fig-mlfm-dynamic-graph-flow fig-env="figure" fig-pos="htb" fig-cap="**Dynamic Graph Execution Flow**: In eager execution, each operation is defined and immediately executed before the next operation begins. This define-by-run model enables natural debugging and data-dependent control flow at the cost of optimization opportunities." fig-alt="Flowchart from Start to Define Operation (labeled Python Dispatch) to Execute Operation (labeled GPU Kernel) to a decision diamond asking More operations. The Yes branch loops back to Define Operation and the No branch leads to End."}

```{.tikz}
\begin{tikzpicture}[line join=round,font=\usefont{T1}{phv}{m}{n}\small]
\tikzset{%
Line/.style={line width=0.75pt,black!50,text=black},
Box/.style={align=flush center,
inner xsep=2pt,
node distance=0.75,
draw=BlueLine,
line width=0.75pt,
fill=BlueL!30,
%text width=35mm,
minimum width=35mm, minimum height=11mm
},
decision/.style = {Box,diamond,text width=35mm,aspect=1.95, inner xsep=7pt,inner ysep=-2ex, fill=VioletL2!70,
draw=VioletLine},
startstop/.style = {Box,minimum width=25mm, rounded corners=10pt, fill=red!10, draw=RedLine},
LineA/.style={black!50,line width=1.5pt,{-{Triangle[width=1.0*5pt,length=1.0*5pt]}},shorten <=0pt,shorten >=0pt},
}

\tikzset{
pics/repeat/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[shift={($(0,0)+(0,0)$)},scale=\scalefac,every node/.append style={transform shape}]
\def\w{4cm}
\def\h{15mm}
\def\r{6mm} % radius
\def\gap{4mm} % break lengths
\draw[\filllcirclecolor, -{Latex[length=10pt,width=15pt]},line width=\Linewidth]
(\w,\h-\r) -- (\w,\r)
arc[start angle=0, end angle=-90, radius=\r]
-- (\gap,0);
%
\draw[\filllcolor, -{Latex[length=10pt,width=15pt]},line width=\Linewidth]
(0,\r) -- (0,\h-\r)
arc[start angle=180, end angle=90, radius=\r]
-- ({\w-\gap},\h);
\end{scope}
}
}
}
\tikzset{
pics/interpreter/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\def\ra{}
\begin{scope}[shift={($(0,0)+(0,0)$)},scale=\scalefac,every node/.append style={transform shape}]
\node[red,font=\Large\bfseries]at(-0.75,0.6){\textless\,/\,\textgreater};
\draw[line cap=round,line join=round,yellow,line width=\Linewidth](-1.15,-0.1)--(-0.95,-0.1);
\draw[line cap=round,line join=round,red,line width=\Linewidth](-0.7,-0.1)--(0.1,-0.1);

\draw[line cap=round,line join=round,yellow,line width=\Linewidth](-1.15,-0.5)--(0,-0.5);

\draw[line cap=round,line join=round,yellow,line width=\Linewidth](-1.15,-0.9)--(-0.75,-0.9);
\draw[line cap=round,line join=round,red,line width=\Linewidth](-0.45,-0.9)--(0.45,-0.9);
\draw[line cap=round,line join=round,cyan,line width=\Linewidth](0.75,-0.9)--(1.1,-0.9);

\draw[line cap=round,line join=round,yellow,line width=\Linewidth](-1.15,-1.3)--(-1,-1.3);
\draw[line cap=round,line join=round,red,line width=\Linewidth](-0.65,-1.3)--(-0.10,-1.3);
\draw[line cap=round,line join=round,cyan,line width=\Linewidth](0.2,-1.3)--(1.1,-1.3);
\end{scope}
}
}
}
%CPU
\tikzset{%
pics/cpu/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[local bounding box=FUNNEL,scale=\scalefac, every node/.append style={transform shape}]
\node[fill=\filllcolor,minimum width=66, minimum height=66,
rounded corners=2,outer sep=2pt] (C1) {};
\node[fill=white,minimum width=54, minimum height=54] (C2) {};
\node[fill=\filllcolor!40,minimum width=44, minimum height=44] (C3) {\Large\bfseries GPU};

\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=3, minimum height=15,
inner sep=0pt,anchor=south](GO\y)at($(C1.north west)!\x!(C1.north east)$){};
}
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=3, minimum height=15,
inner sep=0pt,anchor=north](DO\y)at($(C1.south west)!\x!(C1.south east)$){};
}
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=15, minimum height=3,
inner sep=0pt,anchor=east](LE\y)at($(C1.north west)!\x!(C1.south west)$){};
}
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
\node[fill=\filllcolor,minimum width=15, minimum height=3,
inner sep=0pt,anchor=west](DE\y)at($(C1.north east)!\x!(C1.south east)$){};
}
\end{scope}
}
}
}
%start
\tikzset{%
pics/start/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[local bounding box=START,scale=\scalefac, every node/.append style={transform shape}]
\node[fill=\filllcolor,minimum width=6mm, minimum height=6mm,
outer sep=2pt] (C1) {};
\node[isosceles triangle,isosceles triangle apex angle=45,xshift=-2.5pt,
inner sep=1pt, fill=white,minimum size =3.5mm] (T1)at (C1){};
\end{scope}
}
}
}
%check
\tikzset{pics/.cd,
checkmark/.style={code={
\pgfkeys{/channel/.cd, #1}
\pgfgettransformentries{\tmpxx}{\tmp}{\tmp}{\tmp}{\tmp}{\tmp}
\draw[line width=\tmpxx*1pt,draw=none,fill=\filllcirclecolor,line join=bevel] (0,.35) -- (.25,0) to[bend left=5] (0.8,.6) to[bend
right=5] (.25,.18) -- cycle;}}}
\tikzset{%
pics/checkI/.style = {
code = {
\pgfkeys{/channel/.cd, #1}
\begin{scope}[local bounding box=CHECK,scale=\scalefac, every node/.append style={transform shape}]
\node[fill=\filllcolor,minimum width=6mm, minimum height=6mm,
outer sep=2pt] (C1) {};
\pic[shift={(-0.27,-0.19)},scale=0.7]{checkmark};
\end{scope}
}
}
}
\pgfkeys{
/channel/.cd,
Depth/.store in=\Depth,
Height/.store in=\Height,
Width/.store in=\Width,
filllcirclecolor/.store in=\filllcirclecolor,
filllcolor/.store in=\filllcolor,
drawcolor/.store in=\drawcolor,
drawcircle/.store in=\drawcircle,
scalefac/.store in=\scalefac,
Linewidth/.store in=\Linewidth,
picname/.store in=\picname,
filllcolor=BrownLine,
filllcirclecolor=violet!20,
drawcolor=red,
drawcircle=violet,
scalefac=1,
Linewidth=0.5pt,
Depth=1.3,
Height=0.8,
Width=1.1,
picname=C
}

\node[startstop](B1){};
\coordinate(GO1)at($(B1.north west)!0.33!(B1.north east)$);
\coordinate(T1)at($(GO1)!0.5!(B1.south east)$);
\node[align=center]at(T1){Start};
\begin{scope}
\coordinate(I1)at($(GO1)!0.5!(B1.south west)$);
\clip (B1.south west) rectangle (GO1);
%\fill[RedLine!30,rounded corners=10pt] (B1.south west) rectangle (B1.north east);
\end{scope}
\draw[dashed,RedLine, line width=0.75pt,](GO1)--(GO1|-B1.south west);
\node[startstop,fill=none]{};
\pic[shift={(0,0)}] at (I1){start={scalefac=1,picname=1,filllcolor=GreenLine, Linewidth=0.7pt}};
%Define
\node[Box,right=of B1](B2){};
\node[above=1pt of B2,text=black!70]{Python Dispatch};
\coordinate(GO2)at($(B2.north west)!0.3!(B2.north east)$);
%\fill[fill=BlueL!90](B2.south west)rectangle(GO2);
\coordinate(T2)at($(GO2)!0.5!(B2.south east)$);
\coordinate(I2)at($(GO2)!0.5!(B2.south west)$);
\node[align=center]at(T2){Define\\Operation};
\draw[dashed,BlueLine, line width=0.75pt,](GO2)--(GO2|-B2.south west);
\node[Box,right=of B1,fill=none]{};
\pic[shift={(0.05,0.11)}] at (I2){interpreter={scalefac=0.35,picname=1,
filllcolor=cyan!30!, Linewidth=1.5pt,filllcirclecolor=orange}};
%Execute
\node[Box,right=of B2](B3){};
\node[above=1pt of B3,text=black!70]{GPU Kernel};
\coordinate(GO3)at($(B3.north west)!0.3!(B3.north east)$);
%\fill[fill=BlueL!90](B3.south west)rectangle(GO3);
\coordinate(T3)at($(GO3)!0.5!(B3.south east)$);
\coordinate(I3)at($(GO3)!0.5!(B3.south west)$);
\node[align=center]at(T3){Execute\\Operation};
\draw[dashed,BlueLine, line width=0.75pt,](GO3)--(GO3|-B3.south west);
\node[Box,right=of B2,fill=none](B3){};
\pic[shift={(0,0)}] at (I3){cpu={scalefac=0.23,picname=1,filllcolor=BrownLine, Linewidth=0.7pt}};
%More operations
\node[decision,right=of B3](B4){More\\operations?};
\path[red](B4.west)|-coordinate(SR4)(B4.south);
\pic[shift={(-0.30,-0.14)}] at (SR4){repeat={scalefac=0.3,picname=1,filllcolor=RedLine,
Linewidth=4pt,filllcirclecolor=GreenLine}};
%End
\node[startstop,right=of B4](B5){};
\coordinate(GO5)at($(B5.north west)!0.33!(B5.north east)$);
\coordinate(T5)at($(GO5)!0.5!(B5.south east)$);
\coordinate(I5)at($(GO5)!0.5!(B5.south west)$);
\node[align=center]at(T5){End};
\begin{scope}
\clip (B5.south west) rectangle (GO5);
%\fill[RedLine!30,rounded corners=10pt] (B5.south west) rectangle (B5.north east);
\end{scope}
\draw[dashed,RedLine, line width=0.75pt,](GO5)--(GO5|-B5.south west);
\node[startstop,right=of B4,fill=none](B5){};
\pic[shift={(0,0)}] at (I5){checkI={scalefac=1,picname=1,filllcolor=GreenLine, filllcirclecolor=white,Linewidth=0.7pt}};
%arrows
\foreach \i in {1,2,3,4}{
\pgfmathtruncatemacro{\x}{\i + 1}
\draw[LineA](B\i)--coordinate[pos=0.3](SR\i)(B\x);
}
\node[above=0pt of SR4]{No};
\draw[LineA](B4.south)--node[right,pos=0.5]{Yes}++(270:0.55)-|(B2);
\end{tikzpicture}
```

:::

##### Systems Implications: Flexibility {.unnumbered}

The dynamic autograd tape enables capabilities that are difficult to express with static graphs. Conditionals and loops can depend on tensor values computed during execution, enabling algorithms like beam search, dynamic RNN lengths, or adaptive computation that adjust their behavior based on intermediate results. Different iterations can process tensors of different sizes without redefining the computation, which is essential for natural language processing, where sentence lengths vary. Because operations execute immediately in standard Python, developers can print tensors, inspect values, and use standard debuggers (`pdb`, breakpoints) to diagnose errors in the same way they would debug any Python program.

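
Data-dependent control flow is easiest to appreciate in code. The following is a minimal sketch, assuming PyTorch in its default eager mode; the function `halve_until_small` and its threshold are invented for illustration. The number of loop iterations is decided by a tensor value at run time, and the tape records exactly the operations that executed.

```python
import torch

def halve_until_small(x, limit=1.0):
    """Halve x while its norm exceeds limit; the trip count depends on the data."""
    steps = 0
    while x.norm() > limit:  # Python branches on a computed tensor value
        x = x * 0.5
        steps += 1
    return x, steps

a = torch.tensor([8.0], requires_grad=True)
out, steps = halve_until_small(a)
out.sum().backward()  # the tape holds exactly `steps` multiplications
print(steps, out.item(), a.grad.item())
```

A different input would take a different number of iterations and yield a different tape, with no change to the code.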
##### Systems Implications: Overhead {.unnumbered}

This flexibility comes with performance costs that map directly to the Iron Law (@sec-introduction-iron-law-ml-systems-c32a). Each forward pass rebuilds the autograd tape from scratch, adding Python object creation, reference counting, and node linking overhead to $L_{\text{lat}}$ on every iteration. Every operation goes through Python dispatch (function lookup, argument parsing, type checking), costing approximately 10 $\mu$s per operation, which becomes significant for models with thousands of operations. Because the graph is built during execution, the framework cannot see across operations to fuse kernels, so each operation launches its own GPU kernel, inflating both $O$ and $D_{\text{vol}}$. The autograd tape itself stores references to all intermediate tensors and `Function` nodes, increasing memory consumption by 2--3$\times$ compared to forward-only execution and adding pressure to $D_{\text{vol}}$. Together, these costs create a performance ceiling that is most visible for models built from many small operations, where dispatch overhead dominates computation.

For a typical ResNet-50 forward pass, eager execution overhead adds approximately 5--10 ms compared to an optimized compiled version, with the majority spent in Python dispatch and tape construction rather than actual computation.

```{python}
#| label: dispatch-tax-calc
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ THE DISPATCH TAX CALCULATION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: "The Dispatch Tax" section
# │
# │ Goal: Quantify the overhead of Python dispatch vs. GPU execution.
# │ Show: Why small models are overhead-bound (>90% tax) while large models are compute-bound (<10%).
# │ How: Compare a fixed 10μs Python dispatch latency to varying kernel durations.
# │
# │ Imports: mlsysim.book (fmt)
# │ Exports: python_overhead_str, small_kernel_str, large_kernel_str,
# │ small_tax_pct_str, large_tax_pct_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class DispatchTax:
    """
    Namespace for The Dispatch Tax calculation.
    Scenario: Comparing overhead for small vs large operations.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    python_overhead_us = 10.0  # Standard Python dispatch (μs)

    # Kernel Durations (μs)
    small_kernel_us = 1.0  # e.g. ReLU on 1024 elements
    large_kernel_us = 100.0  # e.g. MatMul 1024x1024

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    # Step 1: Dispatch Tax = Overhead / (Overhead + Execution)
    small_tax_pct = (python_overhead_us / (python_overhead_us + small_kernel_us)) * 100
    large_tax_pct = (python_overhead_us / (python_overhead_us + large_kernel_us)) * 100

    # ┌── 3. GUARD (Invariants) ───────────────────────────────────────────
    check(small_tax_pct > 90, f"Small op tax ({small_tax_pct:.1f}%) should be dominant.")
    check(large_tax_pct < 15, f"Large op tax ({large_tax_pct:.1f}%) should be negligible.")

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    python_overhead_str = f"{int(python_overhead_us)}"
    small_kernel_str = f"{int(small_kernel_us)}"
    large_kernel_str = f"{int(large_kernel_us)}"
    small_tax_pct_str = f"{int(small_tax_pct)}"
    large_tax_pct_str = f"{int(large_tax_pct)}"
```

### The Dispatch Tax: Python Overhead vs. GPU Reality {#sec-ml-frameworks-dispatch-tax}
|
||
|
||
Eager execution's performance ceiling is driven by a fundamental systems mismatch: the speed of the host-side interpreter versus the speed of the device-side silicon. We quantify this using **The Dispatch Tax**\index{Dispatch Tax!framework overhead}, defined as the fraction of total per-operation time spent in host-side orchestration (Python) rather than in actual device execution (GPU), that is, dispatch overhead divided by the sum of dispatch overhead and kernel time.
|
||
|
||
Every operation in an eager framework (like standard PyTorch) must pay a fixed "tax" of approximately **`{python} DispatchTax.python_overhead_str` $\mu$s** for Python to look up the function, check tensor types, and launch the kernel.
|
||
|
||
* **For small operations** (e.g., a ReLU on a small vector), the kernel might execute in only **`{python} DispatchTax.small_kernel_str` $\mu$s**. The dispatch tax is **`{python} DispatchTax.small_tax_pct_str`%**, meaning the GPU spends the vast majority of its time waiting for the next command.
|
||
* **For large operations** (e.g., a $1024\times1024$ matrix multiply), the kernel executes for **`{python} DispatchTax.large_kernel_str` $\mu$s**. The dispatch tax drops to **`{python} DispatchTax.large_tax_pct_str`%**, and the system becomes compute-bound.
|
||
|
||
The dispatch tax explains why models with many small layers run significantly slower than their raw FLOP count predicts. To reach the "Titan" standard of efficiency, frameworks must move from **Kernel-by-Kernel Dispatch** to **Graph-Level Execution**, where the dispatch tax is paid once for the entire graph rather than per operation. The hybrid JIT and compilation strategies in @sec-ml-frameworks-hybrid-approaches-jit-compilation-8954 exist precisely to address this overhead.
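The amortization argument can be checked with a few lines of arithmetic. The sketch below reuses the illustrative numbers from this section (a fixed 10 μs dispatch cost and a 1 μs kernel). The helper names `dispatch_tax` and `graph_dispatch_tax` are hypothetical, not framework APIs, and the model is deliberately idealized: it assumes dispatch and execution add serially and ignores asynchronous launch overlap.

```python
def dispatch_tax(overhead_us: float, kernel_us: float) -> float:
    """Fraction of per-operation time spent on host-side dispatch."""
    return overhead_us / (overhead_us + kernel_us)


def graph_dispatch_tax(overhead_us: float, kernel_us: float, num_ops: int) -> float:
    """Dispatch fraction when a single launch covers num_ops kernels (idealized)."""
    return overhead_us / (overhead_us + num_ops * kernel_us)


OVERHEAD_US = 10.0  # illustrative fixed dispatch cost, as in the text
SMALL_KERNEL_US = 1.0  # illustrative small kernel duration

# Kernel-by-kernel dispatch: the tax is paid on every operation.
per_op = dispatch_tax(OVERHEAD_US, SMALL_KERNEL_US)
print(f"kernel-by-kernel: {per_op:.1%}")

# Graph-level execution: the tax is paid once for the whole graph.
for num_ops in (10, 100, 1000):
    tax = graph_dispatch_tax(OVERHEAD_US, SMALL_KERNEL_US, num_ops)
    print(f"graph of {num_ops:4d} ops: {tax:.1%}")
```

Under these assumptions, a graph of 100 small kernels pays the same proportional tax as a single large kernel, which is the sense in which graph-level execution makes many small operations behave like one large one.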
|
||
|
||
The overhead costs of eager execution raise a natural question: what if we could see the entire computation *before* executing any of it? This is precisely what static computation graphs provide.
|
||
|
||
#### Static Computation Graphs {#sec-ml-frameworks-static-computation-graphs-e100}
|
||
|
||
\index{Static Graphs!define-then-run model}
|
||
\index{Computational Graph!static}
|
||
Static graph execution defines the complete computational graph as a symbolic representation first, then executes it separately. This "define-then-run" execution model means the graph exists **before** any computation occurs, enabling aggressive ahead-of-time optimization. The key insight is that if the framework sees the entire computation before running it, the framework can analyze, transform, and optimize the graph globally---a visibility impossible when operations execute immediately one at a time.
|
||
|
||
##### Two-Phase Execution {.unnumbered}
|
||
|
||
Static graphs implement a clear separation between graph construction and execution. @lst-tf-static-graph illustrates the two phases using TensorFlow 1.x, which pioneered this approach: symbolic definition creates placeholders and operations without computation, while explicit execution triggers actual arithmetic:
|
||
|
||
::: {#lst-tf-static-graph lst-cap="**Static Graph Two-Phase Execution**: Graph construction (symbolic definition) is separated from execution (actual computation), enabling ahead-of-time optimization."}
|
||
|
||
```{.python}
# Phase 1: Graph Construction (symbolic, no computation)
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Define graph symbolically
x = tf.placeholder(tf.float32, shape=[1])  # Just a placeholder
y = x * 2  # Not executed, just recorded
z = y + 1  # Still no execution
# At this point, nothing has been computed

# Phase 2: Graph Execution (actual computation)
with tf.Session() as sess:
    result = sess.run(z, feed_dict={x: [1.0]})
    # Now computation happens: result = [3.0]
```
|
||
|
||
:::
|
||
|
||
Compare this with the dynamic model by examining @fig-mlfm-static-graph. Notice the clear boundary between phases: in the definition phase (left), the framework builds a complete blueprint without touching any data; in the execution phase (right), data flows through an already-optimized graph. This separation enables the framework to answer questions during the definition phase that are impossible to answer operation-by-operation: "Which intermediate tensors can share memory?" "Which operations can fuse into a single kernel?" "What is the total memory footprint?" By the time execution begins, these optimizations are already baked in.
|
||
|
||
::: {#fig-mlfm-static-graph fig-env="figure" fig-pos="htb" fig-cap="**Static Graph: Define then Execute.** The two phases of static graph execution. The definition phase (left) declares operations and builds the graph. The execution phase (right) loads data, runs the optimized graph, and produces results." fig-alt="Flow diagram showing two phases. Definition Phase: Define Operations, Declare Variables, Build Graph. Execution Phase: Load Data, Run Graph, Get Results. Arrows connect boxes left to right."}
|
||
|
||
```{.tikz}
|
||
\begin{tikzpicture}[line join=round,font=\usefont{T1}{phv}{m}{n}\small]
|
||
\tikzset{%
|
||
Box/.style={align=flush center,
|
||
inner xsep=2pt,
|
||
node distance=0.65,
|
||
draw=BlueLine,
|
||
line width=0.75pt,
|
||
fill=BlueL!30,
|
||
%text width=35mm,
|
||
minimum width=30mm, minimum height=14mm
|
||
},
|
||
Box2/.style={Box, draw=BrownLine, fill=BrownL!30,
|
||
},
|
||
LineA/.style={black!40,line width=6.5pt,{-{Triangle[width=1.0*12pt,length=1.0*5pt]}},shorten <=0pt,shorten >=0pt},
|
||
}
|
||
|
||
\tikzset{
|
||
pics/interpreter/.style = {
|
||
code = {
|
||
\pgfkeys{/channel/.cd, #1}
|
||
\def\ra{}
|
||
\begin{scope}[shift={($(0,0)+(0,0)$)},scale=\scalefac,every node/.append style={transform shape}]
|
||
\node[red,font=\Large\bfseries]at(-0.75,0.6){\textless\,/\,\textgreater};
|
||
\draw[line cap=round,line join=round,violet,line width=\Linewidth](-1.15,-0.1)--(-0.95,-0.1);
|
||
\draw[line cap=round,line join=round,red,line width=\Linewidth](-0.7,-0.1)--(0.1,-0.1);
|
||
|
||
\draw[line cap=round,line join=round,violet,line width=\Linewidth](-1.15,-0.5)--(0,-0.5);
|
||
|
||
\draw[line cap=round,line join=round,violet,line width=\Linewidth](-1.15,-0.9)--(-0.75,-0.9);
|
||
\draw[line cap=round,line join=round,red,line width=\Linewidth](-0.45,-0.9)--(0.45,-0.9);
|
||
\draw[line cap=round,line join=round,cyan,line width=\Linewidth](0.75,-0.9)--(1.1,-0.9);
|
||
|
||
\draw[line cap=round,line join=round,violet,line width=\Linewidth](-1.15,-1.3)--(-1,-1.3);
|
||
\draw[line cap=round,line join=round,red,line width=\Linewidth](-0.65,-1.3)--(-0.10,-1.3);
|
||
\draw[line cap=round,line join=round,cyan,line width=\Linewidth](0.2,-1.3)--(1.1,-1.3);
|
||
\end{scope}
|
||
}
|
||
}
|
||
}
|
||
%CPU
|
||
\tikzset{%
|
||
pics/cpu/.style = {
|
||
code = {
|
||
\pgfkeys{/channel/.cd, #1}
|
||
\begin{scope}[local bounding box=FUNNEL,scale=\scalefac, every node/.append style={transform shape}]
|
||
\node[fill=\filllcolor,minimum width=66, minimum height=66,
|
||
rounded corners=2,outer sep=2pt] (C1) {};
|
||
\node[fill=white,minimum width=54, minimum height=54] (C2) {};
|
||
\node[fill=\filllcolor!40,minimum width=44, minimum height=44] (C3) {\Large\bfseries GPU};
|
||
|
||
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
|
||
\node[fill=\filllcolor,minimum width=3, minimum height=15,
|
||
inner sep=0pt,anchor=south](GO\y)at($(C1.north west)!\x!(C1.north east)$){};
|
||
}
|
||
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
|
||
\node[fill=\filllcolor,minimum width=3, minimum height=15,
|
||
inner sep=0pt,anchor=north](DO\y)at($(C1.south west)!\x!(C1.south east)$){};
|
||
}
|
||
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
|
||
\node[fill=\filllcolor,minimum width=15, minimum height=3,
|
||
inner sep=0pt,anchor=east](LE\y)at($(C1.north west)!\x!(C1.south west)$){};
|
||
}
|
||
\foreach \x/\y in {0.11/1,0.26/2,0.41/3,0.56/4,0.71/5,0.85/6}{
|
||
\node[fill=\filllcolor,minimum width=15, minimum height=3,
|
||
inner sep=0pt,anchor=west](DE\y)at($(C1.north east)!\x!(C1.south east)$){};
|
||
}
|
||
\end{scope}
|
||
}
|
||
}
|
||
}
|
||
%graph style
|
||
\tikzset{
|
||
pics/graph3D/.style = {
|
||
code = {
|
||
\pgfkeys{/channel/.cd, #1}
|
||
\begin{scope}[local bounding box=GRAPH,scale=1, every node/.append style={transform shape}]
|
||
\def\dx{\Width}
|
||
\def\dy{\Height}
|
||
\def\dz{\Depth}
|
||
% koordinata donjeg levog ugla (početak bara)
|
||
\def\x{0}
|
||
\def\y{0.15}
|
||
\def\z{0}
|
||
% boje
|
||
\draw[draw=\filllcirclecolor,line width=1pt](-0.2,0)--(1.3,0);
|
||
\draw[draw=\filllcirclecolor,line width=1pt](-0.2,0)--(-0.2,1.2);
|
||
\filldraw[fill=\filllcolor!10, draw=\drawcolor] (\x,\y+\dy,\z) -- (\x,\y+\dy,\z+\dz) -- (\x+\dx,\y+\dy,\z+\dz) -- (\x+\dx,\y+\dy,\z) -- cycle; % gornja strana
|
||
\filldraw[fill=\filllcolor!50, draw=\drawcolor] (\x+\dx,\y,\z) -- (\x+\dx,\y,\z+\dz) -- (\x+\dx,\y+\dy,\z+\dz) -- (\x+\dx,\y+\dy,\z) -- cycle; % desna strana
|
||
\filldraw[fill=\filllcolor!60, draw=\drawcolor] (\x,\y,\z+\dz) -- (\x+\dx,\y,\z+\dz) -- (\x+\dx,\y+\dy,\z+\dz) -- (\x,\y+\dy,\z+\dz) -- cycle; % prednja strana
|
||
\end{scope}
|
||
}
|
||
}
|
||
}
|
||
\tikzset{mycylinder/.style={cylinder, shape border rotate=90, aspect=1.3, draw, fill=white,
|
||
minimum width=25mm,minimum height=11mm,line width=\Linewidth,node distance=-0.15},
|
||
pics/data/.style = {
|
||
code = {
|
||
\pgfkeys{/channel/.cd, #1}
|
||
\begin{scope}[local bounding box=STREAMING,scale=\scalefac, every node/.append style={transform shape}]
|
||
\node[mycylinder,fill=\filllcolor!50] (A) {};
|
||
\node[mycylinder, above=of A,fill=\filllcolor!30] (B) {};
|
||
\node[mycylinder, above=of B,fill=\filllcolor!10] (C) {};
|
||
\fill[\filllcolor!50!black]($(C.west)!0.12!(C.east)$)circle(3pt);
|
||
\fill[\filllcolor!50!black]($(B.west)!0.12!(B.east)$)circle(3pt);
|
||
\fill[\filllcolor!50!black]($(A.west)!0.12!(A.east)$)circle(3pt);
|
||
\end{scope}
|
||
}
|
||
}
|
||
}
|
||
\tikzset{pics/brain/.style = {
|
||
code = {
|
||
\pgfkeys{/channel/.cd, #1}
|
||
\begin{scope}[local bounding box=BRAIN,scale=\scalefac, every node/.append style={transform shape}]
|
||
\fill[fill=\filllcolor!50](0.1,-0.5)to[out=0,in=180](0.33,-0.5)
|
||
to[out=0,in=270](0.45,-0.38)to(0.45,-0.18)
|
||
to[out=40,in=240](0.57,-0.13)to[out=110,in=310](0.52,-0.05)
|
||
to[out=130,in=290](0.44,0.15)to[out=90,in=340,distance=8](0.08,0.69)
|
||
to[out=160,in=80](-0.42,-0.15)to (-0.48,-0.7)to(0.07,-0.7)to(0.1,-0.5)
|
||
(-0.10,-0.42)to[out=310,in=180](0.1,-0.5);
|
||
\draw[draw=\drawcolor,line width=\Linewidth](0.1,-0.5)to[out=0,in=180](0.33,-0.5)
|
||
to[out=0,in=270](0.45,-0.38)to(0.45,-0.18)
|
||
to[out=40,in=240](0.57,-0.13)to[out=110,in=310](0.52,-0.05)
|
||
to[out=130,in=290](0.44,0.15)to[out=90,in=340,distance=8](0.08,0.69)
|
||
(-0.42,-0.15)to (-0.48,-0.7)
|
||
(0.07,-0.7)to(0.1,-0.5)
|
||
(-0.10,-0.42)to[out=310,in=180](0.1,-0.5);
|
||
\draw[fill=\filllcolor,line width=\Linewidth](-0.3,-0.10)to(0.08,0.60)
|
||
to[out=60,in=50,distance=3](-0.1,0.69)to[out=160,in=80](-0.26,0.59)to[out=170,in=90](-0.46,0.42)
|
||
to[out=170,in=110](-0.54,0.25)to[out=210,in=150](-0.54,0.04)
|
||
to[out=240,in=130](-0.52,-0.1)to[out=300,in=240]cycle;
|
||
\draw[fill=\filllcolor,line width=\Linewidth]
|
||
(-0.04,0.64)to[out=120,in=0](-0.1,0.69)(-0.19,0.52)to[out=120,in=330](-0.26,0.59)
|
||
(-0.4,0.33)to[out=150,in=280](-0.46,0.42)
|
||
%
|
||
(-0.44,-0.03)to[bend left=30](-0.34,-0.04)
|
||
(-0.33,0.08)to[bend left=40](-0.37,0.2) (-0.37,0.12)to[bend left=40](-0.45,0.14)
|
||
(-0.26,0.2)to[bend left=30](-0.24,0.13)
|
||
(-0.16,0.32)to[bend right=30](-0.27,0.3)to[bend right=30](-0.29,0.38)
|
||
(-0.13,0.49)to[bend left=30](-0.04,0.51);
|
||
|
||
\draw[rounded corners=0.8pt,\drawcircle,-{Circle[fill=\filllcirclecolor,length=2.5pt]}](-0.23,0.03)--(-0.15,-0.03)--(-0.19,-0.18)--(-0.04,-0.28);
|
||
\draw[rounded corners=0.8pt,\drawcircle,-{Circle[fill=\filllcirclecolor,length=2.5pt]}](-0.17,0.13)--(-0.04,0.05)--(-0.06,-0.06)--(0.14,-0.11);
|
||
\draw[rounded corners=0.8pt,\drawcircle,-{Circle[fill=\filllcirclecolor,length=2.5pt]}](-0.12,0.23)--(0.31,0.0);
|
||
\draw[rounded corners=0.8pt,\drawcircle,-{Circle[fill=\filllcirclecolor,length=2.5pt]}](-0.07,0.32)--(0.06,0.26)--(0.16,0.33)--(0.34,0.2);
|
||
\draw[rounded corners=0.8pt,\drawcircle,-{Circle[fill=\filllcirclecolor,length=2.5pt]}](-0.01,0.43)--(0.06,0.39)--(0.18,0.51)--(0.31,0.4);
|
||
\end{scope}
|
||
}
|
||
}
|
||
}
|
||
\tikzset{pics/brainMEM/.style = {
|
||
code = {
|
||
\pgfkeys{/channel/.cd, #1}
|
||
\begin{scope}[local bounding box=BRAIN,scale=\scalefac, every node/.append style={transform shape}]
|
||
\fill[fill=\filllcolor!50](0.1,-0.5)to[out=0,in=180](0.33,-0.5)%
|
||
to[out=0,in=270](0.45,-0.38)to(0.45,-0.18)
|
||
to[out=40,in=240](0.57,-0.13)to[out=110,in=310](0.52,-0.05)
|
||
to[out=130,in=290](0.44,0.15)to[out=90,in=340,distance=8](0.08,0.69)
|
||
to[out=160,in=80](-0.42,-0.15)to (-0.48,-0.7)to(0.07,-0.7)to(0.1,-0.5)
|
||
(-0.10,-0.42)to[out=310,in=180](0.1,-0.5);
|
||
\draw[draw=\drawcolor,line width=\Linewidth](0.1,-0.5)to[out=0,in=180](0.33,-0.5)
|
||
to[out=0,in=270](0.45,-0.38)to(0.45,-0.18)
|
||
to[out=40,in=240](0.57,-0.13)to[out=110,in=310](0.52,-0.05)
|
||
to[out=130,in=290](0.44,0.15)to[out=90,in=340,distance=8](0.08,0.69)
|
||
(-0.42,-0.15)to (-0.48,-0.7)
|
||
(0.07,-0.7)to(0.1,-0.5)
|
||
(-0.10,-0.42)to[out=310,in=180](0.1,-0.5);
|
||
%
|
||
\node[draw=\drawcolor,fill=\filllcirclecolor,line width=\Linewidth,rounded corners=3pt,minimum size=8mm](MR)at(-0.350,0.31){};
|
||
\node[draw=\drawcolor,fill=white,line width=1.5*\Linewidth,circle,inner sep=1pt,minimum size=2mm]
|
||
at($(MR.south)+(0,0.25)$){};
|
||
\node[fill=\filllcolor!50!blue!50,draw=\drawcolor,line width=\Linewidth,rectangle,anchor=north,
|
||
minimum width=5mm](MMR)at(MR.north){};
|
||
\draw[draw=\drawcolor,line width=\Linewidth](MMR.120)--(MMR.240);
|
||
\node[fill=black,minimum size=0.9mm,inner sep=1pt]at($(MMR.west)!0.75!(MMR.east)$){};
|
||
\end{scope}
|
||
}
|
||
}
|
||
}
|
||
\pgfkeys{
|
||
/channel/.cd,
|
||
Depth/.store in=\Depth,
|
||
Height/.store in=\Height,
|
||
Width/.store in=\Width,
|
||
filllcirclecolor/.store in=\filllcirclecolor,
|
||
filllcolor/.store in=\filllcolor,
|
||
drawcolor/.store in=\drawcolor,
|
||
drawcircle/.store in=\drawcircle,
|
||
scalefac/.store in=\scalefac,
|
||
Linewidth/.store in=\Linewidth,
|
||
picname/.store in=\picname,
|
||
filllcolor=BrownLine,
|
||
filllcirclecolor=violet!20,
|
||
drawcolor=red,
|
||
drawcircle=violet,
|
||
scalefac=1,
|
||
Linewidth=0.5pt,
|
||
Depth=0.2,
|
||
Height=0.5,
|
||
Width=0.25,
|
||
picname=C
|
||
}
|
||
|
||
\node[Box](B1){};
|
||
\coordinate(GO1)at($(B1.north west)!0.38!(B1.north east)$);
|
||
\coordinate(T1)at($(GO1)!0.5!(B1.south east)$);
|
||
\coordinate(I1)at($(B1.west)!0.21!(B1.east)$);
|
||
\node[align=center]at(T1){Define\\ Operations};
|
||
%\draw[dashed,RedLine, line width=0.75pt,](GO1)--(GO1|-B1.south west);
|
||
\node[Box,fill=none]{};
|
||
\pic[shift={(0.05,0.11)}] at (I1){interpreter={scalefac=0.38,picname=1,
|
||
filllcolor=cyan!30!, Linewidth=1.5pt,filllcirclecolor=orange}};
|
||
%Declare Variables
|
||
\node[Box,right=of B1](B2){};
|
||
\coordinate(GO2)at($(B2.north west)!0.4!(B2.north east)$);
|
||
%\fill[fill=BlueL!90](B2.south west)rectangle(GO2);
|
||
\coordinate(T2)at($(GO2)!0.5!(B2.south east)$);
|
||
\coordinate(I2)at($(B2.west)!0.24!(B2.east)$);
|
||
\node[align=center]at(T2){Declare\\ Variables};
|
||
%\draw[dashed,BlueLine, line width=0.75pt,](GO2)--(GO2|-B2.south west);
|
||
\node[Box,right=of B1,fill=none]{};
|
||
\pic[shift={(0,0)}] at (I2){brainMEM={scalefac=0.68,drawcolor=black,
|
||
filllcirclecolor=green!70!black!30,picname=1,filllcolor=orange!30!, Linewidth=0.75pt}};
|
||
%Build Graph
|
||
\node[Box,right=of B2](B3){};
|
||
\coordinate(GO3)at($(B3.north west)!0.4!(B3.north east)$);
|
||
%\fill[fill=BlueL!90](B3.south west)rectangle(GO3);
|
||
\coordinate(T3)at($(GO3)!0.5!(B3.south east)$);
|
||
\coordinate(I3)at($(B3.west)!0.22!(B3.east)$);
|
||
\node[align=center]at(T3){Build\\ Graph};
|
||
%\draw[dashed,BlueLine, line width=0.75pt,](GO3)--(GO3|-B3.south west);
|
||
\node[Box,right=of B2,fill=none](B3){};
|
||
\pic[shift={(0,0)}] at (I3){brain={scalefac=0.65,picname=1,drawcolor=black,filllcolor=orange!30!, Linewidth=0.65pt}};
|
||
%Load Data
|
||
\node[Box2,right=1.5 of B3](B4){};
|
||
\coordinate(GO4)at($(B4.north west)!0.34!(B4.north east)$);
|
||
%\fill[fill=BlueL!90](B3.south west)rectangle(GO3);
|
||
\coordinate(T4)at($(GO4)!0.5!(B4.south east)$);
|
||
\coordinate(I4)at($(B4.west)!0.22!(B4.east)$);
|
||
\node[align=center]at(T4){Load\\ Data};
|
||
%\draw[dashed,BlueLine, line width=0.75pt,](GO4)--(GO4|-B4.south west);
|
||
\node[Box2,right=1.5 of B3,fill=none](B4){};
|
||
\pic[shift={(0,-0.33)}] at (I4){data={scalefac=0.3,picname=1,filllcolor=red, Linewidth=0.6pt}};
|
||
%Run Graph
|
||
\node[Box2,right=of B4](B5){};
|
||
\coordinate(GO5)at($(B5.north west)!0.34!(B5.north east)$);
|
||
\coordinate(T5)at($(GO5)!0.5!(B5.south east)$);
|
||
\coordinate(I5)at($(B5.west)!0.24!(B5.east)$);
|
||
\node[align=center]at(T5){Run\\ Graph};
|
||
%\draw[dashed,RedLine, line width=0.75pt,](GO5)--(GO5|-B5.south west);
|
||
\node[Box2,right=of B4,fill=none](B5){};
|
||
\pic[shift={(0,0)}] at (I5){cpu={scalefac=0.3,picname=1,filllcolor=VioletLine, Linewidth=0.7pt}};
|
||
%Get Results
|
||
\node[Box2,right=of B5](B6){};
|
||
\coordinate(GO6)at($(B6.north west)!0.4!(B6.north east)$);
|
||
\coordinate(T6)at($(GO6)!0.5!(B6.south east)$);
|
||
\coordinate(I6)at($(B6.west)!0.22!(B6.east)$);
|
||
\node[align=center]at(T6){Get\\ Results};
|
||
%\draw[dashed,RedLine, line width=0.75pt,](GO6)--(GO6|-B6.south west);
|
||
\node[Box2,right=of B5,fill=none](B6){};
|
||
\begin{scope}[shift={($(I6)+(-0.25,-0.4)$)},scale=0.7, every node/.append style={transform shape}]
|
||
\pic[shift={(0,0)}] at (0,0){graph3D={filllcirclecolor=black!60,scalefac=0.5,picname=1,drawcolor=black,filllcolor=red,Height=0.5,Linewidth=1.25pt}};
|
||
\pic[shift={(0.33,0)}] at (0,0){graph3D={filllcirclecolor=none,scalefac=0.5,picname=2,drawcolor=black,filllcolor=blue,Height=1,Linewidth=1.25pt}};
|
||
\pic[shift={(0.66,0)}] at (0,0){graph3D={filllcirclecolor=none,scalefac=0.5,picname=3,drawcolor=black,filllcolor=green,Height=0.25,Linewidth=1.25pt}};
|
||
\pic[shift={(0.99,0)}] at (0,0){graph3D={filllcirclecolor=none,scalefac=0.5,picname=4,drawcolor=black,filllcolor=orange,Height=0.75,Linewidth=1.25pt}};
|
||
\end{scope}
|
||
|
||
%arrows
|
||
\foreach \i in {1,2,3,4,5}{
|
||
\pgfmathtruncatemacro{\x}{\i + 1}
|
||
\draw[LineA](B\i)--coordinate[pos=0.3](SR\i)(B\x);
|
||
}
|
||
\begin{scope}[on background layer]
|
||
\node[draw=orange,fit=(B1)(B3),inner xsep=4mm,inner ysep=7mm,yshift=3mm,
|
||
fill=orange!05](F1){};
|
||
\node[font=\usefont{T1}{phv}{b}{n}\small,below=1pt of F1.north]{Definition Phase};
|
||
\node[draw=GreenLine,fit=(B4)(B6),inner xsep=4mm,inner ysep=7mm,yshift=3mm,
|
||
fill=green!03](F2){};
|
||
\node[font=\usefont{T1}{phv}{b}{n}\small,below=1pt of F2.north]{Execution Phase};
|
||
\end{scope}
|
||
\end{tikzpicture}
|
||
|
||
```
|
||
|
||
:::
|
||
|
||
The key difference from eager execution is that during construction, `x`, `y`, and `z` are not tensors containing values but rather symbolic nodes in a graph. Operations like `*` and `+` add nodes to the graph definition without performing any arithmetic. Calling `print(y)` during the construction phase would reveal this distinction: it would print tensor metadata (name, shape, dtype), not a computed value. Execution is triggered explicitly through `sess.run()`, at which point the framework analyzes the complete graph, optimizes it, and executes the optimized version with the provided input data.
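The define-then-run idea does not depend on TensorFlow. The following is a minimal pure-Python sketch, under the assumption of a toy graph with only `mul` and `add` nodes; the `Node`, `placeholder`, and `run` names are hypothetical and chosen for illustration, not taken from any framework.

```python
class Node:
    """A symbolic graph node: records an operation, holds no computed value."""

    def __init__(self, op, inputs=(), value=None):
        self.op = op
        self.inputs = tuple(inputs)
        self.value = value  # only used by "const" nodes

    def _wrap(self, other):
        # Promote Python numbers to constant nodes so they can join the graph.
        return other if isinstance(other, Node) else Node("const", value=other)

    def __mul__(self, other):
        return Node("mul", (self, self._wrap(other)))

    def __add__(self, other):
        return Node("add", (self, self._wrap(other)))

    def __repr__(self):
        return f"Node(op={self.op!r})"


def placeholder():
    return Node("placeholder")


def run(node, feed):
    """Evaluate the graph rooted at `node`, given values for placeholders."""
    if node.op == "placeholder":
        return feed[node]
    if node.op == "const":
        return node.value
    left, right = (run(n, feed) for n in node.inputs)
    return left * right if node.op == "mul" else left + right


# Phase 1: construction only records structure.
x = placeholder()
y = x * 2
z = y + 1
print(y)  # Node(op='mul'): metadata, not a number

# Phase 2: execution supplies data and computes.
result = run(z, {x: 1.0})
print(result)  # 3.0
```

Because the whole graph exists before `run` is called, an optimizer could rewrite it between the two phases, which is exactly the opportunity eager execution lacks.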
|
||
|
||
##### Ahead-of-Time Optimization {.unnumbered}
|
||
|
||
\index{Ahead-of-Time (AOT) Optimization}
|
||
Because the framework has the complete graph before execution, it can perform optimizations impossible in eager mode. The kernel fusion\index{Kernel Fusion!static graph optimization} opportunity introduced in @sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8 becomes actionable here: because the framework sees `y = x * 2` and `z = y + 1` together in the graph, it can fuse them into `z = x * 2 + 1`, eliminating the intermediate `y` and halving memory traffic. With the full graph visible, the compiler can also calculate exact memory requirements for all tensors before execution, pre-allocating memory in a single pass and reusing buffers where lifetimes do not overlap. Tensor layouts can be transformed globally (e.g., NCHW to NHWC) to match hardware preferences without runtime copying. Dead code elimination[^fn-dce-graph-optimization]\index{Dead Code Elimination} removes operations whose results are never consumed, and constant folding\index{Constant Folding} pre-computes operations on constant values at graph construction time, so the cost is paid once rather than on every forward pass.
|
||
|
||
[^fn-dce-graph-optimization]: **Dead Code Elimination (DCE)**: Removes graph nodes whose results are never consumed by any downstream operation. In ML graphs, dead code arises from debugging operations left in production (print nodes, assertions), unused conditional branches, and gradient computations for frozen layers. For large transformer models, DCE eliminates 5--15% of graph nodes, reducing both $O$ (fewer operations) and $L_{\text{lat}}$ (fewer kernel launches). The DAG structure makes this safe: the framework verifies no downstream node depends on a candidate before removing it. \index{Dead Code Elimination!graph optimization}
|
||
|
||
These optimizations map directly to **Iron Law** terms: kernel fusion reduces $D_{\text{vol}}$ by eliminating intermediate memory writes, constant folding reduces $O$ by computing values once, memory pre-allocation reduces $L_{\text{lat}}$ by avoiding runtime allocation overhead, and dead code elimination reduces both $O$ and $D_{\text{vol}}$. Concretely, in large Transformer models, constant folding and dead code elimination can reduce total FLOPs by `{python} GraphOptimizationStats.flop_reduction_range_str` before the first batch even arrives.
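To make two of these passes concrete, here is a toy sketch of constant folding and dead code elimination. It assumes a graph stored as a dictionary in topological order, with each node written as `(op, payload)`; this representation and the function names are illustrative assumptions, far simpler than what a production compiler uses.

```python
import operator

OPS = {"add": operator.add, "mul": operator.mul}


def constant_fold(graph):
    """Replace ops whose inputs are all constants with a single constant node."""
    folded = {}
    for name, (op, args) in graph.items():  # assumes topological order
        if op in OPS and all(folded[a][0] == "const" for a in args):
            value = OPS[op](*(folded[a][1] for a in args))
            folded[name] = ("const", value)
        else:
            folded[name] = (op, args)
    return folded


def eliminate_dead_code(graph, outputs):
    """Keep only nodes reachable backward from the requested outputs."""
    live, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name in live:
            continue
        live.add(name)
        op, args = graph[name]
        if op in OPS:  # only op nodes have upstream dependencies
            stack.extend(args)
    return {n: v for n, v in graph.items() if n in live}


graph = {
    "x": ("input", None),
    "two": ("const", 2),
    "three": ("const", 3),
    "scale": ("mul", ("two", "three")),  # constant subexpression
    "y": ("mul", ("x", "scale")),
    "debug": ("add", ("x", "two")),  # result is never consumed
    "z": ("add", ("y", "two")),
}

optimized = eliminate_dead_code(constant_fold(graph), outputs=["z"])
for name, node in optimized.items():
    print(name, node)
```

In this example, folding turns `scale` into a constant, after which `three` and `debug` have no consumers and are removed, so seven nodes shrink to five before any data arrives.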
|
||
|
||
\index{XLA (Accelerated Linear Algebra)!definition}
|
||
Compilation frameworks like XLA (Accelerated Linear Algebra)\index{XLA (Accelerated Linear Algebra)!graph compilation}[^fn-xla-compiler] [@GoogleXLA] take this further, compiling the TensorFlow graph to optimized machine code for specific hardware. For a transformer encoder block, XLA can achieve 1.5--2$\times$ speedup over unoptimized execution through aggressive fusion and hardware-specific code generation.
|
||
|
||
##### Systems Implications {.unnumbered}
|
||
|
||
\index{Static Graphs!optimization benefits}
|
||
Static graphs achieve high performance through ahead-of-time optimization. Kernel fusion reduces memory bandwidth requirements (often the bottleneck for ML workloads), and hardware-specific compilation enables near-peak utilization.
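The bandwidth saving from fusion can be estimated with simple counting. The sketch below assumes a chain of elementwise operations on 32-bit values with no cache reuse between kernels; `elementwise_traffic_bytes` is a hypothetical helper, and real traffic depends on the hardware and kernel implementation.

```python
def elementwise_traffic_bytes(num_elements, num_ops, bytes_per_element=4, fused=False):
    """Bytes moved to and from device memory for a chain of elementwise ops.

    Unfused: every op reads its input tensor and writes its output tensor.
    Fused: the chain reads the input once and writes the final output once.
    """
    transfers = 2 if fused else 2 * num_ops
    return transfers * num_elements * bytes_per_element


NUM_ELEMENTS = 1 << 20  # about one million float32 values
NUM_OPS = 2  # e.g. y = x * 2 followed by z = y + 1

unfused = elementwise_traffic_bytes(NUM_ELEMENTS, NUM_OPS)
fused = elementwise_traffic_bytes(NUM_ELEMENTS, NUM_OPS, fused=True)
print(f"unfused: {unfused / 2**20:.0f} MiB, fused: {fused / 2**20:.0f} MiB")
print(f"traffic reduction: {unfused / fused:.1f}x")
```

For the two-operation chain the reduction is 2x, and under this model it grows linearly with the length of the fused chain.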
|
||
|
||
The cost of this performance is reduced flexibility. Standard Python control flow (`if`, `for`) cannot depend on computed tensor values in static graphs. TensorFlow provides graph-level control flow primitives (`tf.cond` and `tf.while_loop`) that support data-dependent conditions, but these require special syntax that diverges from standard Python, making code harder to write and reason about. Debugging is difficult because stack traces point to graph construction code, not execution code. Error messages often reference symbolic node names rather than the actual operations that failed.
|
||
|
||
#### Hybrid Approaches: JIT Compilation {#sec-ml-frameworks-hybrid-approaches-jit-compilation-8954}
|
||
|
||
\index{JIT Compilation!fidelity vs. generality}
|
||
Can we have both eager debugging and graph optimization? JIT compilation attempts this by capturing computation at runtime. The core trade-off is *fidelity versus generality*. Tracing captures the exact execution path taken during a sample run, producing high fidelity to that specific input but missing branches not taken. Source-level compilation (scripting) analyzes the full program structure, preserving all control flow branches but requiring a restricted language subset. Both approaches produce an intermediate representation (IR)[^fn-ir-compilation] that enables the same ahead-of-time optimizations available to static graphs: operator fusion, constant folding, dead code elimination, and buffer reuse.
|
||
|
||
[^fn-ir-compilation]: **Intermediate Representation (IR)**: The "intermediate" captures this format's architectural role: a language-independent layer that decouples the frontend (Python capture) from the backend (hardware code generation), exactly as LLVM IR decouples C/Rust/Swift frontends from x86/ARM backends. ML frameworks adopted this compiler pattern because it reduces the $O(M \times N)$ cost of supporting $M$ frontends and $N$ backends to $O(M + N)$: a single graph capture mechanism (TorchDynamo, tf2xla) can target multiple hardware backends without rewriting the capture logic. \index{Intermediate Representation!compiler pattern}
|
||
|
||
The eager-versus-compiled trade-off has a direct **Iron Law** consequence. JIT compilation amortizes the $L_{\text{lat}}$ (dispatch overhead) across the compiled region. Longer compiled regions mean more overhead amortized per operation, which explains why graph breaks are performance-critical: each break forces a return to eager dispatch, resetting the amortization.
|
||
|
||
PyTorch's TorchScript exemplifies both strategies. Tracing\index{JIT Compilation!tracing} executes a function once with example inputs and records every tensor operation into a static computation graph. @lst-torchscript-trace demonstrates the approach: the traced module becomes a compiled artifact that can be serialized, optimized, and executed independently of the Python interpreter:
|
||
|
||
::: {#lst-torchscript-trace lst-cap="**TorchScript Tracing**: Captures tensor operations by executing a function with example inputs and recording the execution path into a static computation graph."}
|
||
|
||
```{.python}
import torch


def forward(x):
    y = x * 2
    z = y + 1
    return z


# Trace the function by running it once
x_example = torch.tensor([1.0])
traced = torch.jit.trace(forward, x_example)

# traced is now a compiled TorchScript module
# Can serialize: torch.jit.save(traced, "model.pt")
# Can optimize: fusion, constant folding
# Can run without Python interpreter
```
|
||
|
||
:::
|
||
|
||
The critical limitation of tracing reveals the fidelity-generality trade-off concretely. Because tracing records a single execution path, it cannot handle data-dependent control flow. @lst-tracing-silent-failure illustrates a silent correctness failure.
|
||
|
||
::: {#lst-tracing-silent-failure lst-cap="**Tracing Silent Failure**: Tracing records only the execution path taken by the example input, silently ignoring all other branches of data-dependent control flow."}
|
||
|
||
```{.python}
import torch


def conditional_forward(x):
    if x.sum() > 0:  # Data-dependent condition
        return x * 2
    else:
        return x * 3


traced = torch.jit.trace(conditional_forward, torch.tensor([1.0]))
# Tracing captures ONLY the x.sum() > 0 branch
# If input later has sum <= 0, traced version
# still executes x * 2 branch
```
|
||
|
||
:::
|
||
|
||
Tracing records whichever branch executed for the example input. Subsequent executions always follow the traced path regardless of input values, silently producing incorrect results for inputs that would have taken the other branch. This failure mode is particularly dangerous because it raises no error at inference time (PyTorch emits at most a `TracerWarning` when the trace is recorded), only wrong outputs. In production, such bugs can persist for months before anyone notices that a small fraction of inputs are being misclassified, and by then debugging is a forensic exercise.
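The mechanism behind this failure can be reproduced without PyTorch. The following toy tracer is a sketch under simplifying assumptions: scalars stand in for tensors, only multiplication is recorded, and the `Traced` and `trace` names are hypothetical. It is not how `torch.jit.trace` is implemented, but it fails for the same reason: the comparison is resolved on the example value and never enters the recorded tape.

```python
class Traced:
    """Proxy that carries a concrete value and logs each operation to a tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __mul__(self, k):
        self.tape.append(("mul", k))
        return Traced(self.value * k, self.tape)

    def __gt__(self, k):
        # Returns a plain bool: the branch decision is made on the example
        # value and is not recorded anywhere.
        return self.value > k


def trace(fn, example):
    """Run fn once on the example and return a function that replays the tape."""
    tape = []
    fn(Traced(example, tape))

    def replay(x):
        for op, k in tape:
            if op == "mul":
                x = x * k
        return x

    return replay


def conditional_forward(x):
    if x > 0:  # data-dependent condition
        return x * 2
    else:
        return x * 3


traced = trace(conditional_forward, 1.0)
print(traced(1.0))  # 2.0, matches eager execution
print(traced(-1.0))  # -2.0, whereas eager execution returns -3.0
```

The replayed function contains only the `mul` by 2, so the second call returns a plausible-looking but wrong value with no error raised.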
|
||
|
||
The alternative, scripting\index{JIT Compilation!scripting}, achieves generality by analyzing Python source code directly and compiling it to TorchScript IR without executing. The scripting compiler parses the abstract syntax tree (AST), converts supported operations to IR operations, and preserves the branching structure so that both branches of a conditional exist in the compiled representation. The cost of this generality is a restricted Python subset: type annotations are required where inference fails, arbitrary Python objects and standard library modules are excluded, and dynamic metaprogramming is forbidden.
|
||
|
||
Tracing suits feed-forward models without conditionals (ResNet, VGG, Vision Transformer) and models where control flow depends only on hyperparameters fixed at trace time. Scripting suits models with data-dependent control flow (RNN variants, recursive networks, adaptive computation) and deployment to environments without a Python interpreter. The following examples demonstrate scripting syntax (@lst-torchscript-script), control flow preservation (@lst-torchscript-conditional), IR inspection (@lst-torchscript-ir), and language restrictions (@lst-torchscript-restrictions).
|
||
|
||
::: {#lst-torchscript-script lst-cap="**TorchScript Scripting**: Compiles Python source code directly to TorchScript IR by parsing the AST, preserving control flow structure without requiring example inputs."}
|
||
|
||
```{.python}
import torch


@torch.jit.script
def forward(x):
    y = x * 2
    z = y + 1
    return z


# Compiles Python source code to TorchScript IR
# No example inputs needed
# Preserves control flow structure
```
|
||
|
||
:::
|
||
|
||
The key advantage of scripting appears when handling conditionals. Unlike tracing, which captures only one branch, scripting preserves both paths in the IR.
|
||
|
||
::: {#lst-torchscript-conditional lst-cap="**Scripted Control Flow**: Unlike tracing, scripting preserves both branches of conditionals in the IR, enabling correct execution based on runtime input values."}
|
||
|
||
```{.python}
import torch


@torch.jit.script
def conditional_forward(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:
        return x * 2
    else:
        return x * 3


# Both branches preserved in IR
# Correct branch executes based on runtime input values
```
|
||
|
||
:::
|
||
|
||
To understand what the compiler produces, we can inspect the generated intermediate representation directly.
|
||
|
||
::: {#lst-torchscript-ir lst-cap="**TorchScript IR Inspection**: The generated intermediate representation shows primitive operations and constants, useful for debugging and understanding compilation results."}
|
||
|
||
```{.python}
import torch


@torch.jit.script
def example(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


# Inspect generated IR (exact names and ordering vary by PyTorch version):
print(example.graph)
# graph(%x : Tensor):
#   %1 : int = prim::Constant[value=2]()
#   %2 : Tensor = aten::mul(%x, %1)
#   %3 : int = prim::Constant[value=1]()
#   %4 : Tensor = aten::add(%2, %3, %3)
#   return (%4)
```
|
||
|
||
:::
|
||
|
||
However, scripting imposes constraints on what Python constructs are supported.
|
||
|
||
::: {#lst-torchscript-restrictions lst-cap="**TorchScript Restrictions**: Scripting requires a restricted Python subset. Common unsupported features include arbitrary imports, NumPy operations, and f-strings."}
|
||
|
||
```{.python}
@torch.jit.script
def invalid_script(x):
    import numpy as np  # ERROR: Cannot import arbitrary modules

    result = np.array([1, 2, 3])  # ERROR: NumPy not supported
    print(f"Debug: {x}")  # ERROR: f-strings not supported
    return result


# Valid alternative:
@torch.jit.script
def valid_script(x: torch.Tensor) -> torch.Tensor:
    # Use TorchScript-compatible operations
    result = torch.tensor([1, 2, 3], dtype=x.dtype, device=x.device)
    return result
```
|
||
|
||
:::
|
||
|
||
Scripting requires a restricted Python subset because TorchScript must statically analyze code that Python normally interprets dynamically. Function signatures and variables need explicit type annotations when type inference fails, and only tensor operations, numeric types, and standard containers (lists, dicts, tuples) are permitted---no arbitrary Python objects, no standard library modules like `os` or `sys`, and no dynamic class modification or metaprogramming. These constraints are the price of compilation: every feature that makes Python flexible also makes it unpredictable for a compiler.

\index{Static Single Assignment!compiler IR}
The TorchScript IR represents operations using the `aten` namespace for core tensor operations, the `prim` namespace for primitives and control flow, static types for every value, and Static Single Assignment (SSA) form, where each variable is assigned exactly once to simplify compiler analysis. This IR enables optimizations independent of Python: operator fusion combines adjacent operations into single kernels, constant folding evaluates constant expressions at compile time, dead code elimination removes unused operations, and memory optimization reuses buffers when possible. @tbl-tracing-vs-scripting summarizes the key trade-offs between these two approaches.

| **Aspect**            | **Tracing**                  | **Scripting**                 |
|:----------------------|:-----------------------------|:------------------------------|
| **Input requirement** | Example inputs needed        | No inputs needed              |
| **Control flow**      | Cannot handle data-dependent | Supports data-dependent       |
| **Conversion ease**   | Simpler (just run function)  | Harder (restricted Python)    |
| **Type annotations**  | Not required                 | Required when inference fails |
| **Error detection**   | Runtime (wrong results)      | Compile time (syntax errors)  |
| **Best for**          | Feed-forward models          | Models with conditionals      |

: **Tracing vs. Scripting Trade-offs.** The fidelity-generality trade-off manifests concretely: tracing is simpler to use but silently ignores data-dependent control flow, while scripting preserves all branches at the cost of a restricted Python subset. Choose tracing for static architectures and scripting for models with runtime conditionals. {#tbl-tracing-vs-scripting}
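
Two of the IR-level optimizations named above, constant folding and dead code elimination, become simple single-pass algorithms once a program is in SSA form. The sketch below is a toy model in plain Python, not TorchScript's implementation: the `Instr` record, the three-opcode instruction set, and the `%n` value names are all assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Instr:
    dest: str    # SSA name, assigned exactly once
    op: str      # "const", "add", or "mul"
    args: tuple  # const: (value,); add/mul: names of two earlier values


def fold_constants(program):
    """Replace add/mul whose inputs are both known constants with a const."""
    known = {}   # SSA name -> compile-time value
    folded = []
    for ins in program:
        if ins.op == "const":
            known[ins.dest] = ins.args[0]
        elif all(a in known for a in ins.args):
            lhs, rhs = (known[a] for a in ins.args)
            value = lhs + rhs if ins.op == "add" else lhs * rhs
            known[ins.dest] = value
            ins = Instr(ins.dest, "const", (value,))
        folded.append(ins)
    return folded


def eliminate_dead_code(program, outputs):
    """Keep only instructions that contribute to the requested outputs."""
    live = set(outputs)
    kept = []
    for ins in reversed(program):
        if ins.dest in live:
            kept.append(ins)
            if ins.op != "const":
                live.update(ins.args)
    return list(reversed(kept))


# y = x * (2 + 1), plus one unused value; "x" is a runtime input.
program = [
    Instr("%1", "const", (2,)),
    Instr("%2", "const", (1,)),
    Instr("%3", "add", ("%1", "%2")),  # foldable: both inputs are constants
    Instr("%4", "mul", ("x", "%3")),   # depends on the runtime input
    Instr("%5", "mul", ("%1", "%2")),  # never used by the output
]

optimized = eliminate_dead_code(fold_constants(program), outputs=["%4"])
for ins in optimized:
    print(ins.dest, ins.op, ins.args)
```

Because every name is assigned exactly once, one forward pass suffices to know which values are constant and one backward pass suffices to know which values are live; without SSA, both analyses would have to track reassignments.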
|
||
|
||
#### Modern Compilation: torch.compile {#sec-ml-frameworks-modern-compilation-torchcompile-d025}

\index{Graph Compilation!eager-mode capture}
The previous approaches force a choice: write flexible code (eager execution) or fast code (static graphs). Modern JIT compilation attempts to eliminate this trade-off by automatically compiling eager code into optimized graphs with minimal developer intervention.

PyTorch 2.0's `torch.compile` [@ansel2024pytorch2] represents this approach: developers write natural Python code that executes eagerly during development, but the framework automatically captures and compiles hot paths into optimized kernels for production. @lst-torch-compile-intro shows the basic usage pattern:

::: {#lst-torch-compile-intro lst-cap="**torch.compile**: PyTorch 2.0's compiler captures execution on first call, compiles an optimized kernel, then reuses compiled code for subsequent calls with matching shapes."}

```{.python}
import torch


@torch.compile
def forward(x):
    return x * 2 + 1


# First call: captures execution, compiles optimized kernel (~100ms)
result1 = forward(torch.tensor([1.0]))

# Same shape and dtype: reuses the compiled code
result2 = forward(torch.tensor([2.0]))
```

:::
|
||
|
||
The one-time compilation cost in this example (approximately 100 ms on the first call) is repaid by every later call, which reuses the compiled code with only microseconds of overhead. The deeper question is *why* compilation helps so much. The answer lies in *the physics of software overhead*: dispatch costs that seem negligible for a single operation, a few microseconds each, compound across the thousands of operations in a forward pass. The following analysis quantifies this effect.
|
||
|
||
```{python}
#| label: fusion-speedup-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ FUSION SPEEDUP CALCULATION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Callout "The Physics of Software Overhead" comparing eager vs compiled
# │
# │ Goal: Quantify the latency and bandwidth benefits of kernel fusion.
# │ Show: A 2× speedup from eliminating redundant kernel launches and memory traffic.
# │ How: Model dispatch and memory overhead for eager vs. fused Add+ReLU.
# │
# │ Imports: None (pure calculation)
# │ Exports: python_dispatch_us_str, kernel_launch_only_us_str, memory_access_us_str,
# │          kernel_launch_us_str, eager_overhead_str, compiled_overhead_str,
# │          overhead_speedup_str, bw_efficiency_str
# └─────────────────────────────────────────────────────────────────────────────

# ┌── LEGO ───────────────────────────────────────────────
class FusionSpeedup:
    """
    Namespace for Kernel Fusion Speedup calculation.
    Scenario: Comparing Eager (2 launches) vs Fused (1 launch) overheads.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    python_dispatch_us = 10
    kernel_launch_us = 5
    memory_access_us = 1

    eager_ops = 2
    fused_ops = 1

    eager_mem_factor = 4  # 2R + 2W
    fused_mem_factor = 2  # 1R + 1W (intermediate fused)

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    launch_overhead = python_dispatch_us + kernel_launch_us

    eager_total_overhead = eager_ops * launch_overhead
    fused_total_overhead = fused_ops * launch_overhead

    speedup = eager_total_overhead / fused_total_overhead
    bw_efficiency = eager_mem_factor / fused_mem_factor

    # ┌── 3. GUARD (Invariants) ───────────────────────────────────────────
    check(speedup >= 1.5, f"Fusion speedup ({speedup:.1f}x) is too small to justify compilation complexity.")

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    python_dispatch_us_str = f"{python_dispatch_us}"
    kernel_launch_only_us_str = f"{kernel_launch_us}"
    memory_access_us_str = f"{memory_access_us}"
    kernel_launch_us_str = f"{launch_overhead}"

    eager_overhead_str = f"{eager_total_overhead}"
    compiled_overhead_str = f"{fused_total_overhead}"

    overhead_speedup_str = f"{int(speedup)}"
    bw_efficiency_str = f"{int(bw_efficiency)}"

# Note: Use FusionSpeedup.python_dispatch_us_str directly.
```
|
||
|
||
::: {.callout-notebook title="The Physics of Software Overhead"}

**The Iron Law Connection:**
The **Latency Term** ($\text{Latency}_{\text{fixed}}$) in the Iron Law is dominated by software overhead: dispatching instructions from Python to the GPU.

**The Constants of Latency:**

* **Python Dispatch:** ~`{python} FusionSpeedup.python_dispatch_us_str` μs per operation.
* **Kernel Launch:** ~`{python} FusionSpeedup.kernel_launch_only_us_str` μs per operation.
* **Memory Access (VRAM):** ~`{python} FusionSpeedup.memory_access_us_str` μs.

**Scenario 1: Eager Mode (The "Tiny Op" Trap)**
Consider a simple activation block: `y = relu(x + bias)`.

* **Operations:** 2 (Add, ReLU).
* **Execution:**

    1. Launch `Add` Kernel: `{python} FusionSpeedup.kernel_launch_us_str` μs overhead.
    2. Read/Write Memory: $2N$ bytes.
    3. Launch `ReLU` Kernel: `{python} FusionSpeedup.kernel_launch_us_str` μs overhead.
    4. Read/Write Memory: $2N$ bytes.
* **Total Overhead:** `{python} FusionSpeedup.eager_overhead_str` μs.
* **Total Memory Traffic:** $4N$ bytes.

**Scenario 2: Compiled Mode (Fusion)**
The compiler fuses this into one kernel: `FusedAddRelu`.

* **Execution:**

    1. Launch `Fused` Kernel: `{python} FusionSpeedup.compiled_overhead_str` μs overhead.
    2. Read/Write Memory: $2N$ bytes (intermediate result stays in registers).
* **Total Overhead:** `{python} FusionSpeedup.compiled_overhead_str` μs (**`{python} FusionSpeedup.overhead_speedup_str`$\times$ speedup**).
* **Total Memory Traffic:** $2N$ bytes (**`{python} FusionSpeedup.bw_efficiency_str`$\times$ bandwidth efficiency**).

**The Conclusion:**
Compilation is not magic; it is **overhead amortization**. For small, element-wise operations (like LayerNorm, GELU, Add), overhead often exceeds compute time by 10--100$\times$. Fusing them is the only way to use the hardware effectively.

:::
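
The two-operation scenario generalizes to longer chains. The following sketch reuses the callout's constants (10 μs dispatch, 5 μs launch) and makes two simplifying assumptions that are not claims about real hardware: every operation is a unary element-wise kernel that reads and writes $N$ values, and the whole chain fuses into a single kernel. The function names are illustrative.

```python
PYTHON_DISPATCH_US = 10  # per-operation dispatch cost from the callout
KERNEL_LAUNCH_US = 5     # per-kernel launch cost from the callout


def launch_overhead_us(num_kernels: int) -> int:
    """Fixed software overhead: every kernel pays dispatch plus launch."""
    return num_kernels * (PYTHON_DISPATCH_US + KERNEL_LAUNCH_US)


def memory_traffic(num_kernels: int, n: int) -> int:
    """Each element-wise kernel reads its input and writes its output (2N)."""
    return num_kernels * 2 * n


def fusion_gain(chain_length: int, n: int = 1_000_000):
    """Compare an eager chain of element-wise ops with one fused kernel."""
    eager = launch_overhead_us(chain_length), memory_traffic(chain_length, n)
    fused = launch_overhead_us(1), memory_traffic(1, n)
    return eager[0] / fused[0], eager[1] / fused[1]


for k in (2, 4, 8):
    overhead_ratio, traffic_ratio = fusion_gain(k)
    print(f"{k} ops: {overhead_ratio:.0f}x less overhead, "
          f"{traffic_ratio:.0f}x less memory traffic")
```

Under these assumptions both ratios grow linearly with chain length, which is why fusion pays off most for long runs of small element-wise operations.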
|
||
|
||
See this tax play out concretely in @fig-python-tax. Notice how eager execution (top) creates "gaps" where the GPU sits idle while Python dispatches the next kernel. The blue compute regions are short; the red dispatch regions are comparatively long. Compilation (bottom) fuses these operations into a single kernel launch, eliminating the gaps entirely so the GPU spends nearly all its time computing rather than waiting.

::: {#fig-python-tax fig-env="figure" fig-pos="htb" fig-cap="**The Python Tax**: Visualizing the overhead analysis from the preceding callout. In Eager Mode (top), the GPU (blue) finishes processing each op in microseconds but must sit idle while the Python interpreter (red) dispatches the next kernel launch. Compilation (bottom) fuses these operations into a single kernel, effectively hiding the dispatch latency and maximizing GPU utilization." fig-alt="Gantt chart of execution timeline. Eager mode shows alternating red (Python) and blue (GPU) blocks with gaps. Compiled mode shows one small red block followed by one long blue block."}

```{python}
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ FIG-PYTHON-TAX
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @fig-python-tax illustrating the Python dispatch overhead
# │
# │ Goal: Visualize the dispatch overhead of eager execution.
# │ Show: How the "Python Tax" creates idle gaps between GPU kernels.
# │ How: Plot a Gantt chart comparing alternating dispatch/compute vs. fused kernels.
# │
# │ Imports: mlsysim.core.viz (viz)
# │ Exports: figure only, no prose variables
# └─────────────────────────────────────────────────────────────────────────────
# ┌── 1. CANVAS ────────────────────────────────────────────────────────────────
# │ Visualize the dispatch overhead of eager execution.
from mlsysim import viz

fig, ax, COLORS, plt = viz.setup_plot()

# =============================================================================
# PLOT: The Python Tax
# =============================================================================

# ┌── 2. ARRAYS ────────────────────────────────────────────────────────────────
t_dispatch, t_compute, n_ops = 10, 1, 5

# Eager execution: alternating dispatch and compute
y_eager = 1

# ┌── 3. RENDER ────────────────────────────────────────────────────────────────
for i in range(n_ops):
    start = i * (t_dispatch + t_compute)
    ax.barh(y_eager, t_dispatch, left=start, height=0.4, color=COLORS['RedLine'], alpha=0.6, label='Python Overhead' if i==0 else "")
    ax.barh(y_eager, t_compute, left=start+t_dispatch, height=0.4, color=COLORS['BlueLine'], alpha=0.8, label='GPU Kernel' if i==0 else "")

# Compiled execution: one dispatch, one fused kernel
y_compiled = 0
t_fused_compute = t_compute * n_ops * 0.8
t_compiled_dispatch = 5
ax.barh(y_compiled, t_compiled_dispatch, left=0, height=0.4, color=COLORS['RedLine'], alpha=0.6)
ax.barh(y_compiled, t_fused_compute, left=t_compiled_dispatch, height=0.4, color=COLORS['BlueLine'], alpha=0.8)


# ┌── 4. DECORATE ──────────────────────────────────────────────────────────────
ax.set_yticks([0, 1])
ax.set_yticklabels(['Compiled / Fused', 'Eager Execution'])
ax.set_xlabel('Execution Time (microseconds)')
ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=2, fontsize=8)
ax.text(30, 0.5, "Gap = The Python Tax", color=COLORS['RedLine'], ha='center', va='center', fontsize=9, fontweight='bold', bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
plt.show()
```

:::
|
||
|
||
The natural question is: can this fusion happen automatically? PyTorch 2.0's `torch.compile`[^fn-torch-compile-hybrid] attempts exactly this by capturing eager code and compiling it into fused kernels without requiring users to write custom CUDA.[^fn-cuda-dispatch-overhead]

[^fn-torch-compile-hybrid]: **`torch.compile`**: Enables automatic fusion by intercepting Python bytecode (via TorchDynamo) to extract a computational graph from unmodified eager code. This graph is then compiled into optimized kernels, trading a one-time compilation delay for a recurring 1.3--$2\times$ throughput gain on transformer models by reducing kernel launch overhead. \index{torch.compile!hybrid execution}

[^fn-cuda-dispatch-overhead]: **CUDA (Compute Unified Device Architecture)**: NVIDIA's parallel computing platform (2007) serving as the foundational layer between high-level Python operations and GPU silicon. When PyTorch executes `torch.matmul(A, B)`, the call traverses the framework's dispatcher, selects a cuBLAS kernel, and launches it on the GPU. Each launch incurs 5--20 $\mu$s of CPU-side overhead. For small operations, this dispatch overhead ($L_{\text{lat}}$) exceeds the useful compute time, which is why compilation (fusing $N$ operations into one kernel launch) yields speedups proportional to the reduction in launch count rather than the reduction in arithmetic. \index{CUDA!dispatch overhead}
|
||
|
||
##### Architecture: Three-Stage Compilation Pipeline {.unnumbered}

torch.compile consists of three coordinated components, each handling a distinct phase of the compilation process:

1. **TorchDynamo** (graph capture): Intercepts Python bytecode execution using CPython's PEP 523 frame evaluation API. Like `torch.jit.trace`, TorchDynamo captures operations during execution. Unlike tracing, which records a single execution path and silently ignores alternative branches, it inserts *graph breaks* when it encounters unsupported code (print statements, arbitrary Python), ensuring correctness rather than silent failure. At a graph break, the current graph is finalized for compilation, the unsupported code executes eagerly, and a new graph begins afterward.

2. **FX Graph** (intermediate representation): Operations captured by TorchDynamo are converted to FX graph format, PyTorch's node-based directed acyclic graph where each node represents an operation with explicit inputs and outputs. The FX graph serves as PyTorch's analog to LLVM IR: a standardized representation that separates frontend (Python code capture) from backend (hardware-specific code generation). This design allows different backends (TorchInductor, ONNX Runtime, TensorRT) to consume FX graphs and enables optimization passes such as dead code elimination, constant folding, and pattern matching for fusion opportunities.

3. **TorchInductor**[^fn-torchinductor-codegen] (code generation): The default backend that compiles FX graphs to optimized machine code. For CUDA GPUs, TorchInductor generates Triton\index{Triton!GPU kernel language}[^fn-triton-kernel-language] kernels, a Python-based GPU kernel language that compiles to PTX[^fn-ptx-portability]. For CPUs, it generates C++ code with vectorization instructions (AVX2, AVX-512). TorchInductor applies three key optimizations: kernel fusion (combining operations to reduce memory traffic), memory layout optimization (choosing tensor layouts that minimize access overhead), and autotuning (measuring performance across implementation variants to select the fastest).
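
The node-based structure of the FX graph in stage 2 can be inspected without invoking the full pipeline. The sketch below uses `torch.fx.symbolic_trace`, which is a different capture mechanism from TorchDynamo (it runs the function on proxy values instead of intercepting bytecode) but produces the same kind of graph object. The function `activation_block` is a hypothetical example, not code from this chapter.

```python
import torch
import torch.fx


def activation_block(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x + 1.0)


# symbolic_trace runs the function on proxy values and records an FX graph.
graph_module = torch.fx.symbolic_trace(activation_block)

# Every node has an opcode, a target, and explicit input arguments.
for node in graph_module.graph.nodes:
    print(node.op, node.target, node.args)

opcodes = [node.op for node in graph_module.graph.nodes]

# The captured graph is still callable and matches the eager function.
sample = torch.tensor([-2.0, 3.0])
assert torch.equal(graph_module(sample), activation_block(sample))
```

Each printed line is one node: a `placeholder` for the input, one `call_function` node per operation, and an `output` node. A backend or optimization pass walks exactly this list when it looks for fusion opportunities.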
|
||
|
||
[^fn-torchinductor-codegen]: **TorchInductor**: The use of Triton to generate GPU code is a deliberate trade-off that prioritizes fast just-in-time (JIT) compilation over maximum hardware performance. This makes on-the-fly optimization practical for an eager-execution framework, even if the resulting kernels are 5--20% slower than highly optimized, hand-written CUDA. \index{TorchInductor!code generation}

[^fn-ptx-portability]: **PTX (Parallel Thread Execution)**: An intermediate representation (IR) from NVIDIA that serves as a stable compilation target for high-level GPU languages like Triton. This allows TorchInductor to generate portable code, because the NVIDIA driver, not the framework, is responsible for the final translation to hardware-specific machine code (SASS). This forward compatibility, however, can result in performance that is 10--15% slower than kernels hand-tuned for a specific GPU architecture. \index{PTX!compilation portability}

[^fn-triton-kernel-language]: **Triton**: TorchInductor generates Triton because its Python-like syntax provides a simpler, more stable compilation target than raw CUDA, making automated code generation tractable. This abstraction allows the compiler to handle complex GPU details like memory coalescing automatically, a requirement for performing kernel fusion. The accepted trade-off is achieving 80--95% of hand-tuned CUDA performance in exchange for enabling the compiler to effectively autotune kernels and reduce development time from weeks to hours. \index{Triton!performance trade-off}

The generated code is cached on disk: TorchInductor maintains its own compilation cache, and Triton kernels are additionally cached in `~/.triton/cache/`. Subsequent runs with the same input shapes can skip compilation and directly execute cached code.
|
||
|
||
##### Execution Flow {.unnumbered}

The first execution follows a multi-step process: TorchDynamo intercepts bytecode and records operations into an FX graph, the FX graph is passed to TorchInductor for compilation (5--30 seconds for transformer models), and the compiled code is cached and executed. Subsequent executions with the same input shapes dispatch directly to the compiled code with only microseconds of overhead. If input shapes change, the cached code no longer applies and the function must be recompiled for the new shapes (shape specialization). Under this specialization, PyTorch maintains a separate compiled version for each unique shape configuration; passing `dynamic=True` to `torch.compile` instead requests kernels that generalize across shapes.
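
The cache-then-dispatch behavior described above can be modeled in a few lines of plain Python. This is a toy model under stated assumptions, not PyTorch's guard mechanism: the class `ShapeSpecializedCache`, the `fake_compile` stand-in, and the use of a bare shape tuple as the cache key are all hypothetical simplifications.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


class ShapeSpecializedCache:
    """Toy model: one compiled artifact per distinct input shape."""

    def __init__(self, compile_fn: Callable[[Shape], Callable]):
        self._compile_fn = compile_fn
        self._cache: Dict[Shape, Callable] = {}
        self.compile_count = 0

    def __call__(self, data, shape: Shape):
        if shape not in self._cache:      # cache miss: pay the compile cost
            self._cache[shape] = self._compile_fn(shape)
            self.compile_count += 1
        return self._cache[shape](data)   # cache hit: direct dispatch


def fake_compile(shape: Shape) -> Callable:
    # Stand-in for code generation; a real backend would emit a kernel here.
    return lambda data: [2 * v + 1 for v in data]


forward = ShapeSpecializedCache(fake_compile)
forward([1.0, 2.0], shape=(2,))        # first call: compiles for shape (2,)
forward([3.0, 4.0], shape=(2,))        # same shape: reuses cached code
forward([1.0, 2.0, 3.0], shape=(3,))   # new shape: compiles again
print("compilations:", forward.compile_count)
```

The model makes the cost structure visible: the number of compilations tracks the number of distinct shapes, not the number of calls.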
|
||
|
||
##### Graph Breaks: Causes and Detection {.unnumbered}

Graph breaks occur when torch.compile encounters code it cannot compile, forcing execution to fall back to eager mode. Knowing what causes graph breaks is the first step toward good compiled performance.

Data-dependent control flow requires tensor values unavailable at compile time, as shown in @lst-graph-break-control-flow.

::: {#lst-graph-break-control-flow lst-cap="**Graph Break from Control Flow**: Data-dependent conditionals force a graph break because tensor values are unavailable at compile time, splitting execution into separate compiled regions."}

```{.python}
import torch


@torch.compile
def conditional_compute(x):
    if x.sum() > 0:  # Graph break: tensor value needed
        return x * 2
    else:
        return x * 3


# Creates two compiled regions: operations before
# and after the if statement
# The if statement itself executes eagerly
```

:::
|
||
|
||
TorchDynamo creates a graph break: operations before the if statement are compiled, the if statement executes eagerly (evaluating which branch to take), and the chosen branch is compiled as a separate region.

Unsupported operations also cause graph breaks, as @lst-graph-break-io demonstrates.

::: {#lst-graph-break-io lst-cap="**Graph Break from I/O**: Unsupported operations like `print` force a graph break, splitting compiled code into two regions with eager execution in between."}

```{.python}
import torch


@torch.compile
def debug_compute(x):
    y = x * 2
    print(f"y = {y}")  # Graph break: I/O operation
    z = y + 1
    return z


# Creates two compiled regions: before and after print
```

:::
|
||
|
||
Common unsupported operations include I/O (`print`, file operations), custom Python objects, and calls to non-PyTorch libraries. Each graph break incurs overhead: tensors must be marshalled from compiled code back to Python (possibly copying from GPU to CPU), the eager operation executes, and results are marshalled into the next compiled region.

Shape changes prevent compiled code reuse, as @lst-graph-break-shapes illustrates.

::: {#lst-graph-break-shapes lst-cap="**Recompilation from Shape Changes**: Each unique input shape triggers a separate compilation, causing significant overhead when shapes vary frequently."}

```{.python}
import torch


@torch.compile
def variable_length(x, length):
    return x[:, :length]  # Shape changes each call


x = torch.randn(8, 16)  # Example input

# Under static shape specialization, each unique length
# triggers recompilation
for i in range(10):
    result = variable_length(x, i)  # Up to 10 compilations
```

:::
|
||
|
||
Detect graph breaks using @lst-graph-break-detect.

::: {#lst-graph-break-detect lst-cap="**Detecting Graph Breaks**: Setting `TORCH_LOGS` to `graph_breaks` prints each break location and reason during execution."}

```{.bash}
TORCH_LOGS="graph_breaks" python train.py
```

:::

This prints each break location and reason: `Graph break in user code at file.py:15 / Reason: call to unsupported function print`. Minimizing graph breaks is key to performance: move unsupported operations outside compiled regions, replace data-dependent control flow with conditional execution (`torch.where`), or accept eager execution for inherently dynamic sections.
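
The `torch.where` rewrite mentioned above can be sketched as follows. This is a minimal illustration run in eager mode, with hypothetical function names `branching` and `branch_free`; it shows only that the two forms compute the same result, not how much a particular model gains from the change.

```python
import torch


def branching(x: torch.Tensor) -> torch.Tensor:
    # Python-level branch on a tensor value: a graph-break candidate.
    if x.sum() > 0:
        return x * 2
    return x * 3


def branch_free(x: torch.Tensor) -> torch.Tensor:
    # Both candidates are computed; torch.where selects one on-device,
    # so the decision never has to surface as a Python bool.
    return torch.where(x.sum() > 0, x * 2, x * 3)


for sample in (torch.ones(4), -torch.ones(4)):
    assert torch.equal(branching(sample), branch_free(sample))
print("branch-free version matches on both paths")
```

The trade-off is that the branch-free version always evaluates both candidates, so it suits cheap element-wise branches better than expensive ones. Passing `fullgraph=True` to `torch.compile` turns any remaining graph break into an error, which is one way to confirm that a rewrite of this kind removed the break.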
|
||
|
||
##### Compilation Modes and Backends {.unnumbered}

\index{Graph Compilation!compilation modes}
As a project matures from prototyping to production, engineers progressively increase compilation aggressiveness. The default mode (`mode='default'`) applies moderate optimization with fast compilation (5--30 seconds for transformer models), making it suitable for development and training where compilation overhead is amortized over many iterations. When deploying an inference server with fixed input shapes, `mode='reduce-overhead'` minimizes Python interpreter overhead by aggressively capturing operations and enabling CUDA graphs that batch kernel launches, improving throughput by 20--40% over the default. For production training that will run for days, `mode='max-autotune'` generates and benchmarks multiple implementation variants for each operation, increasing compilation time (minutes to hours for large models) but improving runtime performance by 10--30%. This progression---default for development, reduce-overhead for inference, max-autotune for long training runs---mirrors the Compilation Continuum principle we formalize below.

The compilation mode controls *how aggressively* to optimize; the backend controls *what target* to optimize for. TorchInductor (the default) generates Triton kernels for CUDA and C++ for CPU, providing the best general-purpose performance for both training and inference. When cross-platform deployment is required, the ONNX Runtime backend\index{ONNX Runtime!compilation backend} exports the FX graph to ONNX format, enabling execution on CPUs, GPUs, mobile, and edge devices---though limited ONNX operation coverage may cause more graph breaks. For maximum inference throughput on NVIDIA GPUs, the TensorRT backend\index{TensorRT!inference compiler} compiles to NVIDIA's inference engine with aggressive int8 quantization, layer fusion, and kernel autotuning, often achieving 1.5--2$\times$ speedup over TorchInductor. The trade-off is clear: each backend narrows the target to unlock deeper optimization, echoing the flexibility-versus-performance axis that distinguishes eager from graph execution.
|
||
|
||
##### Practical Example: Measuring Speedup {.unnumbered}

@lst-torch-compile-benchmark implements correct GPU benchmarking methodology, incorporating CUDA synchronization, warmup iterations to exclude compilation time, and sufficient iterations to amortize measurement overhead:

::: {#lst-torch-compile-benchmark lst-cap="**Benchmarking torch.compile**: Properly measuring speedup requires CUDA synchronization, warmup to exclude compilation time, and sufficient iterations to amortize measurement overhead."}

```{.python}
import time

import torch


def forward(x, w):
    return torch.matmul(x, w).relu()


x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 512, device="cuda")

# Eager mode benchmark
forward(x, w)  # Warmup: exclude one-time CUDA initialization
torch.cuda.synchronize()  # Ensure GPU operations complete
start = time.perf_counter()
for _ in range(100):
    y = forward(x, w)
torch.cuda.synchronize()  # Wait for GPU kernel completion
eager_time = time.perf_counter() - start

# Compiled mode benchmark
forward_compiled = torch.compile(forward)
forward_compiled(x, w)  # Warmup: trigger compilation
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(100):
    y = forward_compiled(x, w)
torch.cuda.synchronize()
compiled_time = time.perf_counter() - start

print(f"Speedup: {eager_time / compiled_time:.2f}x")
# Speedup depends on hardware and workload; gains are largest
# when many small operations can be fused
```

:::
|
||
|
||
Critical benchmarking details: (1) Use `torch.cuda.synchronize()` because CUDA operations are asynchronous; without synchronization, the timer measures only how long it takes to enqueue kernels, not how long they take to execute. (2) Warm up by calling the function once before timing, so that compilation and other one-time initialization are excluded from the measurement. (3) Run 100 or more iterations to amortize measurement overhead.
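
The three rules above can be packaged into a small reusable helper. The sketch below is one possible shape for such a helper, not an established API: the function name `benchmark` and its `sync`, `warmup`, and `iters` parameters are assumptions. The `sync` callback defaults to a no-op so the example runs on a CPU-only machine; on a GPU one would pass `torch.cuda.synchronize`.

```python
import time
from typing import Callable


def benchmark(fn: Callable[[], object],
              sync: Callable[[], None] = lambda: None,
              warmup: int = 3,
              iters: int = 100) -> float:
    """Return mean seconds per call, excluding warmup and pending async work."""
    for _ in range(warmup):      # triggers compilation and lazy initialization
        fn()
    sync()                       # drain queued work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()                       # wait for all timed work to finish
    return (time.perf_counter() - start) / iters


# CPU-only demonstration with a trivial workload.
mean_s = benchmark(lambda: sum(range(1000)))
print(f"{mean_s * 1e6:.1f} us per call")
```

Using `time.perf_counter` instead of `time.time` gives a monotonic, high-resolution clock, which matters when the per-call time is in the microsecond range.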
|
||
|
||
##### Systems Implications {.unnumbered}

First execution includes compilation time: 5--10 s for small models, 30--60 s for BERT-base transformers, 5--10 min for GPT-3 scale models. This overhead is amortized across training (compile once, train for thousands of iterations) but slows development iteration. Compiled kernels are cached on disk; subsequent runs skip compilation.

Compilation adds memory overhead: 100--500 MB for FX graph construction, 500 MB--2 GB peak during Triton compilation, and 10--100 MB per compiled graph for storage. Runtime memory usage is similar to eager mode (kernel fusion can reduce intermediate tensors, but compiled code may allocate temporary buffers). Compiled models typically use 90--110% of eager mode memory.

Errors in compiled code produce stack traces pointing to generated code, not source Python code. Print statements inside compiled regions cause graph breaks (executed eagerly, not compiled). For debugging, remove `@torch.compile` to revert to eager execution, fix bugs, then re-enable compilation. Use `TORCH_COMPILE_DEBUG=1` for verbose compilation logs.
|
||
|
||
##### When to Use torch.compile {.unnumbered}

The decision follows directly from the compilation cost model. Long training runs amortize compilation overhead across hundreds of iterations, and stable architectures with fixed control flow minimize graph breaks---making training the strongest use case. Inference is equally compelling: a deployed model compiles once at startup and serves thousands of requests, where `mode='reduce-overhead'` minimizes per-request overhead. Compilation should be deferred, however, during rapid prototyping, where the overhead slows iteration time and the architecture has not yet stabilized. Models with frequent graph breaks or dynamic shape changes prevent effective compilation, and debugging is harder in compiled mode because error locations point to generated code rather than source Python. The practical strategy is to develop in eager mode, stabilize the architecture, then enable compilation for training and deployment.
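
The amortization argument reduces to a break-even count: compilation pays for itself once the accumulated per-step savings exceed the one-time compile cost. The sketch below works this out for hypothetical numbers (a 30 s compile, a 0.50 s eager step, a 0.25 s compiled step); none of these figures are measurements, and the function name is illustrative.

```python
import math


def break_even_iterations(compile_s: float,
                          eager_step_s: float,
                          compiled_step_s: float) -> float:
    """Iterations needed before compilation has paid for itself."""
    saved_per_step = eager_step_s - compiled_step_s
    if saved_per_step <= 0:
        return math.inf          # compiled code is not faster: never pays off
    return compile_s / saved_per_step


# Hypothetical numbers: 30 s compile, 0.50 s eager step, 0.25 s compiled step.
n_star = break_even_iterations(compile_s=30.0,
                               eager_step_s=0.50,
                               compiled_step_s=0.25)
print(f"break-even after {math.ceil(n_star)} iterations")
```

A training run of many thousands of steps clears such a threshold easily, while a prototyping session that recompiles after every code change may never reach it, which matches the develop-in-eager, compile-when-stable strategy described above.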
|
||
|
||
##### Comparison of Execution Models {.unnumbered}

@tbl-framework-execution-models contrasts the three execution models across six dimensions, revealing that hybrid JIT compilation achieves most of static graph performance while preserving much of eager execution's flexibility:

: **Execution Model Trade-Offs.** Each execution model occupies a distinct position in the flexibility-optimization trade-off space. Eager execution maximizes debugging flexibility but sacrifices optimization potential; static graphs maximize optimization but sacrifice dynamic control flow; hybrid JIT compilation attempts both by compiling captured regions while falling back to eager for unsupported patterns. {#tbl-framework-execution-models}

| **Aspect**               | **Eager + Autograd Tape** **(PyTorch default)** | **Static Graph** **(TensorFlow 1.x)** | **JIT Compilation** **(torch.compile)** |
|:-------------------------|:------------------------------------------------|:--------------------------------------|:----------------------------------------|
| **Execution Model**      | Immediate                                       | Deferred                              | Hybrid                                  |
| **Graph Construction**   | During forward pass                             | Before execution                      | First execution (cached)                |
| **Optimization**         | None (per-operation)                            | Ahead-of-time                         | JIT compilation                         |
| **Dynamic Control Flow** | Full support                                    | Limited (static unroll)               | Partial (graph breaks)                  |
| **Debugging**            | Easy (standard Python)                          | Difficult (symbolic)                  | Moderate (mixed)                        |
| **Performance**          | Baseline                                        | High (optimized)                      | High (compiled regions)                 |

Eager mode's primary value is in the "Workflow Iteration" loop (@sec-ml-workflow): it allows using standard Python debuggers (like PDB) to inspect variables mid-execution, whereas graph-mode debugging often requires specialized framework tools. This immediate feedback accelerates the prototyping phase of the ML lifecycle.
|
||
|
||
Beyond these core execution trade-offs, @tbl-mlfm-graphs highlights additional systems-level distinctions between static and dynamic approaches:

| **Aspect**                       | **Static Graphs**                                    | **Dynamic Graphs**                            |
|:---------------------------------|:-----------------------------------------------------|:----------------------------------------------|
| **Memory Management**            | Precise allocation planning, optimized memory usage  | Flexible but potentially less efficient       |
| **Hardware Utilization**         | Can generate highly optimized hardware-specific code | May sacrifice hardware-specific optimizations |
| **Research Velocity**            | Slower iteration due to define-then-run requirement  | Faster prototyping and model experimentation  |
| **Integration with Legacy Code** | More separation between definition and execution     | Natural integration with imperative code      |

: **Additional Graph Trade-Offs.** Systems-level distinctions between static and dynamic graphs that complement the execution model comparison above. These trade-offs reappear when selecting frameworks in @sec-ml-frameworks-selecting-framework-2949. {#tbl-mlfm-graphs}

These trade-offs are not binary choices. Modern frameworks offer a spectrum of options, which raises the quantitative question of where on this spectrum a given project should operate.
|
||
|
||
### Quantitative Principles of Execution {#sec-ml-frameworks-compilation-continuum-principle-c106}
|
||
|
||
\index{Compilation Continuum!principle}
|
||
|
||
These execution models present a spectrum of trade-offs, but engineers need more than intuition to navigate them. Two quantitative principles formalize the decision. The *Compilation Continuum Principle* establishes when the performance gains from compilation justify its development cost, expressed as the ratio of production time saved to compilation time spent across development iterations. The *Dispatch Overhead Law* quantifies the per-operation cost of framework flexibility, revealing why small operations in eager mode can spend more time in Python overhead than in actual computation. Together, these principles transform framework selection from subjective preference into measurable engineering analysis.
|
||
|
||
#### The Compilation Continuum Principle {#sec-ml-frameworks-compilation-continuum-principle-7122}
|
||
|
||
The Execution Problem demands a quantitative principle: **when should a project compile?**
|
||
|
||
The execution models form a continuum from maximum flexibility to maximum optimization, visualized in @eq-execution-continuum:
|
||
|
||
$$
\text{Eager} \xrightarrow{\text{tracing}} \text{JIT} \xrightarrow{\text{AOT}} \text{Static Graph} \xrightarrow{\text{synthesis}} \text{Custom Hardware}
$$ {#eq-execution-continuum}
|
||
|
||
Each step rightward sacrifices flexibility for performance. The practical question is *where* on this continuum a given project should operate. The optimal compilation strategy depends on the ratio of **development iterations** to **production executions** (@eq-compilation-benefit):
|
||
|
||
$$
\text{Compilation Benefit} = \frac{N_{\text{prod}} \cdot (T_{\text{eager}} - T_{\text{compiled}})}{T_{\text{compile}} + N_{\text{dev}} \cdot T_{\text{compile}}}
$$ {#eq-compilation-benefit}
|
||
|
||
Where:
|
||
|
||
- $N_{\text{prod}}$ = number of production executions (dimensionless count: inference requests, training steps)
|
||
- $N_{\text{dev}}$ = number of development iterations requiring recompilation (dimensionless count)
|
||
- $T_{\text{eager}}$ = time per execution in eager mode (seconds)
|
||
- $T_{\text{compiled}}$ = time per execution in compiled mode (seconds)
|
||
- $T_{\text{compile}}$ = one-time compilation cost (seconds)
|
||
|
||
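**Decision Rule**: Compile when $\text{Compilation Benefit} > 1$. The numerator is the total production time saved and the denominator is the total compilation time paid (one initial compile plus one recompile per development iteration), so the ratio is dimensionless.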
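
The decision rule can be evaluated directly. The following is a minimal sketch, not a framework API: the function name, argument names, and the two example workloads (one million images, a 30-second compile, and either 20 or 2 recompiles) are illustrative assumptions, and the per-image times use the ResNet-50 throughputs reported in @tbl-training-benchmark.

```python
def compilation_benefit(n_prod, n_dev, t_eager, t_compiled, t_compile):
    """Production time saved divided by total compilation time paid.

    Times are in seconds; n_prod and n_dev are plain counts.
    """
    time_saved = n_prod * (t_eager - t_compiled)
    compile_cost = t_compile + n_dev * t_compile
    return time_saved / compile_cost


# Hypothetical workloads: 1M images, 30 s per compile.
# Per-image times are derived from 1,450 and 2,150 img/sec.
t_eager, t_compiled = 1 / 1450, 1 / 2150

heavy_iteration = compilation_benefit(1_000_000, 20, t_eager, t_compiled, 30.0)
light_iteration = compilation_benefit(1_000_000, 2, t_eager, t_compiled, 30.0)

for label, benefit in [("20 recompiles", heavy_iteration),
                       ("2 recompiles", light_iteration)]:
    verdict = "compile" if benefit > 1 else "stay eager"
    print(f"{label}: benefit = {benefit:.2f} -> {verdict}")
```

With the same production volume, the verdict flips purely because of the number of recompiles, which is the point of placing $N_{\text{dev}}$ in the denominator.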
|
||
|
||
@tbl-training-benchmark provides representative throughput data across execution modes and model architectures:
|
||
|
||
| **Model**        | **Eager (img/sec)** | **torch.compile (img/sec)** | **TensorRT (img/sec)** | **Compile Time (seconds)** |
|:-----------------|--------------------:|----------------------------:|-----------------------:|---------------------------:|
| **ResNet-50**    |               1,450 |                       2,150 |                  3,800 |                     15--30 |
| **BERT-Base**    |                 380 |                         520 |                    890 |                     30--60 |
| **ViT-B/16**     |                 620 |                         950 |                  1,650 |                     25--45 |
| **GPT-2 (124M)** |                 180 |                         260 |                    420 |                     45--90 |
|
||
|
||
: **Training and Inference Throughput.** Representative throughput comparison across execution modes for common model architectures on NVIDIA A100 GPU with batch size 32. torch.compile typically provides 1.4 to 1.5$\times$ speedup over eager mode, while TensorRT provides 2 to 3$\times$ speedup but requires longer compilation and is inference only. Compile times vary based on model complexity and optimization level. {#tbl-training-benchmark}
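
The speedup ranges quoted in the caption can be checked against the table rows. This short sketch only copies the throughput figures from @tbl-training-benchmark into a dictionary (the variable names are illustrative) and divides them; it measures nothing itself.

```python
# Throughput in img/sec from the table: (eager, torch.compile, TensorRT)
throughput = {
    "ResNet-50": (1450, 2150, 3800),
    "BERT-Base": (380, 520, 890),
    "ViT-B/16": (620, 950, 1650),
    "GPT-2 (124M)": (180, 260, 420),
}

# Speedup over eager mode for each compiled execution mode
speedups = {
    model: (compiled / eager, trt / eager)
    for model, (eager, compiled, trt) in throughput.items()
}

for model, (s_compile, s_trt) in speedups.items():
    print(f"{model:<14} torch.compile {s_compile:.2f}x   TensorRT {s_trt:.2f}x")
```

Every row falls near 1.4 to 1.5$\times$ for torch.compile and between 2 and 3$\times$ for TensorRT, consistent with the caption.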
|
||
|
||
These throughput differences across execution modes raise a practical question: which *framework execution strategy* best serves each workload archetype?
|
||
|
||
::: {.callout-lighthouse title="Framework Strategy by Archetype"}
|
||
|
||
The optimal framework execution strategy depends on which **Iron Law** term dominates the workload. @tbl-framework-archetype-strategy aligns each archetype to its recommended execution strategy:
|
||
|
||
| **Archetype**         | **Dominant Iron Law Term**                       | **Optimal Framework Strategy**     | **Rationale**                              |
|:----------------------|:-------------------------------------------------|:-----------------------------------|:-------------------------------------------|
| **ResNet-50**         | $\frac{O}{R_{\text{peak}} \cdot \eta}$ (Compute) | **TensorRT** (inference)           | Kernel fusion maximizes MFU; compute-bound |
| **(Compute Beast)**   |                                                  | **torch.compile** (training)       | workloads benefit most from optimization   |
| **GPT-2**             | $\frac{D_{\text{vol}}}{BW}$ (Memory Bandwidth)   | **torch.compile**                  | Kernel fusion reduces HBM round-trips;     |
| **(Bandwidth Hog)**   |                                                  |                                    | keeps data in cache to mitigate bandwidth  |
| **DLRM**              | $\frac{D_{\text{vol}}}{BW}$ (Random Access) +    | **Eager** with specialized kernels | Embedding lookups are inherently irregular |
| **(Sparse Scatter)**  | $T_{network}$                                    | (FBGEMM)                           | and dynamic; compilation gains are small   |
| **DS-CNN**            | $L_{\text{lat}}$ (Overhead)                      | **AOT compilation** (TFLite, ONNX) | Sub-ms inference; every microsecond of     |
| **(Tiny Constraint)** |                                                  |                                    | Python overhead is unacceptable            |
|
||
|
||
: **Framework Execution Strategy by Workload.** Recommended execution strategy for each workload archetype, aligned to the dominant Iron Law term. Compute-bound workloads benefit most from compilation, while irregular access patterns favor eager execution. {#tbl-framework-archetype-strategy}
|
||
|
||
**Key insight**: Compilation benefits scale with how much of the workload is *optimizable*. Compute Beasts (@tbl-training-benchmark: ResNet-50 sees 2.6$\times$ speedup from TensorRT) benefit most. Sparse Scatter workloads gain little because their bottleneck (embedding lookups) is inherently irregular.
|
||
|
||
:::
|
||
|
||
This principle has concrete implications across three regimes. In *research prototyping* ($N_{\text{dev}} \gg N_{\text{prod}}$), teams should stay eager. If the architecture changes every few minutes, compilation overhead dominates. A 30-second compile time with 10 iterations/hour means 5 minutes lost to compilation per hour, often more than the runtime savings.
|
||
|
||
For *training runs* ($N_{\text{prod}} \gg N_{\text{dev}}$), compilation pays off. A typical training run executes millions of forward/backward passes, so even 60 seconds of compilation amortizes to microseconds per step. From @tbl-training-benchmark, torch.compile provides ~48% speedup on ResNet-50 (2,150 vs 1,450 img/sec); this pays off after the breakeven point in @eq-compile-breakeven:
|
||
|
||
$$
N_{\text{breakeven}} = \frac{T_{\text{compile}}}{T_{\text{eager}} - T_{\text{compiled}}} = \frac{30\text{s}}{(1/1450 - 1/2150)\text{s/img}} \approx 134{,}000 \text{ images}
$$ {#eq-compile-breakeven}
|
||
|
||
For ImageNet (1.28M training images), compilation pays off within the first epoch.
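
The breakeven arithmetic is easy to reproduce. The sketch below is illustrative only: the helper name `breakeven_images` is an assumption, the 30-second compile time is the upper end of the ResNet-50 range in @tbl-training-benchmark, and the ImageNet size is the 1.28M figure quoted above.

```python
def breakeven_images(compile_s, eager_ips, compiled_ips):
    """Images needed before compile time is repaid by faster execution."""
    saved_per_image = 1 / eager_ips - 1 / compiled_ips  # seconds per image
    return compile_s / saved_per_image


n_breakeven = breakeven_images(compile_s=30.0, eager_ips=1450, compiled_ips=2150)

imagenet_train_images = 1_280_000
fraction_of_epoch = n_breakeven / imagenet_train_images

print(f"breakeven after about {n_breakeven:,.0f} images "
      f"({fraction_of_epoch:.0%} of one ImageNet epoch)")
```

Because the breakeven count scales linearly with compile time, doubling the compile time to 60 seconds would still repay itself well inside the first epoch.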
|
||
|
||
For *production inference* ($N_{\text{dev}} \approx 0$, $N_{\text{prod}} \rightarrow \infty$), teams should maximize compilation. With no development iterations and potentially millions of requests, every optimization matters. Using `mode='max-autotune'` despite hour-long compilation is worthwhile because the cost is amortized over the deployment lifetime.
|
||
|
||
These three regimes create distinct regions in the compilation decision space. @fig-compilation-continuum maps out these regions so engineers can identify where each strategy wins. Watch for the crossover points: the steep eager line (highest per-execution cost) eventually rises above the JIT line with its moderate slope, while the AOT line, which has the gentlest slope (lowest per-execution cost) but the largest upfront investment, wins only after millions of executions. The slopes reveal per-execution cost; the vertical offsets reveal compilation overhead. A project's position on the x-axis determines which line it should be on.
|
||
|
||
::: {#fig-compilation-continuum fig-env="figure" fig-pos="htb" fig-cap="**The Compilation Continuum**: Optimal execution strategy depends on development-to-production ratio. Left region (few production executions, typical of active development): eager mode dominates. Right region (many production executions): compilation dominates. The crossover point depends on compilation cost and per-execution speedup." fig-alt="Graph with x-axis 'Production Executions' (log scale) and y-axis 'Total Time'. Three lines: Eager (steep slope), JIT (moderate slope with offset), AOT (gentle slope with larger offset). Lines cross at different points showing when compilation becomes beneficial."}
|
||
|
||
```{python}
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ FIG-COMPILATION-CONTINUUM
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: @fig-compilation-continuum showing cost trade-offs across
|
||
# │ execution strategies (eager, JIT, AOT)
|
||
# │
|
||
# │ Goal: Visualize the total cost of ownership for different compilation strategies.
|
||
# │ Show: The crossover points where JIT and AOT beat eager execution.
|
||
# │ How: Plot total time (compile + execute) vs. production volume on a log scale.
|
||
# │
|
||
# │ Imports: numpy (np), mlsysim.core.viz (viz)
|
||
# │ Exports: figure only, no prose variables
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
# ┌── 1. CANVAS ────────────────────────────────────────────────────────────────
|
||
# │ Visualize the total cost of ownership for different compilation strategies.
|
||
import numpy as np
|
||
from mlsysim import viz
|
||
|
||
fig, ax, COLORS, plt = viz.setup_plot()
|
||
|
||
# =============================================================================
|
||
# PLOT: The Compilation Continuum
|
||
# =============================================================================
|
||
|
||
# ┌── 2. ARRAYS ────────────────────────────────────────────────────────────────
|
||
x = np.logspace(2, 7, 200) # 100 to 10,000,000 executions
|
||
|
||
# Cost models: compile_cost + per_exec_cost * x
|
||
y_eager = 0 + 10e-6 * x
|
||
y_jit = 10 + 5e-6 * x
|
||
y_aot = 30 + 2e-6 * x
|
||
|
||
|
||
# ┌── 3. RENDER ────────────────────────────────────────────────────────────────
|
||
ax.plot(x, y_eager, color=COLORS['BlueLine'], linewidth=2.5, label=r'Eager (no compile cost, high per-exec cost)')
|
||
ax.plot(x, y_jit, color=COLORS['OrangeLine'], linewidth=2.5, label=r'JIT (low compile cost, medium per-exec cost)')
|
||
ax.plot(x, y_aot, color=COLORS['GreenLine'], linewidth=2.5, label=r'AOT (high compile cost, low per-exec cost)')
|
||
|
||
|
||
# ┌── 4. DECORATE ──────────────────────────────────────────────────────────────
|
||
ax.text(1000, 5, "Eager wins", color=COLORS['BlueLine'], ha='center', va='center', fontweight='bold', fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
|
||
ax.text(4e6, 34, "JIT wins", color=COLORS['OrangeLine'], ha='center', va='center', fontweight='bold', fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
|
||
ax.text(8e6, 42, "AOT wins", color=COLORS['GreenLine'], ha='center', va='top', fontweight='bold', fontsize=9, bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5))
|
||
|
||
ax.set_xscale('log')
|
||
ax.set_xlabel(r'Production Executions ($N_{prod}$)')
|
||
ax.set_ylabel('Total Time (arbitrary units)')
|
||
ax.set_xlim(100, 1e7)
|
||
ax.set_ylim(0, 100)
|
||
ax.legend(loc='upper left', fontsize=8)
|
||
plt.show()
|
||
```
|
||
|
||
:::
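
The crossover points follow from equating two linear cost models of the form total = compile cost + per-execution cost $\times$ $N_{\text{prod}}$. The sketch below assumes the same illustrative coefficients used to draw @fig-compilation-continuum (arbitrary time units, not measurements); the helper name `crossover` is hypothetical.

```python
def crossover(fixed_a, slope_a, fixed_b, slope_b):
    """Execution count at which strategy b (higher fixed cost, lower
    per-execution cost) becomes cheaper in total than strategy a."""
    return (fixed_b - fixed_a) / (slope_a - slope_b)


# (compile cost, per-execution cost) in arbitrary units, matching the figure
eager = (0.0, 10e-6)
jit = (10.0, 5e-6)
aot = (30.0, 2e-6)

eager_to_jit = crossover(*eager, *jit)
jit_to_aot = crossover(*jit, *aot)

print(f"JIT beats eager after about {eager_to_jit:,.0f} executions")
print(f"AOT beats JIT after about {jit_to_aot:,.0f} executions")
```

Under these assumed coefficients, eager is cheapest up to roughly two million executions and AOT only wins beyond roughly 6.7 million, which is why the "wins" labels sit where they do in the figure.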
|
||
|
||
#### The Dispatch Overhead Law {#sec-ml-frameworks-dispatch-overhead-law-9e0a}
|
||
|
||
\index{Dispatch Overhead!law}
|
||
A second principle emerges from the Dispatch Overhead Equation (@eq-dispatch-overhead): when does framework overhead, rather than compute or memory, dominate execution time? Let $N_{\text{ops}}$ be the number of operations (count), $t_{\text{dispatch}}$ the per-operation dispatch overhead (seconds), and $T_{\text{compute}}$ and $T_{\text{memory}}$ the total compute and memory times (seconds). Framework overhead dominates when operations are small relative to dispatch cost:
|
||
|
||
$$
\text{Overhead Ratio} = \frac{N_{\text{ops}} \cdot t_{\text{dispatch}}}{T_{\text{compute}} + T_{\text{memory}}}
$$ {#eq-dispatch-overhead}
|
||
|
||
When Overhead Ratio $> 1$, the model is **overhead-bound**. Compilation provides maximum benefit for overhead-bound workloads because it eliminates per-operation dispatch.
|
||
|
||
From the case study in @sec-ml-frameworks-putting-together-anatomy-training-step-c7f1, we can quantify this effect.
|
||
|
||
```{python}
|
||
#| label: dispatch-tax-overhead-ratio
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ DISPATCH TAX CALCULATION
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: Callout "The Dispatch Tax" comparing overhead-bound vs compute-bound
|
||
# │
|
||
# │ Goal: Demonstrate why compilation benefits small models disproportionately.
|
||
# │ Show: That small MLPs are 92% overhead-bound, while GPT-3 is ~0.05% overhead.
|
||
# │ How: Calculate the software-to-hardware execution time ratio for both cases.
|
||
# │
|
||
# │ Imports: mlsysim.book (fmt, md_math)
|
||
# │ Exports: dispatch_n_ops_value, dispatch_us_per_op_value, dispatch_hw_time_us_value,
|
||
# │ dispatch_sw_time_str, dispatch_ratio_small_str, dispatch_overhead_pct_str,
|
||
# │ dispatch_compilation_speedup_str, gpt3_hw_time_us_value, gpt3_sw_time_us_value,
|
||
# │ dispatch_ratio_large_str, t_sw_md
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim.core.constants import KIB_TO_BYTES
|
||
from mlsysim.fmt import fmt, check, md_math
|
||
|
||
# ┌── LEGO ───────────────────────────────────────────────
|
||
class DispatchTax:
    """
    Namespace for Dispatch Tax Calculation.
    Scenario: Comparing overhead impact on Small Ops vs Large Ops.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    # Scenario 1: Small MLP (Overhead Bound)
    small_ops_count = 6
    small_dispatch_us = 5.0
    small_hw_us = 2.6

    # Scenario 2: GPT-3 Layer (Compute Bound)
    large_hw_us = 100_000.0  # 100ms
    large_dispatch_us = 50.0

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    # Step 1: Small Model
    small_sw_total = small_ops_count * small_dispatch_us
    small_total_time = small_sw_total + small_hw_us
    small_overhead_ratio = small_sw_total / small_hw_us
    small_overhead_pct = (small_sw_total / small_total_time) * 100
    small_speedup_limit = small_total_time / small_hw_us

    # Step 2: Large Model
    large_overhead_ratio = large_dispatch_us / large_hw_us

    # ┌── 3. GUARD (Invariants) ───────────────────────────────────────────
    check(small_overhead_ratio >= 1.0, f"Small model ratio ({small_overhead_ratio:.1f}) implies it is NOT overhead bound.")
    check(large_overhead_ratio <= 0.01, f"Large model overhead ({large_overhead_ratio:.4f}) is too high.")

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    dispatch_n_ops_value = small_ops_count
    dispatch_us_per_op_value = small_dispatch_us
    dispatch_hw_time_us_value = small_hw_us

    dispatch_sw_time_str = fmt(small_sw_total, precision=0, commas=False)
    dispatch_ratio_small_str = fmt(small_overhead_ratio, precision=1, commas=False)
    dispatch_overhead_pct_str = fmt(small_overhead_pct, precision=0, commas=False)
    dispatch_compilation_speedup_str = fmt(small_speedup_limit, precision=0, commas=False)

    gpt3_hw_time_us_value = large_hw_us
    gpt3_sw_time_us_value = large_dispatch_us
    dispatch_ratio_large_str = fmt(large_overhead_ratio, precision=4, commas=False)
    t_sw_md = md_math(f"T_{{sw}} \\approx {large_dispatch_us} \\, \\mu s")

# Note: Use DispatchTax.dispatch_sw_time_str directly.
|
||
```
|
||
|
||
This cumulative latency creates what is effectively *a dispatch tax* on execution. We define $T_{\text{hw}}$ as hardware execution time and $T_{\text{sw}}$ as software overhead time; both are measured in seconds.
|
||
|
||
::: {.callout-notebook title="The Dispatch Tax"}
|
||
|
||
**Problem**: When does Python overhead kill performance?
|
||
|
||
**Scenario 1: Small MLP (Overhead Bound)**
|
||
|
||
* **Compute**: `{python} DispatchTax.dispatch_n_ops_value` small matrix/element-wise operations.
|
||
* **Hardware Time**: T_hw ≈ `{python} DispatchTax.dispatch_hw_time_us_value` μs (mostly memory latency).
|
||
* **Software Overhead**: T_sw ≈ `{python} DispatchTax.dispatch_n_ops_value` ops $\times$ `{python} DispatchTax.dispatch_us_per_op_value` μs/op = `{python} DispatchTax.dispatch_sw_time_str` μs.
|
||
* **Ratio**: `{python} DispatchTax.dispatch_sw_time_str` / `{python} DispatchTax.dispatch_hw_time_us_value` ≈ **`{python} DispatchTax.dispatch_ratio_small_str`**.
|
||
* **Conclusion**: The system spends `{python} DispatchTax.dispatch_overhead_pct_str`% of its time waiting for Python. Eliminating dispatch through compilation yields up to **`{python} DispatchTax.dispatch_compilation_speedup_str`$\times$ speedup**.
|
||
|
||
**Scenario 2: GPT-3 Layer (Compute Bound)**
|
||
|
||
* **Compute**: Huge matrix multiplications.
|
||
* **Hardware Time**: T_hw ≈ 100 ms = `{python} DispatchTax.gpt3_hw_time_us_value` μs.
|
||
* **Software Overhead**: `{python} DispatchTax.t_sw_md`.
|
||
* **Ratio**: `{python} DispatchTax.gpt3_sw_time_us_value` / `{python} DispatchTax.gpt3_hw_time_us_value` ≈ **`{python} DispatchTax.dispatch_ratio_large_str`**.
|
||
* **Conclusion**: Python overhead is negligible. Compilation helps only via kernel fusion (memory bandwidth), not dispatch elimination.
|
||
|
||
:::
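
The two scenarios in the callout reduce to three ratios. The following is a minimal sketch with a hypothetical helper, `dispatch_tax`; the small-MLP inputs are the callout's values, and the GPT-3 layer is modeled as a single 50 μs software cost against 100 ms of hardware time, which is a simplifying assumption rather than a measurement.

```python
def dispatch_tax(n_ops, dispatch_us_per_op, hw_us):
    """Relate software dispatch time to hardware execution time (microseconds)."""
    sw_us = n_ops * dispatch_us_per_op
    total_us = sw_us + hw_us
    return {
        "overhead_ratio": sw_us / hw_us,        # > 1 means overhead-bound
        "overhead_fraction": sw_us / total_us,  # share of wall time in dispatch
        "max_speedup": total_us / hw_us,        # bound if dispatch drops to zero
    }


small_mlp = dispatch_tax(n_ops=6, dispatch_us_per_op=5.0, hw_us=2.6)
gpt3_layer = dispatch_tax(n_ops=1, dispatch_us_per_op=50.0, hw_us=100_000.0)

for name, result in [("small MLP", small_mlp), ("GPT-3 layer", gpt3_layer)]:
    print(f"{name}: ratio={result['overhead_ratio']:.4f}, "
          f"overhead={result['overhead_fraction']:.1%}, "
          f"max speedup={result['max_speedup']:.2f}x")
```

The `max_speedup` value is an upper bound in the Amdahl sense: it assumes compilation removes all dispatch cost and leaves hardware time unchanged.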
|
||
|
||
```{python}
|
||
#| label: gpt3-params
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ GPT-3 PARAMETER COUNTS
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: @sec-ml-frameworks-dispatch-overhead-law-9e0a — Dispatch Tax
|
||
# │ analysis comparing small MLP vs. GPT-3 Layer overhead regimes
|
||
# │
|
||
# │ Goal: Quantify GPT-3 parameter count (175 B) to establish the compute-bound
|
||
# │ regime where dispatch overhead is negligible (~0.05%).
|
||
# │ Show: "175 B-parameter model sees only 1.3× speedup" — inline in prose
|
||
# │ after The Dispatch Tax callout.
|
||
# │ How: Models.GPT3.parameters → m_as(Bparam); single scalar extraction.
|
||
# │
|
||
# │ Note: PERSISTENT — GPT3Context.gpt3_params_b_str reused at line ~1634
|
||
# │ (Reverse Mode section), line ~1703 (Memory Management Strategies),
|
||
# │ line ~2905 and ~2922 (Parameter Structures / 3D Parallelism paragraphs).
|
||
# │
|
||
# │ Imports: mlsysim.core.constants (Bparam), mlsysim.book (fmt, check)
|
||
# │ Exports: GPT3Context.gpt3_params_b_str
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim import Models
|
||
from mlsysim.core.constants import Bparam
|
||
from mlsysim.fmt import fmt, check
|
||
|
||
# ┌── LEGO ───────────────────────────────────────────────
|
||
class GPT3Context:
    """
    Namespace for GPT-3 Parameter Counts.
    Scenario: Compilation benefits at scale.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    model = Models.GPT3

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    params_b = model.parameters.m_as(Bparam)

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    gpt3_params_b_str = fmt(params_b, precision=0, commas=False)

# Note: Use GPT3Context.gpt3_params_b_str directly.
|
||
```
|
||
|
||
The principle's implication is that small models benefit *disproportionately* from compilation. A 100-parameter toy model might see 10$\times$ speedup from torch.compile, while a `{python} GPT3Context.gpt3_params_b_str` B-parameter model might see only about 1.3$\times$, because its runtime is dominated by large kernels rather than by dispatch. This explains why compilation matters most for efficient inference on smaller, deployed models.
|
||
|
||
The dispatch tax analysis reveals that small operations are overhead-bound regardless of hardware capability. This observation matters most at the extreme edge of the deployment spectrum, where the entire Python runtime is itself an unacceptable overhead.
|
||
|
||
### Frameworks for the Edge: TinyML and Micro-Runtimes {#sec-ml-frameworks-tinyml-micro-runtimes-2a1b}
|
||
|
||
\index{TinyML!micro-runtimes}
|
||
\index{Edge Deployment!TinyML frameworks}
|
||
The compilation continuum reaches its extreme at the far edge. While cloud frameworks like PyTorch and TensorFlow 2.x prioritize flexibility through eager execution, **TinyML**[^fn-tinyml-aot-extreme] systems operating on microcontrollers (MCUs) with kilobytes of memory cannot afford the overhead of a Python interpreter or a dynamic runtime.
|
||
|
||
[^fn-tinyml-aot-extreme]: **TinyML**: Systems designed for microcontrollers (MCUs) that cannot afford the memory or processing overhead of a Python interpreter. Instead of flexible eager execution, frameworks compile models ahead-of-time (AOT) into self-contained C/C++ executables with no dynamic memory allocation. This is a hard requirement, as a single `malloc()` failure on a device with just 256 KB of RAM is unrecoverable. \index{TinyML!AOT compilation}
|
||
|
||
::: {.callout-lighthouse title="Lighthouse Example: Smart Doorbell (TinyML)"}
|
||
|
||
**The Scenario**: Deploying the **Smart Doorbell's** Keyword Spotting (KWS) model to an ARM Cortex-M4 microcontroller with 256 KB of RAM and 1 MB of Flash.
|
||
|
||
**The Constraint**: A standard PyTorch runtime occupies ~500 MB. The Python interpreter itself occupies ~20 MB. Both are orders of magnitude larger than the entire device.
|
||
|
||
**The Framework Solution**:
|
||
Micro-frameworks like **TensorFlow Lite Micro (TFLM)**\index{TensorFlow Lite Micro!extreme AOT} and **PyTorch ExecuTorch** solve this through **Extreme AOT Compilation**:
|
||
|
||
1. **Static memory planning**: The framework calculates the exact memory address for every tensor *at compile time*. There is no dynamic `malloc()` or garbage collection.
|
||
2. **Kernel specialization**: Only the specific kernels used by the model (e.g., Conv2D, DepthwiseConv) are compiled into the binary. Unused code is stripped away.
|
||
3. **No-interpreter execution**: The model is converted into a flat sequence of function calls or a simple "Command Buffer" that the MCU executes directly in C/C++.
|
||
|
||
**The Silicon Contract**: On TinyML devices, the contract is strictly **Memory-Bound**. The framework's primary job is to ensure the model's intermediate activations (the "working set") fit within the MCU's tiny SRAM.
|
||
|
||
:::
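
Static memory planning can be illustrated with a small offline planner. The sketch below is a generic greedy-by-size placement under stated assumptions, not the actual planner of TensorFlow Lite Micro or ExecuTorch: the `Tensor` record, the `plan_arena` function, and the four activation sizes and lifetimes are all hypothetical. Each tensor gets a fixed byte offset in one arena, and tensors whose lifetimes do not overlap may share addresses.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tensor:
    name: str
    size: int       # bytes
    first_use: int  # index of the first op that touches the tensor
    last_use: int   # index of the last op that touches the tensor


def plan_arena(tensors):
    """Assign a fixed arena offset to every tensor, largest first."""
    placed = []   # (tensor, offset) pairs already assigned
    offsets = {}
    for t in sorted(tensors, key=lambda t: t.size, reverse=True):
        # Only tensors that are alive at the same time can collide.
        busy = sorted(
            (off, off + p.size)
            for p, off in placed
            if not (p.last_use < t.first_use or t.last_use < p.first_use)
        )
        offset = 0
        for start, end in busy:
            if offset + t.size <= start:
                break  # the gap before this block is large enough
            offset = max(offset, end)
        offsets[t.name] = offset
        placed.append((t, offset))
    arena_bytes = max(offsets[t.name] + t.size for t in tensors)
    return offsets, arena_bytes


# Hypothetical activations of a four-tensor sequential model
activations = [
    Tensor("input", 16_000, first_use=0, last_use=0),
    Tensor("conv1_out", 32_000, first_use=0, last_use=1),
    Tensor("conv2_out", 8_000, first_use=1, last_use=2),
    Tensor("logits", 4_000, first_use=2, last_use=3),
]

offsets, arena_bytes = plan_arena(activations)
sram_bytes = 256 * 1024  # SRAM budget from the scenario above
print(offsets, arena_bytes, arena_bytes <= sram_bytes)
```

In this example the arena is smaller than the sum of all tensor sizes because tensors with disjoint lifetimes reuse the same bytes, and every address is known before the program runs, so no allocator is needed on the device.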
|
||
|
||
These micro-runtimes represent the "Pure AOT" endpoint of the continuum. By sacrificing all dynamic flexibility, they enable machine learning to run on devices consuming milliwatts of power, fulfilling the **Energy-Movement Invariant** (formalized in @sec-data-engineering) by keeping all data movement local to the chip.
|
||
|
||
The spectrum of execution strategies, from dynamic eager execution to static graph compilation and specialized micro-runtimes, requires developers to make deliberate trade-offs. The following checkpoint summarizes the key decision points before we address the second core problem.
|
||
|
||
::: {.callout-checkpoint title="Execution Models"}
|
||
|
||
The choice of execution mode determines both developer velocity and model performance.
|
||
|
||
**Debuggability vs. Speed**
|
||
|
||
- [ ] **Eager Mode (Python-First)**: Why does executing ops one-by-one make debugging easy but optimization hard? (Hint: The compiler cannot see the "future" ops to fuse them).
|
||
- [ ] **Graph Mode (Compiler-First)**: Why does building a static graph enable **Kernel Fusion**? (Merging Conv+ReLU saves memory bandwidth).
|
||
|
||
**The Modern Compromise**
|
||
|
||
- [ ] **JIT Compilation**: How does `torch.compile` bridge the gap? (It captures the graph *just in time* to optimize, while falling back to Python for dynamic parts).
|
||
|
||
:::
|
||
|
||
The execution problem determines *when* computation happens and *what* optimizations are possible. Neural network training, however, requires a capability that no amount of clever scheduling can provide: the ability to compute gradients automatically.
|
||
|
||
Consider what training actually requires: for each of millions of parameters, compute how a tiny change would affect the loss. Doing this manually for even a simple three-layer network requires deriving and implementing dozens of partial derivatives. For a modern Transformer with billions of parameters, manual differentiation is economically impossible. A framework that executes efficiently but cannot differentiate can run inference but cannot learn.
|
||
|
||
## Differentiation Problem {#sec-ml-frameworks-differentiation-problem-8b8a}
|
||
|
||
\index{Differentiation Problem!definition}
|
||
\index{Gradient!etymology}
|
||
\index{Automatic Differentiation!definition}
|
||
The differentiation problem asks: how should frameworks compute gradients[^fn-gradient-descent-memory] automatically? Neural network training requires derivatives of a scalar loss $\mathcal{L}$ with respect to millions or billions of parameters, making manual differentiation impractical. Because a single scalar loss depends on all parameters, reverse-mode automatic differentiation (AD)[^fn-autodiff-reverse-mode] is the optimal strategy: one backward pass computes all parameter gradients simultaneously, while forward mode would require a separate pass for each parameter. All major ML frameworks therefore implement reverse-mode AD by default [@baydin2018].
|
||
|
||
[^fn-gradient-descent-memory]: **Automatic Differentiation (AD)**: The "automatically" in the triggering sentence is the key word: AD mechanizes the chain rule as a graph traversal, eliminating the manual derivative computation that made scaling beyond toy networks impractical. The systems trade-off that makes this feasible is the choice of reverse mode, which exploits the many-to-one topology of training (many parameters, one scalar loss) to compute all gradients in a single backward pass. Forward mode would require one pass *per parameter*, making billion-parameter training computationally impossible. \index{Gradient!memory cost}
|
||
|
||
[^fn-autodiff-reverse-mode]: **Reverse-Mode AD**: The $O(1)$-vs-$O(N)$ asymmetry mentioned above has a concrete price: reverse mode must store every intermediate value from the forward pass for use during the backward traversal. For a billion-parameter transformer, these stored activations can consume 3--4$\times$ more memory than the weights themselves. This memory cost is the reason frameworks provide activation checkpointing, which trades recomputation time for memory, and gradient accumulation, which lowers peak activation memory by processing a large batch as several smaller micro-batches. \index{Automatic Differentiation!reverse mode}
|
||
|
||
Building on the backpropagation algorithm introduced in @sec-neural-computation (where we established that gradients flow backward through the computation graph via the chain rule), this section shifts focus from the mathematics to the systems engineering of differentiation: how frameworks represent computation graphs, manage memory for intermediate values, and orchestrate the backward pass efficiently across accelerators. The framework's role is not to perform calculus but to manage the bookkeeping at scale, which is required for the training algorithms detailed in @sec-model-training. @lst-auto_diff_intro illustrates the core idea with a simple three-operation function:
|
||
|
||
::: {#lst-auto_diff_intro lst-cap="**Automatic Differentiation**: AD decomposes complex functions into elementary operations with known derivatives, enabling gradient computation through arbitrarily deep compositions in O(n) time where n is the number of operations."}
|
||
|
||
```{.python}
|
||
from math import sin


def f(x):
    a = x * x  # Square
    b = sin(x)  # Sine
    return a * b  # Product
|
||
```
|
||
|
||
:::
|
||
|
||
Frameworks decompose this function into elementary operations, each with a known local derivative, and then combine these local derivatives via the chain rule to compute gradients through arbitrary compositions. The systems challenge is implementing this efficiently: the framework must record the computation graph during the forward pass, store intermediate values, and execute the backward pass with minimal memory overhead. The following subsections trace how production frameworks solve each of these problems.
|
||
|
||
### Forward and Reverse Mode Differentiation {#sec-ml-frameworks-forward-reverse-mode-differentiation-f70a}
|
||
|
||
\index{Automatic Differentiation!forward vs. reverse mode}
|
||
Two primary approaches to automatic differentiation exist, and the choice between them (forward mode versus reverse mode) determines whether gradient computation scales with the number of inputs or the number of outputs, a distinction that explains why neural network training universally uses one mode over the other.
|
||
|
||
#### Forward Mode {#sec-ml-frameworks-forward-mode-c3ff}
|
||
|
||
\index{Automatic Differentiation!forward mode}
|
||
\index{Forward Mode AD!dual numbers}
|
||
Neural network training universally uses reverse mode (covered next), but forward mode illuminates *why* reverse mode is necessary.
|
||
\index{Dual Numbers!forward mode AD}
|
||
Forward mode automatic differentiation computes derivatives alongside the original computation, tracking how changes propagate from input to output. This approach mirrors manual derivative computation, making it intuitive to understand and implement.
|
||
|
||
Forward mode's memory requirements are its strength: the method stores only the original value, a single derivative value, and temporary results. Memory usage stays constant regardless of computation depth, making forward mode particularly suitable for embedded systems, real-time applications, and memory-bandwidth-limited systems. However, this comes with a computational cost. Forward mode doubles the Ops term (in **Iron Law** terms) for each input parameter whose derivative is requested. For a model with $N$ parameters, forward mode multiplies total computation by $N$, because each parameter requires a separate forward pass. Reverse mode, by contrast, adds a constant factor of approximately 2 to 3$\times$ regardless of $N$. This asymmetry explains why forward mode is never used for training neural networks, where $N$ ranges from millions to hundreds of billions. This combination of computational scaling with input count but constant memory creates a specific niche: forward mode excels in scenarios with few inputs but many outputs, such as sensitivity analysis, feature importance computation, and online learning with single-example updates.
|
||
|
||
To see the mechanism concretely, consider computing both the value and derivative of $f(x) = x^2 \sin(x)$. @lst-forward_mode_ad shows how forward mode propagates derivative computations alongside every operation, applying the chain rule and product rule at each step:
|
||
|
||
::: {#lst-forward_mode_ad lst-cap="**Forward Mode AD**: Propagates derivatives forward through the computation graph, computing one directional derivative per forward pass with 2$\times$ computational overhead."}
|
||
|
||
```{.python}
|
||
from math import cos, sin


def f(x):  # Computing both value and derivative
    # Step 1: x -> x²
    a = x * x  # Value: x²
    da = 2 * x  # Derivative: 2x

    # Step 2: x -> sin(x)
    b = sin(x)  # Value: sin(x)
    db = cos(x)  # Derivative: cos(x)

    # Step 3: Combine using product rule
    result = a * b  # Value: x² * sin(x)
    dresult = a * db + b * da  # Derivative: x²*cos(x) + sin(x)*2x

    return result, dresult
|
||
```
|
||
|
||
:::
|
||
|
||
Forward mode achieves this systematic derivative computation by augmenting each number with its derivative value, creating what mathematicians call a "dual number." @lst-forward_mode_dual traces a concrete execution with x = 2.0, revealing how each intermediate result carries both its value and derivative through the computation:
|
||
|
||
::: {#lst-forward_mode_dual lst-cap="**Dual Number Computation**: Forward mode augments each value with its derivative, doubling memory per intermediate but enabling single-pass gradient computation."}
|
||
|
||
```{.python}
x = 2.0  # Initial value
dx = 1.0  # We're tracking derivative with respect to x

# Step 1: x²
a = 4.0  # (2.0)²
da = 4.0  # 2 * 2.0

# Step 2: sin(x)
b = 0.909  # sin(2.0)
db = -0.416  # cos(2.0)

# Final result
result = 3.636  # 4.0 * 0.909 = 3.636
# 4.0 * (-0.416) + 0.909 * 4.0 = -1.664 + 3.636 = 1.972
dresult = 1.972
```
|
||
|
||
:::
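
The hand-written trace can be mechanized. The following is a minimal sketch of dual-number arithmetic, assuming only multiplication and sine are needed; the `Dual` class and `dual_sin` helper are hypothetical names introduced for illustration and do not correspond to any framework's API.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A value paired with its derivative with respect to one chosen input."""

    value: float
    deriv: float

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule: (uv)' = u'v + uv'
        return Dual(
            self.value * other.value,
            self.deriv * other.value + self.value * other.deriv,
        )


def dual_sin(u: Dual) -> Dual:
    # Chain rule: sin(u)' = cos(u) * u'
    return Dual(math.sin(u.value), math.cos(u.value) * u.deriv)


def f(x: Dual) -> Dual:
    return (x * x) * dual_sin(x)


# Seeding deriv=1.0 selects x as the input to differentiate against
out = f(Dual(2.0, 1.0))
print(out.value, out.deriv)
```

Each additional input would need its own seeded pass through `f`, which is the $O(N)$ cost discussed above.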
|
||
|
||
The dual number trace demonstrates the 2$\times$ computational overhead per input: every arithmetic operation (multiply, sine, product rule combination) is performed twice, once for the value and once for the derivative. For this single-input function, the overhead is acceptable. For a neural network with $N = 100{,}000{,}000$ parameters, computing all gradients would require 100 million such passes, which is why forward mode is restricted to the few-input applications described above.
|
||
|
||
Forward mode's strength in single-input analysis becomes its fatal weakness for training. A neural network has one scalar loss but millions of parameters, and forward mode would require a separate pass for each one---an intractable $O(N)$ cost that explains why no production framework uses forward mode for training. Forward mode remains useful for targeted analyses such as sensitivity analysis (how does changing one pixel affect the prediction?) and feature importance (which input dimensions most influence the output?), where the number of inputs of interest is small.
|
||
|
||
Given forward mode's $O(N)$ scaling with parameter count, we need an entirely different approach for training. Reverse mode provides exactly this: by propagating gradients backward from output to input, it computes all $N$ parameter gradients in a single pass.
|
||
|
||
#### Reverse Mode {#sec-ml-frameworks-reverse-mode-d328}
|
||
|
||
\index{Automatic Differentiation!reverse mode}
|
||
\index{Reverse Mode AD!backpropagation}
|
||
\index{Backpropagation!computational asymmetry}
|
||
Why does every modern ML framework default to reverse mode for training? The answer is computational asymmetry, one of the most consequential design decisions in machine learning software.
|
||
|
||
A neural network has one scalar loss but millions of parameters. Forward mode computes one parameter's gradient per pass, requiring $n$ passes for $n$ parameters. Reverse mode computes all $n$ gradients in a single backward pass. For a model with 100 million parameters, that is the difference between 100 million forward passes and exactly one backward pass, a speedup proportional to the parameter count.
|
||
|
||
This asymmetry makes reverse mode the only viable option for training. Consider a function where $x$ influences the output through two distinct paths. @lst-reverse_simple defines such a function, and @lst-reverse_forward traces its forward and backward computation for a concrete input.
|
||
|
||
::: {#lst-reverse_simple lst-cap="**Reverse Mode Example Function**: A function in which the input reaches the output through two paths, one through the square and one through the sine."}
|
||
|
||
```{.python}
from math import sin


def f(x):
    a = x * x  # First operation: square x
    b = sin(x)  # Second operation: sine of x
    c = a * b  # Third operation: multiply results
    return c
```
|
||
|
||
:::
|
||
|
||
::: {#lst-reverse_forward lst-cap="**Forward and Backward Pass**: The forward pass stores intermediate values; the backward pass propagates gradients from output to input, accumulating contributions from all paths."}
|
||
|
||
```{.python}
from math import cos

# --- Forward pass: compute and store values ---
x = 2.0  # Input value
a = 4.0  # x * x = 2.0 * 2.0 = 4.0
b = 0.909  # sin(2.0) ≈ 0.909
c = 3.636  # a * b = 4.0 * 0.909 = 3.636

# --- Backward pass: propagate gradients from output ---
dc_dc = 1.0  # Seed gradient

# Through multiplication c = a * b
dc_da = b  # ∂(a*b)/∂a = b = 0.909
dc_db = a  # ∂(a*b)/∂b = a = 4.0

# Combine contributions from both paths through x
# Path 1: x -> x² -> c        contribution: 2x * dc/da
# Path 2: x -> sin(x) -> c    contribution: cos(x) * dc/db
dc_dx = (2 * x * dc_da) + (cos(x) * dc_db)
#     = (2 * 2.0 * 0.909) + (cos(2.0) * 4.0)
#     = 3.636 + (-0.416 * 4.0)
#     ≈ 3.636 - 1.664 = 1.972
```
|
||
|
||
:::
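
The same backward traversal can be automated with a small recorded graph. The sketch below is a toy scalar implementation written for this example (the `Var` class, `var_sin`, and `backward` are hypothetical names); production autograd engines follow the same idea but operate on tensors and are far more elaborate.

```python
import math


class Var:
    """Scalar node that records its parents and local partial derivatives."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # Tuples of (parent_node, local_gradient)
        self.grad = 0.0

    def __mul__(self, other):
        return Var(
            self.value * other.value,
            ((self, other.value), (other, self.value)),
        )


def var_sin(u):
    return Var(math.sin(u.value), ((u, math.cos(u.value)),))


def backward(output):
    # Topologically order the graph so each node's gradient is complete
    # before it is pushed to its parents
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0  # Seed gradient
    for node in reversed(order):
        for parent, local_grad in node.parents:
            parent.grad += local_grad * node.grad  # Sum over all paths


x = Var(2.0)
c = (x * x) * var_sin(x)
backward(c)
print(c.value, x.grad)
```

The `+=` in the inner loop is what merges the contributions from the two paths through `x`; no path is ever enumerated explicitly.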
|
||
|
||
The critical observation is that this single backward pass computed dc/dx regardless of how many paths connect x to c. In a neural network, each weight can affect the loss through thousands of paths across layers, and reverse mode handles them all in one traversal. This is why training a `{python} GPT3Context.gpt3_params_b_str` B parameter model like GPT-3 is feasible at all: reverse mode's O(1) backward passes (relative to parameter count) keep gradient computation tractable.
|
||
|
||
Translating this mathematical elegance into a working system requires solving a concrete engineering problem: the backward pass needs values computed during the forward pass, so the framework must decide what to store, when to store it, and when to free it. Modern frameworks accomplish this through computational graphs and automatic gradient accumulation[^fn-gradient-accumulation-batch].
|
||
|
||
[^fn-gradient-accumulation-batch]: **Gradient Accumulation**: A direct answer to the "when to free it" question: the framework breaks a large logical batch into smaller mini-batches processed sequentially, freeing activation memory after each mini-batch's backward pass and accumulating only the small gradient tensors. This lets a system simulate a batch size of 4,096 using the memory footprint of a 64-sample batch, trading sequential compute time for a 64$\times$ reduction in peak activation memory. Without this technique, many production training configurations would exceed accelerator memory on the first batch. \index{Gradient Accumulation!memory trade-off}
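
A short PyTorch sketch can check that accumulation reproduces the full-batch gradient. The model, batch size of 16, and split into four micro-batches are illustrative assumptions chosen to keep the example small, not a recommended configuration.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
data = torch.randn(16, 8)
target = torch.randn(16, 1)
loss_fn = torch.nn.MSELoss()

# Reference: one backward pass over the full logical batch
model.zero_grad()
loss_fn(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Accumulation: four micro-batches of 4 samples each
accum_steps = 4
model.zero_grad()
for x_mb, y_mb in zip(data.chunk(accum_steps), target.chunk(accum_steps)):
    # Divide so the summed gradients match the full-batch mean
    loss = loss_fn(model(x_mb), y_mb) / accum_steps
    loss.backward()  # Activations for this micro-batch are released here

accum_grad = model.weight.grad
print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

Only one micro-batch's activations are alive at a time, while the `.grad` buffers stay the same size throughout.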
|
||
|
||
@lst-reverse_simple_nn illustrates this with a two-layer network, showing both the forward computation that stores intermediate values and the backward pass that consumes them to produce gradients for every parameter simultaneously.
|
||
|
||
::: {#lst-reverse_simple_nn lst-cap="**Reverse Mode in a Neural Network**: The forward pass computes and stores intermediate values; the backward pass walks the computation in reverse to produce gradients for every parameter."}
|
||
|
||
```{.python}
def simple_network(x, w1, w2):
    hidden = x * w1  # First layer
    activated = max(0, hidden)  # ReLU activation
    output = activated * w2  # Second layer
    return output


# --- Forward pass stores intermediates ---
x, w1, w2 = 1.0, 2.0, 3.0
hidden, activated, output = 2.0, 2.0, 6.0

# --- Backward pass consumes them ---
d_output = 1.0  # Seed gradient
d_w2 = activated  # = 2.0
d_activated = w2  # = 3.0
d_hidden = d_activated * (1 if hidden > 0 else 0)  # ReLU gate: 3.0
d_w1 = x * d_hidden  # = 3.0
d_x = w1 * d_hidden  # = 6.0
```
|
||
|
||
:::
|
||
|
||
Three implementation requirements emerge from this example. First, the framework must track dependencies between operations to determine the correct reverse traversal order. Second, intermediate values (hidden, activated) must persist in memory until the backward pass consumes them. Third, every operation needs both a forward implementation and a corresponding backward rule. These requirements define the engineering surface of any AD system, and the second requirement, memory persistence, turns out to be the dominant cost.
|
||
|
||
```{python}
#| label: gpt3-memory-footprint
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ GPT-3 MEMORY FOOTPRINT
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @sec-ml-frameworks-memory-management-strategies-b008 — opening
# │ paragraph establishing memory as the binding training constraint
# │
# │ Goal: Quantify GPT-3 FP16 weight memory (350 GB) to show that a 175B-param
# │ model far exceeds any single GPU's capacity.
# │ Show: "175 B parameters → 350 GB" — inline in opening prose paragraph.
# │ How: model_memory() formula; m_as(Bparam) for param count extraction.
# │
# │ Imports: mlsysim.core.constants (GPT3_PARAMS, BYTES_FP16, GB, Bparam),
# │ mlsysim.formulas (model_memory), mlsysim.book (fmt)
# │ Exports: gpt3_params_b_str, gpt3_fp16_gb_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import GPT3_PARAMS, BYTES_FP16, GB, Bparam
from mlsysim.core.formulas import model_memory
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class GPT3MemoryFootprint:
    """GPT-3 FP16 weight memory to show single-GPU capacity is exceeded."""

    # ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
    params_b = GPT3_PARAMS.m_as(Bparam)  # 175 billion

    # ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
    fp16_gb = model_memory(GPT3_PARAMS, BYTES_FP16, GB)  # 350 GB

    # ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
    gpt3_params_b_str = fmt(params_b, precision=0, commas=False)  # e.g. "175"
    gpt3_fp16_gb_str = fmt(fp16_gb, precision=0, commas=False)  # e.g. "350"
```
|
||
|
||
#### Memory Management Strategies {#sec-ml-frameworks-memory-management-strategies-b008}
|
||
|
||
A `{python} GPT3Context.gpt3_params_b_str` B parameter model in FP16 requires `{python} GPT3MemoryFootprint.gpt3_fp16_gb_str` GB just for weights, far exceeding any single GPU's memory. Weights, however, are only the beginning: reverse mode AD also stores every intermediate activation from the forward pass for use during the backward pass. For a 100-layer network processing a batch of 64 images, these stored activations can consume 8 to 12 GB on top of the model weights, gradients, and optimizer state. Memory, not compute, is the binding constraint on what models a framework can train.
|
||
|
||
The problem scales linearly with depth. @lst-reverse_memory shows how each layer in a deeper network adds another activation tensor that must persist until the backward pass reaches that layer.
|
||
|
||
::: {#lst-reverse_memory lst-cap="**Reverse Mode Memory Management**: Stores intermediate values for gradient computation during backpropagation."}
|
||
|
||
```{.python}
def deep_network(x, w1, w2, w3):
    # Forward pass - must store intermediates
    hidden1 = x * w1
    activated1 = max(0, hidden1)  # Store for backward
    hidden2 = activated1 * w2
    activated2 = max(0, hidden2)  # Store for backward
    output = activated2 * w3
    return output
```
|
||
|
||
:::
|
||
|
||
\index{Activation Checkpointing!definition}
|
||
Frameworks attack this memory wall with two primary strategies. The first is *activation checkpointing*\index{Gradient Checkpointing!memory-compute trade-off} (also called gradient checkpointing): rather than storing every activation, the framework stores only selected checkpoints and recomputes the intermediate activations during the backward pass. @sec-model-training examines checkpointing strategies in detail, including optimal checkpoint placement algorithms. @lst-checkpoint_scheme shows the pattern: save activations at checkpoint boundaries, recompute everything between them.
|
||
|
||
::: {#lst-checkpoint_scheme lst-cap="**Checkpointing**: Reduces memory usage by selectively storing intermediate activations during forward passes. Frameworks balance storage needs with computational efficiency to optimize model training."}
|
||
|
||
```{.python}
# Conceptual representation of checkpointing
checkpoint1 = save_for_backward(activation1)
# Intermediate activations can be recomputed
checkpoint2 = save_for_backward(activation4)
# Framework balances storage vs recomputation
```
|
||
|
||
:::
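
For a runnable counterpart to the conceptual listing, PyTorch exposes this pattern through `torch.utils.checkpoint.checkpoint`. The sketch below wraps a small block (the layer sizes are arbitrary assumptions) and confirms that recomputation leaves the gradient unchanged; it does not measure the memory saving itself.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
)
x = torch.randn(4, 16, requires_grad=True)

# Baseline: every activation inside `block` is stored for backward
block(x).sum().backward()
baseline_grad = x.grad.clone()

# Checkpointed: only the block's input is kept; activations inside the
# block are recomputed when the backward pass reaches it
x.grad = None
block.zero_grad()
checkpoint(block, x, use_reentrant=False).sum().backward()
checkpointed_grad = x.grad.clone()

print(torch.allclose(baseline_grad, checkpointed_grad))
```

The extra forward computation inside the backward pass is the compute cost paid for the smaller activation footprint.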
|
||
|
||
\index{Kernel Fusion!eliminating intermediate allocations}
|
||
The second strategy is *operation fusion*\index{Kernel Fusion}[^fn-operation-fusion-bandwidth]. Rather than executing matrix multiplication, bias addition, and ReLU as three separate operations, each writing intermediate results to memory, frameworks fuse them into a single kernel. This eliminates intermediate memory allocations entirely and achieves 2 to 3$\times$ speedup on modern GPUs by keeping data in registers and caches.
|
||
|
||
[^fn-operation-fusion-bandwidth]: **Operation Fusion**: The 2--3$\times$ speedup arises from a specific hardware fact: GPU registers and L1 cache deliver 10--100$\times$ higher bandwidth than HBM. When matmul, bias, and ReLU execute as separate kernels, each writes its output to HBM and the next reads it back, a round-trip that dominates execution time for memory-bound operations. Fusing them into one kernel keeps intermediates in registers, eliminating the intermediate HBM round-trips so that only the final result is written back. \index{Operation Fusion!bandwidth reduction}
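
A back-of-the-envelope count makes the traffic difference concrete. The tensor shape below is a hypothetical layer output, and the count covers only output-side traffic, ignoring the matmul operand reads that both versions share, so it overstates the end-to-end benefit.

```python
# Hypothetical layer output: batch of 64 rows x 4096 features in FP16
elements = 64 * 4096
bytes_per_element = 2
tensor_bytes = elements * bytes_per_element

# Unfused: matmul writes; bias add reads + writes; ReLU reads + writes
unfused_traffic = tensor_bytes * (1 + 2 + 2)

# Fused: bias add and ReLU are applied in registers before the single write
fused_traffic = tensor_bytes * 1

print(unfused_traffic // fused_traffic)
```

Because the shared operand reads are excluded here, the realized speedup is smaller than this ratio, consistent with the 2 to 3$\times$ range quoted above.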
|
||
|
||
The backward pass itself benefits from hardware-specific optimization. Rather than directly translating the mathematical definition of a convolution gradient into code, frameworks implement specialized backward kernels that exploit memory access patterns and hardware capabilities of modern accelerators [@chetlur2014cudnn]. These optimizations, checkpointing, fusion, and specialized kernels, work together to make training practical for architectures that would otherwise exhaust GPU memory in a single forward pass.
|
||
|
||
### Framework Implementation of Automatic Differentiation {#sec-ml-frameworks-framework-implementation-automatic-differentiation-1407}
|
||
|
||
Checkpointing, fusion, and specialized kernels solve the systems problems of AD. Practitioners, however, never interact with these mechanisms directly. Instead, frameworks expose AD through high-level APIs that hide the underlying machinery behind simple method calls. A PyTorch training loop---`optimizer.zero_grad()`, forward pass, `loss.backward()`, `optimizer.step()`---appears to be four function calls. Behind each call, however, the framework tracks all operations during the forward pass, builds and maintains the computation graph, manages memory for intermediate values, schedules gradient computations efficiently, and interfaces with hardware accelerators. The same graph machinery extends to advanced scenarios: nested `torch.autograd.grad` calls compute second-order derivatives for techniques like natural gradient descent, and mixed-precision contexts (`autocast`) select reduced-precision kernels for compute-intensive operations while maintaining FP32 for numerical stability.
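
The four-call loop can be written out in full for a toy problem. The model, data, learning rate, and step count in this sketch are arbitrary assumptions; the point is only which call triggers which piece of autograd machinery.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(32, 4)
targets = inputs.sum(dim=1, keepdim=True)  # Toy regression target

losses = []
for step in range(50):
    optimizer.zero_grad()  # Clear gradients left over from the last step
    # Forward pass: autograd records the graph as operations execute
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()  # Walk the graph in reverse and fill each .grad
    optimizer.step()  # Update parameters from their .grad values
    losses.append(loss.item())

print(losses[0], losses[-1])
```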
|
||
|
||
#### PyTorch Autograd Internals {#sec-ml-frameworks-pytorch-autograd-internals-4fa0}
|
||
|
||
\index{Autograd Engine!internals}
|
||
\index{PyTorch!autograd system}
|
||
The autograd system is the framework component that solves the differentiation problem described in @sec-ml-frameworks-three-problems-every-framework-must-solve-317d. Three systems principles govern its design: the data structure that enables efficient gradient computation, the memory cost of maintaining that data structure, and the control mechanisms that production systems require. Understanding these principles explains why training consumes 100$\times$ more memory than inference for the same model, and why frameworks provide specific mechanisms to manage that cost.
|
||
|
||
##### Principle 1: The Reverse-Linked Graph Structure { .unnumbered}
|
||
|
||
\index{Autograd!reverse-linked graph}
|
||
\index{Gradient Function Nodes!computation graph}
|
||
During the forward pass, the autograd system constructs a reverse-linked graph of `Function` nodes. Each node records the operation performed and stores references to the tensors it needs for gradient computation. This graph is the data structure that makes reverse-mode automatic differentiation possible: regardless of how many parameters a model has, a single backward pass through this graph computes all gradients. For a model with $N$ parameters, reverse-mode AD requires $O(1)$ backward passes (compared to $O(N)$ for forward-mode), which is why every major framework implements this approach.
|
||
|
||
Concretely, every tensor produced by a differentiable operation stores a `grad_fn` attribute pointing to the `Function` that created it. Each `Function` links to its inputs through `next_functions`, forming a chain from the loss back to the leaf parameters. @lst-grad-fn-chain illustrates this structure for a simple computation:
|
||
|
||
::: {#lst-grad-fn-chain lst-cap="**Reverse-Linked Graph Structure**: Each tensor's `grad_fn` links to the `Function` that created it, forming a reverse chain from output to leaf parameters that enables O(1) backward passes."}
|
||
|
||
```{.python}
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x * 3
z = y.pow(2)

# Traverse the reverse-linked graph
print(z.grad_fn)  # PowBackward0
print(z.grad_fn.next_functions)  # -> MulBackward0
print(
    z.grad_fn.next_functions[0][0].next_functions
)  # -> AccumulateGrad (leaf)
```
|
||
|
||
:::
|
||
|
||
The traversal reveals the chain: `PowBackward0` (for `z = y**2`) links to `MulBackward0` (for `y = x * 3`), which terminates at `AccumulateGrad` for the leaf tensor `x`. Leaf tensors are the endpoints of the graph where gradients accumulate into the `.grad` attribute rather than propagating further. The tuple format `(Function, index)` tracks which output of a multi-output operation each connection corresponds to.
|
||
|
||
This reverse-linked structure has a critical systems implication: the entire graph must remain in memory from the time a tensor is created until the backward pass consumes it. The graph itself is lightweight (pointers and metadata), but the tensors it references are not.
|
||
|
||
The graph structure thus introduces a second implication: memory consumption scales with model depth.
|
||
|
||
##### Principle 2: The Memory-Compute Trade-off {#sec-ml-frameworks-principle-2-memorycompute-tradeoff-19f2}
|
||
|
||
Every activation saved for the backward pass persists in memory until consumed by gradient computation. This is the primary reason training memory dwarfs inference memory. Computing the gradient of most operations requires values from the forward pass: multiplication needs both inputs ($\frac{\partial}{\partial x}(x \cdot y) = y$), exponentiation needs the base ($\frac{\partial}{\partial x}(x^2) = 2x$), and softmax needs its output values. The autograd system stores these tensors in each `Function` node's `saved_tensors` attribute.
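
The volume of saved data can be observed directly with `torch.autograd.graph.saved_tensors_hooks`, which calls a user function each time autograd saves a tensor. The following sketch tallies bytes for a small matmul followed by ReLU; the shapes are arbitrary assumptions, and a tensor saved by more than one node may be counted more than once.

```python
import torch

stats = {"saved_bytes": 0}


def pack(tensor):
    # Called each time autograd saves a tensor for the backward pass
    stats["saved_bytes"] += tensor.numel() * tensor.element_size()
    return tensor


def unpack(tensor):
    # Called when the backward pass retrieves the saved tensor
    return tensor


x = torch.randn(64, 256, requires_grad=True)
w = torch.randn(256, 256, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = torch.relu(x @ w).sum()

print(f"{stats['saved_bytes'] / 1024:.0f} KiB saved for backward")
out.backward()
```

Running the same forward pass under `torch.no_grad()` would never invoke `pack`, which is the inference side of the memory gap.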
|
||
|
||
```{python}
#| label: resnet50-memory-breakdown
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ RESNET-50 MEMORY BREAKDOWN
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @sec-ml-frameworks-principle-2-memorycompute-tradeoff-19f2 —
# │ "Principle 2: The Memory-Compute Trade-off" subsection
# │
# │ Goal: Quantify ResNet-50 training vs. inference memory (25.6 M params,
# │ ~102 MB FP32 weights, 10–15 GB training footprint) to show the ~100×
# │ ratio driving the $D_{\text{vol}}$ term in the Iron Law.
# │ Show: "~102 MB inference vs. 10–15 GB training" — inline in Principle 2 prose
# │ and in "The Administrative Tax" callout (@sec-ml-frameworks-tensor-structure-dimensions-4a14).
# │ How: model.size_in_bytes() helpers; m_as(MB/Mparam) for unit extraction.
# │
# │ Note: PERSISTENT — ResNetMemory reused at line ~2224 (Abstraction section),
# │ line ~2324 (Administrative Tax callout), line ~3143 (nn.Module intro).
# │
# │ Imports: mlsysim.core.constants (BYTES_FP32, BYTES_ADAM_STATE, MB, Mparam),
# │ mlsysim.book (fmt, check)
# │ Exports: resnet_params_m_str, resnet_fp32_mb_str, resnet_adam_mb_str,
# │ resnet_training_min_gb_str, resnet_training_max_gb_str, resnet_training_ratio_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim import Models
from mlsysim.core.constants import BYTES_FP32, BYTES_ADAM_STATE, MB, Mparam
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class ResNetMemory:
    """
    Namespace for ResNet-50 Memory Breakdown.
    Scenario: Comparing training vs inference memory costs.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    model = Models.ResNet50
    training_min_gb = 10
    training_max_gb = 15

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    params_m = model.parameters.m_as(Mparam)

    fp32_mb = model.size_in_bytes(BYTES_FP32).m_as(MB)
    adam_mb = model.size_in_bytes(BYTES_ADAM_STATE).m_as(MB)

    training_ratio = (training_min_gb * KIB_TO_BYTES) / fp32_mb

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    resnet_params_m_str = fmt(params_m, precision=1, commas=False)
    resnet_fp32_mb_str = fmt(fp32_mb, precision=0, commas=False)
    resnet_adam_mb_str = fmt(adam_mb, precision=0, commas=False)

    resnet_training_min_gb_str = fmt(training_min_gb, precision=0, commas=False)
    resnet_training_max_gb_str = fmt(training_max_gb, precision=0, commas=False)
    resnet_training_ratio_str = fmt(training_ratio, precision=0, commas=False)

# Note: Use ResNetMemory.resnet_params_m_str directly.
```
|
||
|
||
For a network with $N_L$ layers, the system must save approximately $N_L$ activation tensors, one per layer, for the entire batch. Consider a concrete example: ResNet-50 has `{python} ResNetMemory.resnet_params_m_str` M parameters (~`{python} ResNetMemory.resnet_fp32_mb_str` MB in FP32) and processes a batch of 64 images at $224\times224$ resolution. The memory breakdown reveals the scale of this trade-off. Forward activations alone consume approximately 8--12 GB (varying by implementation and checkpointing strategy). Parameter gradients add another ~`{python} ResNetMemory.resnet_fp32_mb_str` MB (the same size as the parameters themselves), and Adam optimizer state contributes ~`{python} ResNetMemory.resnet_adam_mb_str` MB for its two per-parameter buffers (the first- and second-moment estimates). The total training footprint reaches `{python} ResNetMemory.resnet_training_min_gb_str`--`{python} ResNetMemory.resnet_training_max_gb_str` GB, compared to just ~`{python} ResNetMemory.resnet_fp32_mb_str` MB for inference alone.
|
||
|
||
This `{python} ResNetMemory.resnet_training_ratio_str`$\times$ ratio between training and inference memory quantifies why the Data Movement ($D_{\text{vol}}$) term dominates training latency in the **Iron Law**. During training, the framework must write all activations to memory during the forward pass and read them back during the backward pass, doubling the memory traffic compared to inference alone. For a complete derivation of the four-component training memory equation ($M_{total} = M_{weights} + M_{gradients} + M_{optimizer} + M_{activations}$) and worked examples at larger model scales, see @sec-algorithm-foundations-true-cost-training-memory-e54e.
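
The component arithmetic can be reproduced in a few lines. This sketch assumes 25.6 M parameters, 4 bytes per FP32 value, and two FP32 Adam buffers per parameter; the activation range is taken from the text rather than computed, and framework workspace and allocator overhead are not modeled, so the sum lands somewhat below the quoted training footprint.

```python
params = 25.6e6  # ResNet-50 parameter count (assumed from the text)
bytes_fp32 = 4

weights_mb = params * bytes_fp32 / 1e6
gradients_mb = weights_mb  # One FP32 gradient per parameter
adam_mb = 2 * weights_mb  # Assumed: first- and second-moment estimates
activations_gb = (8, 12)  # Range quoted in the text for batch size 64

static_gb = (weights_mb + gradients_mb + adam_mb) / 1e3
total_gb = tuple(round(static_gb + a, 1) for a in activations_gb)

print(f"weights {weights_mb:.1f} MB, static state {static_gb:.2f} GB")
print(f"modeled training total {total_gb[0]}-{total_gb[1]} GB")
```

Under these assumptions, weights, gradients, and optimizer state together stay well under 1 GB, so activations account for nearly all of the training footprint.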
|
||
|
||
Frameworks provide two primary mechanisms to manage this trade-off. **Gradient checkpointing**\index{Gradient Checkpointing!recomputation strategy} [@chen2016training] trades recomputation for memory: instead of saving all activations, the framework saves only a subset and recomputes the rest during the backward pass. This typically reduces activation memory by 50--90% at the cost of 20--33% additional compute (with optimal $\sqrt{n}$ checkpoint placement). In Iron Law terms, checkpointing increases the $O$ term (recomputation) to reduce the $D_{\text{vol}}$ term (memory traffic). **Tensor detachment** provides a complementary mechanism: calling `.detach()` on a tensor returns a new tensor that shares the same storage but is cut off from the computation graph, so gradients do not flow back through that path and the graph upstream of it can be released. This is essential for transfer learning, where pretrained layers should not accumulate gradients, and reduces the $D_{\text{vol}}$ term by eliminating unnecessary activation storage.
|
||
|
||
Mixed-precision training offers a third approach, reducing activation memory by storing values in lower precision formats. The detailed trade-offs of mixed precision are examined later in this chapter.
|
||
|
||
##### Principle 3: Extensibility and Control { .unnumbered}
|
||
|
||
Production training systems require fine-grained control over gradient flow that goes beyond the default backward pass. Three categories of control arise in practice. First, **selective gradient computation**: transfer learning and fine-tuning require freezing subsets of parameters, which the framework supports through `requires_grad=False` flags and the `.detach()` mechanism described above. Second, **gradient inspection and modification**: debugging vanishing or exploding gradients, implementing per-tensor gradient clipping, and logging gradient statistics all require intercepting gradients mid-computation, which frameworks expose through hook APIs. Third, **custom differentiation rules**: operations not in the framework's built-in library (custom CUDA kernels, novel activation functions, domain-specific operations) require user-defined forward and backward implementations.
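
Selective gradient computation is the simplest of the three to demonstrate. In the sketch below, two small `Linear` layers stand in for a pretrained backbone and a new task head; these names and sizes are assumptions made for illustration.

```python
import torch

backbone = torch.nn.Linear(8, 8)  # Stands in for pretrained layers
head = torch.nn.Linear(8, 2)  # New task-specific layer

# Freeze the backbone: autograd will not compute gradients for it
for param in backbone.parameters():
    param.requires_grad_(False)

x = torch.randn(4, 8)
loss = head(torch.relu(backbone(x))).sum()
loss.backward()

print(backbone.weight.grad)  # None: no gradient was computed
print(head.weight.grad.shape)  # torch.Size([2, 8])
```

Because no input to the backbone requires gradients, autograd records no graph for it at all, which saves both backward compute and activation storage.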
|
||
|
||
These control mechanisms share a common systems design: they are callback-based extensions that the autograd engine invokes at specific points during graph traversal, without modifying the core differentiation algorithm. This extensibility pattern allows the framework to maintain a single optimized backward pass while supporting arbitrarily complex gradient manipulation. The following examples demonstrate these mechanisms in practice, showing how to inspect and control PyTorch's autograd system.
|
||
|
||
###### Retaining the Computation Graph {.unnumbered}
|
||
|
||
\index{Graph Retention!multiple backward passes}
|
||
By default, `backward()` frees the graph after use. To run multiple backward passes (for multi-loss optimization or higher-order derivatives), use `retain_graph=True`, which keeps the graph and its saved tensors alive until the final backward call instead of releasing them after the first, as shown in @lst-retain-graph.
|
||
|
||
::: {#lst-retain-graph lst-cap="**Retaining Computation Graph**: Use retain_graph=True to run multiple backward passes on the same graph, useful for multi-loss optimization or higher-order derivatives."}
|
||
|
||
```{.python}
import torch

x = torch.tensor([2.0], requires_grad=True)
y = x**2

# First backward pass - graph is freed by default
y.backward()
print(x.grad)  # tensor([4.])

# Second backward on SAME y fails - graph was freed
# y.backward()  # RuntimeError: graph already freed!

# Solution: retain_graph=True keeps graph for multiple passes
x.grad.zero_()
y = x**2
y.backward(retain_graph=True)  # First pass, keep graph
y.backward()  # Second pass works, graph freed after this
```
|
||
|
||
:::
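
Higher-order derivatives rely on a related flag. As a minimal sketch, `torch.autograd.grad` with `create_graph=True` records the backward computation itself so that it can be differentiated again; the cubic function here is an arbitrary choice.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x**3

# create_graph=True makes the first derivative itself differentiable
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)

print(dy_dx.item(), d2y_dx2.item())  # 3x² = 12.0 and 6x = 12.0 at x = 2
```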
|
||
|
||
###### Gradient Accumulation Behavior {.unnumbered}
|
||
|
||
Gradients accumulate across backward passes by default. As @lst-gradient-accumulation demonstrates, without calling `zero_grad()`, successive backward passes sum their gradients:
|
||
|
||
::: {#lst-gradient-accumulation lst-cap="**Gradient Accumulation Behavior**: Gradients accumulate across backward passes by default. Use zero_grad() to reset gradients before each optimization step."}
|
||
|
||
```{.python}
import torch

x = torch.tensor([1.0], requires_grad=True)

# First backward pass
y = x * 2
y.backward()
print(x.grad)  # tensor([2.])

# Second backward pass (without zero_grad)
y = x * 3
y.backward()
print(x.grad)  # tensor([5.]) = 2 + 3 (accumulated!)

# Reset gradients
x.grad.zero_()
y = x * 3
y.backward()
print(x.grad)  # tensor([3.])
```
|
||
|
||
:::
|
||
|
||
###### Custom Autograd Functions {.unnumbered}
|
||
|
||
\index{Custom Autograd Functions!backward implementation}
|
||
When implementing custom operations, the developer explicitly specifies what to save for the backward pass and how to compute gradients. @lst-custom-autograd-function shows the pattern:
|
||
|
||
::: {#lst-custom-autograd-function lst-cap="**Custom Autograd Function**: Implement forward and backward methods to define custom differentiable operations, explicitly specifying tensors to save for gradient computation."}
|
||
|
||
```{.python}
import torch


class MultiplyAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, z):
        # Save tensors needed for backward
        ctx.save_for_backward(x, y)
        return x * y + z

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        x, y = ctx.saved_tensors

        # Compute gradients using chain rule
        grad_x = grad_output * y  # ∂L/∂x = ∂L/∂out * ∂out/∂x
        grad_y = grad_output * x  # ∂L/∂y = ∂L/∂out * ∂out/∂y
        grad_z = grad_output  # ∂L/∂z = ∂L/∂out * 1

        return grad_x, grad_y, grad_z


# Usage
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = torch.tensor([1.0], requires_grad=True)

output = MultiplyAdd.apply(x, y, z)
output.backward()

print(x.grad, y.grad, z.grad)  # tensor([3.]) tensor([2.]) tensor([1.])
```
|
||
|
||
:::
|
||
|
||
###### Gradient Hooks {.unnumbered}
|
||
|
||
\index{Gradient Hooks!inspection and modification}
|
||
Register hooks on tensors to inspect or modify gradients during backpropagation, as shown in @lst-gradient-hooks:
|
||
|
||
::: {#lst-gradient-hooks lst-cap="**Gradient Hooks**: Register hooks on tensors to inspect or modify gradients during backpropagation, useful for debugging, gradient clipping, or custom gradient manipulation."}
|
||
|
||
```{.python}
import torch


def gradient_hook(grad):
    print(f"Gradient: {grad}")
    # Modify gradient (e.g., gradient clipping)
    return grad.clamp(-1.0, 1.0)


x = torch.tensor([2.0], requires_grad=True)
x.register_hook(gradient_hook)

y = x * 10
y.backward()
# Prints: Gradient: tensor([10.])
# x.grad contains clamped value: tensor([1.])
```
|
||
|
||
:::
|
||
|
||
###### Detach vs. Data {.unnumbered}
|
||
|
||
\index{Gradient Flow!detaching tensors}\index{Gradient Flow!breaking with detach}
|
||
Use `.detach()` to safely break gradient flow. @lst-detach-vs-data illustrates how the legacy `.data` attribute can silently corrupt gradient computation through in-place operations:
|
||
|
||
::: {#lst-detach-vs-data lst-cap="**Safe Gradient Detachment**: Use `.detach()` to safely break gradient flow. The legacy `.data` attribute can silently corrupt gradients through in-place operations."}
|
||
|
||
```{.python}
import torch

# SAFE: .detach() shares storage AND the version counter, so autograd
# notices in-place edits to values it saved for the backward pass
x = torch.tensor([1.0], requires_grad=True)
y = x.exp()  # exp saves its output y for the backward pass
z_safe = y.detach()
z_safe.mul_(100)  # Modifies y's storage and bumps its version
# y.backward() now raises a RuntimeError instead of a wrong gradient

# DANGEROUS: .data shares storage but hides the edit from autograd
x = torch.tensor([1.0], requires_grad=True)
y = x.exp()
z_unsafe = y.data
z_unsafe.mul_(100)  # Modifies y's storage with no version bump
y.backward()  # No error, but x.grad is 100 * e instead of e

# For inference, torch.no_grad() skips graph construction entirely
with torch.no_grad():
    inference_output = x * 2  # Already graph-free; no .detach() needed
```
|
||
|
||
:::
|
||
|
||
These three principles connect directly to the framework's role as a compiler for the **Silicon Contract**. The reverse-linked graph determines which operations the backward pass must execute (the $O$ term). The memory-compute trade-off governs how much data the framework must move through the memory hierarchy (the $D_{\text{vol}}$ term). And the extensibility mechanisms allow engineers to tune both terms for their specific workload. The interaction between autograd memory management and numerical precision leads naturally to mixed-precision training, which further reduces the $D_{\text{vol}}$ term.
|
||
|
||
#### Mixed-Precision Training Support {#sec-ml-frameworks-mixedprecision-training-support-d31d}
|
||
|
||
\index{Mixed Precision!FP16 vs. FP32 trade-offs}
|
||
Mixed precision exploits a hardware asymmetry to improve two Iron Law terms simultaneously: Tensor Cores execute FP16 matrix multiplications at 2$\times$ or more the throughput of FP32, depending on the GPU generation (raising effective $R_{\text{peak}}$ and so shrinking the $O/R_{\text{peak}}$ term), while FP16 activations halve the memory footprint (reducing $D_{\text{vol}}$). Improving both terms at once is rare; most optimizations improve one at the expense of the other.
|
||
|
||
Frameworks exploit this through automatic mixed-precision APIs that apply per-operation precision rules: matrix multiplications and convolutions run in FP16 for throughput and bandwidth efficiency, while numerically sensitive operations such as softmax and layer normalization remain in FP32. This selective precision maintains accuracy while achieving speedups on GPUs with specialized matrix units. Because FP16 has a narrower dynamic range than FP32, small gradients can underflow to zero during backpropagation. Loss scaling addresses this by multiplying the loss by a large factor before the backward pass, then dividing the gradients by the same factor before the optimizer step.
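
The underflow problem and its loss-scaling fix can be reproduced without a GPU. The following minimal sketch uses Python's `struct` module, whose `"e"` format code rounds a value to IEEE half precision. The gradient magnitude `1e-8` and the scale factor `2**16` are illustrative assumptions, not values any framework prescribes.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest FP16 value and convert it back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


grad = 1e-8          # assumed "true" gradient, below FP16's smallest subnormal
scale = 2.0 ** 16    # assumed loss-scale factor

unscaled = to_fp16(grad)          # stored directly in FP16: underflows
scaled = to_fp16(grad * scale)    # stored after scaling: representable
recovered = scaled / scale        # unscaled again in higher precision

print(f"FP16 without scaling: {unscaled}")
print(f"FP16 with scaling, then unscaled: {recovered:.3e}")
```

Scaling moves the value into FP16's representable range before it is rounded, so the division afterward recovers the gradient to within FP16's relative precision.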
|
||
|
||
Frameworks also support multiple precision formats including FP16, BF16[^fn-bf16-design], and TF32, each with different trade-offs between range and precision. BF16 maintains FP32's dynamic range, which eliminates most gradient underflow and usually removes the need for loss scaling. @sec-model-training examines the mechanics of mixed-precision training in detail, including loss scaling algorithms, memory savings analysis, and numerical stability considerations. @lst-autocast-usage demonstrates PyTorch's mixed-precision API: the `autocast` context manager automatically selects FP16 for compute-intensive operations while `GradScaler` prevents gradient underflow by dynamically scaling loss values.
|
||
|
||
[^fn-bf16-design]: **BFloat16 Design Rationale**: Developed by Google Brain circa 2018 specifically for TPU training stability, BF16 preserves FP32's 8-bit exponent range while halving memory footprint, an explicit trade-off of mantissa precision (7 bits vs. FP16's 10) for dynamic range. The critical consequence is that loss scaling becomes unnecessary in most cases: FP16's 5-bit exponent puts its smallest normal value near $6 \times 10^{-5}$, so smaller gradients lose precision and eventually underflow to zero unless loss scaling keeps them in range. BF16's FP32-matched exponent largely removes this class of training instability, which is why BF16 and FP16 are not interchangeable: BF16 is preferred when training stability matters; FP16 is preferred when numerical precision matters more than gradient stability. \index{BFloat16!design rationale}
|
||
|
||
::: {#lst-autocast-usage lst-cap="**Mixed-Precision API**: Modern frameworks provide automatic mixed-precision support through context managers that handle precision selection and numerical stability."}
|
||
|
||
```{.python}
import torch
from torch.amp import autocast, GradScaler

# MyModel and dataloader are placeholders defined elsewhere
model = MyModel().cuda()
criterion = torch.nn.CrossEntropyLoss()  # example loss function
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler("cuda")

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # Framework automatically selects precision per operation
    with autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    # GradScaler handles gradient scaling for numerical stability
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
|
||
|
||
:::
|
||
|
||
BF16 training typically does not require loss scaling, as @lst-bf16-training demonstrates.
|
||
|
||
::: {#lst-bf16-training lst-cap="**BF16 Training**: BF16 maintains FP32's dynamic range, eliminating the need for loss scaling that FP16 requires."}
|
||
|
||
```{.python}
# BF16 training typically does not require loss scaling
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(inputs)
    loss = criterion(outputs, targets)
loss.backward()  # No GradScaler needed
optimizer.step()
```
|
||
|
||
:::
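
The range-versus-precision trade-off among these formats follows directly from their bit layouts. As a rough sketch, the hypothetical helper `float_format_limits` below derives the largest normal value, the smallest normal value, and the mantissa step from the exponent and mantissa widths of each IEEE-style format (FP16: 5 and 10 bits, BF16: 8 and 7, FP32: 8 and 23).

```python
def float_format_limits(exp_bits, mantissa_bits):
    """Return (max_normal, min_normal, epsilon) for an IEEE-style binary format."""
    bias = 2 ** (exp_bits - 1) - 1
    max_normal = (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias
    min_normal = 2.0 ** (1 - bias)
    epsilon = 2.0 ** -mantissa_bits  # spacing between 1.0 and the next value
    return max_normal, min_normal, epsilon


formats = {"FP16": (5, 10), "BF16": (8, 7), "FP32": (8, 23)}
limits = {name: float_format_limits(*bits) for name, bits in formats.items()}

for name, (max_normal, min_normal, epsilon) in limits.items():
    print(f"{name}: max={max_normal:.3e}  "
          f"min_normal={min_normal:.3e}  eps={epsilon:.1e}")
```

BF16 and FP32 share the same smallest normal value, which is why BF16 gradients rarely underflow, while BF16's larger epsilon shows the precision it gives up in exchange.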
|
||
|
||
```{python}
|
||
#| label: model-7b-optimizer-state
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ 7B MODEL OPTIMIZER STATE MEMORY
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: @sec-ml-frameworks-pytorch-autograd-internals-4fa0 Optimizer State
|
||
# │ and Checkpointing subsection discussing Adam memory overhead
|
||
# │
|
||
# │ Goal: Quantify the memory overhead of Adam optimizer state for a 7B-param
|
||
# │ model: FP16 weights (14 GB) + Adam state (56 GB) = ~70 GB total,
|
||
# │ showing that optimizer state dominates training memory.
|
||
# │ Show: "~70 GB total (14 GB weights + 56 GB optimizer state)", inline in
|
||
# │ Optimizer State prose; model_7b_fp16_gb_str reused later in nn.Module
|
||
# │ serialization and Fallacies sections.
|
||
# │ How: model_memory() for both BYTES_FP16 and BYTES_ADAM_STATE quantities.
|
||
# │
|
||
# │ Note: PERSISTENT — Model7B.model_7b_fp16_gb_str reused at line ~3170
|
||
# │ (System-Level Operations memory manager paragraph), line ~3267
|
||
# │ (nn.Module Principle 3 serialization paragraph), line ~4130
|
||
# │ (Fallacies: batch size Fallacy).
|
||
# │
|
||
# │ Imports: mlsysim.core.constants (BYTES_FP16, BYTES_ADAM_STATE, GB, ureg),
|
||
# │ mlsysim.formulas (model_memory), mlsysim.book (fmt, check)
|
||
# │ Exports: Model7B.model_7b_fp16_gb_str, Model7B.model_7b_adam_gb_str,
|
||
# │ Model7B.model_7b_total_gb_str
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim.core.constants import BYTES_FP16, BYTES_ADAM_STATE, GB, ureg
|
||
from mlsysim.core.formulas import model_memory
|
||
from mlsysim.fmt import fmt, check
|
||
|
||
# ┌── LEGO ───────────────────────────────────────────────
|
||
class Model7B:
|
||
"""
|
||
Namespace for 7B Model Memory.
|
||
Scenario: Optimizer state overhead.
|
||
"""
|
||
|
||
# ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
|
||
params = 7e9 * ureg.param
|
||
|
||
# ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
|
||
fp16_gb = model_memory(params, BYTES_FP16, GB)
|
||
adam_gb = model_memory(params, BYTES_ADAM_STATE, GB)
|
||
total_gb = fp16_gb + adam_gb
|
||
|
||
# ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
|
||
model_7b_fp16_gb_str = fmt(fp16_gb, precision=0, commas=False)
|
||
model_7b_adam_gb_str = fmt(adam_gb, precision=0, commas=False)
|
||
model_7b_total_gb_str = fmt(total_gb, precision=0, commas=False)
|
||
|
||
# Note: Use Model7B.model_7b_total_gb_str directly.
|
||
```
|
||
|
||
Resuming training after interruption requires restoring model weights and optimizer state together: momentum buffers, adaptive learning rates, and gradient statistics. For Adam, the optimizer state raises the memory footprint to roughly five times that of the weights alone (two FP32 states are stored for each FP16 parameter, adding 8 bytes to every 2-byte weight), meaning a 7B-parameter model requires approximately `{python} Model7B.model_7b_total_gb_str` GB total (`{python} Model7B.model_7b_fp16_gb_str` GB weights + `{python} Model7B.model_7b_adam_gb_str` GB optimizer state). Checkpoint size therefore bounds recovery speed after failure, connecting fault tolerance directly to the Iron Law's $D_{\text{vol}}$ term.
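
The arithmetic behind these figures is short enough to check by hand. The sketch below assumes decimal gigabytes ($10^9$ bytes), 2 bytes per FP16 weight, and two 4-byte FP32 Adam states per parameter. The 2 GB/s checkpoint write bandwidth is a hypothetical figure chosen only to illustrate how checkpoint size bounds recovery time.

```python
PARAMS = 7e9
BYTES_FP16 = 2              # one FP16 weight
BYTES_ADAM_STATE = 2 * 4    # assumed: FP32 momentum + FP32 variance

weights_gb = PARAMS * BYTES_FP16 / 1e9
adam_gb = PARAMS * BYTES_ADAM_STATE / 1e9
total_gb = weights_gb + adam_gb

write_bandwidth_gb_s = 2.0  # hypothetical storage bandwidth
checkpoint_seconds = total_gb / write_bandwidth_gb_s

print(f"weights: {weights_gb:.0f} GB, Adam state: {adam_gb:.0f} GB, "
      f"total: {total_gb:.0f} GB")
print(f"checkpoint write at {write_bandwidth_gb_s} GB/s: "
      f"{checkpoint_seconds:.0f} s")
```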
|
||
|
||
@sec-model-training covers optimizer memory requirements and optimization strategies for large-scale training, where checkpoint size becomes a binding constraint. Frameworks provide the `state_dict()` interface to access optimizer state for serialization (@lst-state-dict-interface), and resuming training requires loading both model parameters and optimizer state (@lst-checkpoint-save-load).
|
||
|
||
::: {#lst-state-dict-interface lst-cap="**State Dictionary Interface**: Optimizers expose internal state through state_dict(), enabling serialization of momentum buffers and adaptive learning rate estimates for checkpointing."}
|
||
|
||
```{.python}
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
|
||
model = nn.Linear(10, 5)
|
||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||
|
||
# After training steps, optimizer accumulates state
|
||
loss = model(torch.randn(3, 10)).sum()
|
||
loss.backward()
|
||
optimizer.step()
|
||
|
||
# Access state for checkpointing
|
||
state = optimizer.state_dict()
|
||
# Contains: {'state': {...}, 'param_groups': [{'lr': 0.001, ...}]}
|
||
```
|
||
|
||
:::
|
||
|
||
::: {#lst-checkpoint-save-load lst-cap="**Checkpoint Save and Load**: Save both model parameters and optimizer state to properly resume training with correct momentum and adaptive learning rate values."}
|
||
|
||
```{.python}
# Saving checkpoint
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# Resuming training
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```
|
||
|
||
:::
|
||
|
||
The mathematics of automatic differentiation were established decades before deep learning's resurgence. What changed was the systems engineering. Before framework automation, implementing gradient computation for a single fully connected layer meant writing separate forward and backward functions, manually tracking intermediate values, and verifying mathematical correctness across dozens of operations. A modern Transformer involves hundreds of operations with complex dependencies; manual gradient derivation for attention, layer normalization, and residual connections would require months of careful work per architecture variant.
|
||
|
||
The breakthrough was turning this manual process into software infrastructure. A single matrix multiplication requires different gradient computations depending on which inputs require gradients, tensor shapes, hardware capabilities, and memory constraints. Autograd systems handle these variations transparently, which is why the rate of architectural innovation accelerated after frameworks matured. The mathematics did not change; software engineering made the mathematics practical to apply at scale.
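
To give a sense of what that manual process involved, here is a minimal pure-Python sketch of a single fully connected layer with hand-written forward and backward functions. The function names, the toy loss (a plain sum of outputs), and the numeric values are all illustrative assumptions. The finite-difference comparison at the end is the kind of correctness check engineers had to run by hand for every new operation.

```python
def linear_forward(w, b, x):
    """Compute y = W x + b and return the cache the backward pass needs."""
    y = [sum(wij * xj for wij, xj in zip(row, x)) + bi
         for row, bi in zip(w, b)]
    return y, x  # the input must be kept alive until backward runs


def linear_backward(w, cache, grad_y):
    """Hand-derived gradients of a scalar loss with respect to W, b, and x."""
    x = cache
    grad_w = [[gy * xj for xj in x] for gy in grad_y]
    grad_b = list(grad_y)
    grad_x = [sum(w[i][j] * grad_y[i] for i in range(len(w)))
              for j in range(len(x))]
    return grad_w, grad_b, grad_x


w = [[1.0, -2.0], [0.5, 3.0]]
b = [0.1, -0.2]
x = [2.0, 1.0]

y, cache = linear_forward(w, b, x)
loss = sum(y)  # toy loss, so dloss/dy = [1, 1]
grad_w, grad_b, grad_x = linear_backward(w, cache, [1.0, 1.0])

# Verify one hand-derived entry against a finite-difference estimate
eps = 1e-6
w_perturbed = [row[:] for row in w]
w_perturbed[0][1] += eps
numeric = (sum(linear_forward(w_perturbed, b, x)[0]) - loss) / eps
print(grad_w[0][1], numeric)
```

Even this single layer needs a saved input, three separate gradient formulas, and a verification step, which is the bookkeeping autograd systems now generate automatically.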
|
||
|
||
#### Memory Management in Gradient Computation {#sec-ml-frameworks-memory-management-gradient-computation-bb77}
|
||
|
||
The memory strategies from @sec-ml-frameworks-reverse-mode-d328 (checkpointing, gradient accumulation) exist because reverse-mode differentiation requires preserving computational history. As @lst-reverse_memory demonstrated, each layer adds an activation tensor that persists until the backward pass consumes it, creating a memory wave that peaks at the start of backpropagation and recedes as gradients are computed. Modern frameworks track the lifetime of each intermediate value automatically, freeing memory as soon as it is no longer needed. Even with precise lifetime tracking, however, a deeper problem remains: the cost of acquiring memory from the GPU in the first place.
|
||
|
||
The cost of raw GPU memory allocation provides a critical engineering lesson: production systems require **Memory Abstraction**. Requesting memory directly from a GPU is a high-latency operation that can synchronize the entire device, creating an allocation bottleneck that stalls computation. To solve this, modern frameworks implement **Caching Allocators**\index{Memory Management!caching allocators}. Instead of communicating with the hardware for every new tensor, the framework requests large blocks of memory upfront and manages its own internal pool. This abstraction matters for a second reason as well: by recycling blocks from its pool, it limits **memory fragmentation**\index{Memory Management!fragmentation}, the scenario where free memory is available but scattered in pieces too small to hold a large tensor. Together, these properties allow models to push the physical limits of the hardware without constant system-level overhead.
|
||
|
||
::: {.callout-perspective title="Caching Allocator and Utilization"}
|
||
|
||
The **Caching Allocator** is the framework's primary mechanism for maximizing the **Utilization** term in the Iron Law ($\frac{1}{\text{Utilization}}$). Without it, two factors degrade performance significantly:
|
||
|
||
1. **Allocation Latency**: `cudaMalloc` is a synchronous operation that costs 10--100 microseconds. In a training loop with thousands of operations per second, this latency would dominate execution time. The caching allocator pays this cost once, then serves subsequent requests in nanoseconds from its pool.
|
||
2. **Fragmentation**: A "Swiss cheese" memory pattern reduces **Effective Capacity**. If 10 GB is free but the largest contiguous block is 1 GB, a 2 GB tensor cannot be allocated. By rounding allocations into a small set of standard size classes (for example, powers of 2), the allocator makes it likely that freed memory can be reused for future requests, keeping **Utilization** high.
|
||
|
||
\index{Memory Fragmentation!OOM errors}
|
||
When "OOM" (Out of Memory) errors appear despite `nvidia-smi` showing free memory, **fragmentation** is often the culprit. The allocator cannot find a contiguous block large enough for the requested tensor.
|
||
|
||
:::
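
The pooling idea can be illustrated with a toy model. The class below is a deliberately simplified sketch, not PyTorch's actual allocator: `ToyCachingAllocator`, its 512-byte rounding granularity, and the `device_calls` counter (standing in for slow `cudaMalloc` calls) are all hypothetical names and values chosen for illustration.

```python
from collections import defaultdict


class ToyCachingAllocator:
    """Illustrative pool allocator that reuses freed blocks by rounded size."""

    BLOCK = 512  # hypothetical rounding granularity in bytes

    def __init__(self):
        self.free_blocks = defaultdict(list)  # rounded size -> cached block ids
        self.device_calls = 0                 # simulated slow driver allocations
        self.next_id = 0

    def _round(self, nbytes):
        # Round up to the next multiple of BLOCK so sizes fall into few classes
        return -(-nbytes // self.BLOCK) * self.BLOCK

    def malloc(self, nbytes):
        size = self._round(nbytes)
        if self.free_blocks[size]:
            return size, self.free_blocks[size].pop()  # fast path: cache hit
        self.device_calls += 1                         # slow path: ask the device
        self.next_id += 1
        return size, self.next_id

    def free(self, block):
        size, block_id = block
        self.free_blocks[size].append(block_id)  # keep the block for reuse


allocator = ToyCachingAllocator()
for step in range(1000):
    activations = allocator.malloc(1000)
    gradients = allocator.malloc(4000)
    allocator.free(activations)
    allocator.free(gradients)

print(f"device allocations for 2000 requests: {allocator.device_calls}")
```

Because a training loop requests the same tensor sizes every iteration, only the first iteration in this sketch pays the device allocation cost, and every later request is served from the pool.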
|
||
|
||
#### Production System Integration Challenges {#sec-ml-frameworks-production-system-integration-challenges-cd6e}
|
||
|
||
A training iteration that takes 300 ms in profiling may take 500 ms in production because the AD system must coordinate with the memory allocator, the device manager, the operation scheduler, and the optimizer on every single step. Each gradient computation can trigger data movement between CPU and GPU, memory allocation for intermediate tensors, and kernel launches on accelerators. These system interactions dominate wall-clock time for small models and remain significant even at scale. The gap between what the programmer writes (a five-line training loop) and what the system executes (dozens of memory allocations, kernel launches, and synchronization points) is the central tension of AD system design.
|
||
|
||
Beyond sequential overhead, the AD system must also exploit *concurrency*. Modern networks frequently contain independent branches---two convolutional paths processing the same input before merging, as in Inception-style architectures. On a GPU with sufficient resources, the framework's scheduler can execute both branch backward passes on separate CUDA streams, reducing backward pass time by up to 30--40%. The AD system therefore tracks dependencies for two purposes: correctness (computing the right gradients) and performance (scheduling independent computations concurrently). Frameworks hide this complexity behind `loss.backward()`, but the scheduling, memory allocation, and data movement decisions behind that call determine whether training runs at 40% or 80% of peak hardware utilization.
|
||
|
||
The memory and system integration challenges examined above (caching allocators, activation storage, and checkpoint overhead) affect all frameworks. Yet *how* frameworks implement automatic differentiation in the first place varies significantly, with consequences for both *optimization potential* and developer experience. The distinction between *tape-based and transform-based autodiff* captures this architectural divergence.
|
||
|
||
::: {.callout-perspective title="Tape-based vs. Transform-based Autodiff"}
|
||
|
||
\index{Automatic Differentiation!tape-based vs. transform-based}
|
||
**PyTorch (Tape-based)**: Records operations on a dynamic "tape" during the forward pass. This is flexible and easy to debug but makes it hard for a compiler to see the whole graph at once for global optimization.
|
||
|
||
**JAX (Transform-based)**: Treats automatic differentiation as a high-level function transformation (`grad(f)`). Because JAX sees the mathematical function before execution, it can easily chain other transformations like `jit(grad(f))` or `vmap(grad(f))`, producing highly optimized, compiled kernels that often outperform dynamic frameworks on specialized hardware like TPUs.
|
||
|
||
:::
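
The difference in programming interface can be seen in a few lines of pure Python. This is only a sketch of the two API styles: `Var` and `grad` are hypothetical names, and the `grad` function here still records a tape internally, whereas a real transform-based system such as JAX traces the function into an intermediate representation that a compiler can optimize.

```python
class Var:
    """Scalar that records how it was computed (a tiny 'tape')."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (input Var, local derivative)
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def backward(self):
        # Order nodes so each is processed after everything that depends on it
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad


def grad(f):
    """Transform-style API: take a function, return its derivative function."""
    def df(x):
        x_var = Var(x)
        f(x_var).backward()
        return x_var.grad
    return df


def f(x):
    return x * x * x + x  # f(x) = x^3 + x, so f'(x) = 3x^2 + 1


# Tape style: run the computation, then walk the recorded history
x = Var(2.0)
f(x).backward()
tape_result = x.grad

# Transform style: build a new function first, then call it
transform_result = grad(f)(2.0)
print(tape_result, transform_result)
```

In the tape style the gradient is a side effect stored on `x`; in the transform style it is the return value of a new function, which is what makes further transformations easy to stack on top.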
|
||
|
||
JAX[^fn-jax-functional-transforms] exemplifies the transform-based approach, where composable function transformations replace imperative tape recording.
|
||
|
||
[^fn-jax-functional-transforms]: **JAX**: The "transform-based" distinction matters because JAX's `grad`, `jit`, and `vmap` are not ordinary library calls but transformations on pure functions that compose freely with one another. A chain like `jit(vmap(grad(f)))` compiles into a single XLA computation because functional purity (no side effects, no mutation) lets the compiler reason about the entire program mathematically. The payoff can exceed 90% hardware utilization on TPUs; the cost is that any impurity (printing, mutation, unkeyed randomness) runs only while the function is first traced and is silently absent from the compiled version. \index{JAX!composable transformations}
|
||
|
||
#### How Different Frameworks Implement AD {#sec-ml-frameworks-different-frameworks-implement-ad-2f8e}
|
||
|
||
The execution models covered in @sec-ml-frameworks-execution-problem-e1e1, namely eager, static graph, and hybrid, directly shape how each framework implements automatic differentiation:
|
||
|
||
- **PyTorch** [@paszke2019pytorch] builds its autograd tape dynamically during forward execution, providing immediate debugging at the cost of graph-level optimization. The `grad_fn` chain mechanism detailed in @sec-ml-frameworks-pytorch-autograd-internals-4fa0 enables flexible control flow but requires storing the complete graph until backward pass completion.
|
||
- **TensorFlow** (in its 1.x incarnation) performed symbolic differentiation during graph construction, enabling ahead-of-time optimization. Modern TensorFlow 2.x uses eager execution by default but provides `tf.function` for graph compilation when performance matters.
|
||
- **JAX** [@frostig2018compiling] transforms functions rather than tracking operations. The `jax.grad()` transformation returns a new function that computes gradients, enabling composition with `jax.vmap()` for vectorization and `jax.jit()` for compilation. This approach requires pure functions but enables composable program transformations that chain differentiation, vectorization, and compilation in a single expression.
|
||
|
||
These implementation differences have direct practical consequences for framework selection, which @sec-ml-frameworks-major-framework-platform-analysis-fe96 examines in detail.
|
||
|
||
A recurring tension runs through every AD design decision: mathematical correctness demands storing computational history, but hardware imposes strict memory limits. Every framework resolves this tension differently, choosing which activations to checkpoint, which operations to fuse, and how aggressively to trade recomputation for memory. These choices determine which models can train on which hardware, making AD system design one of the most consequential engineering decisions in any framework.
|
||
|
||
::: {.callout-checkpoint title="The Systems Cost of Gradients"}
|
||
|
||
Training is *inherently more expensive* than inference because of Automatic Differentiation.
|
||
|
||
**Computational Reality**
|
||
|
||
- [ ] **Reverse Mode AD**: Why is this the only viable method for neural networks? (With one scalar loss and $10^9$ parameters, forward mode would require $10^9$ passes, whereas reverse mode needs only one.)
|
||
- [ ] **The Activation Tax**: Do you understand why training memory scales linearly with depth? (We must stash forward activations to compute backward gradients).
|
||
|
||
**Optimization Mechanics**
|
||
|
||
- [ ] **Gradient Checkpointing**: How does re-computing activations save memory? (We discard the stash and regenerate it on demand).
|
||
|
||
:::
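
The checkpointing trade-off can be made concrete with a deliberately simplified cost model, assumed here for illustration: a network of equal-sized layers is split into segments, the framework keeps one checkpoint per segment, and during the backward pass it recomputes and holds the activations of only one segment at a time. The function name and the 100-layer network are hypothetical.

```python
import math


def peak_stored_activations(num_layers, segment):
    """Simplified model: one checkpoint per segment plus one live segment."""
    checkpoints = math.ceil(num_layers / segment)
    return checkpoints + segment


num_layers = 100
baseline = num_layers  # without checkpointing, every activation is stored
best_segment = min(range(1, num_layers + 1),
                   key=lambda s: peak_stored_activations(num_layers, s))
best_peak = peak_stored_activations(num_layers, best_segment)

print(f"no checkpointing: {baseline} activations stored")
print(f"segment length {best_segment}: {best_peak} activations stored")
```

Under this model the minimum falls at a segment length near $\sqrt{n}$, so stored activations grow with the square root of depth rather than linearly, at the price of recomputing each segment's forward pass once.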
|
||
|
||
The execution and differentiation problems together enable the training loop: the execution model determines when computation happens, while automatic differentiation computes the gradients that drive learning. Both problems, however, quietly assume something that cannot be taken for granted: that the same code can run across diverse hardware. A model trained on an NVIDIA A100 must serve inference on a mobile phone's ARM CPU, a Google TPU, or a microcontroller with kilobytes of memory. The same `torch.matmul` call must dispatch to cuBLAS on one device and a hand-tuned ARM NEON kernel on another. This hardware diversity creates the third problem.
|
||
|
||
## Abstraction Problem {#sec-ml-frameworks-abstraction-problem-37a5}
|
||
|
||
\index{Abstraction Problem!definition}
|
||
\index{Hardware Abstraction!framework design}
|
||
\index{Hardware Abstraction!two dimensions (data, execution)}
|
||
The hardware diversity described above is not merely inconvenient; it is architecturally fundamental. A GPU offers 1,000$\times$ the parallelism of a CPU but has different memory semantics. A TPU provides higher throughput but requires static shapes. A microcontroller has kilobytes where a server has gigabytes. The abstraction problem is precisely this: frameworks must hide this complexity behind a single programming interface while still enabling efficient utilization of each target's unique capabilities.
|
||
|
||
The problem decomposes into two interacting dimensions. The first is *data representation*: how should frameworks represent tensors, parameters, and computational state in ways that work across hardware? The second is *execution mapping*: how should high-level operations translate to hardware-specific implementations? These dimensions are not independent concerns. The way data is represented (memory layout, precision, device placement) directly affects *what* execution strategies are possible. A tensor stored in row-major format on a GPU requires different kernels than one in column-major format on a CPU. A model quantized to INT8 enables entirely different execution paths than FP32.
|
||
|
||
Solving the abstraction problem requires sophisticated software infrastructure: tensor representations that encode both mathematical semantics and hardware constraints, intermediate representations that enable hardware-specific compilation, and runtime systems that manage data movement across the memory hierarchy.
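
The execution-mapping dimension can be sketched as a dispatch table that routes one logical operation to a device-specific implementation. This is a toy illustration under stated assumptions: the `KERNELS` registry, the `register` decorator, and the device names are hypothetical, and both kernels are plain Python stand-ins for what would be vendor libraries in a real framework.

```python
KERNELS = {}  # (operation name, device name) -> implementation


def register(op, device):
    """Decorator that files a function under an (op, device) key."""
    def wrap(fn):
        KERNELS[(op, device)] = fn
        return fn
    return wrap


@register("matmul", "cpu")
def matmul_cpu(a, b):
    # Straightforward row-by-column product
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


@register("matmul", "accelerator")
def matmul_accelerator(a, b):
    # Stand-in for a vendor library call: same math, different code path
    columns = list(zip(*b))  # pre-transpose once, as a layout-aware kernel might
    return [[sum(x * y for x, y in zip(row, col)) for col in columns]
            for row in a]


def dispatch(op, device, *args):
    """Look up the kernel for this device and run it."""
    try:
        kernel = KERNELS[(op, device)]
    except KeyError:
        raise NotImplementedError(f"no {op} kernel for {device}") from None
    return kernel(*args)


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
cpu_result = dispatch("matmul", "cpu", a, b)
accelerator_result = dispatch("matmul", "accelerator", a, b)
print(cpu_result, accelerator_result)
```

The caller names only the operation and the device; which code actually runs is decided by the registry, which is the property that lets one user-facing call reach many hardware backends.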
|
||
|
||
To make this concrete, trace what must happen when a programmer writes `model(input)`. The framework must answer five questions in rapid succession: *What is the data?* (tensor shape, memory layout, numeric precision), *Where does it live?* (device placement and the bandwidth hierarchy connecting CPU, GPU, and accelerator memory), *How does it arrive fast enough?* (data pipelines that sustain hundreds of MB/s to keep the accelerator fed), *How does it scale beyond one device?* (parameter synchronization and distributed execution contexts), and *What actually runs on the hardware?* (kernel dispatch, scheduling, and resource optimization). The following sub-sections answer these questions in order, building from the data container up to the hardware execution layer.
|
||
|
||
### Data Structures and Tensor Abstractions {#sec-ml-frameworks-data-structures-tensor-abstractions-9cbf}
|
||
|
||
```{python}
|
||
#| label: resnet50-params-intro
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ RESNET-50 PARAMETERS FOR ABSTRACTION SECTION
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: Introducing tensor abstractions with concrete model scale
|
||
# │
|
||
# │ Goal: Motivate the need for sophisticated framework data structures.
|
||
# │ Show: That 25.6M parameters must be managed without manual pointers.
|
||
# │ How: Retrieve ResNet-50 parameter count from mlsysim.core.constants.
|
||
# │
|
||
# │ Imports: mlsysim.core.constants (RESNET50_PARAMS), mlsysim.book (fmt)
|
||
# │ Exports: resnet_params_m_str
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim.core.constants import RESNET50_PARAMS, Mparam
|
||
from mlsysim.fmt import fmt, check
|
||
|
||
# ┌── LEGO ───────────────────────────────────────────────
|
||
class ResNetAbstraction:
|
||
"""
|
||
Namespace for ResNet-50 parameter scale in abstraction section.
|
||
"""
|
||
params_m = RESNET50_PARAMS.m_as(Mparam)
|
||
check(params_m > 20, f"ResNet-50 parameter count too low: {params_m:.1f}M")
|
||
resnet_params_m_str = fmt(params_m, precision=1, commas=False)
|
||
```
|
||
|
||
A ResNet-50 forward pass touches `{python} ResNetAbstraction.resnet_params_m_str` million parameters, produces intermediate activations at every layer, and must coordinate memory across CPU and GPU address spaces. How do frameworks organize all of this data so that a single Python call like `model(input)` executes millions of operations without the programmer managing a single pointer? Answering this question requires solving four problems in sequence: defining a universal data container (tensors), placing it on the right device (memory management), feeding data fast enough (data pipelines), and dispatching the right hardware kernel (core operations). We trace this path from data representation to hardware execution.
|
||
|
||
Computational graphs specify the *logical* flow of operations, but data structures determine how those operations access and manipulate data in *physical* memory. This distinction matters because the same mathematical operation can differ by an order of magnitude in throughput depending on whether data is contiguous in cache, pinned for DMA transfer, or scattered across pages.
|
||
|
||
The first step is the data container itself. Framework data structures must sustain memory bandwidth (hundreds of GB/s on modern GPUs), accommodate architectures from 1D sequences to 5D video tensors, and hide device management behind clean APIs. Tensors are the universal answer.
|
||
|
||
#### Tensors {#sec-ml-frameworks-tensors-1cb7}
|
||
|
||
\index{Tensor!n-dimensional arrays}
|
||
\index{Tensor!definition}
|
||
\index{Tensor!etymology}
|
||
At the foundation of every framework's data representation lies a single abstraction: the **tensor**.
|
||
|
||
::: {.callout-definition title="Tensor"}
|
||
|
||
***Tensors***\index{Tensor!definition} are $n$-dimensional arrays with explicit shape, data type, and memory layout metadata that allow ML frameworks to map mathematical operations directly onto hardware vector units without intermediate data transformation.
|
||
|
||
1. **Significance (Quantitative):** Tensor memory footprint is fully deterministic from its metadata: a contiguous FP32 tensor of shape $[1024, 1024]$ occupies exactly $1024 \times 1024 \times 4 = 4{,}194{,}304$ bytes (4 MiB). Non-contiguous layouts (e.g., from a transpose operation) require explicit `.contiguous()` calls before certain CUDA kernels can execute, adding a memory-copy overhead that can dominate the $L_{\text{lat}}$ term for tensors under 1 MB.
|
||
2. **Distinction (Durable):** Unlike a Python list, and beyond what a generic NumPy array records, a framework tensor carries device placement metadata (CPU vs. GPU) alongside dtype (FP32, BF16, INT8) and stride information, enabling zero-copy view operations and kernel dispatch from metadata alone, without inspecting or moving the underlying data.
|
||
3. **Common Pitfall:** A frequent misconception is that tensor operations are always in-place. Framework tensor operations return new tensors by default, allocating fresh GPU memory for each intermediate result. In a long computation graph, these intermediate allocations accumulate and can exhaust GPU memory before any weights are updated.
|
||
|
||
:::
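
The metadata involved is small enough to model directly. The sketch below uses hypothetical helper names and computes a tensor's byte size from its shape and element size, derives row-major strides (measured in elements), and shows why a transposed view is no longer contiguous even though no data has moved.

```python
from math import prod


def contiguous_strides(shape):
    """Row-major strides, in elements, for a densely packed tensor."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    return tuple(strides) == contiguous_strides(shape)


# Footprint follows from metadata alone: shape times bytes per element
fp32_itemsize = 4
nbytes = prod((1024, 1024)) * fp32_itemsize

# A transpose swaps shape and stride entries; the stored data is untouched
shape = (2, 3)
strides = contiguous_strides(shape)
transposed_shape, transposed_strides = shape[::-1], strides[::-1]

print(f"bytes: {nbytes}")
print(f"original: shape={shape}, strides={strides}, "
      f"contiguous={is_contiguous(shape, strides)}")
print(f"transposed: shape={transposed_shape}, strides={transposed_strides}, "
      f"contiguous={is_contiguous(transposed_shape, transposed_strides)}")
```

A kernel that assumes densely packed rows cannot consume the transposed view directly, which is the situation where a framework must first copy the data into a contiguous layout.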
|
||
|
||
Every computation in a neural network operates on tensors.[^fn-tensor-hardware-contract] Training batches, activation maps, parameter gradients, and optimizer states are all tensors. This unified representation lets frameworks optimize a single data structure for hardware rather than managing separate containers for each role.
|
||
|
||
[^fn-tensor-hardware-contract]: **Tensor**: From Latin *tendere* ("to stretch"), coined in its mathematical sense by physicist Woldemar Voigt in 1898 for objects defined by how they transform under coordinate changes. ML inherited the term because framework tensors are similarly defined by transformation behavior: transposing changes strides but not data, reshaping changes metadata without moving bytes. This transformation-centric design is also the source of layout sensitivity: choosing NCHW when the target accelerator prefers NHWC (or vice versa) can halve computational throughput, because misaligned memory access patterns break hardware coalescing. \index{Tensor!hardware contract}
|
||
|
||
The tensor abstraction consumes far more memory than model weights alone suggest. Engineers who estimate memory from parameter count alone allocate accordingly and encounter out-of-memory errors that seem inexplicable. The following notebook quantifies what we call the *administrative tax*: the shadow tensors for gradients, optimizer momentum, and stored activations that accompany every weight tensor.
|
||
|
||
```{python}
#| label: admin-tax-calc
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ ADMINISTRATIVE TAX CALCULATION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Callout "The Administrative Tax" explaining hidden memory costs
# │
# │ Goal: Demonstrate the hidden memory costs of large-scale training.
# │ Show: That a 2 GB model requires ~19 GB of VRAM during training.
# │ How: Sum gradients, optimizer states, and activations for a 1B param model.
# │
# │ Imports: mlsysim.core.constants (BYTES_FP16, BYTES_ADAM_STATE, GB, BILLION),
# │          mlsysim.fmt (fmt, check, md_math)
# │ Exports: admin_weights_gb_str, admin_grads_gb_str, admin_opt_gb_str, admin_act_gb_str,
# │          admin_total_gb_str, admin_tax_gb_str, admin_batch_str, admin_layers_str,
# │          admin_width_str, admin_act_calc_md
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import BYTES_FP16, BYTES_ADAM_STATE, GB, BILLION
from mlsysim.fmt import fmt, check, md_math

# ┌── LEGO ───────────────────────────────────────────────
class AdminTax:
    """
    Namespace for Administrative Tax Calculation.
    Scenario: Memory overhead for 1B parameter model.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    params_count = 1 * BILLION
    batch_size = 32
    layers = 100
    width = 1024

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    weights_gb = (params_count * BYTES_FP16).m_as(GB)
    grads_gb = weights_gb
    opt_gb = (params_count * BYTES_ADAM_STATE).m_as(GB)

    # Step 1: Activation approximation: Batch * Layers * Width^2 * FP16
    act_bytes = batch_size * layers * (width ** 2) * BYTES_FP16
    act_gb = act_bytes.m_as(GB)

    total_gb = weights_gb + grads_gb + opt_gb + act_gb
    tax_gb = total_gb - weights_gb

    # ┌── 3. GUARD (Invariants) ───────────────────────────────────────────
    check(tax_gb > 15, f"Administrative tax ({tax_gb:.1f} GB) unexpectedly low.")

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    admin_weights_gb_str = fmt(weights_gb, precision=0, commas=False)
    admin_grads_gb_str = fmt(grads_gb, precision=0, commas=False)
    admin_opt_gb_str = fmt(opt_gb, precision=0, commas=False)
    admin_act_gb_str = fmt(act_gb, precision=1, commas=False)
    admin_total_gb_str = fmt(total_gb, precision=0, commas=False)
    admin_tax_gb_str = fmt(tax_gb, precision=0, commas=False)
    admin_batch_str = fmt(batch_size, precision=0, commas=False)
    admin_layers_str = fmt(layers, precision=0, commas=False)
    admin_width_str = fmt(width, precision=0, commas=False)

    admin_act_calc_md = md_math(f"{admin_batch_str} \\times {admin_layers_str} \\times {admin_width_str}^{{2}} \\times 2 \\approx \\mathbf{{{admin_act_gb_str} \\text{{ GB}}}}")
```
|
||
|
||
::: {.callout-notebook title="The Administrative Tax"}
|
||
|
||
The memory breakdown for ResNet-50 in @sec-ml-frameworks-principle-2-memorycompute-tradeoff-19f2 showed a concrete ~`{python} ResNetMemory.resnet_training_ratio_str`$\times$ ratio between training and inference memory. Here we generalize that analysis to reveal the full administrative overhead at billion-parameter scale.
|
||
|
||
**Problem**: Why does training a model need so much more GPU memory than its weights occupy?
|
||
|
||
**The Math (The Hidden Tax)**:
|
||
|
||
1. **Model Weights**: `{python} AdminTax.admin_weights_gb_str` GB.
|
||
2. **Gradients**: `{python} AdminTax.admin_grads_gb_str` GB (same size as weights).
|
||
3. **Optimizer States (Adam)**: `{python} AdminTax.admin_opt_gb_str` GB (two FP32 state tensors, momentum and velocity, each holding one value per weight).
|
||
4. **Activations**: For a batch size of `{python} AdminTax.admin_batch_str` and a `{python} AdminTax.admin_layers_str`-layer network, the framework must store every intermediate layer output for the backward pass.
|
||
|
||
$$ \text{Activations} \approx \text{Batch} \times \text{Layers} \times \text{Width}^{2} \times 2 \text{ bytes} $$
|
||
For a `{python} AdminTax.admin_width_str`-width model: `{python} AdminTax.admin_act_calc_md`. (Each layer's activation is modeled as a Width $\times$ Width matrix per sample, a rough approximation for transformer-style models where intermediate projections scale with the hidden dimension squared.)
|
||
|
||
**The Systems Conclusion**: A `{python} AdminTax.admin_weights_gb_str` GB model carries an **"Administrative Tax"** of ~`{python} AdminTax.admin_tax_gb_str` GB (`{python} AdminTax.admin_grads_gb_str` GB gradients + `{python} AdminTax.admin_opt_gb_str` GB optimizer + `{python} AdminTax.admin_act_gb_str` GB activations) on top of its weights. During training, **Data Movement** includes saving and retrieving these activations, which is why training is often 3--4$\times$ slower than pure inference.
|
||
|
||
:::
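The arithmetic in the callout can be reproduced in a few lines of plain Python. This is a minimal sketch, not the book's notebook cell: the per-parameter byte counts (2 bytes for FP16 weights and gradients, 8 bytes for two FP32 Adam moments) and the use of decimal gigabytes are assumptions chosen to match the figures quoted above, and `training_memory_gb` is a hypothetical helper name.

```python
# Assumed per-parameter costs (not read from the book's constants module)
BYTES_FP16 = 2        # one FP16 weight or gradient
BYTES_ADAM_STATE = 8  # two FP32 moments: momentum + velocity
GB = 1e9              # decimal gigabytes


def training_memory_gb(params, batch, layers, width):
    """Return the memory components (in GB) for one training step."""
    weights = params * BYTES_FP16 / GB
    grads = weights  # one gradient per weight, same precision
    optimizer = params * BYTES_ADAM_STATE / GB
    # Rough activation model from the callout: Batch x Layers x Width^2 x 2 bytes
    activations = batch * layers * width**2 * BYTES_FP16 / GB
    total = weights + grads + optimizer + activations
    return {
        "weights": weights,
        "gradients": grads,
        "optimizer": optimizer,
        "activations": activations,
        "total": total,
        "tax": total - weights,  # everything that is not the model itself
    }


budget = training_memory_gb(params=1_000_000_000, batch=32, layers=100, width=1024)
for name, gb in budget.items():
    print(f"{name:>12}: {gb:6.2f} GB")
```

Changing `batch` or `width` shows how quickly the activation term grows relative to the fixed per-parameter costs, which is the lever that activation checkpointing and smaller batches pull on.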
|
||
|
||
#### Tensor Structure and Dimensions {#sec-ml-frameworks-tensor-structure-dimensions-4a14}
|
||
|
||
\index{Tensor!rank hierarchy}\index{Broadcasting!shape compatibility}
|
||
A tensor generalizes scalars, vectors, and matrices to arbitrary dimensions. The hierarchy is straightforward: a scalar is a rank-0 tensor (single value), a vector is rank-1 (sequence of values), and a matrix is rank-2 (rows and columns). Higher ranks extend this pattern through nesting, so a rank-3 tensor is a stack of matrices---compare all four ranks side by side in @fig-tensor-data-structure-a to see how each level adds a new axis of organization.
|
||
|
||
::: {#fig-tensor-data-structure-a fig-env="figure" fig-pos="htb" fig-cap="**Tensor Rank Hierarchy.** Four shapes illustrating tensor ranks from left to right: a single value (rank 0, scalar), a column of values (rank 1, vector), a grid of values (rank 2, matrix), and a cube of values (rank 3, three-dimensional tensor)." fig-alt="Four shapes showing tensor ranks left to right: single box labeled Rank 0, vertical column of numbers labeled Rank 1, 2D grid of numbers labeled Rank 2, and 3D cube labeled Rank 3."}
|
||
|
||
```{.tikz}
|
||
\scalebox{0.8}{%
|
||
\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}]
|
||
\begin{scope}
|
||
\pgfmathsetmacro{\cubex}{2.5}
|
||
\pgfmathsetmacro{\cubey}{2.5}
|
||
\pgfmathsetmacro{\cubez}{2.5}
|
||
\draw[BrownLine,fill=BrownL!40] (0,0,0) -- ++(-\cubex,0,0) -- ++(0,-\cubey,0) -- ++(\cubex,0,0) -- cycle;
|
||
\draw[BrownLine,fill=BrownL] (0,0,0) -- ++(0,0,-\cubez)coordinate(G) -- ++(0,-\cubey,0) -- ++(0,0,\cubez) -- cycle;
|
||
\draw[BrownLine,fill=BrownL!70] (0,0,0) -- ++(-\cubex,0,0) -- ++(0,0,-\cubez) -- ++(\cubex,0,0) -- cycle;
|
||
\path[red] (-\cubex,-\cubey,0)coordinate(A) -- (0,-\cubey,0)coordinate(B);
|
||
\node[below=0.3of $(A)!0.5!(B)$]{Rank 3};
|
||
\end{scope}
|
||
|
||
\begin{scope}[shift={(-5.5,-0.77)}]
|
||
\node[draw=BrownLine,fill=BrownL!40,rectangle,%anchor=north west,
|
||
minimum width=98,minimum height=98](R){};
|
||
\node[right=2pt of $(R.north west)!0.1!(R.south west)$]{1 \ldots ~2};
|
||
\node[right=2pt of $(R.north west)!0.24!(R.south west)$]{3 \ldots ~5};
|
||
\node[right=2pt of $(R.north west)!0.39!(R.south west)$]{5 \phantom{\ldots} 3};
|
||
\node[right=2pt of $(R.north west)!0.58!(R.south west)$]{$\vdots$ \phantom{\ldots~} $\vdots$};
|
||
\node[right=2pt of $(R.north west)!0.9!(R.south west)$]{3 \phantom{\ldots} 3};
|
||
\node[below=0.3of $(R.south west)!0.5!(R.south east)$]{Rank 2};
|
||
\end{scope}
|
||
|
||
\begin{scope}[shift={(-8.75,-0.77)}]
|
||
\node[draw=BrownLine,fill=BrownL!40,rectangle,%anchor=north west,
|
||
minimum width=18,minimum height=98](R){};
|
||
\node[right=2pt of $(R.north west)!0.1!(R.south west)$]{1};
|
||
\node[right=2pt of $(R.north west)!0.24!(R.south west)$]{3};
|
||
\node[right=2pt of $(R.north west)!0.39!(R.south west)$]{5};
|
||
\node[right=2pt of $(R.north west)!0.58!(R.south west)$]{$\vdots$};
|
||
\node[right=2pt of $(R.north west)!0.9!(R.south west)$]{3};
|
||
\node[below=0.3of $(R.south west)!0.5!(R.south east)$](R1){Rank 1};
|
||
\end{scope}
|
||
|
||
\begin{scope}[shift={(-10.5,-0.77)}]
|
||
\node[draw=BrownLine,fill=BrownL!40,rectangle,%anchor=north west,
|
||
minimum width=18,minimum height=18](3R){0};
|
||
\end{scope}
|
||
\path[red](R1)-|coordinate(P)(3R);
|
||
\node[]at(P){Rank 0};
|
||
\end{tikzpicture}}
|
||
```
|
||
|
||
:::
|
||
|
||
This rank hierarchy maps directly onto ML data. A color image is a rank-3 tensor: height $\times$ width $\times$ 3 channels (red, green, blue). @fig-tensor-data-structure-b breaks this apart, stacking the three color channels to illustrate how a single photograph becomes a three-layer numerical grid. Stacking a batch of $N$ images adds a fourth dimension, producing a rank-4 tensor of shape $[N, 3, H, W]$ in the channels-first convention. Convolutional layers in a vision model consume and produce tensors of this form, with only the channel count and spatial size changing from layer to layer, which is why the tensor abstraction is so central to framework design.
|
||
|
||
::: {#fig-tensor-data-structure-b fig-env="figure" fig-pos="htb" fig-cap="**Image as RGB Tensor.** Three stacked grids representing the red, green, and blue color channels of an image, with dimension labels showing width, height, and channel depth forming a rank-3 tensor. *Credit: Niklas Lang [https://towardsdatascience.com/what-are-tensors-in-machine-learning-5671814646ff](https://towardsdatascience.com/what-are-tensors-in-machine-learning-5671814646ff)*." fig-alt="Three stacked $3\times3$ grids in red, green, and blue representing RGB color channels. Dimension labels show width 3 pixels, height 3 pixels, and 3 color channels forming a 3D tensor for image data."}
|
||
|
||
```{.tikz}
|
||
\scalebox{0.7}{%
|
||
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\Large]
|
||
%
|
||
\tikzset{
|
||
Line/.style={line width=1.0pt,black!70,font=\usefont{T1}{phv}{m}{n}\footnotesize
|
||
},
|
||
Box/.style={align=flush center,
|
||
inner xsep=4pt,
|
||
node distance=0,
|
||
draw=white,
|
||
line width=0.75pt,
|
||
fill=red!80,
|
||
minimum width=10mm,
|
||
minimum height=10mm
|
||
},
|
||
}
|
||
\node[Box](B1){\textbf{6}};
|
||
\node[Box,right=of B1](B2){\textbf{2}};
|
||
\node[Box,right=of B2](B3){\textbf{5}};
|
||
\node[Box,below=of B1](B4){\textbf{32}};
|
||
\node[Box,right=of B4](B5){\textbf{15}};
|
||
\node[Box,right=of B5](B6){\textbf{4}};
|
||
\node[Box,below=of B4](B7){\textbf{1}};
|
||
\node[Box,right=of B7](B8){\textbf{8}};
|
||
\node[Box,right=of B8](B9){\textbf{3}};
|
||
%%
|
||
\node[Box,fill= OliveLine, draw= white,above=of B2](2B1){\textbf{8}};
|
||
\node[Box,fill= OliveLine, draw= white,right=of 2B1](2B2){\textbf{7}};
|
||
\node[Box,fill= OliveLine, draw= white,right=of 2B2](2B3){\textbf{5}};
|
||
\node[Box,fill= OliveLine, draw= white,below=of 2B3](2B4){\textbf{1}};
|
||
\node[Box,fill= OliveLine, draw= white,below=of 2B4](2B5){\textbf{2}};
|
||
%%
|
||
\node[Box,fill= BlueLine!80, draw= white,above=of 2B2](3B1){\textbf{2}};
|
||
\node[Box,fill= BlueLine!80, draw= white,right=of 3B1](3B2){\textbf{1}};
|
||
\node[Box,fill= BlueLine!80, draw= white,right=of 3B2](3B3){\textbf{9}};
|
||
\node[Box,fill= BlueLine!80, draw= white,below=of 3B3](3B4){\textbf{4}};
|
||
\node[Box,fill= BlueLine!80, draw= white,below=of 3B4](3B5){\textbf{3}};
|
||
%
|
||
\draw[dashed,Line,latex-latex]([yshift=-3mm]B7.south west)--
|
||
node[below=1mm]{Width: 3 Pixel}([yshift=-3mm]B9.south east);
|
||
\draw[dashed,Line,latex-latex]([xshift=-4mm]B7.south west)--
|
||
node[left]{Height: 3 Pixel}([xshift=-4mm]B1.north west);
|
||
\draw[dashed,Line,latex-latex,shorten <=2mm]([xshift=-4mm]B1.north west)--
|
||
node[left=3mm,pos=0.6]{3 Color Channels}([xshift=-4mm]3B1.north west);
|
||
\end{tikzpicture}}
|
||
```
|
||
|
||
:::
|
||
|
||
Framework tensors carry more than raw numbers. Each tensor stores metadata that the runtime uses to validate operations and select fast execution paths: a *shape* tuple (e.g., `[64, 3, 224, 224]` for a batch of images), a *dtype* (float32, float16, int8), and a *device* tag (CPU, cuda:0). A matrix multiplication, for instance, checks shape compatibility at dispatch time and uses the dtype to route to the correct hardware kernel, whether a standard FP32 GEMM or a Tensor Core FP16 path.
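To make the role of this metadata concrete, the sketch below models it without any framework at all. `TensorMeta`, `ITEMSIZE`, and `matmul_meta` are hypothetical names invented for illustration; real frameworks perform equivalent checks inside their dispatchers, with many more cases (broadcasting, type promotion, layouts) than shown here.

```python
from dataclasses import dataclass
from math import prod

# Bytes per element for a few common dtypes
ITEMSIZE = {"float32": 4, "float16": 2, "int8": 1}


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for the metadata a framework tensor carries."""
    shape: tuple
    dtype: str
    device: str

    @property
    def nbytes(self):
        # Storage is the element count times the element size
        return prod(self.shape) * ITEMSIZE[self.dtype]


def matmul_meta(a, b):
    """Validate a 2-D matmul at 'dispatch time' and return the result's metadata."""
    if a.device != b.device:
        raise RuntimeError(f"device mismatch: {a.device} vs {b.device}")
    if a.dtype != b.dtype:
        raise TypeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")
    if len(a.shape) != 2 or len(b.shape) != 2 or a.shape[1] != b.shape[0]:
        raise ValueError(f"shape mismatch: {a.shape} @ {b.shape}")
    return TensorMeta((a.shape[0], b.shape[1]), a.dtype, a.device)


images = TensorMeta((64, 3, 224, 224), "float32", "cuda:0")
print(images.nbytes)  # 38535168 bytes, about 38.5 MB for one batch

w = TensorMeta((128, 256), "float16", "cuda:0")
x = TensorMeta((256, 32), "float16", "cuda:0")
print(matmul_meta(w, x))
```

All three checks run on metadata only, before any element is touched, which is why a shape or device error surfaces immediately rather than after an expensive kernel launch.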
|
||
|
||
\index{Memory Layout!stride patterns}
|
||
Tensors present multi-dimensional data, but physical memory is linear. Stride patterns\index{Tensor!stride patterns}\index{Memory Layout!row-major vs. column-major} bridge this gap by mapping multi-dimensional tensor indices to linear memory addresses. Because strides dictate the order in which an operation touches memory, they directly affect performance. @fig-tensor-memory-layout makes this concrete with a $2\times3$ tensor: follow the same six values as they map into two different linear orderings (row-major and column-major) and note how the stride values change to compensate.
|
||
|
||
::: {#fig-tensor-memory-layout fig-env="figure" fig-pos="htb" fig-cap="**Tensor Memory Layout**: A $2\times3$ tensor can be stored in linear memory using either row-major (C-style) or column-major (Fortran-style) ordering. Strides define the number of elements to skip in each dimension when moving through memory, enabling frameworks to calculate memory addresses for tensor[i,j] as base_address + i$\times$ stride[0] + j$\times$ stride[1]. The choice of memory layout significantly impacts cache performance and computational efficiency." fig-alt="Left: $2\times3$ tensor grid with values 1-6. Right: two linear arrays showing row-major layout (1,2,3,4,5,6) and column-major layout (1,4,2,5,3,6). Below: stride calculations for row-major [3,1] and column-major [1,2]."}
|
||
|
||
```{.tikz}
|
||
\begin{tikzpicture}[font=\footnotesize\usefont{T1}{phv}{m}{n}]
|
||
% Define colors
|
||
\definecolor{col1}{RGB}{135, 206, 250}
|
||
\definecolor{col2}{RGB}{255, 182, 193}
|
||
\definecolor{col3}{RGB}{152, 251, 152}
|
||
% 2x3 tensor visualization (LEFT SIDE)
|
||
\foreach \row in {0,1} {
|
||
\foreach \col in {0,1,2} {
|
||
\pgfmathsetmacro{\val}{\row * 3 + \col + 1}
|
||
\node[draw, minimum width=15mm, minimum height=10mm,
|
||
fill=col1!50](B\row\col) at (\col*1.7, 1-\row*1.2) {\val};
|
||
}
|
||
}
|
||
\node[above=2pt of B01]{\textbf{2D Tensor ($2\times3$)}};
|
||
\path[red](B02.north east)--++(1.35,0)coordinate(CR);
|
||
\path[red](B12.340)--++(1.35,0)coordinate(ZE);
|
||
% Row-major memory layout (RIGHT SIDE)
|
||
\foreach \i in {0,1,2,3,4,5} {
|
||
\pgfmathsetmacro{\val}{\i + 1}
|
||
\node[draw, minimum width=10mm, minimum height=8mm,
|
||
anchor=north west,fill=col2!50](CB\i) at ($(CR)+(\i*1.1, 0)$) {\val};
|
||
\node[below=0pt of CB\i, font=\tiny\usefont{T1}{phv}{m}{n}] {[\i]};
|
||
}
|
||
\node[above=2pt of CB2.north east]{\textbf{Row-Major Layout}};
|
||
% Column-major memory layout (RIGHT SIDE)
|
||
\foreach \i in {0,1,2,3,4,5} {
|
||
\pgfmathsetmacro{\val}{int(mod(\i,2)*3 + int(\i/2) + 1)}
|
||
\node[draw, minimum width=10mm, minimum height=8mm,
|
||
anchor=north west,fill=col3!50](ZE\i) at ($(ZE)+(\i*1.1, 0)$) {\val.0};
|
||
\node[below=0pt of ZE\i, font=\tiny\usefont{T1}{phv}{m}{n}] {[\i]};
|
||
}
|
||
\node[above=2pt of ZE2.north east]{\textbf{Column-Major Layout}};
|
||
% Strides explanation (BOTTOM)
|
||
\node[anchor=north west,align=left,inner sep=0pt] at ($(B10.south west)+(0,-0.2)$) {%
|
||
\textbf{Stride Calculation:}\\
|
||
Row-major strides: [3, 1]\\
|
||
Column-major strides: [1, 2]\\
|
||
Element [i,j] offset = i$\times$ stride[0] + j$\times$ stride[1]
|
||
};
|
||
\end{tikzpicture}
|
||
```
|
||
|
||
:::
|
||
|
||
Row-major layout (the default in NumPy and PyTorch) stores elements row by row, making row-wise traversal cache-friendly. Column-major layout (the Fortran convention that BLAS interfaces inherit) stores elements column by column, favoring column-wise access. The stride values encode the layout: in row-major layout for a $2\times3$ tensor, moving to the next row skips 3 elements (stride[0]=3), while moving to the next column skips 1 element (stride[1]=1). In column-major layout the roles reverse: stride[0]=1 and stride[1]=2.
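The stride arithmetic for the $2\times3$ example can be checked directly. The following is a small pure-Python sketch; `strides_for`, `offset`, and `linearize` are illustrative helper names, and strides are counted in elements rather than bytes.

```python
def strides_for(shape, order="row"):
    """Element strides for a dense tensor in row-major or column-major order."""
    strides = [1] * len(shape)
    if order == "row":
        # Last dimension is contiguous; each earlier one spans everything after it
        for d in range(len(shape) - 2, -1, -1):
            strides[d] = strides[d + 1] * shape[d + 1]
    elif order == "col":
        # First dimension is contiguous; each later one spans everything before it
        for d in range(1, len(shape)):
            strides[d] = strides[d - 1] * shape[d - 1]
    else:
        raise ValueError("order must be 'row' or 'col'")
    return strides


def offset(index, strides):
    """Linear offset of a multi-dimensional index: sum of index * stride."""
    return sum(i * s for i, s in zip(index, strides))


def linearize(tensor, order):
    """Lay a nested 2-D list out in linear memory using the chosen order."""
    shape = (len(tensor), len(tensor[0]))
    strides = strides_for(shape, order)
    flat = [None] * (shape[0] * shape[1])
    for i in range(shape[0]):
        for j in range(shape[1]):
            flat[offset((i, j), strides)] = tensor[i][j]
    return flat


tensor = [[1, 2, 3], [4, 5, 6]]
print(strides_for((2, 3), "row"), linearize(tensor, "row"))  # [3, 1] [1, 2, 3, 4, 5, 6]
print(strides_for((2, 3), "col"), linearize(tensor, "col"))  # [1, 2] [1, 4, 2, 5, 3, 6]
```

The same logical tensor yields two different memory images, and only the stride vector tells a kernel which one it is reading.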
|
||
|
||
These memory layout details have direct performance implications. When a convolution kernel reads weight values along the innermost (last) dimension, row-major layout guarantees they are contiguous in memory, enabling efficient vectorized loads. Reading along an outer dimension instead strides across memory, forcing slower gather operations. Careful alignment of stride patterns with hardware memory hierarchies maximizes cache efficiency and memory throughput, with optimal layouts achieving 80--90% of theoretical memory bandwidth (1.5--3.0 TB/s on modern data-center GPUs like the A100 and H100) compared to suboptimal patterns that may achieve only 20--30% utilization.
|
||
|
||
\index{Tensor!data types (dtypes)}
|
||
Tensor implementations use type systems to control numerical precision and memory consumption. The standard choice in machine learning has been 32-bit floating-point numbers (`float32`), offering a balance of precision and efficiency. Modern frameworks extend this with multiple numeric types for different needs. Integer types support indexing and embedding operations. Reduced-precision types like 16-bit floating-point numbers enable efficient mobile deployment. 8-bit integers allow fast inference on specialized hardware.
|
||
|
||
The choice of numeric type affects both model behavior and computational efficiency.
|
||
\index{Quantization!inference precision}
|
||
Neural network training has traditionally relied on float32 precision to keep gradient computations stable. Inference can often use lower precision (`int8` or even `int4`), reducing memory usage and increasing processing speed. Mixed-precision training combines these benefits by keeping float32 for critical accumulations while performing most computations at lower precision.
|
||
|
||
Type conversions between different numeric representations require careful management. Operating on tensors with different types demands explicit conversion rules to preserve numerical correctness. These conversions introduce computational costs and risk precision loss. Frameworks provide type casting capabilities but rely on developers to maintain numerical precision across operations.
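The precision loss from a down-cast is easy to observe with the standard library alone. The sketch below round-trips Python floats through IEEE 754 half and single precision using `struct`; `to_fp16` and `to_fp32` are illustrative helper names, and the chosen sample values are only examples.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


samples = {
    "typical weight": 0.1,
    "tiny gradient": 1e-8,
    "large integer": 2049.0,
}

for label, value in samples.items():
    print(f"{label:>15}: fp32={to_fp32(value)!r}  fp16={to_fp16(value)!r}")

# A gradient of 1e-8 survives in float32 but flushes to zero in float16,
# which is one reason mixed-precision schemes keep float32 accumulators.
print(to_fp32(1e-8) > 0.0, to_fp16(1e-8) == 0.0)  # True True
```

The third sample shows the other failure mode: float16 has an 11-bit significand, so integers above 2048 are no longer all representable and 2049 silently becomes 2048.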
|
||
|
||
Tensors answer the first question---*what is the data?*---by encoding shape, layout, and precision into a single abstraction. A perfectly shaped tensor on the wrong device, however, or one that must cross a 60$\times$ bandwidth gap to reach the GPU, can erase every layout optimization. The next question is *where* data lives and *how* it moves.
|
||
|
||
#### Device and Memory Management {#sec-ml-frameworks-device-memory-management-9404}
|
||
|
||
```{python}
#| label: device-bandwidth-hierarchy
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ DEVICE BANDWIDTH HIERARCHY
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Device and Memory Management section introducing bandwidth costs
# │
# │ Goal: Quantify the bandwidth hierarchy across system interconnects.
# │ Show: The 60× bandwidth gap between PCIe (32 GB/s) and HBM (2 TB/s).
# │ How: Compare transfer rates for standard interfaces and on-chip memory.
# │
# │ Imports: mlsysim.core.constants (PCIE_GEN4_BW, NVLINK_A100_BW, A100_MEM_BW, A100_MEM_CAPACITY,
# │          A100_FLOPS_FP16_TENSOR), mlsysim.fmt (fmt, check)
# │ Exports: pcie4_gbs_str, pcie4_bidir_gbs_str, nvlink_a100_gbs_str, a100_bw_gbs_str,
# │          a100_bw_tbs_str, a100_tflops_fp16_str, pcie4_4mb_ms_str, nvlink_4mb_ms_str,
# │          hbm_4mb_ms_str, pcie4_1gb_ms_str, pcie4_1gb_equiv_ops_str, a100_mem_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import (PCIE_GEN4_BW, NVLINK_A100_BW, A100_MEM_BW, A100_MEM_CAPACITY,
                                    A100_FLOPS_FP16_TENSOR, GB, TB, GiB, TFLOPs, flop, byte, second,
                                    BILLION, MILLION, THOUSAND)
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class DeviceBandwidthHierarchy:
    """
    Namespace for Device Bandwidth Hierarchy.
    Scenario: Comparing PCIe vs NVLink vs HBM speeds.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    tensor_4mb = 4 * MILLION

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    pcie4_gbs = PCIE_GEN4_BW.m_as(GB/second)
    nvlink_a100_gbs = NVLINK_A100_BW.m_as(GB/second)
    a100_bw_gbs = A100_MEM_BW.m_as(GB/second)
    a100_bw_tbs = A100_MEM_BW.m_as(TB/second)
    a100_mem = A100_MEM_CAPACITY.m_as(GiB)
    a100_tflops_fp16 = A100_FLOPS_FP16_TENSOR.m_as(TFLOPs/second)

    # Transfer times for 4 MB tensor
    pcie4_4mb_ms = (tensor_4mb / PCIE_GEN4_BW.m_as(byte/second)) * THOUSAND
    nvlink_4mb_ms = (tensor_4mb / NVLINK_A100_BW.m_as(byte/second)) * THOUSAND
    hbm_4mb_ms = (tensor_4mb / A100_MEM_BW.m_as(byte/second)) * THOUSAND

    # 1 GB transfer cost analysis
    # 1 GB / BW (B/s) * 1000 ms
    pcie4_1gb_ms = (BILLION / PCIE_GEN4_BW.m_as(byte/second)) * THOUSAND
    # Equiv Ops: (ms / 1000) * FLOPS
    pcie4_1gb_equiv_ops = (pcie4_1gb_ms / THOUSAND) * A100_FLOPS_FP16_TENSOR.m_as(flop/second)

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    pcie4_gbs_str = fmt(pcie4_gbs, precision=0, commas=False)
    pcie4_bidir_gbs_str = fmt(pcie4_gbs * 2, precision=0, commas=False)
    nvlink_a100_gbs_str = fmt(nvlink_a100_gbs, precision=0, commas=False)
    a100_bw_gbs_str = fmt(a100_bw_gbs, precision=0, commas=False)
    a100_bw_tbs_str = fmt(a100_bw_tbs, precision=1, commas=False)
    a100_tflops_fp16_str = fmt(a100_tflops_fp16, precision=0, commas=False)
    a100_mem_str = fmt(a100_mem, precision=0, commas=False)
    pcie4_4mb_ms_str = fmt(pcie4_4mb_ms, precision=3, commas=False)
    nvlink_4mb_ms_str = fmt(nvlink_4mb_ms, precision=3, commas=False)
    hbm_4mb_ms_str = fmt(hbm_4mb_ms, precision=3, commas=False)
    pcie4_1gb_ms_str = fmt(pcie4_1gb_ms, precision=0, commas=False)
    pcie4_1gb_equiv_ops_str = fmt((pcie4_1gb_equiv_ops * flop).m_as(TFLOPs), precision=1, commas=False)
```
|
||
|
||
\index{Device Management!CPU-GPU transfers}
|
||
\index{Memory Management!device placement}
|
||
Tensors and their memory layouts establish *what* the framework computes with. Where that data physically resides, and how it moves between locations, determines whether computation happens at full speed or crawls.
|
||
|
||
### Frameworks as the Operating System Interface {#sec-ml-frameworks-os-interface}
|
||
|
||
\index{CUDA Runtime!OS layer}\index{PCIe DMA!data movement interface}While the high-level API focuses on math, the framework's backend functions as the **Operating System** of the Single-Machine Stack. It manages the two critical resources of a single node: compute scheduling and data movement.
|
||
|
||
The **CUDA Runtime** serves as this OS layer, providing the low-level primitives for launching kernels and managing device memory. The framework coordinates with this runtime to implement **Direct Memory Access (DMA)** over the PCIe bus. As established in @sec-hardware-acceleration, the bandwidth gap between the host (CPU) and device (GPU) is the primary "Data Loading Bottleneck." Frameworks mitigate this through **pinned memory** (page-locked memory) that allows the GPU to read directly from CPU RAM via DMA without interrupting the processor. This "HW/OS" interface is what makes high-throughput training loops possible on a single machine.
|
||
|
||
Every tensor resides on a specific device, and cross-device operations incur transfer costs that can dominate execution time. PCIe 4.0 delivers `{python} DeviceBandwidthHierarchy.pcie4_gbs_str` GB/s between CPU and GPU, while HBM2e provides `{python} MemoryWallSpecs.a100_bw_tbs_str` TB/s within the GPU. This bandwidth gap, exceeding 60$\times$, means a single misplaced tensor transfer can erase the entire speedup from GPU acceleration.
|
||
|
||
Why does this matter for framework design? Because the framework must track where every tensor lives and enforce that operations only combine tensors on the same device. When data must move, the framework must decide whether to block execution or overlap the transfer with other work. These decisions, invisible to most users, determine whether a training loop achieves 30% or 80% of theoretical hardware throughput.
|
||
|
||
Three systems principles govern effective device and memory management: understanding the bandwidth hierarchy that constrains data movement, overlapping computation with communication to hide transfer latency, and using fine-grained synchronization to maintain correctness without sacrificing concurrency. The remainder of this section develops each principle, with quantitative analysis grounded in the **Iron Law's** data movement term.
|
||
|
||
##### Principle 1: The Device Bandwidth Hierarchy { .unnumbered}
|
||
|
||
The cost of moving data between devices varies by orders of magnitude depending on the interconnect.[^fn-nvlink-bandwidth-hierarchy] Before examining optimization strategies, we need to understand these costs quantitatively. @tbl-device-transfer-overhead shows transfer times for a $1000\times1000$ float32 tensor (4 MB)---roughly the size of a typical activation tensor in a moderately sized model. The numbers reveal why careless device placement can erase any speedup from GPU acceleration:
|
||
|
||
[^fn-nvlink-bandwidth-hierarchy]: **NVLink**: NVIDIA's high-bandwidth GPU-to-GPU interconnect (see @sec-hardware-acceleration), providing `{python} DeviceBandwidthHierarchy.nvlink_a100_gbs_str` GB/s bidirectional bandwidth (NVLink 3.0 on A100) compared to `{python} DeviceBandwidthHierarchy.pcie4_bidir_gbs_str` GB/s for PCIe 4.0 x16. This ~10$\times$ bandwidth advantage determines whether tensor parallelism is practical for a given model size: splitting a model across GPUs connected by PCIe can make the $D_{\text{vol}}/BW$ communication term dominate total training time, erasing the benefit of additional compute. \index{NVLink!bandwidth hierarchy}
|
||
|
||
| **Interconnect** |                                                              **Bandwidth** |                                        **Transfer Time** | **Relative to Compute**            |
|:-----------------|---------------------------------------------------------------------------:|---------------------------------------------------------:|:-----------------------------------|
| **PCIe 3.0 x16** |                                                                    16 GB/s |                                                  0.25 ms | 10$\times$ slower than GPU compute |
| **PCIe 4.0 x16** |                     `{python} DeviceBandwidthHierarchy.pcie4_gbs_str` GB/s |  `{python} DeviceBandwidthHierarchy.pcie4_4mb_ms_str` ms | 5$\times$ slower than GPU compute  |
| **NVLink 3.0**   | `{python} DeviceBandwidthHierarchy.nvlink_a100_gbs_str` GB/s bidirectional | `{python} DeviceBandwidthHierarchy.nvlink_4mb_ms_str` ms | Comparable to GPU compute          |
| **GPU Memory**   |                   `{python} DeviceBandwidthHierarchy.a100_bw_gbs_str` GB/s |    `{python} DeviceBandwidthHierarchy.hbm_4mb_ms_str` ms | Optimal                            |
|
||
|
||
: **Device Transfer Overhead.** Transfer time for a 4 MB tensor across different interconnects. PCIe bandwidth shown is unidirectional (typical for GPU transfers), with full-duplex operation providing 2$\times$ total bandwidth. NVLink bandwidth is bidirectional (300 GB/s per direction). Transfer times dominate for small operations, making device placement critical for performance. {#tbl-device-transfer-overhead}
|
||
|
||
These numbers connect directly to the **Iron Law** of performance. Every cross-device transfer adds to the data movement term ($D_{\text{vol}}/BW$), and it does so at a small fraction of on-device bandwidth. At PCIe 4.0's `{python} DeviceBandwidthHierarchy.pcie4_gbs_str` GB/s, moving a 1 GB activation tensor adds approximately `{python} DeviceBandwidthHierarchy.pcie4_1gb_ms_str` ms of data movement, time in which a GPU delivering `{python} A100BLAS.dense_tflops_str` TFLOPS could have performed roughly `{python} DeviceBandwidthHierarchy.pcie4_1gb_equiv_ops_str` trillion operations. At the other end of the scale, a model whose forward pass takes 0.5 ms on the GPU sees its total latency double if, for example, a 4 MB input and a 4 MB output each cross PCIe 3.0 at 0.25 ms apiece. When batches are small or models are lightweight, transfer overhead can exceed computation time entirely.
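The transfer-time arithmetic behind these figures is a single division, sketched below in plain Python. The PCIe bandwidths follow the table; the NVLink per-direction figure, the 2 TB/s HBM figure, and the 312 TFLOPS peak are assumptions for an A100-class GPU rather than values read from the book's constants, and `transfer_ms` is an illustrative helper name.

```python
MB, GB = 1e6, 1e9

# Unidirectional bandwidths in bytes per second (see assumptions above)
BANDWIDTH = {
    "PCIe 3.0 x16": 16 * GB,
    "PCIe 4.0 x16": 32 * GB,
    "NVLink 3.0 (per direction)": 300 * GB,
    "GPU memory (HBM)": 2000 * GB,
}
PEAK_FLOPS = 312e12  # assumed dense FP16 Tensor Core peak


def transfer_ms(num_bytes, bandwidth):
    """Data movement term D_vol / BW, expressed in milliseconds."""
    return num_bytes / bandwidth * 1e3


tensor_bytes = 1000 * 1000 * 4  # a 1000 x 1000 float32 tensor (4 MB)
for name, bw in BANDWIDTH.items():
    print(f"{name:<28} {transfer_ms(tensor_bytes, bw):8.4f} ms")

# Cost of moving 1 GB over PCIe 4.0, and the compute the GPU could have done instead
one_gb_ms = transfer_ms(1 * GB, BANDWIDTH["PCIe 4.0 x16"])
forgone_tflop = one_gb_ms / 1e3 * PEAK_FLOPS / 1e12
print(f"1 GB over PCIe 4.0: {one_gb_ms:.2f} ms, about {forgone_tflop:.2f} trillion FLOPs forgone")
```

This model ignores per-transfer launch latency, which adds a roughly fixed cost to every copy and makes many small transfers even more expensive than the bandwidth figure alone implies.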
|
||
|
||
The systems implication is clear: every tensor should reside on the device where it will be consumed, and transfers should occur only when unavoidable. Frameworks track device placement for every tensor and raise errors when operations attempt to combine tensors from different devices, enforcing this discipline at the API level.
|
||
|
||
##### Principle 2: Overlapping Computation and Communication { .unnumbered}
|
||
|
||
\index{Execution Streams!overlapping computation}
|
||
\index{Asynchronous Execution!hiding transfer latency}
|
||
When transfers are unavoidable, the next optimization is to hide their latency by executing them concurrently with computation. Modern GPUs contain independent hardware units for computation (streaming multiprocessors) and data transfer (copy engines), enabling true simultaneous execution. The framework abstraction that exposes this hardware parallelism is the *CUDA stream*\index{Execution Streams!definition}: an independent execution queue whose operations run in issue order, while operations in different streams may run concurrently.
|
||
|
||
Without explicit concurrency control, the GPU serializes all operations on a single default stream, leaving execution units idle while data transfers complete. By placing data transfers on one stream and computation on another, the effective latency approaches the theoretical minimum of $\max(\text{compute\_time}, \text{transfer\_time})$ rather than their sum. Stream-based overlap effectively hides the $D_{\text{vol}}/BW$ penalty when computation is the longer operation (see @lst-overlap-compute-transfer):
|
||
|
||
::: {#lst-overlap-compute-transfer lst-cap="**Overlapping Computation and Transfer**: Use separate streams for data transfer and computation to hide transfer latency. Pinned memory enables truly asynchronous non-blocking transfers."}
|
||
|
||
```{.python}
import torch

# model, criterion, labels, current_batch (already on the GPU), and
# next_batch_cpu (in pinned host memory) come from the training loop.
compute_stream = torch.cuda.Stream()
transfer_stream = torch.cuda.Stream()

# Transfer next batch while computing current batch
with torch.cuda.stream(transfer_stream):
    next_batch = next_batch_cpu.to("cuda", non_blocking=True)

with torch.cuda.stream(compute_stream):
    output = model(current_batch)
    loss = criterion(output, labels)

# Order the streams: later compute work must not read next_batch
# until the copy on transfer_stream has finished
compute_stream.wait_stream(transfer_stream)

# Pinned memory enables non_blocking transfers
x_pinned = torch.randn(1000, 1000).pin_memory()
x_gpu = x_pinned.to("cuda", non_blocking=True)  # Asynchronous

# Regular memory requires blocking transfer
y_regular = torch.randn(1000, 1000)
y_gpu = y_regular.to("cuda", non_blocking=True)  # Still blocks
```
|
||
|
||
:::
|
||
|
||
\index{Pinned Memory!definition}
|
||
\index{DMA Transfer!GPU copy engine}
|
||
The `non_blocking=True` flag enables asynchronous transfers that return immediately without waiting for completion. This works only when the source tensor uses *pinned memory*\index{Pinned Memory!DMA transfers} (page-locked memory that enables DMA transfers). Without pinned memory, the transfer blocks even when `non_blocking=True` is specified, because the GPU's copy engine cannot initiate a DMA transfer from pageable host memory.
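The payoff from overlap can be estimated with a toy timeline before touching any GPU. The sketch below is a simplified model, not a measurement: it assumes one copy engine, one compute queue, enough device buffers that a copy never waits for space, and fixed per-batch times. The function names and the sample timings are illustrative.

```python
def sequential_ms(n_batches, transfer_ms, compute_ms):
    """Single stream: every batch is copied, then computed, back to back."""
    return n_batches * (transfer_ms + compute_ms)


def overlapped_ms(n_batches, transfer_ms, compute_ms):
    """Two streams: batch i+1 is copied while batch i is being computed."""
    copy_done = 0.0
    compute_done = 0.0
    for _ in range(n_batches):
        copy_done += transfer_ms  # the copy engine handles one batch at a time
        # Compute starts once its data has arrived and the previous batch is done
        compute_done = max(compute_done, copy_done) + compute_ms
    return compute_done


# Compute-bound case: 0.125 ms copy hidden behind 0.5 ms of compute per batch
print(sequential_ms(100, 0.125, 0.5))  # 62.5
print(overlapped_ms(100, 0.125, 0.5))  # 50.125

# Transfer-bound case: the copy is now the longer operation and sets the pace
print(overlapped_ms(100, 0.5, 0.125))  # 50.125
```

In both cases the overlapped total is close to the batch count times $\max(\text{compute\_time}, \text{transfer\_time})$, with only the first copy or the last compute left exposed.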
|
||
|
||
This overlap principle extends naturally to pipeline parallelism within a single node. Different model stages on separate GPUs can process different microbatches concurrently, with each stage's computation overlapping the next stage's data reception (see @lst-pipeline-parallelism-streams):
|
||
|
||
::: {#lst-pipeline-parallelism-streams lst-cap="**Pipeline Parallelism with Streams**: Overlap multiple model stages across microbatches using streams and events for inter-stage synchronization."}
|
||
|
||
```{.python}
import torch

# Pipeline parallelism: overlap stages across microbatches.
# Stage1..Stage3, microbatches, and num_microbatches are defined elsewhere;
# for brevity every stage lives on the default GPU here.
stages = [Stage1().cuda(), Stage2().cuda(), Stage3().cuda()]
streams = [torch.cuda.Stream() for _ in stages]
events = [
    [torch.cuda.Event() for _ in range(num_microbatches)]
    for _ in stages
]

for mb in range(num_microbatches):
    activation = microbatches[mb]
    for stage_idx, (stage, stream) in enumerate(zip(stages, streams)):
        with torch.cuda.stream(stream):
            if stage_idx > 0:
                # Wait for previous stage to complete this microbatch
                events[stage_idx - 1][mb].wait()

            # Each stage consumes the previous stage's output
            activation = stage(activation)
            events[stage_idx][mb].record()
```
|
||
|
||
:::
|
||
|
||
Extending this pattern across multiple machines requires distributed training techniques that constitute an advanced topic, but the single-node implementation above illustrates the core synchronization principles that underlie all pipeline-parallel systems.
|
||
|
||
With computation and communication overlapping effectively, the remaining challenge is ensuring correctness when operations complete out of order.
|
||
|
||
##### Principle 3: Synchronization and Correctness { .unnumbered}
|
||
|
||
\index{Synchronization Events!inter-stream}
|
||
\index{Synchronization!device vs. stream}
|
||
Concurrent execution introduces ordering constraints. When one stream's output becomes another stream's input, the system must enforce a happens-before relationship without unnecessarily serializing independent work. Two synchronization mechanisms exist, with dramatically different performance implications.
|
||
|
||
Full device synchronization (`torch.cuda.synchronize()`) blocks the calling CPU thread until every operation queued on every stream of the device completes. Called after each step, it becomes a global serialization point that eliminates all overlap benefits. CUDA events\index{Synchronization Events!fine-grained} provide the alternative: fine-grained synchronization that blocks only the dependent stream, allowing other streams and the CPU to continue execution (see @lst-cuda-events):
|
||
|
||
::: {#lst-cuda-events lst-cap="**CUDA Events for Synchronization**: Events enable fine-grained producer-consumer patterns between streams without blocking the entire device."}
|
||
|
||
```{.python}
# Create streams and event
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()
event = torch.cuda.Event()

# Stream 1: producer
with torch.cuda.stream(stream1):
    result1 = expensive_computation(data1)
    event.record()  # Mark completion point

# Stream 2: consumer (waits only for stream1's event)
with torch.cuda.stream(stream2):
    event.wait()  # Block stream2 until event is recorded
    result2 = dependent_computation(result1)  # Safe to use result1
```
|
||
|
||
:::
|
||
|
||
The performance difference between these approaches is not incremental but categorical. Full synchronization after every operation converts a concurrent pipeline into a sequential one, entirely negating the hardware parallelism that streams expose. Event-based synchronization preserves the concurrent execution model while enforcing only the dependencies that correctness requires.
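The size of this gap can be estimated with a toy timing model. The sketch below is not a CUDA program; it is a pure-Python schedule calculation under two simplifying assumptions: every stage takes one time unit per microbatch, and transfers are free. The function names are illustrative only.

```python
def makespan_with_events(num_stages, num_microbatches, stage_time=1.0):
    """Finish time when each operation waits only on its true dependencies."""
    finish = [[0.0] * num_microbatches for _ in range(num_stages)]
    for s in range(num_stages):
        for m in range(num_microbatches):
            # Event wait: the previous stage must have produced this microbatch.
            prev_stage = finish[s - 1][m] if s > 0 else 0.0
            # Stream order: this stage must have finished the previous microbatch.
            prev_batch = finish[s][m - 1] if m > 0 else 0.0
            finish[s][m] = max(prev_stage, prev_batch) + stage_time
    return finish[-1][-1]


def makespan_fully_synchronized(num_stages, num_microbatches, stage_time=1.0):
    """Finish time when a device-wide sync follows every operation."""
    return num_stages * num_microbatches * stage_time


overlapped = makespan_with_events(3, 4)
serialized = makespan_fully_synchronized(3, 4)
print(overlapped, serialized)
```

For $S$ stages and $M$ microbatches, the event-based schedule in this model finishes in $S + M - 1$ time units, while the fully synchronized one needs $S \times M$.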
|
||
|
||
###### Device Placement Discipline {.unnumbered}
|
||
|
||
Every tensor carries a `device` attribute, and frameworks enforce a strict rule: operations can only combine tensors on the *same* device. Mixing `cuda:0` and `cuda:1` tensors raises a `RuntimeError` rather than triggering a silent cross-device transfer. The `.to()` method moves tensors between devices and copies only when it has to: calling `.to("cuda")` on a tensor already on that GPU returns the same object without copying. Module `.to()` recursively moves all parameters and buffers, ensuring the entire model hierarchy lands on a single device. Three placement principles prevent transfer bottlenecks: (1) allocate tensors on the target device from the start rather than creating on CPU and transferring, (2) reuse GPU memory across iterations rather than re-allocating, and (3) colocate all inputs, labels, and model parameters on the same device to eliminate implicit transfers. Violating any of these principles inserts PCIe transfers into the critical path, which at `{python} DeviceBandwidthHierarchy.pcie4_gbs_str` GB/s can dominate a training iteration that otherwise runs at `{python} MemoryWallSpecs.a100_bw_tbs_str` TB/s on-device.
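A minimal sketch of the first and third principles follows. It falls back to the CPU when no GPU is present, so the tensor shapes and the `Linear` layer are placeholders rather than a realistic workload.

```python
import torch

# Use the GPU when present; the example also runs on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Principle 1: allocate on the target device instead of CPU-then-transfer.
x = torch.zeros(64, 128, device=device)

# .to() with a matching device and dtype returns the same tensor object.
same = x.to(device)
print(same is x)

# Principle 3: colocate the model with its inputs before the forward pass.
model = torch.nn.Linear(128, 10).to(device)
y = model(x)
print(y.device == x.device)
```

Checking `tensor.device` at module boundaries is a cheap way to catch placement mistakes before they surface as a `RuntimeError` deep inside a forward pass.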
|
||
|
||
###### Synchronization Patterns {.unnumbered}
|
||
|
||
As @lst-cuda-events demonstrated, event-based synchronization preserves parallelism by enforcing only the dependencies that correctness requires. A common mistake in production code is inserting `torch.cuda.synchronize()` calls for debugging and forgetting to remove them, silently converting an overlapped pipeline into a serialized one.
|
||
|
||
###### Profiling Transfer Bottlenecks {.unnumbered}
|
||
|
||
When overlap is insufficient, profiling reveals where time is lost. NVIDIA provides two complementary tools: **Nsight Systems** (`nsys profile`) captures system-wide timelines correlating CPU activity, GPU kernel execution, and memory transfers, identifying *which* kernels dominate runtime. **Nsight Compute** (`ncu`) provides kernel-level analysis with hardware counters, revealing *why* those kernels underperform. @tbl-nsight-metrics lists the key metrics to examine when optimizing ML kernels.
|
||
|
||
| **Metric**             | **Meaning**                  | **Optimization Target**            |
|:-----------------------|:-----------------------------|:-----------------------------------|
| **SM Occupancy**       | Active warps / maximum warps | Increase parallelism if low        |
| **Memory Throughput**  | Achieved / peak bandwidth    | Optimize memory access patterns    |
| **Compute Throughput** | Achieved / peak FLOPS        | Reduce memory bottlenecks          |
| **Tensor Core Active** | Time in Tensor Core ops      | Verify mixed-precision utilization |

: **Nsight Compute Metrics.** Key metrics for ML kernel optimization. Low values indicate specific optimization opportunities. Nsight Systems identifies which kernels dominate runtime, and Nsight Compute reveals why those kernels underperform. {#tbl-nsight-metrics}
|
||
|
||
#### Data Pipelines and Loading {#sec-ml-frameworks-domainspecific-data-organizations-48d9}
|
||
|
||
```{python}
#| label: dataloader-throughput-calc
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ DATALOADER THROUGHPUT CALCULATIONS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Domain-Specific Data Organizations section on sustaining GPU throughput
# │
# │ Goal: Quantify the data ingestion requirements for high-speed training.
# │ Show: That sustaining 1000 images/s requires 150 MB/s continuous throughput.
# │ How: Calculate bandwidth from image resolution, batch size, and rate.
# │
# │ Imports: mlsysim.core.constants (PCIE_GEN4_BW, BYTES_FP32), mlsysim.book (fmt)
# │ Exports: img_res, dataloader_mbs_str, batch_mb_str, batch_transfer_ms_str, DeviceBandwidthHierarchy.pcie4_gbs_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import PCIE_GEN4_BW, BYTES_FP32, MB, GB, byte, second, MS_PER_SEC
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class DataloaderThroughput:
    """
    Namespace for Dataloader Throughput.
    Scenario: GPU data ingestion requirements.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    img_per_sec = 1000
    img_res = 224
    img_channels = 3
    batch_size = 64

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    # Step 1: Throughput requirement
    throughput_bytes_sec = img_per_sec * img_res * img_res * img_channels * byte
    dataloader_mbs = throughput_bytes_sec.m_as(MB)

    # Step 2: Batch transfer
    batch_bytes = batch_size * img_res * img_res * img_channels * BYTES_FP32
    batch_mb = batch_bytes.m_as(MB)
    batch_transfer_ms = (batch_bytes / PCIE_GEN4_BW).m_as(second) * MS_PER_SEC

    # Step 3: PCIe Ref
    pcie4_gbs = PCIE_GEN4_BW.m_as(GB/second)

    # ┌── 3. GUARD (Invariants) ───────────────────────────────────────────
    check(dataloader_mbs > 100, f"Throughput requirement ({dataloader_mbs:.1f} MB/s) too low.")

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    dataloader_mbs_str = fmt(dataloader_mbs, precision=0, commas=False)
    batch_mb_str = fmt(batch_mb, precision=0, commas=False)
    batch_transfer_ms_str = fmt(batch_transfer_ms, precision=1, commas=False)
    pcie4_gbs_str = fmt(pcie4_gbs, precision=0, commas=False)
```
|
||
|
||
\index{Data Pipeline!throughput optimization}
|
||
Streams and events answer the second question (*where does data live, and how does it move?*) by overlapping transfers with computation so that the GPU rarely stalls on a single tensor. Scheduling alone, however, cannot help if data arrives too slowly in the first place. The third question is *how does data arrive fast enough?* The core systems principle is straightforward: the data pipeline must sustain the accelerator's consumption rate. A GPU processing 1,000 images per second at `{python} DataloaderThroughput.img_res`$\times$`{python} DataloaderThroughput.img_res` resolution requires approximately `{python} DataloaderThroughput.dataloader_mbs_str` MB/s of sustained data throughput, assuming 8-bit RGB pixels (3 bytes per pixel). If the pipeline cannot maintain this rate, the accelerator idles and the effective utilization term in the **Iron Law** drops below 1.
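The arithmetic behind this figure is short enough to spell out. The helper below is an illustrative sketch rather than part of any framework, and it assumes decimal megabytes (1 MB = 10^6 bytes) and uncompressed pixels.

```python
def required_throughput_mb_s(images_per_sec, height, width, channels=3,
                             bytes_per_value=1):
    """Sustained input bandwidth in MB/s (1 MB = 10**6 bytes)."""
    bytes_per_image = height * width * channels * bytes_per_value
    return images_per_sec * bytes_per_image / 1e6


uint8_rate = required_throughput_mb_s(1000, 224, 224)
fp32_rate = required_throughput_mb_s(1000, 224, 224, bytes_per_value=4)
print(f"uint8: {uint8_rate:.1f} MB/s, fp32: {fp32_rate:.1f} MB/s")
```

The same images converted to FP32 before they reach the device quadruple the requirement, which is one reason pipelines often keep data in 8-bit form for as long as possible.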
|
||
|
||
Frameworks address this throughput requirement through three mechanisms. The first is *parallel worker processes*: the DataLoader spawns multiple CPU processes, each independently loading and preprocessing samples. Because data loading involves disk I/O and CPU-bound transformations (decoding, augmentation, normalization), a single process often cannot saturate a modern GPU. Multiple workers overlap I/O wait times with preprocessing computation, collectively sustaining throughput that no single process could achieve. When `num_workers > 0`, the DataLoader hands each worker the sample indices for whole batches through index queues; the worker loads, preprocesses, and collates its batch, then pushes the result onto a shared data queue that the main process drains in order.
|
||
|
||
The second mechanism is *prefetching*. The `prefetch_factor` parameter (default 2 when `num_workers > 0`) controls how many batches each worker prepares in advance. With 4 workers and `prefetch_factor=2`, the pipeline keeps up to 8 batches in flight, so the GPU does not stall as long as the workers' aggregate throughput matches its consumption rate. While the model processes batch $N$ on the GPU, workers simultaneously load and preprocess batches $N+1$ through $N+8$ on CPUs, effectively hiding data loading latency behind computation. The cost is host memory proportional to batch size times prefetch depth times the number of workers.
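The memory cost of prefetching can be estimated the same way. This sketch assumes FP32 batches of 64 images at 224×224×3, matching the running example, and ignores per-process overhead; the function names are illustrative.

```python
def in_flight_batches(num_workers, prefetch_factor):
    """Upper bound on batches buffered ahead of the training loop."""
    return num_workers * prefetch_factor


def prefetch_memory_mb(num_workers, prefetch_factor, batch_size,
                       height, width, channels=3, bytes_per_value=4):
    """Host memory held by prefetched batches, in MB (10**6 bytes)."""
    batch_bytes = batch_size * height * width * channels * bytes_per_value
    return in_flight_batches(num_workers, prefetch_factor) * batch_bytes / 1e6


batches = in_flight_batches(4, 2)
memory_mb = prefetch_memory_mb(4, 2, batch_size=64, height=224, width=224)
print(batches, round(memory_mb, 1))
```

Doubling either `num_workers` or `prefetch_factor` doubles this buffer, so deeper pipelines trade host memory for tolerance to loading-time jitter.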
|
||
|
||
The third mechanism is *pinned memory for DMA transfers*. The `pin_memory=True` option allocates batch data in page-locked (pinned) host memory rather than pageable memory. Pageable memory can be swapped to disk by the operating system, forcing the CUDA runtime to first copy data to a temporary pinned buffer before initiating the GPU transfer. Pinned memory bypasses this intermediate copy, enabling direct memory access (DMA) transfers in which the GPU's copy engine reads directly from host memory while the CPU continues other work. For a batch of 64 images at $224\times224\times3$ in FP32 (`{python} DataloaderThroughput.batch_mb_str` MB), pinned memory transfer takes approximately `{python} DataloaderThroughput.batch_transfer_ms_str` ms over PCIe 4.0 x16 (`{python} DataloaderThroughput.pcie4_gbs_str` GB/s) compared to ~3.0 ms with pageable memory, a 2--3$\times$ speedup. The cost is reduced available system memory, as pinned pages cannot be swapped.
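As a back-of-envelope check on these numbers, the sketch below assumes a nominal PCIe 4.0 x16 peak of 32 GB/s (an assumption; the chapter's own constant may differ slightly) and takes the ~3.0 ms pageable figure from the text as given.

```python
batch_bytes = 64 * 224 * 224 * 3 * 4   # FP32 batch from the text
pcie_bytes_per_s = 32e9                 # assumed PCIe 4.0 x16 peak bandwidth
pinned_ms = batch_bytes / pcie_bytes_per_s * 1e3
pageable_ms = 3.0                       # figure quoted in the text

print(f"{batch_bytes / 1e6:.1f} MB, {pinned_ms:.2f} ms, "
      f"{pageable_ms / pinned_ms:.1f}x")
```

This is an ideal-case bound: real transfers rarely reach peak link bandwidth, so measured times will be somewhat higher on both paths.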
|
||
|
||
These three mechanisms appear together in the DataLoader configuration. Understanding how each parameter connects to the underlying systems principle helps practitioners diagnose data pipeline bottlenecks. @lst-dataloader-throughput shows a typical setup where `num_workers` enables parallel loading, `prefetch_factor` controls pipeline depth, and `pin_memory` enables DMA transfers:
|
||
|
||
::: {#lst-dataloader-throughput lst-cap="**DataLoader Throughput Configuration**: Each parameter addresses a specific throughput bottleneck. num_workers parallelizes I/O and preprocessing across CPU cores, prefetch_factor controls pipeline depth, and pin_memory enables DMA transfers to the GPU."}
|
||
|
||
```{.python}
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,  # Parallel worker processes (mechanism 1)
    prefetch_factor=2,  # Batches prepared ahead per worker (mechanism 2)
    pin_memory=True,  # Page-locked memory for DMA (mechanism 3)
    worker_init_fn=seed_worker,  # Reproducible augmentation per worker
)

# Pipeline effect: while GPU processes batch N,
# 4 workers load batches N+1..N+8 into pinned memory,
# ready for DMA transfer when the GPU finishes.
```
|
||
|
||
:::
|
||
|
||
A practical starting point is setting `num_workers` equal to the number of available CPU cores. The optimal value depends on whether loading is I/O-bound or CPU-bound. For I/O-bound workloads such as reading images from network storage, additional workers overlap storage latency and improve throughput. For CPU-bound workloads involving heavy augmentation, the benefit saturates once all cores are in use. Too many workers waste memory, since each maintains a copy of the Dataset object.
|
||
|
||
Worker process management introduces several subtle issues. Because workers are separate processes, random number generators used in data augmentation must be explicitly seeded per worker via `worker_init_fn` to ensure reproducibility. Without proper seeding, generators inherited from the parent process (NumPy's global generator is the usual example) can start from the same state in every worker, so workers produce identical augmentation sequences and effective data diversity drops. Shared state between workers presents a separate challenge: each worker has its own memory space, so modifications to global variables in one worker do not propagate to others or to the main process. For large datasets where caching matters, memory-mapped files or shared memory regions that persist across processes are the standard solution.
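One common way to write the `seed_worker` function passed as `worker_init_fn` follows the pattern in PyTorch's reproducibility notes. Treat it as a sketch: it assumes augmentation draws from Python's `random` module and NumPy's global generator, and other libraries would need their own seeding calls.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    """Derive per-worker seeds for Python and NumPy from PyTorch's seed.

    Inside a DataLoader worker, torch.initial_seed() already differs per
    worker, so reusing it keeps workers distinct while remaining
    reproducible for a fixed base seed.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy expects a 32-bit seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# A seeded generator fixes the base seed from which worker seeds derive.
g = torch.Generator()
g.manual_seed(0)
# Usage: DataLoader(..., worker_init_fn=seed_worker, generator=g)
```

Passing both `worker_init_fn` and `generator` makes the shuffle order and the per-worker augmentation streams repeatable from run to run.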
|
||
|
||
The DataLoader wraps a Dataset object that defines how individual samples are accessed. PyTorch supports two dataset paradigms. *Map-style* datasets implement `__len__` and `__getitem__`, enabling random access to samples by index---this pattern works well for datasets that fit in memory or support efficient random access on disk. *Iterable-style* datasets implement `__iter__` instead, yielding samples sequentially for streaming data sources where random access is impractical. The choice between paradigms determines whether the DataLoader can shuffle samples (map-style only) or must process them in arrival order (iterable-style).
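The two paradigms differ only in which methods the dataset implements. The toy classes below (hypothetical names, synthetic data) are a minimal sketch of each.

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class SquaresMapDataset(Dataset):
    """Map-style: random access by index via __len__ and __getitem__."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)


class SquaresStream(IterableDataset):
    """Iterable-style: yields samples in arrival order via __iter__."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield torch.tensor(i * i)


# Map-style datasets can be shuffled because the sampler picks indices.
map_batches = list(DataLoader(SquaresMapDataset(8), batch_size=4, shuffle=True))

# Iterable-style datasets are consumed sequentially; shuffle=True would raise.
stream_batches = list(DataLoader(SquaresStream(8), batch_size=4))
print(stream_batches[0])
```

With `num_workers > 0`, an iterable-style dataset must also split its stream across workers itself (for example by inspecting `torch.utils.data.get_worker_info()`), otherwise every worker yields the full stream and samples are duplicated.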
|
||
|
||
A final detail is *collation*: the `collate_fn` parameter determines how individual samples are combined into batches. The default collation stacks tensors along a new batch dimension, which works when all samples have identical shapes. For variable-length data such as text sequences, custom collation handles padding, sorting by length, or creating attention masks---directly affecting both memory usage and training throughput.
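For variable-length sequences, a custom collation function might look like the sketch below. The name `pad_collate`, the padding value of 0, and the mask convention (True at real tokens) are illustrative choices, not framework defaults.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(samples):
    """Collate variable-length 1-D token tensors into a padded batch.

    Returns the padded tensor of shape (batch, max_len) and a boolean mask
    that is True at real tokens and False at padding positions.
    """
    lengths = torch.tensor([len(s) for s in samples])
    padded = pad_sequence(samples, batch_first=True, padding_value=0)
    mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


batch = [torch.tensor([5, 6, 7]), torch.tensor([8]), torch.tensor([9, 10])]
padded, mask = pad_collate(batch)
print(padded.shape, mask.sum().item())
```

The function plugs in as `DataLoader(dataset, collate_fn=pad_collate)`. Because every batch is padded to its own longest sequence, grouping samples of similar length reduces the wasted computation on padding.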
|
||
|
||
```{python}
#| label: gpt3-parameter-structures
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ GPT-3 PARAMETER STRUCTURES
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: @sec-ml-frameworks-abstraction-problem-37a5 "Parameter Structures"
# │ subsection; values used in two prose paragraphs and @fig-3d-parallelism
# │
# │ Goal: Quantify GPT-3 FP16 storage (350 GB) to motivate parameter sharding
# │ across multiple devices and multi-GPU parallelism strategies.
# │ Show: "175 B parameters require 350 GB in FP16" — inline in prose before
# │ @fig-3d-parallelism and in "Parameter Structures" sub-paragraph.
# │ How: model_memory() formula; m_as(Bparam) for param count extraction.
# │
# │ Note: PERSISTENT — GPT3MemoryFootprint.gpt3_params_b_str / GPT3MemoryFootprint.gpt3_fp16_gb_str reused again at
# │ @sec-ml-frameworks-nnmodule-abstraction-2622 (line ~3225) and Fallacies (line ~4082).
# │ Produces identical values to gpt3-memory-footprint cell (line ~1651).
# │
# │ Imports: mlsysim.core.constants (GPT3_PARAMS, BYTES_FP16, GB, Bparam),
# │ mlsysim.formulas (model_memory), mlsysim.book (fmt)
# │ Exports: GPT3MemoryFootprint.gpt3_params_b_str, GPT3MemoryFootprint.gpt3_fp16_gb_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import GPT3_PARAMS, BYTES_FP16, GB, Bparam
from mlsysim.core.formulas import model_memory
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class GPT3ParameterStructures:
    """GPT-3 FP16 storage to motivate parameter sharding across devices."""

    # ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
    params_b = GPT3_PARAMS.m_as(Bparam)  # 175 billion

    # ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
    fp16_gb = model_memory(GPT3_PARAMS, BYTES_FP16, GB)  # 350 GB

    # ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
    gpt3_params_b_str = fmt(params_b, precision=0, commas=False)  # e.g. "175"
    gpt3_fp16_gb_str = fmt(fp16_gb, precision=0, commas=False)  # e.g. "350"
```
|
||
|
||
DataLoaders, Datasets, and collation functions answer the third question---*how does data arrive fast enough?*---by sustaining accelerator-rate throughput through parallelism, prefetching, and DMA. These structures, however, handle only *ephemeral* data: samples flow through the pipeline once per epoch and are discarded. The fourth question asks how frameworks manage data that *persists*---the model's own weights---especially when those weights exceed the memory of any single device.
|
||
|
||
##### Parameter Structures {.unnumbered}
|
||
|
||
A GPT-3 scale model stores `{python} GPT3Context.gpt3_params_b_str` billion parameters, occupying `{python} GPT3ParameterStructures.gpt3_fp16_gb_str` GB in FP16. Managing these parameters across devices, keeping gradients synchronized, and maintaining optimizer state (which can triple the memory footprint, as the Administrative Tax notebook showed) is a core framework responsibility.
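The storage figure follows directly from parameter count times bytes per parameter. The sketch below repeats that arithmetic and, assuming an 80 GB accelerator (an illustrative choice), counts how many devices are needed just to hold the FP16 weights.

```python
import math


def weights_gb(num_params, bytes_per_param):
    """Storage for model weights in GB (1 GB = 10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


GPT3_PARAMS = 175e9            # parameter count quoted in the text
fp16_gb = weights_gb(GPT3_PARAMS, 2)
fp32_gb = weights_gb(GPT3_PARAMS, 4)

GPU_MEMORY_GB = 80             # assumption: one 80 GB accelerator
min_gpus_for_weights = math.ceil(fp16_gb / GPU_MEMORY_GB)
print(fp16_gb, fp32_gb, min_gpus_for_weights)
```

This count is a lower bound for inference only; gradients, optimizer state, and activations push the training requirement well beyond it.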
|
||
|
||
Because parameters persist throughout training and inference, frameworks organize them into compact structures that minimize memory while enabling fast read and write access [@li2014communication]. During multi-GPU training, frameworks may replicate parameters across devices for parallel computation while keeping a synchronized master copy. Synchronizing multi-billion-parameter models can require transferring tens of GB of gradients per step, which is why frameworks implement gradient compression and efficient communication patterns like ring all-reduce.
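The appeal of ring all-reduce shows up in its communication volume. In the standard formulation, a reduce-scatter phase and an all-gather phase each take $N-1$ steps, and every step moves one chunk of size payload$/N$. The sketch below applies that formula to a hypothetical 10-billion-parameter model with FP16 gradients.

```python
def ring_allreduce_bytes_per_device(payload_bytes, num_devices):
    """Bytes each device sends during one ring all-reduce.

    Reduce-scatter and all-gather each take (N - 1) steps, and every step
    moves one chunk of size payload / N.
    """
    if num_devices < 2:
        return 0.0
    steps = 2 * (num_devices - 1)
    return steps * payload_bytes / num_devices


grad_bytes = 10e9 * 2          # assumption: 10 B parameters, FP16 gradients
for n in (2, 8, 64):
    sent_gb = ring_allreduce_bytes_per_device(grad_bytes, n) / 1e9
    print(n, round(sent_gb, 2))
```

Per-device traffic approaches twice the payload as $N$ grows but never exceeds it, so adding devices does not increase the load on any single link.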
|
||
|
||
Parameter structures must also adapt to varying precision requirements. Training traditionally uses FP32 for gradient stability, but large-scale training increasingly relies on FP16 or BF16 mixed precision, and inference often runs in FP16 or INT8. Frameworks implement type casting and mixed-precision management to enable these optimizations while keeping numerical error within acceptable bounds.
|
||
|
||
##### Distributed Execution Contexts {.unnumbered}
|
||
|
||
\index{Distributed Training!execution contexts}
|
||
The computational graph defines *what* to compute, but *where* and *how* that computation runs across devices is the job of execution contexts. On a single node, execution contexts manage CUDA streams and events (discussed earlier in this chapter) to overlap computation and data transfer across GPUs.
|
||
|
||
When training scales beyond a single machine, these same abstractions extend to manage process groups and communication primitives. Frameworks use constructs like `ProcessGroup` (PyTorch) or `Mesh` (JAX) to define how devices communicate, maintaining state for collective operations such as AllReduce that synchronize gradients across thousands of GPUs. Execution contexts are also responsible for partitioning computational graphs and redistributing data as needed.
|
||
|
||
We introduce these concepts here because they shape framework API design even for single-node code. The implementation details of distributed training---including gradient compression, communication topologies, and fault tolerance---constitute advanced topics that build on these single-node foundations.
|
||
|
||
\index{Data Parallelism!definition}
|
||
\index{Pipeline Parallelism!definition}
|
||
When models exceed single-device memory, frameworks combine multiple parallelism strategies simultaneously. A GPT-3 scale model, for instance, cannot fit on a single GPU: its `{python} GPT3Context.gpt3_params_b_str` B parameters alone require `{python} GPT3ParameterStructures.gpt3_fp16_gb_str` GB in FP16, far exceeding the memory of any single GPU. Practitioners train such models by distributing computation across multiple devices using three complementary strategies. @fig-3d-parallelism lays out how large-scale training distributes computation across three orthogonal dimensions to overcome this constraint. In the figure, look for how each dimension addresses a different scaling need: **Data Parallelism**\index{Data Parallelism!model replication} (replicating the model across the two rows) scales throughput by processing different batches in parallel; **Pipeline Parallelism**\index{Pipeline Parallelism!layer splitting} (splitting layers across the four columns) distributes a single model's depth across devices; and **Model Parallelism**\index{Model Parallelism!tensor sharding} (sharding tensors within each cluster) partitions individual layers that are too large for one device. This "3D" approach allows frameworks to scale beyond the memory limits of any single device. @sec-model-training examines these parallelism strategies in depth, including their implementation trade-offs and communication patterns.
|
||
|
||
::: {#fig-3d-parallelism fig-env="figure" fig-pos="htb" fig-cap="**3D Parallelism.** A grid of eight accelerator clusters arranged in two rows and four columns, each containing stacked computational units. The braces mark the three parallelism dimensions: data parallelism across rows, pipeline parallelism across columns, and model parallelism within each cluster." fig-alt="Grid of 8 GPU clusters in 2 rows and 4 columns. Each cluster contains 4 stacked cubes. Colors vary: blue, red, green, orange in bottom row; olive, pink, green, red in top row."}
|
||
|
||
```{.tikz}
\resizebox{0.70\textwidth}{!}{
\begin{tikzpicture}[line cap=round,line join=round,font=\small\usefont{T1}{phv}{m}{n}]
\tikzset{
  pics/square/.style = {
    code = {
      \pgfkeys{/channel/.cd, #1}
      \begin{scope}[local bounding box=SQUARE,scale=\scalefac,every node/.append style={transform shape}]
        % Right Face
        \draw[fill=\channelcolor!70,line width=\Linewidth]
        (\Depth,0,0)coordinate(\picname-ZDD)--(\Depth,\Width,0)--(\Depth,\Width,\Height)--(\Depth,0,\Height)--cycle;
        % Front Face
        \draw[fill=\channelcolor!40,line width=\Linewidth]
        (0,0,\Height)coordinate(\picname-DL)--(0,\Width,\Height)coordinate(\picname-GL)--
        (\Depth,\Width,\Height)coordinate(\picname-GD)--(\Depth,0,\Height)coordinate(\picname-DD)--(0,0,\Height);
        % Top Face
        \draw[fill=\channelcolor!20,line width=\Linewidth]
        (0,\Width,0)coordinate(\picname-ZGL)--(0,\Width,\Height)coordinate(\picname-ZGL)--
        (\Depth,\Width,\Height)--(\Depth,\Width,0)coordinate(\picname-ZGD)--cycle;
      \end{scope}
    }
  }
}
\pgfkeys{
  /channel/.cd,
  Depth/.store in=\Depth,
  Height/.store in=\Height,
  Width/.store in=\Width,
  channelcolor/.store in=\channelcolor,
  drawchannelcolor/.store in=\drawchannelcolor,
  scalefac/.store in=\scalefac,
  Linewidth/.store in=\Linewidth,
  picname/.store in=\picname,
  Depth=1.6,
  Height=1.1,
  Width=1.4,
  channelcolor=BrownLine,
  drawchannelcolor=BrownLine,
  scalefac=1,
  Linewidth=1.0pt,
  picname=C
}
\def\ras{0.95}
\def\dis{2.2}
\begin{scope}[local bounding box=BELOW,shift={($(0,0)+(0,0)$)},scale=1,every node/.append style={transform shape}]
  \begin{scope}[local bounding box=GPU0,shift={($(0,0)+(0,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=4,channelcolor=BlueLine,Linewidth=0.7pt}};
    }
  \end{scope}
  \begin{scope}[local bounding box=GPU8,shift={($(0,0)+(\dis,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=12,channelcolor=RedLine,Linewidth=0.7pt}};
    }
  \end{scope}
  \begin{scope}[local bounding box=GPU16,shift={($(0,0)+(2*\dis,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=20,channelcolor=GreenLine,Linewidth=0.7pt}};
    }
  \end{scope}
  \begin{scope}[local bounding box=GPU16,shift={($(0,0)+(3*\dis,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=28-\i,channelcolor=OrangeLine,Linewidth=0.7pt}};
    }
  \end{scope}
\end{scope}
%%%%ABOVE
\begin{scope}[local bounding box=ABOVE,shift={($(0,0)+(0,2.2)$)},scale=1,every node/.append style={transform shape}]
  \begin{scope}[local bounding box=GPU0,shift={($(0,0)+(0,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=0,channelcolor=OliveLine,Linewidth=0.7pt}};
    }
  \end{scope}
  \begin{scope}[local bounding box=GPU8,shift={($(0,0)+(1*\dis,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=8,channelcolor=pink,Linewidth=0.7pt}};
    }
  \end{scope}
  \begin{scope}[local bounding box=GPU16,shift={($(0,0)+(2*\dis,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=16,channelcolor=green!70!,Linewidth=0.7pt}};
    }
  \end{scope}
  \begin{scope}[local bounding box=GPU16,shift={($(0,0)+(3*\dis,0)$)},scale=1,every node/.append style={transform shape}]
    \foreach \i in {1,...,4} {
      \pic[shift={(0,0)}] at ({-\i*\ras}, {-\ras*\i}) {square={scalefac=1,picname=24,channelcolor=red,Linewidth=0.7pt}};
    }
  \end{scope}
\end{scope}
\node[]at($(28-4-GL)!0.5!(28-4-DD)$){GPU 28};
%
\foreach \i in {0,8,16,24,4,12,20} {
  \node[]at($(\i-GL)!0.5!(\i-DD)$){GPU \i};
}
\draw[thick,decoration={brace,amplitude=5pt,mirror},decorate]([yshift=-2mm]4-DL)--
([yshift=-2mm]28-4-DD) node [midway,below=2mm] {Pipeline Parallel};
\draw[thick,decoration={brace,amplitude=5pt},decorate]([xshift=-2mm]4-DL)--
([xshift=-2mm]0-GL) node [midway,above=5mm, sloped,pos=0.9,anchor=east] {Zero Data Parallel};
\draw[thick,decoration={brace,amplitude=5pt,mirror},decorate]([xshift=2mm]28-4-DD)--
([xshift=2mm]28-1-ZDD)node[midway, below=4mm, anchor=west, sloped,pos=0.25] {Model Parallel};
\end{tikzpicture}}
```
|
||
|
||
:::
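The GPU numbering in the figure can be reproduced with a small index calculation. The layout below is an assumption read off the figure's labels (model-parallel shards innermost, then data-parallel replicas, then pipeline stages); real frameworks let the user configure this ordering.

```python
def gpu_rank(pipeline_stage, data_replica, tensor_shard,
             data_parallel=2, tensor_parallel=4):
    """Map 3D parallelism coordinates to a flat GPU index.

    Assumed layout: tensor shards vary fastest, data-parallel replicas
    next, pipeline stages slowest.
    """
    return ((pipeline_stage * data_parallel + data_replica)
            * tensor_parallel + tensor_shard)


PIPELINE, DATA, TENSOR = 4, 2, 4
total_gpus = PIPELINE * DATA * TENSOR

# First GPU of each cluster, printed one data-parallel replica per row.
for replica in range(DATA):
    print([gpu_rank(stage, replica, 0) for stage in range(PIPELINE)])
```

The product of the three degrees gives the device count, 32 in this configuration, and the two printed rows correspond to the cluster labels in the top and bottom rows of the figure.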
|
||
|
||
The data structures examined so far---tensors, device managers, data pipelines, parameter structures, and distributed execution contexts---define *what* data a framework manages and *where* it lives. What remains is the final question of *what actually runs on the hardware*.
|
||
|
||
### Core Operations {#sec-ml-frameworks-core-operations-914f}
|
||
|
||
\index{Operator Kernels!dispatch hierarchy}
|
||
When an engineer writes `y = torch.matmul(x, w)`, the gap between that single line of Python and thousands of parallel GPU threads is larger than it appears. @fig-mlfm-core-ops breaks the bridge across this gap into three layers that work in coordination. Starting from the hardware end, hardware abstraction operations manage computing platform complexity, basic numerical operations implement mathematical computations, and system-level operations coordinate resources and execution.
|
||
|
||
::: {#fig-mlfm-core-ops fig-env="figure" fig-pos="htb" fig-cap="**Core Operations Stack.** Three grouped layers showing how frameworks bridge Python code to hardware. The top layer contains system-level operations (scheduling, memory management, resource optimization), the middle layer holds numerical operations (GEMM, BLAS, element-wise), and the bottom layer provides hardware abstraction (kernel management, memory abstraction, execution control)." fig-alt="Three grouped boxes connected by arrows. System-Level: Scheduling, Memory Management, Resource Optimization. Numerical: GEMM, BLAS, Element-wise Operations. Hardware: Kernel Management, Memory Abstraction, Execution Control."}
|
||
|
||
```{.tikz}
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
%
\tikzset{Line/.style={line width=1.0pt,black!50
},
Box/.style={align=flush center,
    inner xsep=2pt,
    node distance=0.3,
    draw=BlueLine,
    line width=0.75pt,
    fill=BlueL,
    text width=34mm,
    minimum width=30mm,
    minimum height=10mm
  },
}
\begin{scope}[local bounding box=box1]
\node[Box,](B1){Scheduling};
\node[Box,below=of B1](B2){Memory Management};
\node[Box,below=of B2](B3){Resource Optimization};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=5mm,yshift=3mm,
fill=BackColor,fit=(B1)(B2)(B3),line width=0.75pt](BB1){};
\node[below=2pt of BB1.north,anchor=north]{System-Level Operations};
\end{scope}

\begin{scope}[local bounding box=box2,shift={(5.5,0)}]
\node[Box,fill=BrownL,draw=BrownLine,](B1){GEMM Operations};
\node[Box,fill=BrownL,draw=BrownLine,below=of B1](B2){BLAS Operations};
\node[Box,fill=BrownL,draw=BrownLine,below=of B2](B3){Element-wise Operations};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=5mm,yshift=3mm,
fill=BackColor,fit=(B1)(B2)(B3),line width=0.75pt](BB2){};
\node[below=2pt of BB2.north,anchor=north]{Basic Numerical Operations};
\end{scope}

\begin{scope}[local bounding box=box3,shift={(11,0)}]
\node[Box,fill=OrangeL,draw=OrangeLine,](B1){Compute Kernel Management};
\node[Box,fill=OrangeL,draw=OrangeLine,below=of B1](B2){Memory Abstraction};
\node[Box,fill=OrangeL,draw=OrangeLine,below=of B2](B3){Execution Control};
%
\scoped[on background layer]
\node[draw=BackLine,inner xsep=4mm,inner ysep=5mm,yshift=3mm,
fill=BackColor,fit=(B1)(B2)(B3),line width=0.75pt](BB3){};
\node[below=2pt of BB3.north,anchor=north]{Hardware Operations};
\end{scope}

\foreach \x/\y in{1/2,2/3}
\draw[-latex,Line](box\x)--(box\y);
\end{tikzpicture}
```
|
||
|
||
:::
|
||
|
||
#### Hardware Abstraction Operations {#sec-ml-frameworks-hardware-abstraction-operations-1204}
|
||
|
||
The hardware abstraction layer isolates framework code from platform-specific details. It solves three concrete problems: selecting the right compute kernel, moving data through the memory hierarchy, and coordinating execution across processing units.
|
||
|
||
##### Compute Kernel Management {.unnumbered}
|
||
|
||
\index{Kernel Dispatch!hardware selection}
|
||
The kernel manager dispatches each operation to the fastest available implementation for the current hardware. When a framework encounters a matrix multiplication, it selects among AVX-512 vector instructions on modern CPUs, [cuBLAS](https://developer.nvidia.com/cublas) on NVIDIA GPUs, or dedicated tensor processing instructions on AI accelerators. The dispatch decision depends on input dimensions, data layout, and hardware capabilities. A $4096\times4096$ GEMM on an A100 GPU routes to cuBLAS Tensor Core kernels that sustain up to `{python} A100BLAS.dense_tflops_str` TFLOPS in FP16, while the same operation on a CPU falls back to an AVX-512 path at roughly 2 TFLOPS. When no specialized kernel exists, the manager falls back to a generic implementation rather than failing.
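The dispatch-with-fallback pattern can be illustrated without any framework code. The registry below is a deliberately simplified, hypothetical sketch: production dispatchers also key on dtype, layout, and input shape, and none of these names correspond to a real framework API.

```python
from typing import Callable, Dict, Tuple

Matrix = list  # list of rows, kept dependency-free for illustration


def generic_matmul(a: Matrix, b: Matrix) -> Matrix:
    """Portable fallback: plain loop-based matrix multiply."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


# Hypothetical registry keyed by (operation, device type).
KERNELS: Dict[Tuple[str, str], Callable[[Matrix, Matrix], Matrix]] = {}


def register(op: str, device: str):
    """Decorator that records a specialized kernel for (op, device)."""
    def decorator(fn):
        KERNELS[(op, device)] = fn
        return fn
    return decorator


@register("matmul", "cpu")
def cpu_matmul(a, b):
    # Stand-in for a vectorized CPU kernel.
    return generic_matmul(a, b)


def dispatch(op: str, device: str):
    """Pick the specialized kernel if one exists, else the generic one."""
    return KERNELS.get((op, device), generic_matmul)


a, b = [[1, 2], [3, 4]], [[5, 6], [7, 8]]
print(dispatch("matmul", "cpu").__name__, dispatch("matmul", "tpu").__name__)
print(dispatch("matmul", "tpu")(a, b))
```

The request for an unregistered device still returns a correct result through the generic path, which mirrors the graceful fallback behavior described above.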
|
||
|
||
##### Memory System Abstraction {.unnumbered}
|
||
|
||
\index{Memory Layout!NCHW vs. NHWC}
|
||
The memory abstraction layer moves tensors between memory types (pageable and pinned host memory, GPU device memory, unified memory) and transforms data layouts to match hardware preferences. A convolutional layer, for example, may store activations in NCHW format (batch, channels, height, width) on NVIDIA GPUs but convert to NHWC for Apple's Metal backend. Alignment requirements vary from 4 bytes on CPUs to 128 bytes on some accelerators, and misaligned access can halve effective memory bandwidth. The layer also enforces cache coherency when multiple execution units read and write the same tensor, preventing silent data corruption during concurrent operations.
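The difference between NCHW and NHWC comes down to strides: the number of elements skipped in memory when an index along one dimension increases by one. The pure-Python sketch below computes both sets of strides for a small, arbitrarily chosen shape.

```python
def contiguous_strides(shape):
    """Element strides for a row-major (C-contiguous) array of this shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def offset(index, strides):
    """Flat element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


N, C, H, W = 2, 3, 4, 5
nchw_strides = contiguous_strides((N, C, H, W))

# NHWC stores (N, H, W, C) contiguously; reorder its strides to n, c, h, w.
sn, sh, sw, sc = contiguous_strides((N, H, W, C))
nhwc_strides = (sn, sc, sh, sw)

print(nchw_strides, nhwc_strides)
# Distance in elements between channel 0 and channel 1 of the same pixel:
print(offset((0, 1, 0, 0), nchw_strides), offset((0, 1, 0, 0), nhwc_strides))
```

In NHWC the channels of one pixel sit next to each other in memory, whereas in NCHW they are a full image plane apart. Which arrangement a kernel prefers depends on whether it vectorizes across channels or across spatial positions.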
|
||
|
||
##### Execution Control {.unnumbered}
|
||
|
||
The execution controller coordinates work across multiple processing units and memory spaces. On a modern GPU, this means managing dozens of concurrent CUDA streams: when two independent convolutions are both ready to execute, the controller launches them on separate streams so they overlap on the GPU's streaming multiprocessors, improving utilization from as low as 40% (sequential) to over 80% (concurrent). The controller inserts synchronization barriers only where true data dependencies exist, tracks event completions to trigger dependent operations, and routes hardware errors (ECC failures, timeout watchdogs) to the framework's error handling path.
|
||
|
||
```{python}
#| label: resnet50-gflops
#| echo: false
# ┌─────────────────────────────────────────────────────────────────────────────
# │ RESNET-50 GFLOPS FOR GEMM DISCUSSION
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Basic Numerical Operations section discussing GEMM dominance
# │
# │ Goal: Demonstrate the computational dominance of GEMM in vision models.
# │ Show: That ResNet-50 performs 8.2 GFLOPs per forward pass.
# │ How: Retrieve ResNet-50 GFLOPs from mlsysim.core.constants.
# │
# │ Imports: mlsysim.core.constants (RESNET50_FLOPs), mlsysim.book (fmt)
# │ Exports: resnet_gflops_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import RESNET50_FLOPs, GFLOPs
from mlsysim.fmt import fmt, check

# ┌── LEGO ───────────────────────────────────────────────
class ResNetGFLOPS:
    """
    Namespace for ResNet GFLOPS.
    Scenario: Compute intensity check.
    """

    # ┌── 1. LOAD (Constants) ───────────────────────────────────────────────
    flops = RESNET50_FLOPs

    # ┌── 2. EXECUTE (The Compute) ─────────────────────────────────────────
    gflops = flops.m_as(GFLOPs)

    # ┌── 4. OUTPUT (Formatting) ──────────────────────────────────────────────
    resnet_gflops_str = fmt(gflops, precision=1, commas=False)
```
|
||
|
||
#### Basic Numerical Operations {#sec-ml-frameworks-basic-numerical-operations-5516}
|
||
|
||
\index{GEMM!arithmetic intensity}
|
||
With hardware abstraction managing the platform-specific details, frameworks build a layer of mathematical operations on top. General Matrix Multiply (GEMM)\index{GEMM!matrix multiplication}\index{Tensor Operations!GEMM} dominates ML computation (see @sec-algorithm-foundations-general-matrix-multiply-gemm-b55d for arithmetic intensity analysis and the roofline implications). The operation $C = \alpha AB + \beta C$ accounts for the vast majority of arithmetic in neural networks: a single ResNet-50 forward pass performs approximately `{python} ResNetGFLOPS.resnet_gflops_str` billion floating-point operations, nearly all of which reduce to GEMM. Frameworks optimize GEMM through cache-aware tiling (splitting matrices into blocks that fit in L1/L2 cache), loop unrolling for instruction-level parallelism, and shape-specific kernels. Fully connected layers use standard dense GEMM, while convolutional layers use im2col transformations that reshape input patches into matrix columns, converting convolution into GEMM.
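
The im2col idea can be checked on a tiny example. The sketch below is a simplified illustration in plain Python, assuming a single channel, stride 1, and no padding. It uses the deep learning convention in which "convolution" is cross-correlation, and the function names are chosen for this example only.

```python
def im2col(image, k):
    """Flatten every k x k patch of a 2-D image into one row of a patch matrix."""
    out_h, out_w = len(image) - k + 1, len(image[0]) - k + 1
    patches = [
        [image[i + di][j + dj] for di in range(k) for dj in range(k)]
        for i in range(out_h)
        for j in range(out_w)
    ]
    return patches, (out_h, out_w)


def conv2d_direct(image, kernel):
    """Sliding-window reference implementation."""
    k = len(kernel)
    out_h, out_w = len(image) - k + 1, len(image[0]) - k + 1
    return [
        [
            sum(image[i + di][j + dj] * kernel[di][dj]
                for di in range(k) for dj in range(k))
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def conv2d_im2col(image, kernel):
    """Same result, computed as a patch matrix times the flattened kernel."""
    patches, (out_h, out_w) = im2col(image, len(kernel))
    flat_kernel = [v for row in kernel for v in row]
    flat_out = [sum(p * w for p, w in zip(patch, flat_kernel)) for patch in patches]
    return [flat_out[r * out_w:(r + 1) * out_w] for r in range(out_h)]


image = [[4 * i + j for j in range(4)] for i in range(4)]  # values 0..15
kernel = [[1, 0], [0, -1]]
direct = conv2d_direct(image, kernel)
via_gemm = conv2d_im2col(image, kernel)
```

The patch matrix duplicates overlapping pixels, so im2col trades extra memory for the ability to reuse a heavily tuned GEMM kernel.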
|
||
|
||
Beyond GEMM, frameworks implement BLAS operations\index{BLAS!vector and matrix operations} (AXPY for scaled vector addition, GEMV for matrix-vector products) and element-wise operations\index{Element-wise Operations!memory bandwidth} (activation functions, normalization). Element-wise operations are individually cheap but collectively expensive because they are limited by memory bandwidth rather than arithmetic. Each operation reads and writes the full tensor, so a sequence of five element-wise operations on a 100 MB tensor moves 1 GB of data. Fusing those five operations into a single kernel reduces memory traffic to 200 MB, a 5$\times$ bandwidth savings that directly translates to faster execution.
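
The traffic arithmetic above can be reproduced directly. This is a back-of-the-envelope model, assuming each kernel reads its whole input once and writes its whole output once with no cache reuse between kernels; the function name is illustrative.

```python
def elementwise_traffic_mb(tensor_mb, num_ops, fused):
    """Memory traffic in MB for a chain of element-wise ops on one tensor."""
    passes = 1 if fused else num_ops  # a fused chain touches memory once
    return passes * 2 * tensor_mb     # each pass: one full read + one full write


unfused_mb = elementwise_traffic_mb(100, 5, fused=False)
fused_mb = elementwise_traffic_mb(100, 5, fused=True)
savings = unfused_mb / fused_mb
```

Under this model the savings factor equals the number of fused operations, which is why long element-wise chains are prime fusion targets.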
|
||
|
||
Numerical precision adds another dimension. Storing a parameter in FP32 uses 4 bytes; quantizing to INT8 reduces this to 1 byte, cutting memory by 4$\times$ and enabling 2--4$\times$ throughput improvements on hardware with INT8 acceleration. Training typically keeps FP32 weights and accumulations for gradient stability, even when mixed precision runs much of the arithmetic in FP16 or BF16, while inference runs at FP16 or INT8 with minimal accuracy loss. Frameworks maintain separate kernel implementations for each precision format and handle mixed-precision workflows where different layers operate at different bit widths within a single forward pass.
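
A minimal sketch of what quantizing to INT8 means, assuming the simplest scheme (symmetric, one scale per tensor). Production frameworks typically add per-channel scales, zero points, and calibration, none of which are modeled here; the function names and sample weights are illustrative.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization: real value ~= scale * q, q in [-127, 127]."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid a zero scale
    scale = max_abs / 127.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Map the integer codes back to approximate real values."""
    return [scale * v for v in q]


weights = [0.50, -1.27, 0.02, 0.90]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

fp32_bytes = 4 * len(weights)
int8_bytes = 1 * len(weights)  # plus one shared scale for the tensor
```

Rounding to the nearest code bounds the per-element error by half a quantization step (`scale / 2`), which is the accuracy cost paid for the 4$\times$ storage reduction.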
|
||
|
||
#### System-Level Operations {#sec-ml-frameworks-systemlevel-operations-6a1c}
|
||
|
||
Hardware abstraction and numerical operations provide the building blocks; system-level operations orchestrate them. The system layer ties scheduling, memory management, and resource optimization into a coherent execution engine.
|
||
|
||
The operation scheduler analyzes the computational graph to find parallelism while respecting data dependencies. In a static graph, the scheduler sees the full dependency structure before execution begins and can plan an optimal ordering. In a dynamic graph, dependencies emerge at runtime, forcing the scheduler to make greedy decisions. Concretely, when a ResNet block produces two independent branch outputs, the scheduler launches both branches simultaneously rather than serializing them, reducing idle cycles on the GPU's streaming multiprocessors.
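
For a static graph, finding this parallelism amounts to grouping operations into dependency "waves". The sketch below is one simple way to do it, assuming the graph is given as a dictionary from each operation to its prerequisites; the operation names and the `parallel_levels` helper are hypothetical.

```python
def parallel_levels(deps):
    """Group ops into waves; ops in the same wave do not depend on each other.

    deps maps each op name to the list of ops it must wait for.
    """
    remaining = {op: set(d) for op, d in deps.items()}
    levels = []
    while remaining:
        ready = sorted(op for op, d in remaining.items() if not d)
        if not ready:
            raise ValueError("dependency cycle detected")
        levels.append(ready)
        for op in ready:
            del remaining[op]
        for pending in remaining.values():
            pending.difference_update(ready)
    return levels


# A residual-style block with two independent branches.
graph = {
    "input": [],
    "branch_a": ["input"],
    "branch_b": ["input"],
    "add": ["branch_a", "branch_b"],
}
levels = parallel_levels(graph)
```

Both branches land in the same wave, so a runtime could launch them concurrently and synchronize only before `add`. A dynamic-graph scheduler cannot compute these waves ahead of time because `deps` is only discovered as the program runs.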
|
||
|
||
The memory manager allocates and reclaims GPU memory across the computational graph's lifetime. Model parameters (a 7B-parameter model consumes approximately `{python} Model7B.model_7b_fp16_gb_str` GB in FP16) persist for the entire training run, while activation tensors live only until the backward pass consumes them. PyTorch's caching allocator maintains a memory pool, subdividing and reusing freed blocks without returning them to CUDA, which avoids repeated `cudaMalloc` and `cudaFree` calls whose overhead can reach the millisecond range. For models that exceed GPU memory, frameworks offer gradient checkpointing: discarding selected activations during the forward pass and recomputing them during the backward pass, trading roughly 20--33% additional compute for 60% or more memory savings (with optimal checkpoint placement).
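
The pooling idea can be illustrated with a toy allocator. This is a deliberately simplified sketch and not PyTorch's actual allocator: it reuses blocks only on an exact size match, whereas real caching allocators round sizes and split larger blocks. The class name and the `device_calls` counter (standing in for calls to the device allocator) are assumptions for illustration.

```python
class CachingAllocator:
    """Toy memory pool: freed blocks are kept and reused by exact size."""

    def __init__(self):
        self.free_blocks = {}   # size -> list of cached blocks
        self.device_calls = 0   # stands in for expensive device allocations
        self._next_id = 0

    def malloc(self, size):
        pool = self.free_blocks.get(size)
        if pool:
            return pool.pop()          # fast path: reuse a cached block
        self.device_calls += 1         # slow path: ask the device
        self._next_id += 1
        return (self._next_id, size)

    def free(self, block):
        # Keep the block in the pool instead of returning it to the device.
        self.free_blocks.setdefault(block[1], []).append(block)


allocator = CachingAllocator()
for _ in range(100):                   # 100 training iterations
    activation = allocator.malloc(4096)
    gradient = allocator.malloc(1024)
    allocator.free(activation)
    allocator.free(gradient)
```

Because every iteration requests the same sizes, only the first iteration reaches the device allocator; the remaining 99 are served from the pool.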
|
||
|
||
The resource optimizer integrates these scheduling and memory decisions. When two matrix multiplications with different shapes are ready to execute, it selects the algorithm variant (Winograd, Strassen, or standard tiled GEMM) that best fits each shape and the current memory pressure. A poorly scheduled graph wastes compute; a poorly managed memory pool triggers out-of-memory errors on hardware that theoretically has capacity to spare.
|
||
|
||
The preceding sections examined what happens *below* the API surface: tensors manage data layout, streams overlap computation with communication, and kernel dispatch routes operations to hardware. These mechanisms operate at the level of individual tensors and operations---the raw materials of machine learning computation. Practitioners, however, rarely write code at this level. A ResNet-50 has `{python} ResNetMemory.resnet_params_m_str` million parameters organized into dozens of layers; manually tracking each tensor, registering it with an optimizer, and handling device placement would be error-prone and tedious. The abstraction problem is not fully solved by hardware-level mechanisms alone; it also requires a *programming model* that organizes these low-level primitives into the clean APIs that practitioners actually use.
|
||
|
||
::: {.callout-checkpoint title="Hardware Abstraction"}
|
||
|
||
The abstraction problem is the bridge between *portable code* and *efficient execution*.
|
||
|
||
- [ ] **Two dimensions**: Can you distinguish **data representation** (layout, dtype, placement) from **execution mapping** (kernel selection, scheduling), and explain how they constrain each other?
|
||
- [ ] **Kernel dispatch**: Can you explain why the same high-level operation (e.g., GEMM) needs multiple implementations (CPU vector path vs GPU Tensor Core path) and how shapes/dtypes affect the choice?
|
||
- [ ] **Memory abstraction**: Can you explain why frameworks use caching allocators and layout transforms (NCHW↔NHWC) rather than calling device allocators on every tensor?
|
||
- [ ] **Execution control**: Can you describe what the runtime is doing when it overlaps independent work (streams) and inserts synchronization only where dependencies require it?
|
||
|
||
:::
|
||
|
||
Individual operations---matrix multiplications, activations, normalizations---are the atoms of deep learning computation. Building models from individual operations, however, would be like building a house from individual atoms. Frameworks need an organizational abstraction that lets engineers compose operations into reusable, nestable building blocks. That abstraction is the *module*.
|
||
|
||
## nn.Module Abstraction {#sec-ml-frameworks-nnmodule-abstraction-2622}
|
||
|
||
\index{Module Abstraction!PyTorch nn.Module}
|
||
\index{Module Abstraction!hierarchical composition}
|
||
The hardware-facing half of the abstraction problem---tensors, kernels, streams, and memory managers---makes individual operations fast on diverse silicon. A ResNet-50, however, contains fifty layers, each with multiple parameter tensors, buffers, and mode-dependent behaviors. Manually wiring each tensor to the correct device, registering it with an optimizer, toggling dropout behavior between training and inference, and serializing state for checkpointing---for every layer---would drown practitioners in bookkeeping that has nothing to do with model design. The upper layer of the abstraction problem is organizational: composing thousands of low-level primitives into the clean, composable APIs that practitioners actually use.
|
||
|
||
Every major framework addresses this need through a *module abstraction* that bundles parameters, forward computation, and state management into a single reusable unit. PyTorch's `nn.Module`[^fn-nn-module-composition] provides an instructive case study because its design patterns recur across frameworks: Keras uses similar layer abstractions, JAX's Flax employs analogous module structures, and TensorFlow's functional API shares conceptual parallels. Rather than catalog its API, we extract three enduring design principles that every framework must address regardless of its syntax or programming paradigm.
|
||
|
||
[^fn-nn-module-composition]: **nn.Module**: The "design patterns recur" claim holds because `nn.Module` solves a universal organizational problem: it automatically registers any assigned submodule or parameter into a hierarchical tree, enabling a single `.to('cuda')` call to recursively place millions of parameters onto a GPU. Keras layers, JAX Flax modules, and TensorFlow's `tf.Module` all implement the same tree-walking pattern. Without it, managing model state would require manual bookkeeping that scales linearly with architectural depth, a cost that grows prohibitive for models with hundreds of layers. \index{nn.Module!parameter management}
|
||
|
||
### Principle 1: Automatic Parameter Discovery { .unnumbered}
|
||
|
||
\index{nn.Module!parameter discovery}
|
||
\index{Parameter Registration!automatic tracking}
|
||
A modern neural network may contain millions of trainable parameters spread across dozens of layers. Without automation, a programmer would need to enumerate every parameter tensor and pass it to the optimizer manually, an error-prone process that scales poorly with model complexity. Frameworks solve this through *automatic parameter discovery*: the system walks the module tree, collecting every parameter tensor so the optimizer can update them in a single call.
|
||
|
||
This is a graph traversal problem at its core. When a developer assigns an `nn.Parameter` to a module attribute (for example, `self.weight = nn.Parameter(...)`), `nn.Module`'s overridden `__setattr__` intercepts the assignment and registers the tensor in an internal dictionary. A call to `.parameters()` then recursively traverses the module tree, yielding every registered parameter. The same pattern appears in every major framework: Keras layers maintain a `trainable_weights` list, JAX's Flax modules use `init()` to return a nested parameter dictionary, and TensorFlow's `tf.Module` provides `trainable_variables`. The mechanism differs but the principle is universal.
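
The registration-and-traversal mechanism can be sketched without any framework. The classes below are a simplified imitation written for illustration, not PyTorch's implementation: `Parameter`, `Module`, `Linear`, and `MLP` here are stand-ins, and details such as buffers, hooks, and de-duplication of shared parameters are omitted.

```python
class Parameter:
    """Marker type for a trainable tensor (plain Python data in this sketch)."""

    def __init__(self, data):
        self.data = data


class Module:
    def __init__(self):
        # Create the registries without triggering our own __setattr__.
        object.__setattr__(self, "_parameters", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Intercept attribute assignment and register by type.
        if isinstance(value, Parameter):
            self._parameters[name] = value
        elif isinstance(value, Module):
            self._modules[name] = value
        object.__setattr__(self, name, value)

    def named_parameters(self, prefix=""):
        # Own parameters first, then recurse into children depth-first.
        for name, param in self._parameters.items():
            yield prefix + name, param
        for name, child in self._modules.items():
            yield from child.named_parameters(prefix + name + ".")


class Linear(Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight = Parameter([[0.0] * n_in for _ in range(n_out)])
        self.bias = Parameter([0.0] * n_out)


class MLP(Module):
    def __init__(self):
        super().__init__()
        self.hidden = Linear(4, 8)
        self.out = Linear(8, 2)


names = [name for name, _ in MLP().named_parameters()]
```

No layer ever calls a "register" function explicitly: ordinary attribute assignment is enough, and the dotted names fall out of the recursion.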
|
||
|
||
The systems consequence is significant. Because the optimizer receives every parameter tensor up front, a single `optimizer.step()` call updates all of them, and multi-tensor implementations can batch the updates for many tensors into a few kernel launches instead of dispatching from Python once per tensor. Without this abstraction, the programmer would have to enumerate and update each tensor by hand, and for models with thousands of small tensors the interpreter overhead alone could dominate the update step. @lst-parameter_registration demonstrates the core mechanism: attribute assignment triggers registration, and `.parameters()` returns all discovered tensors.
|
||
|
||
::: {#lst-parameter_registration lst-cap="**Parameter Registration**: Automatic parameter tracking through attribute assignment enables optimizer access to all trainable weights without manual enumeration."}
|
||
|
||
```{.python}
import torch
import torch.nn as nn


class CustomLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(output_size, input_size)
        )
        self.bias = nn.Parameter(torch.randn(output_size))
        self.register_buffer("running_mean", torch.zeros(output_size))

    def forward(self, x):
        return torch.matmul(x, self.weight.t()) + self.bias


layer = CustomLayer(10, 20)
# Framework discovers both parameters automatically:
for name, param in layer.named_parameters():
    print(f"{name}: shape {param.shape}")
```
|
||
|
||
:::
|
||
|
||
The distinction between *parameters* and *buffers* illustrates a subtlety of the discovery mechanism. Parameters carry `requires_grad=True` by default and participate in gradient computation. Buffers, registered through `register_buffer()`, travel with the model during device transfers but remain excluded from gradient updates. This separation is essential for normalization layers, where running statistics must persist across batches but must not receive gradients. The same dual-track design appears in Keras (via `non_trainable_weights`) and Flax (via `state` versus `params`).
|
||
|
||
::: {.callout-perspective title="Cross-Framework Parameter Discovery" collapse="true"}
|
||
|
||
The same principle manifests differently across frameworks:
|
||
|
||
| **Framework**  | **Parameter Access**          | **Non-Trainable State**          |
|:---------------|:------------------------------|:---------------------------------|
| **PyTorch**    | `model.parameters()`          | `register_buffer()`              |
| **Keras**      | `layer.trainable_weights`     | `layer.non_trainable_weights`    |
| **JAX/Flax**   | `params = model.init(key, x)` | Separate `state` dict            |
| **TensorFlow** | `module.trainable_variables`  | `module.non_trainable_variables` |
|
||
|
||
Despite syntactic differences, all frameworks solve the same problem: enabling optimizers to discover and update trainable parameters while preserving non-trainable state across forward passes.
|
||
|
||
:::
|
||
|
||
### Principle 2: Mode-Dependent Behavior { .unnumbered}
|
||
|
||
\index{nn.Module!train vs. eval modes}
|
||
\index{Dropout!training vs. inference}
|
||
\index{BatchNormalization!mode-dependent behavior}
|
||
Training and inference require different computational behavior from the same model graph. During training, dropout layers randomly zero elements with probability $p$ to regularize the network, while during inference those same layers must perform identity mapping to produce deterministic outputs. Batch normalization uses per-batch statistics during training but switches to accumulated running statistics during inference. If these behavioral changes are left to the programmer, forgetting a single mode switch produces silently incorrect predictions in production.
|
||
|
||
Frameworks solve this with a *state flag* that propagates through the module hierarchy. A single call to `.eval()` on the root module recursively sets `self.training = False` on every descendant, and each layer queries this flag to select its behavior. This is an instance of a broader systems principle: the same computation graph must produce different execution behavior depending on context. Compilers face the same challenge when the same source code must produce debug builds (with bounds checking and symbol tables) versus release builds (with aggressive optimization). The flag-propagation pattern ensures correctness by centralizing the mode decision at the root rather than requiring per-layer coordination.
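
The flag-propagation pattern is small enough to sketch in plain Python. The classes below are illustrative stand-ins rather than PyTorch's own code; in particular, `children` is a simple list here, and the dropout layer uses the common "inverted dropout" convention of scaling kept values by $1/(1-p)$ during training.

```python
import random


class Module:
    def __init__(self):
        self.training = True
        self.children = []

    def train(self, mode=True):
        # Set the flag here, then push it down to every descendant.
        self.training = mode
        for child in self.children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Dropout(Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, xs):
        if not self.training:
            return list(xs)  # identity mapping at inference
        keep = 1.0 - self.p
        return [x / keep if random.random() < keep else 0.0 for x in xs]


class Sequential(Module):
    def __init__(self, *layers):
        super().__init__()
        self.children = list(layers)

    def forward(self, xs):
        for layer in self.children:
            xs = layer.forward(xs)
        return xs


model = Sequential(Dropout(0.5), Dropout(0.2))
model.eval()  # one call at the root switches every layer
eval_out = model.forward([1.0, 2.0, 3.0])
```

Each layer consults only its own `training` flag, so correctness depends on a single call at the root rather than on per-layer coordination.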
|
||
|
||
This principle extends to parameter freezing for transfer learning. Setting `requires_grad=False` on specific parameters excludes them from gradient computation, effectively creating a third behavioral mode where some parameters train while others remain fixed. Selective freezing achieves computational savings by pruning the backward pass graph: frozen parameters need no gradient storage, reducing memory consumption proportionally.
|
||
|
||
### Principle 3: Hierarchical Composition and Serialization { .unnumbered}
|
||
|
||
\index{nn.Module!hierarchical composition}
|
||
\index{Model Serialization!state dictionary pattern}
|
||
Complex models compose from reusable submodules, creating a tree structure. A ResNet is not implemented as a monolithic block of operations but as a hierarchy: the root module contains a sequence of residual blocks, each block contains convolution layers and normalization layers, and each layer contains parameter tensors. This hierarchical composition must support two critical operations: *recursive parameter collection* for training and *state serialization* for checkpointing and deployment.
|
||
|
||
Hierarchical composition mirrors the hardware memory hierarchy in a systems-relevant way: each submodule's parameters can be loaded independently, enabling model parallelism across devices. When a model is too large for a single GPU, the framework can assign different subtrees of the module hierarchy to different devices, with the tree structure providing natural partition boundaries.
|
||
|
||
The state dictionary mechanism provides the serialization half of this principle. The `state_dict()` method produces a flat key-value mapping of the full module tree, where dotted path names (e.g., `blocks.0.conv1.weight`) encode the hierarchy. This flat structure enables efficient serialization: a 7B-parameter model's approximately `{python} Model7B.model_7b_fp16_gb_str` GB FP16 checkpoint can be written as a sequential byte stream, maximizing storage bandwidth utilization. The inverse operation, `load_state_dict()`, copies each tensor back into an already-constructed module tree by matching these dotted keys, enabling checkpoint recovery; cross-framework model exchange relies on graph-level export formats like ONNX. @lst-nested_modules demonstrates how the module tree enables both recursive parameter access and hierarchical state serialization.
|
||
|
||
::: {#lst-nested_modules lst-cap="**Nested Module Composition**: Hierarchical module composition enables recursive parameter collection and flat state serialization across the module tree."}
|
||
|
||
```{.python}
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return torch.relu(x + residual)


class ResNet(nn.Module):
    def __init__(self, num_blocks, channels=64):
        super().__init__()
        self.conv_in = nn.Conv2d(3, channels, 7, padding=3)
        self.blocks = nn.ModuleList(
            [ResidualBlock(channels) for _ in range(num_blocks)]
        )
        self.fc = nn.Linear(channels, 10)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.blocks:
            x = block(x)
        x = x.mean(dim=[2, 3])  # Global average pooling
        return self.fc(x)


model = ResNet(num_blocks=4)
total = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total}")
# state_dict() flattens the tree: 'blocks.0.conv1.weight', etc.
print(list(model.state_dict().keys())[:4])
```
|
||
|
||
:::
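
The flattening convention behind the state dictionary can be shown with nested dictionaries standing in for a module tree. This is a framework-free sketch under that assumption: `flatten` and `load_flat` are hypothetical helpers, and plain lists play the role of tensors.

```python
def flatten(tree, prefix=""):
    """Nested dict of values -> flat dict keyed by dotted paths."""
    flat = {}
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict):
            flat.update(flatten(value, key + "."))
        else:
            flat[key] = value
    return flat


def load_flat(tree, flat, prefix=""):
    """Copy values from a flat mapping into an existing nested structure."""
    for name, value in tree.items():
        key = prefix + name
        if isinstance(value, dict):
            load_flat(value, flat, key + ".")
        else:
            tree[name] = flat[key]  # a KeyError signals a missing entry


trained = {
    "conv_in": {"weight": [1, 2]},
    "blocks": {"0": {"conv1": {"weight": [3], "bias": [4]}}},
}
checkpoint = flatten(trained)

# A freshly constructed model with the same structure but no values yet.
fresh = {
    "conv_in": {"weight": None},
    "blocks": {"0": {"conv1": {"weight": None, "bias": None}}},
}
load_flat(fresh, checkpoint)
```

Loading requires the target structure to exist already: the flat mapping carries names and values, while the hierarchy itself comes from constructing the model in code.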
|
||
|
||
The hierarchical structure also enables module-level traversal for systematic operations. Methods like `.named_modules()` iterate the entire tree, supporting bulk transformations such as replacing all BatchNorm layers with GroupNorm or applying Xavier initialization to every Linear layer. These traversal operations depend on the same tree structure that enables parameter discovery, illustrating how a single design decision propagates benefits across multiple use cases.
|
||
|
||
These three principles, automatic parameter discovery, mode-dependent behavior, and hierarchical composition with serialization, are not PyTorch-specific. Every framework must solve them. Keras layers, JAX's Flax modules, and even functional approaches all address the same problems of parameter management, state tracking, and compositional design. The differences lie not in *what* problems they solve but in *how* they prioritize among competing solutions. Two practical patterns built on these principles deserve attention: *selective parameter freezing* for transfer learning (@lst-parameter_freezing) and *module hooks* for non-invasive inspection (@lst-module_hooks).
|
||
|
||
::: {#lst-parameter_freezing lst-cap="**Parameter Freezing**: Demonstrates selective parameter freezing for transfer learning, where pretrained layers remain fixed while new layers train."}
|
||
|
||
```{.python}
import torch
import torch.nn as nn

# Load a pretrained model (torchvision >= 0.13 selects weights via
# `weights=`; the older `pretrained=True` flag is deprecated)
pretrained_model = torch.hub.load(
    "pytorch/vision", "resnet18", weights="IMAGENET1K_V1"
)

# Freeze all parameters in the pretrained model
for param in pretrained_model.parameters():
    param.requires_grad = False

# Replace final layer with trainable parameters
pretrained_model.fc = nn.Linear(512, 10)  # New layer is trainable

# Only fc.parameters() will receive gradients during training
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, pretrained_model.parameters()),
    lr=0.001,
)
```
|
||
|
||
:::
|
||
|
||
\index{Module Hooks!forward and backward}
|
||
Forward and backward hooks intercept intermediate computations without modifying model code, enabling gradient flow diagnosis and activation monitoring. @lst-module_hooks illustrates both hook types.
|
||
|
||
::: {#lst-module_hooks lst-cap="**Module Hooks**: Shows forward and backward hooks for inspecting activations and gradients during training."}
|
||
|
||
```{.python}
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))


# Forward hook to inspect activations
def forward_hook(module, input, output):
    print(
        f"Layer: {module.__class__.__name__}, "
        f"Output shape: {output.shape}, "
        f"mean={output.mean():.3f}, "
        f"std={output.std():.3f}"
    )


# Backward hook to inspect gradients
def backward_hook(module, grad_input, grad_output):
    print(f"Gradient norm: {grad_output[0].norm():.3f}")


# Register hooks on specific layer
handle_fwd = model[0].register_forward_hook(forward_hook)
handle_bwd = model[0].register_full_backward_hook(backward_hook)

# Execute forward and backward pass
x = torch.randn(32, 10)
y = model(x)
loss = y.sum()
loss.backward()

# Remove hooks when done
handle_fwd.remove()
handle_bwd.remove()
```
|
||
|
||
:::
|
||
|
||
Together, these patterns---parameter discovery, freezing, and hooks---demonstrate how the three principles translate into practical APIs. The `nn.Module` patterns above illustrate PyTorch's approach to the abstraction problem. PyTorch, however, is only one of several major frameworks, and its choices (mutable state, class inheritance, eager execution by default) are not the only valid design points. TensorFlow centralizes state differently, and JAX avoids mutable state entirely. These are not superficial API differences; they reflect deeply different answers to the three problems we examined at the chapter's start.
|
||
|
||
## Framework Platform Analysis {#sec-ml-frameworks-major-framework-platform-analysis-fe96}
|
||
|
||
\index{ML Framework!platform comparison}
|
||
Each major framework represents a distinct point in the design space defined by the three core problems: TensorFlow prioritizes the Abstraction Problem through its comprehensive deployment ecosystem, PyTorch prioritizes the Execution Problem through its dynamic graph approach, and JAX reframes the Differentiation Problem through composable function transformations. These differences are architectural, reflecting fundamental capability trade-offs that determine what each framework can and cannot do well.
|
||
|
||
### TensorFlow: The Graph-First Production Machine {#sec-ml-frameworks-tensorflow-ecosystem-063c}
|
||
|
||
\index{TensorFlow!graph-first architecture}
|
||
\index{TensorFlow!production deployment}
|
||
TensorFlow's architecture reflects a comprehensive solution to the Abstraction Problem: targeting diverse hardware, from cloud TPUs to microcontrollers, through a single interface. Google's production environment demanded this breadth because the same model often needed to serve predictions on TPU pods in the datacenter, on Android phones via TensorFlow Lite, and in web browsers through TensorFlow.js. This deployment diversity drove the choice of a **Static Graph** (or "Define-and-Run") design. By requiring the model to be represented as a complete computational graph before execution, TensorFlow enables ahead-of-time (AOT) compilation and optimization for each target platform.
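
Having the whole graph available before execution is what makes ahead-of-time rewrites possible. As a minimal sketch of one such rewrite, constant folding, the code below simplifies a tiny expression graph encoded as nested tuples. The encoding and the `fold` function are assumptions made for this example and do not reflect TensorFlow's internal graph format.

```python
import operator

# Supported binary operations in this toy graph.
OPS = {"add": operator.add, "mul": operator.mul}


def fold(node):
    """Replace any subgraph whose inputs are all constants with its value."""
    kind = node[0]
    if kind in ("const", "input"):
        return node
    lhs, rhs = fold(node[1]), fold(node[2])
    if lhs[0] == "const" and rhs[0] == "const":
        return ("const", OPS[kind](lhs[1], rhs[1]))
    return (kind, lhs, rhs)


# x + (3 * 4): the multiplication does not depend on any runtime input.
graph = ("add", ("input", "x"), ("mul", ("const", 3), ("const", 4)))
folded = fold(graph)
```

The multiplication is computed once at compile time, so the deployed graph performs a single addition per request. An eager framework that sees one operation at a time has no opportunity to apply this rewrite without first capturing a graph.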
|
||
|
||
The graph-first approach prioritizes the **Deployment Spectrum**: because the framework sees the entire graph, it can perform aggressive optimizations like constant folding, operator fusion, and memory layout optimization before the first byte of data is processed. TensorFlow's dominance in complex production ecosystems traces directly to this ahead-of-time optimization capability. @fig-tensorflow-architecture maps the full training-to-deployment pipeline---trace how a model flows from data preprocessing through distributed training on the left, then fans out to serving, mobile (TF Lite), browser (TF.js), and language bindings on the right.
|
||
|
||
::: {#fig-tensorflow-architecture fig-env="figure" fig-pos="htb" fig-cap="**TensorFlow Training-to-Deployment Pipeline.** Two-column diagram showing the training path (left) from data preprocessing through tf.keras and distribution strategy across CPU, GPU, and TPU, and the deployment path (right) from SavedModel export to TensorFlow Serving, Lite, JS, and language bindings. Source: [TensorFlow](https://blog.tensorflow.org/2019/01/whats-coming-in-tensorflow-2-0.html)." fig-alt="Two-column diagram. Training: data preprocessing, tf.keras, TensorFlow Hub, Premade Estimators, Distribution Strategy across CPU/GPU/TPU. Deployment via SavedModel to TensorFlow Serving, Lite, JS, and language bindings."}
|
||
|
||
```{.tikz}
|
||
\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small]
|
||
%
|
||
\tikzset{Line/.style={line width=1.0pt,black!50
|
||
},
|
||
Box/.style={align=flush center,
|
||
inner xsep=4pt,
|
||
node distance=0.8,
|
||
draw=BlueLine,
|
||
line width=0.75pt,
|
||
fill=BlueL,
|
||
minimum height=11mm
|
||
},
|
||
}
|
||
|
||
\node[Box,text width=70mm,fill= BrownL,
|
||
draw= BrownLine](B1){\textbf{Read \& Preprocess Data}\\ tf.data, feature columns};
|
||
\node[Box,fill= BrownL,draw= BrownLine,below=of B1.south west,minimum width=20mm,
|
||
anchor=north west](B2){\textbf{tf.keras}};
|
||
\node[Box,fill= BrownL,draw= BrownLine,below=of B1.south east,minimum width=20mm,
|
||
anchor=north east](B3){\textbf{Premade}\\\textbf{Estimators}};
|
||
\node[Box,fill= BrownL,draw= BrownLine,
|
||
minimum width=20mm](B4)at($(B2.east)!0.5!(B3.west)$){\textbf{TensorFlow}\\\textbf{Hub}};
|
||
%
|
||
\node[Box,text width=70mm,fill= BrownL,below=of B4,
|
||
draw= BrownLine](B5){\textbf{Distribution Strategy}};
|
||
\node[Box,fill= BrownL,draw= BrownLine,below=of B5.south west,minimum width=18mm,
|
||
anchor=north west](B6){\textbf{CPU}};
|
||
\node[Box,fill= BrownL,draw= BrownLine,below=of B5.south east,minimum width=18mm,
|
||
anchor=north east](B7){\textbf{TPU}};
|
||
\node[Box,fill= BrownL,draw= BrownLine,minimum width=18mm](B8)at($(B6.east)!0.5!(B7.west)$){\textbf{GPU}};
|
||
%
|
||
\node[Box,fill= BlueL,draw= BlueLine,right=1.0 of $(B1.east)!0.5!(B7.east)$](B9){\textbf{SavedModel}};
|
||
%
|
||
\def\di{4.35}
|
||
\node[Box,text width=50mm,fill= RedL,right=\di of B1,
|
||
draw= RedLine](L1){\textbf{TensorFlow Serving}\\ Cloud, on-prem};
|
||
\node[Box,text width=50mm,fill= RedL,right=\di of B3,
|
||
draw= RedLine](L2){\textbf{TensorFlow Lite}\\ Android, iOS, Raspberry Pi};
|
||
\node[Box,text width=50mm,fill= RedL,right=\di of B5,
|
||
draw= RedLine](L3){\textbf{TensorFlow.js}\\ Browser and Node Server};
|
||
\node[Box,text width=50mm,fill= RedL,right=\di of B7,
|
||
draw= RedLine](L4){\textbf{Other Language Bindings}\\ C, Java, Go, C\#, Rust, R,\ldots};
|
||
%
|
||
\node[above=2mm of B1]{\textbf{TRAINING}};
|
||
\node[above=2mm of L1]{\textbf{DEPLOYMENT}};
|
||
%
|
||
\draw[latex-,Line](B2)--(B1.south-|B2);
|
||
\draw[latex-,Line](B3)--(B1.south-|B3);
|
||
\draw[-latex,Line](B4)--(B2);
|
||
\draw[-latex,Line](B4)--(B3);
|
||
\draw[-latex,Line](B2)--(B5.north-|B2);
|
||
\draw[-latex,Line](B3)--(B5.north-|B3);
|
||
\draw[latex-,Line](B6)--(B5.south-|B6);
|
||
\draw[latex-,Line](B7)--(B5.south-|B7);
|
||
\draw[latex-,Line](B8)--(B5.south-|B8);
|
||
\draw[Line](B6)--++(270:1)-|(B7);
|
||
\draw[-latex,Line](B8)--++(270:1.35)-|(B9);
|
||
\foreach \x in {1,2,3,4}
|
||
\draw[-latex,Line](B9.east)--(L\x.west);
|
||
\end{tikzpicture}
|
||
```
|
||
|
||
:::
|
||
|
||
While TensorFlow 2.0 introduced eager execution to bridge the gap between research and production, its core strength remains the robust, compiled path from research to global-scale deployment. @sec-model-training examines how TensorFlow's distribution strategies enable large-scale training, while @sec-model-serving covers its production serving infrastructure.
|
||
|
||
### PyTorch: The Eager Research Standard {#sec-ml-frameworks-pytorch-85cd}
|
||
|
||
\index{PyTorch!eager execution default}
|
||
\index{PyTorch!research-first design}
|
||
Where TensorFlow's graph-first approach prioritizes production optimization, PyTorch makes the opposite trade-off: it prioritizes *developer experience*. PyTorch's architecture represents a sharply different answer to the Execution Problem, built on **Dynamic Graphs** (or "Define-by-Run"). Instead of building a blueprint before execution, PyTorch builds the computational graph on-the-fly as the code runs. Facebook AI Research (FAIR) adopted this design because researchers need immediate feedback when experimenting with novel architectures; the define-then-run cycle of static graphs introduced a compilation delay that slowed the rapid prototyping essential to research workflows.
|
||
|
||
PyTorch's approach won the broader research community for the same reason: it treats deep learning as standard Python programming. Developers can use Python loops, conditionals, and debuggers (like `pdb`) directly within a model's forward pass, with no special syntax, no separate compilation step, and no waiting to see if the code works. Eager execution enables rapid iteration and intuitive model design, which is essential for the trial-and-error nature of frontier AI research.
|
||
|
||
PyTorch's answer to the Differentiation Problem is the tape-based autograd system examined in @sec-ml-frameworks-pytorch-autograd-internals-4fa0: flexible and debuggable, but harder to optimize globally because the tape is rebuilt each iteration. Its answer to the Abstraction Problem is more pragmatic than comprehensive: strong GPU support through cuBLAS and cuDNN, but deployment to mobile, edge, and browser environments requires exporting through ONNX or specialized runtimes rather than a native path.
|
||
|
||
The trade-off is therefore a more fragmented deployment path. Because the graph is dynamic, the framework cannot easily perform global optimizations before execution. A model that works perfectly in development may hit performance walls in production when dispatch overhead dominates small operations. To bridge this research-to-production gap, PyTorch introduced **TorchScript** and **PyTorch 2.0 (with `torch.compile`)**, which allow developers to capture a dynamic model and turn it into an optimized, static representation for deployment. This evolution shows PyTorch moving toward the production end of the compilation continuum while preserving the eager experience that made it dominant in research.
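
A back-of-the-envelope model shows why dispatch overhead matters for small operations. Every number below is a hypothetical placeholder chosen for illustration, not a measurement of any framework: 200 small operations, 10 µs of framework overhead per kernel launch, and 2 µs of kernel time per operation. The helper `step_time_us` is likewise an illustrative name.

```python
def step_time_us(num_launches, dispatch_us, total_kernel_us):
    """Total time = per-launch framework overhead + time spent in kernels."""
    return num_launches * dispatch_us + total_kernel_us


# Hypothetical workload: many small operations (assumed values, not benchmarks).
NUM_SMALL_OPS = 200
DISPATCH_US = 10.0       # assumed overhead per launch
KERNEL_US_PER_OP = 2.0   # assumed kernel time per operation
total_kernel_us = NUM_SMALL_OPS * KERNEL_US_PER_OP

# Eager: one launch per operation.
eager_us = step_time_us(NUM_SMALL_OPS, DISPATCH_US, total_kernel_us)

# Captured graph: the same arithmetic, assumed to be fused into 20 launches.
fused_us = step_time_us(20, DISPATCH_US, total_kernel_us)
```

Under these assumptions the eager step spends most of its time in dispatch rather than arithmetic, and capturing the graph shrinks only the overhead term. With large kernels the same model predicts little benefit, which is consistent with the observation that the gain from compilation depends on the operation mix.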
|
||
|
||
### JAX: The Functional Transformation Engine {#sec-ml-frameworks-jax-functional-transformation-engine-242e}
|
||
|
||
\index{JAX!functional transformations}
|
||
\index{JAX!composable program transformations}
|
||
\index{JAX!pure functions requirement}
|
||
\index{JAX!transformation composition}
|
||
PyTorch's eager execution and TensorFlow's graph compilation represent two points on a spectrum, yet both share an imperative programming heritage where computation proceeds as a sequence of stateful operations. JAX represents a radically different approach, one built on functional programming principles and composable program transformations rather than computational graphs [@jax2018github]. Developed by Google Research, JAX has gained significant traction in research settings, particularly for work requiring custom differentiation, advanced optimization research, and large-scale distributed training.
|
||
|
||
JAX's architecture reframes the Differentiation Problem entirely. Google Research built JAX on a key observation: if functions are pure (no side effects, no mutable state), the compiler can safely reorder, fuse, and parallelize any operation, because outputs depend only on inputs. This constraint, borrowed from functional programming, is what makes JAX's composable transformations possible. Rather than implementing automatic differentiation as a tape-based system (PyTorch) or a graph transformation pass (TensorFlow), JAX treats differentiation as one of several *composable function transformations*. The `jax.grad` function does not compute gradients directly; it returns a *new function* that computes gradients. This subtle distinction enables arbitrary compositions: differentiating a differentiated function yields higher-order derivatives, vectorizing a gradient computation (`vmap(grad(f))`) parallelizes across examples, and compiling a vectorized gradient to XLA (`jit(vmap(grad(f)))`) eliminates Python overhead entirely.
|
||
|
||
JAX's functional paradigm requires a genuine mental shift from "tracking state through objects" to "transforming pure functions." The conceptual introduction here covers JAX's core design; transformation composition, pytree handling, and XLA tracing mechanics each warrant dedicated study for production use.
|
||
|
||
#### Transformations over State {#sec-ml-frameworks-transformations-state-3b46}
|
||
|
||
While PyTorch and TensorFlow build computational graphs (dynamically or statically), JAX transforms functions. The core insight is that automatic differentiation, vectorization, and JIT compilation are all *program transformations* that can compose. @lst-jax-transformations demonstrates this composable approach.
|
||
|
||
::: {#lst-jax-transformations lst-cap="**JAX Function Transformations**: JAX treats differentiation, vectorization, and compilation as composable function transformations rather than graph operations."}
|
||
|
||
```{.python}
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = jnp.dot(x, params["w"]) + params["b"]
    return jnp.mean((pred - y) ** 2)


# Transform: compute gradients
grad_fn = jax.grad(loss_fn)

# Transform: vectorize over batch dimension
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))

# Transform: compile to XLA
fast_batched_grad = jax.jit(batched_grad)

# Compose all three: fast, batched gradient computation
composed_grad = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))
```
|
||
|
||
:::
|
||
|
||
This functional approach requires **pure functions** (no side effects) and **immutable data** (arrays cannot be modified in place). These constraints may seem restrictive coming from PyTorch's mutable object model, but they enable formal guarantees: the compiler can safely reorder, fuse, and parallelize operations because function outputs depend only on inputs. The restriction is the feature; purity is what makes transformation composition possible.
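
The purity requirement is easiest to see with a toy model of trace-and-cache compilation. The sketch below is pure Python and is not how `jax.jit` is implemented; `Tracer`, `toy_jit`, and `impure` are hypothetical names. It assumes a function built only from multiplication and addition by constants, and it uses the argument's Python type as a stand-in for the shape and dtype that a real tracer would specialize on.

```python
class Tracer:
    """Stand-in for a traced array: records arithmetic instead of computing it."""

    def __init__(self, ops):
        self.ops = ops

    def __mul__(self, const):
        self.ops.append(("mul", const))
        return self

    def __add__(self, const):
        self.ops.append(("add", const))
        return self


def toy_jit(fn):
    cache = {}  # one recorded program per input type (stand-in for shape/dtype)

    def wrapped(x):
        key = type(x)
        if key not in cache:
            ops = []
            fn(Tracer(ops))  # the Python body runs only here, during tracing
            cache[key] = ops
        for op, const in cache[key]:  # replay the recorded operations
            x = x * const if op == "mul" else x + const
        return x

    return wrapped


trace_log = []


def impure(x):
    trace_log.append("python body ran")  # side effect, invisible to the tracer
    return x * 2 + 1


fast = toy_jit(impure)
results = [fast(3), fast(10)]  # both calls share one trace
```

After the two integer calls, `trace_log` holds a single entry: the side effect happened during tracing and never again. Calling `fast(1.5)` changes the cache key and triggers a second trace, a small-scale analogue of recompilation when input shapes change.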
|
||
|
||
#### Key Transformations {#sec-ml-frameworks-key-transformations-4105}
|
||
|
||
JAX's power emerges from composition. Start with a loss function `f` and apply `jax.grad`\index{JAX!gradient transformation} to obtain a new function that computes gradients. Unlike PyTorch's tape-based autograd, `grad` returns a *function*, not a value; `grad` itself is reverse-mode, and JAX also exposes forward-mode (`jacfwd`) and reverse-mode (`jacrev`) Jacobian transformations. Apply `jax.vmap`\index{JAX!automatic vectorization}\index{Vectorization!automatic batching} to that gradient function and it automatically vectorizes across a batch dimension, transforming single-example code into batched code without manual reshaping. Wrap the vectorized function in `jax.jit`\index{JAX!JIT compilation} and JAX traces it once, compiles it to optimized XLA machine code, caches the result, and eliminates Python overhead on subsequent calls with the same input shapes. Finally, `jax.pmap`\index{JAX!device-parallel mapping} maps the compiled, vectorized gradient function across multiple GPUs or TPUs, running one replica per device and supporting collective operations for inter-device communication. The result, `pmap(jit(vmap(grad(f))))`, expresses distributed, compiled, batched gradient computation as a single composed expression. Composition at this level is the defining feature of JAX's design.
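
The idea that a transformation takes a function and returns a new function can be sketched without JAX at all. The stand-ins below are assumptions for illustration only: this `grad` uses a central finite difference rather than automatic differentiation, and this `vmap` is a plain Python loop rather than true vectorization. What they share with the real transformations is the signature, function in and function out, which is what makes them compose.

```python
def grad(f, h=1e-3):
    """Return a new function approximating df/dx (central difference stand-in)."""
    def df(x):
        return (f(x + h) - f(x - h)) / (2 * h)
    return df


def vmap(f):
    """Return a new function that applies f to every element of a batch (a list)."""
    def batched(xs):
        return [f(x) for x in xs]
    return batched


def cube(x):
    return x ** 3


d_cube = grad(cube)                # approximately 3 * x**2
d2_cube = grad(grad(cube))         # approximately 6 * x
batched_d_cube = vmap(grad(cube))  # per-example derivative over a batch
```

Because each stand-in returns an ordinary function, nesting them needs no special support: `grad(grad(cube))` and `vmap(grad(cube))` are plain function composition. The real `jax.grad` and `jax.vmap` follow the same pattern while computing exact derivatives and emitting vectorized code.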
|
||
|
||
#### Ecosystem and Libraries {#sec-ml-frameworks-ecosystem-libraries-5dae}
|
||
|
||
JAX's minimalist core delegates neural network abstractions to companion libraries (Flax, Haiku, Equinox) and optimization to Optax. This separation reflects the functional philosophy: the core provides transformations, while libraries build conventional abstractions on top. The ecosystem is younger and smaller than PyTorch's or TensorFlow's, which affects the availability of pre-built components for production use.
|
||
|
||
#### Trade-offs and Use Cases {#sec-ml-frameworks-tradeoffs-use-cases-6453}
|
||
|
||
\index{JAX!XLA compilation}
|
||
The functional constraints that JAX imposes become advantages in specific domains. Custom differentiation---higher-order gradients, custom VJP/JVP rules---composes cleanly because pure functions make differentiation rules predictable. Research on optimization algorithms benefits from transformations that let researchers manipulate gradient computation as naturally as they manipulate data. Large-scale distributed training, particularly on TPUs, uses XLA compilation to extract maximum hardware utilization. Scientific computing with AD requirements benefits from functional purity that enables mathematical reasoning about code. JAX requires more upfront investment than PyTorch: the functional paradigm has a learning curve, state management requires explicit patterns, and debugging compiled code is harder than eager execution. Teams should choose JAX when its strengths align with project requirements, not as a default.
|
||
|
||
### Quantitative Platform Performance Analysis {#sec-ml-frameworks-quantitative-platform-performance-analysis-816d}
|
||
|
||
The preceding sections described each framework's design philosophy in qualitative terms: graph-first versus eager-first, stateful versus functional. Design philosophy claims, however, are only meaningful when backed by measurement. @tbl-mlfm-comparison quantifies how the architectural choices of TensorFlow, PyTorch, and JAX translate to system characteristics. When examining this comparison, note particularly the differences in execution mode, compilation optimization potential, and distributed scalability---these dimensions most directly impact production deployment decisions.
|
||
|
||
| **Aspect**                    | **TensorFlow**                   | **PyTorch**            | **JAX**                    |
|:------------------------------|:---------------------------------|:-----------------------|:---------------------------|
| **Graph Type**                | Static (1.x), Dynamic (2.x)      | Dynamic                | Functional transformations |
| **Programming Model**         | Imperative (2.x), Symbolic (1.x) | Imperative             | Functional                 |
| **Core Data Structure**       | Tensor (mutable)                 | Tensor (mutable)       | Array (immutable)          |
| **Execution Mode**            | Eager (2.x default), Graph       | Eager                  | Just-in-time compilation   |
| **Automatic Differentiation** | Reverse mode                     | Reverse mode           | Forward and Reverse mode   |
| **Hardware Acceleration**     | CPU, GPU, TPU                    | CPU, GPU               | CPU, GPU, TPU              |
| **Compilation Optimization**  | XLA: 3--10$\times$ speedup       | TorchScript: 2$\times$ | XLA: 3--10$\times$ speedup |
| **Memory Efficiency**         | 70--90% (workload dependent)     | 70--90% (varies)       | 75--95% (with XLA fusion)  |
| **Distributed Scalability**   | High (1024+ GPUs)                | High                   | Very High (1024+ GPUs)     |
|
||
|
||
: **Framework Characteristics.** Each column reflects a distinct answer to the three core problems. TensorFlow's static graph roots enable XLA compilation but constrain dynamic control flow; PyTorch's eager default maximizes debugging flexibility but limits ahead-of-time optimization; JAX's functional model enables composable transformations and achieves the highest memory efficiency through XLA fusion. {#tbl-mlfm-comparison}
|
||
|
||
An important caveat applies to these numbers: GPU utilization and compilation speedups vary significantly by model architecture, batch size, and operation mix. JAX/XLA achieves higher utilization for TPU workloads through aggressive fusion, while PyTorch and TensorFlow perform similarly for most deep learning workloads. These framework-level generalizations provide useful orientation but cannot substitute for profiling specific workloads on target hardware.
|
||
|
||
How do these architectural differences look in practice? @lst-framework-hello-world implements the same neural network (a single linear layer mapping 10 inputs to 1 output) across all three frameworks, revealing how design philosophy shapes even the simplest code:
|
||
|
||
::: {#lst-framework-hello-world lst-cap="**Framework Comparison: Hello World**: The same simple neural network implemented in PyTorch (object-oriented), TensorFlow/Keras (declarative), and JAX (functional), illustrating each framework's distinct design philosophy."}
|
||
|
||
```{.python}
# PyTorch - Dynamic, Pythonic
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


# TensorFlow/Keras - High-level API
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, input_shape=(10,))]
)

# JAX - Functional approach
import jax.numpy as jnp
from jax import random


def simple_net(params, x):
    return jnp.dot(x, params["w"]) + params["b"]


# Split the key so "w" and "b" draw from independent random streams
w_key, b_key = random.split(random.PRNGKey(0))
params = {
    "w": random.normal(w_key, (10, 1)),
    "b": random.normal(b_key, (1,)),
}
```
|
||
|
||
:::
|
||
|
||
These three implementations solve the same mathematical problem but reveal distinct answers to the Three Problems. The differences are not cosmetic; they shape debugging workflows, deployment options, and optimization potential.
|
||
|
||
PyTorch binds state and computation together through class inheritance (`nn.Module`), solving the Execution Problem through eager evaluation: the graph builds as Python runs, making standard debuggers and control flow work naturally. The cost is that no optimizer sees the full computation before execution begins.
|
||
|
||
TensorFlow/Keras inverts this priority through the `Sequential` API, which declares structure without executing it, solving the Abstraction Problem first: the same declaration compiles to server GPUs, mobile NPUs, or browser WebGL backends. Eager mode (default in TensorFlow 2.x) recovers some of PyTorch's debugging flexibility, but production deployment still relies on graph capture for optimization.
|
||
|
||
JAX makes the most radical trade-off by treating the model as a pure function[^fn-pure-function-jax] with immutable data and no internal state. This functional purity solves the Differentiation Problem most elegantly: `grad`, `vmap` (automatic vectorization), and `jit` (just-in-time compilation[^fn-jit-compilation-tradeoff]) are composable transformations on stateless functions, not infrastructure bolted onto an object system. The cost is explicit parameter management and a programming model unfamiliar to most engineers.
|
||
|
||
[^fn-jit-compilation-tradeoff]: **Just-in-Time (JIT) Compilation**: Translates high-level code into optimized machine code at runtime, specializing for the actual data shapes and hardware present. The trade-off is compilation latency: the first execution pays a one-time cost (5--30 seconds for transformer models) while subsequent calls with the same shapes execute cached compiled code with microsecond dispatch overhead. Shape changes trigger recompilation, which is why dynamic sequence lengths in language models can degrade JIT performance unless the framework pads to fixed shape buckets. \index{JIT Compilation!shape specialization}
|
||
|
||
[^fn-pure-function-jax]: **Pure Function**: Has no side effects and always returns the same output for the same inputs. In JAX, purity is not a style preference but a compiler requirement: `jax.jit` traces the function once and caches the compiled result, so any side effect (printing, modifying global state, random number generation without explicit key threading) would execute only during the first trace and silently vanish from subsequent calls. This constraint is the cost JAX pays for composable, whole-program optimization. \index{Pure Function!JIT requirement}
|
||
|
||
[^fn-xla-compiler]: **XLA (Accelerated Linear Algebra)**: The "optimized machine code" in the triggering sentence means XLA fuses an entire subgraph into one kernel, eliminating both launch overhead ($L_{\text{lat}}$) and intermediate memory writes ($D_{\text{vol}}$). The 1.5--2$\times$ speedup for transformer blocks is modest because their large GEMM operations are already compute-bound, leaving little overhead for fusion to remove. Memory-bound models see 3--10$\times$ gains, where fusion hides the relative cost of many small, sequential operations. \index{XLA!compilation speedup}
|
||
|
||
[^fn-onnx-portability]: **ONNX (Open Neural Network Exchange)**: The "fragmentation" ONNX addresses is that the best training framework (often PyTorch for research velocity) rarely matches the best serving runtime (often TensorRT for latency, TF Lite for mobile). ONNX defines a hardware-agnostic graph representation that decouples the two, eliminating the engineer-months of manual model conversion that would otherwise be required each time a deployment target changes. The accepted trade-off is that ONNX export can lose framework-specific optimizations or custom operators, requiring fallback implementations. \index{ONNX!framework portability}
|
||
|
||
No framework optimizes all three problems simultaneously; each makes deliberate trade-offs that shape everything from API design to performance characteristics. PyTorch prioritizes the Execution Problem (eager debugging, dynamic graphs) at the cost of optimization potential. TensorFlow prioritizes the Abstraction Problem (unified deployment from cloud to microcontroller) at the cost of development flexibility. JAX reframes the Differentiation Problem (composable function transformations) at the cost of a steeper learning curve. These are the same design tensions examined in the subsections above, now visible even in a ten-line program. Exploratory research favors PyTorch's debugging immediacy, production deployment favors TensorFlow's optimization depth, and algorithmic research favors JAX's composable transformations. Each philosophy shapes code syntax, team workflows, debugging practices, and deployment pipelines, which is why framework migration costs are measured in engineer-months rather than engineer-days.
|
||
|
||
These design differences are not arbitrary; they reflect which term of the Iron Law each framework prioritizes. TensorFlow's graph compilation minimizes the *Overhead* term through ahead-of-time optimization, PyTorch's eager execution minimizes the *developer iteration* overhead at the cost of runtime optimization, and JAX's XLA backend minimizes the *Data Movement* term through aggressive operation fusion.
|
||
|
||
### Quantitative Framework Efficiency Comparison {#sec-ml-frameworks-quantitative-framework-efficiency-comparison-3b77}
|
||
|
||
How large are these differences in practice? @tbl-framework-efficiency-matrix compares major frameworks across efficiency dimensions using benchmark workloads representative of production deployment scenarios.
|
||
|
||
| **Framework**             | **Inference Latency (ms)** | **Memory Usage (MB)** | **Energy (mJ/inference)** | **Model Size Reduction**  | **Hardware Utilization (%)** |
|:--------------------------|---------------------------:|----------------------:|--------------------------:|--------------------------:|-----------------------------:|
| **TensorFlow**            |                         45 |                 2,100 |                       850 |                      None |                           35 |
| **TensorFlow Lite**       |                         12 |                   180 |                       120 |     4$\times$ (quantized) |                           65 |
| **TensorFlow Lite Micro** |                          8 |                    32 |                        45 |  8$\times$ (pruned+quant) |                           75 |
| **PyTorch**               |                         52 |                 1,800 |                       920 |                      None |                           32 |
| **PyTorch Mobile**        |                         18 |                   220 |                       180 |     3$\times$ (quantized) |                           58 |
| **ONNX Runtime**          |                         15 |                   340 |                       210 |     2$\times$ (optimized) |                           72 |
| **TensorRT**              |                          3 |                   450 |                        65 | 2$\times$ (precision opt) |                           88 |
| **Apache TVM**            |                          6 |                   280 |                        95 |      3$\times$ (compiled) |                           82 |
|
||
|
||
: **Framework Efficiency Comparison.** Quantitative comparison of major ML frameworks across efficiency dimensions using ResNet-50 inference on representative hardware (NVIDIA A100 GPU for server frameworks, ARM Cortex-A78 for mobile). Metrics reflect production workloads with accuracy maintained within 1% of baseline. Hardware utilization represents percentage of theoretical peak performance on typical operations. {#tbl-framework-efficiency-matrix}
|
||
|
||
\index{TensorRT!hardware utilization}
|
||
\index{Apache TVM!ML compiler}
|
||
The efficiency data reveals several important patterns. First, specialized inference frameworks (TensorRT, Apache TVM) achieve roughly 7--17$\times$ lower latency than general-purpose training frameworks (PyTorch, TensorFlow) on the same server hardware, demonstrating that framework selection has quantitative performance implications beyond qualitative design preferences. Second, mobile-optimized variants (TF Lite, PyTorch Mobile) reduce memory requirements by roughly 8--12$\times$ compared to their full counterparts while maintaining accuracy within 1% through quantization and graph optimization. Third, hardware utilization varies dramatically: TensorRT achieves 88% GPU utilization through aggressive kernel fusion while vanilla PyTorch achieves only 32%, a 2.75$\times$ efficiency gap that directly translates to cost differences in production deployment.
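
The ratios behind these patterns can be recomputed directly from the table. The short script below copies the latency, memory, and utilization columns for the relevant rows and derives each gap. The helper `ratio` and the column constants are illustrative names; the numbers are taken from the table rather than from any new measurement.

```python
# Values copied from the efficiency table: (latency ms, memory MB, utilization %)
table = {
    "TensorFlow": (45, 2100, 35),
    "TensorFlow Lite": (12, 180, 65),
    "PyTorch": (52, 1800, 32),
    "PyTorch Mobile": (18, 220, 58),
    "TensorRT": (3, 450, 88),
    "Apache TVM": (6, 280, 82),
}
LATENCY, MEMORY, UTILIZATION = 0, 1, 2


def ratio(a, b, column):
    """How many times larger framework a's value is than framework b's."""
    return table[a][column] / table[b][column]


# Latency gap: general-purpose framework vs. specialized inference runtime.
latency_gaps = {
    (general, special): round(ratio(general, special, LATENCY), 1)
    for general in ("TensorFlow", "PyTorch")
    for special in ("TensorRT", "Apache TVM")
}

# Memory reduction: full framework vs. its mobile variant.
memory_savings = {
    "TensorFlow -> TensorFlow Lite": round(
        ratio("TensorFlow", "TensorFlow Lite", MEMORY), 1),
    "PyTorch -> PyTorch Mobile": round(
        ratio("PyTorch", "PyTorch Mobile", MEMORY), 1),
}

# Utilization gap between TensorRT and vanilla PyTorch.
utilization_gap = ratio("TensorRT", "PyTorch", UTILIZATION)
```

The latency gaps range from 7.5$\times$ (TensorFlow versus Apache TVM) to about 17.3$\times$ (PyTorch versus TensorRT), the memory reductions are about 11.7$\times$ and 8.2$\times$, and the utilization gap is 2.75$\times$.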
|
||
|
||
These efficiency gaps, significant in the data center, become existential as we move beyond the server room. A 17$\times$ latency difference between PyTorch and TensorRT is an optimization opportunity on a cloud GPU; on a microcontroller with 256 KB of RAM, a framework that requires 1.8 GB of memory simply cannot run at all. The question shifts from "which framework is fastest?" to "which framework fits?"
|
||
|
||
## Deployment Targets {#sec-ml-frameworks-deployment-targets-13f1}
|
||
|
||
\index{Deployment Targets!cloud to edge spectrum}
|
||
\index{Edge Deployment!framework selection}
|
||
\index{Edge Deployment!three problems reweighting}
|
||
As ML models move from cloud servers to edge devices, the efficiency gaps measured above transform from optimization opportunities into hard deployment constraints. The three core problems reweight dramatically at the edge. The *execution problem* shifts from "eager vs. graph" to "can we execute at all within 10 ms and 50 KB?" The *differentiation problem* often disappears entirely, since edge devices run inference only. The *abstraction problem* intensifies: targeting ARM vs. x86, mobile NPUs vs. edge TPUs, microcontrollers with kilobytes of memory.
|
||
|
||
@tbl-deployment-frameworks summarizes framework choices by deployment target:
|
||
|
||
| **Environment**              | **Primary Frameworks**           | **Key Optimizations**                 | **Typical Constraints**           |
|:-----------------------------|:---------------------------------|:--------------------------------------|:----------------------------------|
| **Cloud/Server**             | PyTorch, TensorFlow, JAX         | Distributed training, mixed precision | Throughput, cost                  |
| **Edge**                     | TensorFlow Lite, ONNX Runtime    | Quantization (INT8), static graphs    | Latency <10 ms, limited memory    |
| **Mobile**                   | TF Lite, Core ML, PyTorch Mobile | NPU acceleration, model compression   | Battery, thermal, app size limits |
| **Microcontroller (TinyML)** | TF Lite Micro, uTensor           | 4-bit quantization, static allocation | <256 KB RAM, no dynamic memory    |
|
||
|
||
: **Framework Selection by Deployment Target.** Recommended frameworks, optimization techniques, and key constraints for each deployment tier, from cloud servers to microcontrollers. {#tbl-deployment-frameworks}
|
||
|
||
@tbl-deployment-frameworks reveals a fragmented landscape: different deployment targets favor different frameworks. The Smart Doorbell KWS model from @sec-ml-frameworks-tinyml-micro-runtimes-2a1b exemplifies the Microcontroller tier, where TF Lite Micro's extreme AOT compilation is the only viable path. This fragmentation creates a practical problem when organizations train in one framework but deploy on a target best served by another.
|
||
|
||
\index{ONNX!definition}
|
||
\index{ONNX!hub-and-spoke interoperability}
|
||
The Open Neural Network Exchange (ONNX)\index{ONNX!cross-framework portability}[^fn-onnx-portability] format addresses this fragmentation by enabling model portability across frameworks: train in PyTorch, deploy via TensorFlow Lite or ONNX Runtime. This standardization eliminates manual conversion when moving between development and production environments. @fig-onnx captures this hub-and-spoke interoperability model---notice how ONNX sits at the center, accepting models from any training framework on the left and dispatching them to specialized runtimes on the right. Detailed deployment optimization (quantization, pruning, hardware-specific compilation) appears in @sec-model-compression and @sec-model-serving.
|
||
|
||
::: {#fig-onnx fig-env="figure" fig-pos="htb" fig-cap="**Framework Interoperability**: ONNX enables model portability across frameworks, allowing training in one framework and deployment in another." fig-alt="Hub diagram with ONNX logo at center. Left side: PyTorch, TensorFlow, Keras with arrows pointing inward. Right side: TF Lite, ONNX Runtime with arrows outward."}
|
||
|
||
{width="70%"}
|
||
|
||
:::
|
||
|
||
ONNX reduces the cost of framework fragmentation, but it does not eliminate the initial selection decision. With the deployment landscape mapped and interoperability options understood, we can now address the practical question: given a specific project's requirements, how should an engineer select a framework?
|
||
|
||
## Framework Selection {#sec-ml-frameworks-selecting-framework-2949}
|
||
|
||
\index{Framework Selection!trade-off analysis}
|
||
\index{Framework Selection!constrained optimization}
|
||
Framework selection is a constrained optimization problem across technical capabilities, operational requirements, and organizational factors. No single framework dominates across all criteria, which means the goal is not to find the "best" framework but to find the one whose trade-offs align with the project's constraints.
|
||
|
||
### The Framework Selection Trade-off Space {#sec-ml-frameworks-framework-selection-tradeoff-space-c76e}
|
||
|
||
Framework selection involves three interconnected tensions. The first is between development velocity and production performance: eager execution (PyTorch) prioritizes iteration speed, while graph compilation (TensorFlow/XLA, JAX/JIT) prioritizes runtime optimization. Research teams that need to test ten architecture variants per day cannot afford minutes of compilation between experiments; production teams that deploy a single model for months cannot afford the throughput penalty of eager dispatch. The optimal point shifts as a project moves through its lifecycle.
|
||
|
||
This velocity-performance tension leads directly to the second: flexibility versus optimization depth. Dynamic graphs enable the arbitrary control flow that makes eager development fast, but they limit the compiler's scope. Static graphs constrain expressiveness but enable aggressive fusion and hardware-specific code generation. As @tbl-mlfm-graphs demonstrated, this trade-off cascades through memory management, utilization, and debugging workflows---it is not a single design decision but a system-wide constraint.
|
||
|
||
The flexibility-optimization tension, in turn, exposes a third: ecosystem breadth versus specialization. General-purpose frameworks cover broad operation sets but underperform specialized runtimes. TensorRT achieves 88% GPU utilization versus PyTorch's 32% (@tbl-framework-efficiency-matrix) precisely because it optimizes for a narrower problem. ONNX bridges this gap through standardized interchange, but the underlying trade-off remains: the more a runtime specializes, the faster it runs and the less it supports.
|
||
|
||
::: {.callout-perspective title="Framework Selection Constraints"}
|
||
|
||
Rather than seeking the "best" framework, effective selection identifies the framework that satisfies hard constraints (deployment target, required operations, team expertise) while optimizing soft preferences (performance, development speed, ecosystem). Hard constraints eliminate options; soft preferences rank remaining candidates.
|
||
|
||
:::
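
The filter-then-rank procedure described in the callout can be sketched in a few lines. Everything here is hypothetical: the `Candidate` fields, the runtime names, the operator sets, and the scores are placeholders for whatever a team's own evaluation produces, not properties of any real framework.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    name: str
    targets: frozenset  # deployment targets the runtime supports
    ops: frozenset      # operators the runtime implements
    score: float        # aggregate soft-preference score (higher is better)


def select(candidates, target, required_ops):
    """Hard constraints eliminate options; soft preferences rank the rest."""
    viable = [
        c for c in candidates
        if target in c.targets and required_ops <= c.ops
    ]
    return sorted(viable, key=lambda c: c.score, reverse=True)


# Hypothetical candidates for illustration only.
candidates = [
    Candidate("runtime_a", frozenset({"server"}),
              frozenset({"conv", "matmul", "lstm"}), 0.9),
    Candidate("runtime_b", frozenset({"server", "mobile"}),
              frozenset({"conv", "matmul"}), 0.7),
    Candidate("runtime_c", frozenset({"mobile", "mcu"}),
              frozenset({"conv", "matmul"}), 0.8),
]

ranked = select(candidates, target="mobile",
                required_ops=frozenset({"conv", "matmul"}))
```

Here `runtime_a` has the highest score but is eliminated because it does not support the mobile target, so the ranking contains only `runtime_c` and `runtime_b`. Requiring an operator that neither survivor implements returns an empty list, a signal that a constraint, not a preference, has to change.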
|
||
|
||
The TensorFlow ecosystem illustrates how these axes interact concretely. Its three variants (TensorFlow\index{TensorFlow!full framework}, TensorFlow Lite\index{TensorFlow Lite!mobile deployment}, TensorFlow Lite Micro\index{TensorFlow Lite Micro!microcontroller deployment}) trace a single design philosophy across progressively tighter constraints, a pattern that generalizes to any framework family. @tbl-tf-comparison quantifies the trade-offs.
|
||
|
||
|                                 | **TensorFlow**                  | **TensorFlow Lite**   | **TensorFlow Lite for Microcontrollers** |
|:--------------------------------|:--------------------------------|:----------------------|:-----------------------------------------|
| **Training**                    | Yes                             | No                    | No                                       |
| **Inference**                   | Yes (*but inefficient on edge*) | Yes (*and efficient*) | Yes (*and even more efficient*)          |
| **How Many Ops**                | ~1400                           | ~130                  | ~50                                      |
| **Native Quantization Tooling** | No                              | Yes                   | Yes                                      |
|
||
|
||
: **TensorFlow Variant Software Comparison.** Design trade-offs across TensorFlow, TensorFlow Lite, and TensorFlow Lite Micro, balancing model expressiveness, binary size, and resource constraints. Supported operations decrease from approximately 1,400 in full TensorFlow to approximately 50 in TensorFlow Lite Micro, reflecting a shift from training capability to efficient edge inference. Native quantization tooling enables further optimization for constrained environments. {#tbl-tf-comparison}
|
||
|
||
The principle is progressive constraint leading to progressive optimization: fewer supported operations enable smaller binaries, tighter memory budgets, and native quantization. Three dimensions structure this analysis: model requirements (what operations must the framework support?), software dependencies (what runtime environment is available?), and hardware constraints (what are the physical limits?).
|
||
|
||
### Framework Selection Criteria {#sec-ml-frameworks-model-requirements-2e01}
|
||
|
||
Three dimensions structure systematic framework evaluation: what the model requires (supported operations and graph semantics), what the software environment provides (OS, memory management, accelerator delegation), and what the hardware physically permits (compute, memory, power). Each dimension acts as a filter: hard constraints eliminate candidates, and soft preferences rank the survivors.
|
||
|
||
#### Model Requirements {#sec-ml-frameworks-model-requirements-ca8b}
|
||
|
||
The first question is whether a framework can express the models a project requires. Examine @tbl-tf-comparison: notice how operator count drops from approximately $10^3$ (full TensorFlow) to $10^2$ (TensorFlow Lite) to $10^1$ (TensorFlow Lite Micro). The first reduction eliminates training capability and adds native quantization tooling; the second trims general-purpose operations further. The engineering principle is that expressiveness and efficiency trade against each other: fewer supported operations enable tighter code generation, smaller binaries, and hardware-specific optimization paths. This progressive constraint model applies to any framework family, not just TensorFlow. The choice between *dynamic and static computational graphs* further shapes which optimizations each constraint level permits.
|
||
|
||
::: {.callout-perspective title="Dynamic vs Static Computational Graphs"}
|
||
|
||
The static-versus-dynamic graph distinction (examined in @sec-ml-frameworks-execution-problem-e1e1) has direct implications for model requirements analysis. Static graphs constrain which operations are expressible but enable ahead-of-time optimization for deployment. Dynamic graphs support arbitrary Python control flow but require explicit compilation steps (e.g., `torch.compile`, `tf.function`) to recover optimization potential.
|
||
|
||
:::
|
||
|
||
#### Software Dependencies {#sec-ml-frameworks-software-dependencies-5245}
|
||
|
||
Once model requirements are satisfied, the framework must integrate with the target software environment. @tbl-tf-sw-comparison reveals how operating system requirements, memory management, and accelerator support vary across TensorFlow variants.
|
||
|
||
|                                | **TensorFlow** | **TensorFlow Lite** | **TensorFlow Lite for Microcontrollers** |
|:-------------------------------|:---------------|:--------------------|:-----------------------------------------|
| **Needs an OS**                | Yes            | Yes                 | No                                       |
| **Memory Mapping of Models**   | No             | Yes                 | Yes                                      |
| **Delegation to accelerators** | Yes            | Yes                 | No                                       |
|
||
|
||
: **TensorFlow Variant Capability Comparison.** Capabilities of TensorFlow, TensorFlow Lite, and TensorFlow Lite Micro regarding operating system dependence, memory management, and hardware acceleration. Progressive constraint across variants enables selection by deployment context, from full-scale servers to resource-constrained edge devices. {#tbl-tf-sw-comparison}
|
||
|
||
The key distinctions follow the same progressive constraint pattern. TensorFlow Lite Micro eliminates the OS requirement entirely, enabling bare-metal execution on microcontrollers (though it integrates with RTOSes like FreeRTOS and Zephyr when available). Both Lite variants support memory-mapped model access from flash storage, avoiding the RAM overhead of loading full models. Accelerator delegation drops out at the microcontroller tier, where specialized hardware is rarely available. Each software dependency removed is a deployment target gained.
|
||
|
||
#### Hardware Constraints {#sec-ml-frameworks-hardware-constraints-e449}
|
||
|
||
Software compatibility alone does not guarantee deployment; the framework must fit within physical hardware limits. @tbl-tf-hw-comparison quantifies this final constraint dimension.
|
||
|
||
|                             | **TensorFlow**                                        | **TensorFlow Lite**    | **TensorFlow Lite for Microcontrollers** |
|:----------------------------|:------------------------------------------------------|:-----------------------|:-----------------------------------------|
| **Base Binary Size**        | A few MB (varies by platform and build configuration) | Tens to hundreds of KB | On the order of 10 KB                    |
| **Base Memory Footprint**   | Several MB (minimum runtime overhead)                 | Hundreds of KB         | Tens of KB                               |
| **Optimized Architectures** | x86, TPUs, GPUs                                       | Arm Cortex-A, x86      | Arm Cortex-M, DSPs, MCUs                 |

: **TensorFlow Hardware Optimization.** Resource requirements (binary size and memory footprint) decrease across TensorFlow variants as they target increasingly constrained hardware, from servers to microcontrollers. Optimized architectures shift from general-purpose CPUs and GPUs to Arm Cortex-M processors and digital signal processors for resource-limited environments. {#tbl-tf-hw-comparison}
|
||
|
||
Binary size spans two to three orders of magnitude, from a few MB (full TensorFlow) to roughly 10 KB (TensorFlow Lite Micro). Memory footprint follows the same pattern. Processor architecture support shifts correspondingly from x86/GPU/TPU (data center) through Arm Cortex-A (mobile/edge) to Arm Cortex-M, DSPs, and MCUs (embedded). These tiers are not arbitrary: they mirror the physical constraints (Light Barrier, Power Wall, Memory Wall) that carve the deployment spectrum into distinct paradigms (@sec-ml-systems-deployment-spectrum-71be). The engineering lesson generalizes beyond TensorFlow: every framework family that spans deployment tiers makes analogous trade-offs between capability and resource footprint, and the framework's job is to make those trade-offs navigable rather than invisible.
|
||
|
||
#### Production-Ready Evaluation Factors {#sec-ml-frameworks-productionready-evaluation-factors-e088}
|
||
|
||
Technical specifications establish necessary but not sufficient conditions for selection. Production deployments also require evaluating operational factors: migration cost (typically 3--6 engineer-months for production systems), maintenance burden, and deployment success rates.
|
||
|
||
These hardware constraints cascade into performance trade-offs that are tightly coupled. Inference latency (tens of milliseconds for mobile image classification, sub-millisecond for industrial control), memory footprint (MB for full TensorFlow down to tens of KB for TF Lite Micro), power consumption (INT8 inference consuming several-fold less energy than FP32), and hardware utilization (operator fusion improving FLOPS utilization from 10--20% to 60--80% of peak) are not independent dimensions. Quantization simultaneously reduces memory, latency, and energy at the cost of precision, and framework selection determines which of these optimization levers are available in the first place. Scalability introduces a further concern: consistent deployment from microcontrollers to servers, smooth prototype-to-production transitions, and version management across deployed fleets all depend on the framework's deployment toolchain. The three-dimension methodology illustrated here---model requirements, software dependencies, hardware constraints---applies to any framework ecosystem, not just TensorFlow.
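
The three-dimension screen can be sketched as a filter over candidate runtimes. This is a minimal illustration rather than a selection tool: the `Variant` and `Target` records and every numeric footprint below are hypothetical placeholders, chosen only to match the orders of magnitude in the tables above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Variant:
    name: str
    needs_os: bool
    supports_training: bool
    footprint_kb: int  # illustrative order of magnitude, not a measurement


@dataclass(frozen=True)
class Target:
    has_os: bool
    needs_training: bool
    free_memory_kb: int


# Hypothetical footprints, ordered from most to least capable.
VARIANTS = [
    Variant("TensorFlow", True, True, 5_000),
    Variant("TensorFlow Lite", True, False, 300),
    Variant("TensorFlow Lite Micro", False, False, 20),
]


def feasible(variant: Variant, target: Target) -> bool:
    """Apply the three screens in order."""
    if target.needs_training and not variant.supports_training:
        return False  # model requirements
    if variant.needs_os and not target.has_os:
        return False  # software dependencies
    return variant.footprint_kb <= target.free_memory_kb  # hardware constraints


def select(target: Target) -> list[str]:
    return [v.name for v in VARIANTS if feasible(v, target)]


microcontroller = Target(has_os=False, needs_training=False, free_memory_kb=256)
phone = Target(has_os=True, needs_training=False, free_memory_kb=2_000)
server = Target(has_os=True, needs_training=True, free_memory_kb=16_000_000)

for label, target in [("mcu", microcontroller), ("phone", phone), ("server", server)]:
    print(label, select(target))
```

Ordering the checks from model requirements to hardware limits mirrors the methodology: a variant that cannot express the workload is rejected before its footprint is even considered.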
|
||
|
||
### Development Support and Long-term Viability Assessment {#sec-ml-frameworks-development-support-longterm-viability-assessment-d1d7}
|
||
|
||
What determines whether a framework remains viable five years into a production deployment? Technical capabilities are necessary but not sufficient. Community composition shapes framework evolution in measurable ways: PyTorch's academic community drives research-oriented features and reproducibility tools, though production tooling (PyTorch Lightning, TorchServe) has historically lagged; TensorFlow's enterprise community emphasizes production reliability through TFX pipelines, TensorBoard visualization, and TensorFlow Model Analysis; JAX's smaller community concentrates on mathematical rigor, producing specialized research tools (composable transformations, custom VJP rules) but with a steeper onboarding curve.
|
||
|
||
A framework's practical utility, however, often depends more on its surrounding ecosystem than on its core capabilities. Hugging Face provides consistent model APIs across all three major frameworks, making pretrained model availability a near-commodity. Cross-framework tools (Weights & Biases, MLflow for experiment tracking; ONNX Runtime for serving) reduce lock-in, while framework-native tools (XLA, TorchScript, TensorFlow Serving) offer deeper optimization at the cost of portability. Cloud ML services (SageMaker, Google Vertex AI, Azure ML) provide native integration for specific frameworks, creating operational advantages that compound over time.
|
||
|
||
These compounding effects make framework migration progressively harder. Integration with existing CI/CD pipelines, monitoring infrastructure, and cloud providers creates operational inertia that resists change. The measurable indicators of viability---contributor diversity (single-company dependence is a risk), backward compatibility track record, and hiring pool alignment with organizational needs---should be evaluated before commitment, not after. The mitigation strategy is defensive: use standardized formats (ONNX), maintain framework-agnostic data pipelines, and document framework-specific customizations to preserve future flexibility.
|
||
|
||
We have now examined the three core problems individually, compared how major frameworks resolve them, and established criteria for choosing among alternatives. What remains is to see all three problems interact inside a single execution---to watch the machinery we have studied operate as one integrated system.
|
||
|
||
## Anatomy of a Training Step {#sec-ml-frameworks-putting-together-anatomy-training-step-c7f1}
|
||
|
||
\index{Training Step!anatomy}
|
||
\index{Forward Pass!kernel dispatch}
|
||
\index{Backward Pass!gradient computation}
|
||
The concepts developed throughout this chapter (eager versus graph execution, reverse-mode autodiff, tensor abstractions, kernel dispatch) remain abstract until we see them interact inside a real execution. To solidify understanding, we trace a single training step through the PyTorch stack, revealing how a few lines of Python trigger the execution, differentiation, and abstraction machinery simultaneously.
|
||
|
||
@lst-training-step-anatomy presents a minimal training iteration for a two-layer multilayer perceptron. Though only a few lines long, this code exercises the entire framework stack: tensor allocation, kernel dispatch, autograd recording, gradient computation, and parameter updates. Tracing each phase reveals the three problems in action and connects the quantitative principles developed earlier to concrete execution.
|
||
|
||
::: {#lst-training-step-anatomy lst-cap="**Training Step Anatomy**: A minimal training iteration for a two-layer MLP, exercising tensor allocation, kernel dispatch, autograd recording, gradient computation, and parameter updates."}
|
||
|
||
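```{.python}
import torch
import torch.nn.functional as F

# Setup (assumed so the listing runs): 784 -> 256 -> 10 MLP parameters
W1 = torch.randn(784, 256, device="cuda", requires_grad=True)
b1 = torch.zeros(256, device="cuda", requires_grad=True)
W2 = torch.randn(256, 10, device="cuda", requires_grad=True)
b2 = torch.zeros(10, device="cuda", requires_grad=True)
optimizer = torch.optim.SGD([W1, b1, W2, b2], lr=0.01)

# Single training step for a 2-layer MLP
x = torch.randn(32, 784, device="cuda")  # Input batch
y = torch.randint(0, 10, (32,), device="cuda")  # Labels

# Forward pass
h = torch.relu(x @ W1 + b1)  # Hidden layer
logits = h @ W2 + b2  # Output layer
loss = F.cross_entropy(logits, y)

# Backward pass
optimizer.zero_grad()  # Clear gradients left over from any previous step
loss.backward()

# Parameter update
optimizer.step()
```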
|
||
|
||
:::
|
||
|
||
```{python}
|
||
#| echo: false
|
||
#| label: training-step-dims
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ TRAINING STEP DIMENSIONS
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: Phase 1 prose references matrix dimensions inline
|
||
# │
|
||
# │ Goal: Define model dimensions for the training step example.
|
||
# │ Show: The matrix sizes for a 2-layer MLP (784 -> 256 -> 10).
|
||
# │ How: Set constants for input batch, hidden layers, and output classes.
|
||
# │
|
||
# │ Imports: (none)
|
||
# │ Exports: train_batch, train_input, train_hidden, train_output
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
|
||
# ┌── LEGO ───────────────────────────────────────────────
|
||
class TrainingStepDims:
|
||
"""Model dimensions for the two-layer MLP training step example."""
|
||
|
||
# ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
|
||
train_batch = 32 # batch size
|
||
train_input = 784 # MNIST input (28 × 28)
|
||
train_hidden = 256 # hidden layer
|
||
train_output = 10 # 10 classes
|
||
```
|
||
|
||
### Phase 1: Forward Pass (Solving the Execution Problem) {.unnumbered}
|
||
|
||
When `h = torch.relu(x @ W1 + b1)` executes, PyTorch's eager execution triggers immediate computation:
|
||
|
||
1. **Python Dispatch**\index{Python Dispatch!overhead} (~1μs): Python interpreter calls `torch.matmul`, which routes through PyTorch's dispatcher to select the CUDA backend.
|
||
|
||
2. **Kernel Selection**\index{BLAS Library!kernel selection} (~0.5μs): cuBLAS selects an optimized GEMM kernel based on the matrix dimensions: a `{python} TrainingStepDims.train_batch`$\times$ `{python} TrainingStepDims.train_input` input multiplied by a `{python} TrainingStepDims.train_input`$\times$ `{python} TrainingStepDims.train_hidden` weight matrix. For these dimensions, it might choose a tiled algorithm optimized for L2 cache.
|
||
|
||
3. **Kernel Launch** (~5μs): The selected kernel is queued to the GPU's command buffer. The CPU continues immediately (asynchronous execution).
|
||
|
||
4. **GPU Execution** (~15μs):
|
||
- Load W1 from HBM[^fn-hbm-capacity-constraint] to L2 cache (~200 GB/s effective bandwidth)
|
||
- Perform matrix multiply in tensor cores (if available)
|
||
- Write result to HBM
|
||
|
||
[^fn-hbm-capacity-constraint]: **HBM (High Bandwidth Memory)**: Provides 2--3 TB/s bandwidth on modern GPUs (introduced in @sec-network-architectures). HBM bandwidth determines whether operations are memory-bound or compute-bound, and its `{python} DeviceBandwidthHierarchy.a100_mem_str` GB capacity on an A100 sets the hard ceiling on model size: weights, activations, gradients, and optimizer state must all fit simultaneously during training. When they do not, the framework must resort to offloading, checkpointing, or model parallelism, each adding complexity to what the programmer perceives as a single `loss.backward()` call. \index{HBM!capacity constraint}
|
||
|
||
5. **Autograd Recording**: Simultaneously, PyTorch's autograd engine records a `MmBackward` node on the tape, storing references to `x` and `W1` for gradient computation.
|
||
|
||
The bias addition and ReLU follow similar patterns, each adding a node to the autograd tape.
|
||
|
||
### Phase 2: Backward Pass (Solving the Differentiation Problem) {.unnumbered}
|
||
|
||
When `loss.backward()`\index{Backward Pass!triggering gradient computation} executes:
|
||
|
||
1. **Tape Traversal**\index{Autograd Tape!reverse traversal}: The autograd engine traverses the recorded graph in reverse topological order.
|
||
|
||
2. **Gradient Computation**: For each node, it calls the registered backward function:
|
||
- `CrossEntropyBackward`: Computes $\frac{\partial \mathcal{L}}{\partial \text{logits}}$ using softmax derivative
|
||
- `MmBackward` (W2): Computes $\frac{\partial \mathcal{L}}{\partial W_2} = h^T \cdot \frac{\partial \mathcal{L}}{\partial \text{logits}}$ and $\frac{\partial \mathcal{L}}{\partial h}$
|
||
- `ReluBackward`: Applies ReLU derivative mask (zero where h ≤ 0)
|
||
- `MmBackward` (W1): Computes $\frac{\partial \mathcal{L}}{\partial W_1}$ and $\frac{\partial \mathcal{L}}{\partial x}$
|
||
|
||
3. **Gradient Accumulation**: Gradients are accumulated into `.grad` attributes of leaf tensors.
|
||
|
||
4. **Memory Management**: After each backward node completes, its saved tensors are freed, allowing memory reuse.
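
The four steps above can be sketched with a scalar tape in plain Python. This is a teaching sketch under stated assumptions: the `Value` class and its methods are hypothetical stand-ins for tensors and autograd nodes, not PyTorch internals, and scalars replace matrices so the arithmetic can be checked by hand.

```python
class Value:
    """A scalar that remembers how it was produced."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents
        self.backward_fn = None  # pushes self.grad into the parents

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            self.grad += other.data * out.grad  # accumulate, never overwrite
            other.grad += self.data * out.grad

        out.backward_fn = backward_fn
        return out

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            self.grad += out.grad
            other.grad += out.grad

        out.backward_fn = backward_fn
        return out

    def relu(self):
        out = Value(max(0.0, self.data), (self,))

        def backward_fn():
            self.grad += (1.0 if self.data > 0 else 0.0) * out.grad

        out.backward_fn = backward_fn
        return out

    def backward(self):
        # Step 1: build a topological order, then walk it in reverse.
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            if node.backward_fn is not None:
                node.backward_fn()       # steps 2 and 3: compute and accumulate
                node.backward_fn = None  # step 4: drop saved references
                node.parents = ()


x, w, b = Value(2.0), Value(3.0), Value(-1.0)
h = (x * w + b).relu()  # relu(2 * 3 - 1) = 5
loss = h * h            # 25
loss.backward()
print(w.grad, x.grad, b.grad)
```

Because `h` feeds `loss` twice, its gradient receives two contributions of 5 each, which is why the backward functions accumulate with `+=` instead of assigning.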
|
||
|
||
### Phase 3: Memory Traffic Analysis (The Physics at Work) {.unnumbered}
|
||
|
||
To apply the Dispatch Overhead Equation (@eq-dispatch-overhead) to this step, we first need per-operation costs. @tbl-training-step-roofline breaks down the FLOPs, memory traffic, and arithmetic intensity for each operation:
|
||
|
||
```{python}
|
||
#| label: training-step-calc
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ TRAINING STEP FLOPS AND MEMORY ANALYSIS
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: Training Step Anatomy table showing FLOPs and memory per operation
|
||
# │
|
||
# │ Goal: Quantify the arithmetic intensity of each operation in a forward pass.
|
||
# │ MatMul has high AI (~14), ReLU has near-zero AI (0.125), demonstrating
|
||
# │ why element-wise ops are memory-bound and benefit most from fusion.
|
||
# │
|
||
# │ Imports: mlsysim.core.constants (byte, KB, MB, flop, KFLOPs, MFLOPs), mlsysim.book (fmt)
|
||
# │ Exports: train_mm1_flops_str, train_mm1_mem_str, train_mm1_ai_str,
|
||
# │ train_relu_flops_str, train_relu_mem_str, train_relu_ai_str,
|
||
# │ train_mm2_flops_str, train_mm2_mem_str, train_mm2_ai_str, train_ce_flops_str,
|
||
# │ train_ce_mem_str, train_ce_ai_str, train_bwd_flops_str
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim.fmt import fmt, check
|
||
from mlsysim.core.constants import byte, KB, MB, flop, KFLOPs, MFLOPs
|
||
|
||
# ┌── LEGO ───────────────────────────────────────────────
|
||
class TrainingStepCalc:
|
||
"""FLOPs and memory analysis for each operation in a two-layer MLP training step."""
|
||
|
||
# ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
|
||
# Dimensions from TrainingStepDims (class attributes)
|
||
_batch = TrainingStepDims.train_batch
|
||
_input = TrainingStepDims.train_input
|
||
_hidden = TrainingStepDims.train_hidden
|
||
_output = TrainingStepDims.train_output
|
||
|
||
# ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
|
||
# Step 1: MatMul 1: [32, 784] @ [784, 256]
|
||
mm1_flops = 2 * _batch * _input * _hidden # ~12.8M FLOPs
|
||
mm1_mem_bytes = (_batch * _input + _input * _hidden + _batch * _hidden) * 4 # FP32
|
||
|
||
# Step 2: ReLU: [32, 256]
|
||
relu_flops = _batch * _hidden # ~8K FLOPs
|
||
relu_mem_bytes = (_batch * _hidden) * 2 * 4 # read + write
|
||
|
||
# Step 3: MatMul 2: [32, 256] @ [256, 10]
|
||
mm2_flops = 2 * _batch * _hidden * _output # ~164K FLOPs
|
||
mm2_mem_bytes = (_batch * _hidden + _hidden * _output + _batch * _output) * 4
|
||
|
||
# Step 4: Cross Entropy: [32, 10]
|
||
ce_flops = _batch * _output * 3 # ~1K FLOPs (exp + sum + div)
|
||
ce_mem_bytes = (_batch * _output) * 2 * 4
|
||
|
||
# Step 5: Backward pass (2x forward FLOPs)
|
||
bwd_flops = 2 * (mm1_flops + relu_flops + mm2_flops + ce_flops)
|
||
bwd_mem_str = "~3.2 MB" # approximation
|
||
bwd_ai_str = "~8.0" # approximation
|
||
|
||
# ┌── 3. GUARD (Invariants) ──────────────────────────────────────────
|
||
check(mm1_flops / mm1_mem_bytes > relu_flops / relu_mem_bytes,
|
||
"MatMul must have higher arithmetic intensity than ReLU.")
|
||
|
||
# ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
|
||
train_mm1_flops_str = fmt((mm1_flops * flop).m_as(MFLOPs), precision=1, commas=False) + "M"
|
||
train_mm1_mem_str = fmt((mm1_mem_bytes * byte).m_as(MB), precision=1, commas=False) + " MB"
|
||
train_mm1_ai_str = fmt(mm1_flops / mm1_mem_bytes, precision=1, commas=False)
|
||
|
||
train_relu_flops_str = fmt((relu_flops * flop).m_as(KFLOPs), precision=0, commas=False) + "K"
|
||
train_relu_mem_str = fmt((relu_mem_bytes * byte).m_as(KB), precision=0, commas=False) + " KB"
|
||
train_relu_ai_str = fmt(relu_flops / relu_mem_bytes, precision=3, commas=False)
|
||
|
||
train_mm2_flops_str = fmt((mm2_flops * flop).m_as(KFLOPs), precision=0, commas=False) + "K"
|
||
train_mm2_mem_str = fmt((mm2_mem_bytes * byte).m_as(KB), precision=0, commas=False) + " KB"
|
||
train_mm2_ai_str = fmt(mm2_flops / mm2_mem_bytes, precision=1, commas=False)
|
||
|
||
train_ce_flops_str = fmt((ce_flops * flop).m_as(KFLOPs), precision=0, commas=False) + "K"
|
||
train_ce_mem_str = fmt((ce_mem_bytes * byte).m_as(KB), precision=0, commas=False) + " KB"
|
||
train_ce_ai_str = fmt(ce_flops / ce_mem_bytes, precision=1, commas=False)
|
||
|
||
train_bwd_flops_str = fmt((bwd_flops * flop).m_as(MFLOPs), precision=0, commas=False) + "M"
|
||
train_bwd_mem_str = bwd_mem_str
|
||
train_bwd_ai_str = bwd_ai_str
|
||
```
|
||
|
||
| **Component**                    | **FLOPs**                                                                            | **Memory Traffic**                             | **Arithmetic Intensity**                      |
|:---------------------------------|-------------------------------------------------------------------------------------:|-----------------------------------------------:|----------------------------------------------:|
| **MatMul (x @ W1)**              | $2 \times 32 \times 784 \times 256$ = `{python} TrainingStepCalc.train_mm1_flops_str` | `{python} TrainingStepCalc.train_mm1_mem_str`  | `{python} TrainingStepCalc.train_mm1_ai_str`  |
| **ReLU**                         | $32 \times 256$ = `{python} TrainingStepCalc.train_relu_flops_str`                   | `{python} TrainingStepCalc.train_relu_mem_str` | `{python} TrainingStepCalc.train_relu_ai_str` |
| **MatMul (h @ W2)**              | $2 \times 32 \times 256 \times 10$ = `{python} TrainingStepCalc.train_mm2_flops_str` | `{python} TrainingStepCalc.train_mm2_mem_str`  | `{python} TrainingStepCalc.train_mm2_ai_str`  |
| **Cross-entropy**                | ~`{python} TrainingStepCalc.train_ce_flops_str`                                      | `{python} TrainingStepCalc.train_ce_mem_str`   | `{python} TrainingStepCalc.train_ce_ai_str`   |
| **Backward (2$\times$ forward)** | ~`{python} TrainingStepCalc.train_bwd_flops_str`                                     | `{python} TrainingStepCalc.train_bwd_mem_str`  | `{python} TrainingStepCalc.train_bwd_ai_str`  |

: **Per-Operation Roofline Analysis.** FLOPs, memory traffic, and arithmetic intensity for each operation in a two-layer MLP training step. MatMul operations achieve arithmetic intensity above 1.0 FLOP/byte, though at these small sizes they remain below the ridge point of a tensor-core GPU, while ReLU and cross-entropy fall far below 1.0 (memory-bound), quantifying why kernel fusion targets element-wise operations. {#tbl-training-step-roofline}
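
The forward-pass rows of the table can be reproduced with a short calculation. The sketch below assumes FP32 operands (4 bytes each) and counts only the reads and writes of each operand once; the helper names are illustrative, and bias additions are omitted, as in the table.

```python
BYTES_FP32 = 4


def matmul_cost(m, k, n):
    """FLOPs and bytes moved for an [m, k] @ [k, n] product in FP32."""
    flops = 2 * m * k * n  # one multiply and one add per term
    bytes_moved = (m * k + k * n + m * n) * BYTES_FP32  # read A, read B, write C
    return flops, bytes_moved


def elementwise_cost(num_elements, flops_per_element=1):
    """FLOPs and bytes moved for an op that reads and writes each element once."""
    flops = num_elements * flops_per_element
    bytes_moved = num_elements * 2 * BYTES_FP32
    return flops, bytes_moved


batch, d_in, d_hidden, d_out = 32, 784, 256, 10
ops = {
    "matmul_1": matmul_cost(batch, d_in, d_hidden),
    "relu": elementwise_cost(batch * d_hidden),
    "matmul_2": matmul_cost(batch, d_hidden, d_out),
    "cross_entropy": elementwise_cost(batch * d_out, flops_per_element=3),
}

for name, (flops, nbytes) in ops.items():
    print(f"{name:14s} {flops:>12,d} FLOPs {nbytes:>9,d} B  AI = {flops / nbytes:.3f}")
```

The ratio for ReLU is fixed at one FLOP per eight bytes regardless of tensor size, whereas the MatMul ratio grows with the matrix dimensions; that asymmetry is what fusion exploits.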
|
||
|
||
```{python}
|
||
#| label: mnist-training-step-calc
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ MNIST TRAINING STEP ROOFLINE ANALYSIS
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: Showing that small model training is overhead-bound, not compute-bound
|
||
# │
|
||
# │ Goal: On an A100, 40M FLOPs takes 0.1 μs and 5 MB memory takes 2.5 μs, but
|
||
# │ 6 ops × 5 μs dispatch = 30 μs overhead. The step is overhead-bound!
|
||
# │ This explains why torch.compile provides 2-3x speedup for small models.
|
||
# │
|
||
# │ Imports: mlsysim.core.constants (A100_FLOPS_FP16_TENSOR, A100_MEM_BW), mlsysim.book (fmt)
|
||
# │ Exports: mnist_total_flops_str, mnist_mem_traffic_str, a100_peak_tflops_str,
|
||
# │ a100_mem_bw_tbs_str, mnist_t_compute_us_str, mnist_t_memory_us_str,
|
||
# │ mnist_n_ops_str, mnist_us_per_op_str, mnist_t_overhead_us_str
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim.core.constants import A100_FLOPS_FP16_TENSOR, A100_MEM_BW, MILLION, TRILLION, flop, byte, second
|
||
from mlsysim.fmt import fmt, check
|
||
|
||
class MnistTrainingStepCalc:
|
||
"""MNIST overhead-bound analysis: dispatch dominates compute and memory on A100."""
|
||
# ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
|
||
total_flops = 40e6 # ~40M FLOPs total
|
||
mem_traffic_bytes = 5e6 # ~5 MB traffic
|
||
n_ops = 6 # 6 kernel launches
|
||
us_per_op = 5 # 5 μs dispatch/op
|
||
peak_flops = A100_FLOPS_FP16_TENSOR.m_as(flop/second) # 312e12 FLOPS
|
||
mem_bw_bytes = A100_MEM_BW.m_as(byte/second) # ~2e12 B/s
|
||
# ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
|
||
t_compute_us = total_flops / peak_flops * MILLION # ~0.1 μs
|
||
t_memory_us = mem_traffic_bytes / mem_bw_bytes * MILLION # ~2.5 μs
|
||
t_overhead_us = n_ops * us_per_op # 30 μs (dominant!)
|
||
# ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
|
||
mnist_t_compute_us_str = fmt(t_compute_us, precision=1, commas=False)
|
||
mnist_t_memory_us_str = fmt(t_memory_us, precision=1, commas=False)
|
||
mnist_t_overhead_us_str = fmt(t_overhead_us, precision=0, commas=False)
|
||
mnist_total_flops_str = f"{total_flops/MILLION:.0f}M"
|
||
mnist_mem_traffic_str = f"{mem_traffic_bytes/MILLION:.0f} MB"
|
||
mnist_n_ops_str = fmt(n_ops, precision=0, commas=False)
|
||
mnist_us_per_op_str = fmt(us_per_op, precision=0, commas=False)
|
||
a100_peak_tflops_str = fmt(peak_flops/TRILLION, precision=0, commas=False)
|
||
a100_mem_bw_tbs_str = fmt(mem_bw_bytes/TRILLION, precision=0, commas=False)
|
||
```
|
||
|
||
Total: ~`{python} MnistTrainingStepCalc.mnist_total_flops_str` FLOPs, ~`{python} MnistTrainingStepCalc.mnist_mem_traffic_str` memory traffic. On an A100:
|
||
|
||
- Tcompute ≈ `{python} MnistTrainingStepCalc.mnist_total_flops_str` FLOPs / `{python} MnistTrainingStepCalc.a100_peak_tflops_str` TFLOPS ≈ `{python} MnistTrainingStepCalc.mnist_t_compute_us_str` μs
- Tmemory ≈ `{python} MnistTrainingStepCalc.mnist_mem_traffic_str` / `{python} MnistTrainingStepCalc.a100_mem_bw_tbs_str` TB/s ≈ `{python} MnistTrainingStepCalc.mnist_t_memory_us_str` μs
- Toverhead ≈ `{python} MnistTrainingStepCalc.mnist_n_ops_str` ops$\times$ `{python} MnistTrainingStepCalc.mnist_us_per_op_str` μs ≈ `{python} MnistTrainingStepCalc.mnist_t_overhead_us_str` μs
|
||
|
||
The training step is overhead-bound.\index{Overhead-bound!small model training} For small models, Python dispatch and kernel launch dominate. This explains why:
|
||
|
||
- `torch.compile` provides 2--3$\times$ speedup by fusing operations and reducing kernel launches
|
||
- Batch size increases help amortize per-batch overhead
|
||
- Production training uses much larger models where compute dominates
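
The comparison above can be reproduced with a few lines of arithmetic. This sketch uses the A100 figures quoted in this section (312 TFLOPS peak, roughly 2 TB/s) and the 5 μs per-launch cost. The additive timing model, which ignores any overlap between launch and execution, and the fused launch count of 2 are simplifying assumptions for illustration, not measurements.

```python
PEAK_FLOPS = 312e12  # A100 FP16 tensor-core peak, FLOP/s
MEM_BW = 2e12        # A100 HBM bandwidth, roughly 2 TB/s
LAUNCH_US = 5.0      # assumed dispatch + launch cost per kernel


def step_time_us(total_flops, traffic_bytes, n_kernels):
    """Return (compute, memory, overhead) times in microseconds."""
    t_compute = total_flops / PEAK_FLOPS * 1e6
    t_memory = traffic_bytes / MEM_BW * 1e6
    t_overhead = n_kernels * LAUNCH_US
    return t_compute, t_memory, t_overhead


eager = step_time_us(40e6, 5e6, n_kernels=6)
fused = step_time_us(40e6, 5e6, n_kernels=2)  # hypothetical: 6 launches fused into 2

speedup = sum(eager) / sum(fused)
print(f"eager: {sum(eager):.1f} us, fused: {sum(fused):.1f} us, speedup: {speedup:.2f}x")
```

Under these assumptions the FLOPs and bytes are unchanged by fusion; the entire gain comes from issuing fewer launches, which is why the benefit shrinks as models grow and compute time begins to dominate.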
|
||
|
||
### Phase 4: Hardware Abstraction (Solving the Abstraction Problem) {.unnumbered}
|
||
|
||
The same Python code runs on different hardware through abstraction layers:
|
||
|
||
- **CUDA GPU**\index{GPU Backend!CUDA implementation}: cuBLAS GEMM kernels, CUDA streams for async execution
|
||
- **CPU**\index{CPU Backend!optimized BLAS libraries}: Intel MKL or OpenBLAS, OpenMP for parallelism
|
||
- **TPU**\index{TPU!XLA compilation}: XLA compilation to TPU-specific HLO operations
|
||
- **Apple Silicon**: Metal Performance Shaders via MPS backend
|
||
|
||
Each backend implements the same tensor operations with hardware-specific optimizations. The framework's abstraction layer (@sec-ml-frameworks-abstraction-problem-37a5) ensures numerically equivalent results (within floating-point tolerance) across platforms. This is the abstraction problem solved: a single `loss.backward()` call triggers completely different code paths depending on hardware, yet produces mathematically equivalent gradients.
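
The idea of one front-end call routing to device-specific kernels can be sketched as a small registry. Everything here is illustrative: the registry, the `register` decorator, and the `fake_accel` device name are hypothetical and far simpler than a real framework dispatcher.

```python
from typing import Callable

# (op name, device) -> kernel implementation. Names are illustrative only.
_KERNELS: dict[tuple[str, str], Callable] = {}


def register(op: str, device: str):
    def decorator(fn):
        _KERNELS[(op, device)] = fn
        return fn
    return decorator


@register("matmul", "cpu")
def matmul_cpu(a, b):
    # Naive index loops standing in for an optimized BLAS call.
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


@register("matmul", "fake_accel")
def matmul_accel(a, b):
    # Same math with a transposed access pattern: a stand-in for a device kernel.
    b_columns = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in b_columns] for row in a]


def matmul(a, b, device):
    """Single user-facing entry point; the backend is chosen at call time."""
    try:
        kernel = _KERNELS[("matmul", device)]
    except KeyError:
        raise NotImplementedError(f"no matmul kernel registered for {device!r}") from None
    return kernel(a, b)


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
print(matmul(a, b, "cpu"))
print(matmul(a, b, "fake_accel"))
```

The missing-kernel error path matters as much as the happy path: a framework without a registered backend for the target device fails at exactly this lookup.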
|
||
|
||
::: {.callout-perspective title="The Three Problems in Action"}
|
||
|
||
This trace reveals the three problems in concrete terms:
|
||
|
||
- **Execution**: Eager mode enables line-by-line debugging but incurs dispatch overhead
|
||
- **Differentiation**: Autograd tape records operations during forward, replays in reverse during backward
|
||
- **Abstraction**: Same code runs on GPU/CPU/TPU through backend-specific kernel implementations
|
||
|
||
Understanding this flow enables informed optimization: fuse operations to reduce overhead, use appropriate batch sizes, and match model scale to hardware capabilities.
|
||
|
||
:::
|
||
|
||
This detailed trace through a single training step demonstrates how deeply the three core problems interact. Even simple code exercises the full framework stack, and seemingly minor decisions---device placement, batch size, compilation mode---cascade through execution, differentiation, and abstraction layers in ways that are difficult to predict without systems-level understanding. The following section catalogs the most common misconceptions that arise when engineers lack this understanding.
|
||
|
||
## Fallacies and Pitfalls {#sec-ml-frameworks-fallacies-pitfalls-61ef}
|
||
|
||
Framework selection involves subtle trade-offs where intuitions from conventional software engineering fail. The memory wall, kernel fusion constraints, and deployment target diversity create pitfalls that waste months of engineering effort and cause production systems to miss latency targets by 10$\times$ or more.
|
||
|
||
```{python}
|
||
#| label: framework-gaps-calc
|
||
#| echo: false
|
||
|
||
# ┌─────────────────────────────────────────────────────────────────────────────
|
||
# │ FRAMEWORK PERFORMANCE GAPS
|
||
# ├─────────────────────────────────────────────────────────────────────────────
|
||
# │ Context: @sec-ml-frameworks-fallacies-pitfalls-61ef — first two Fallacy/Pitfall
|
||
# │ items on framework equivalence and deployment memory footprint
|
||
# │
|
||
# │ Goal: Quantify the latency gap (PyTorch vs TensorRT: 52 ms vs 3 ms = 17×)
|
||
# │ and runtime memory gap (PyTorch Mobile vs TFLite Micro: 6875×).
|
||
# │ Show: "17× performance gap" and "6875× memory difference" — inline in both
|
||
# │ Fallacy and Pitfall prose paragraphs.
|
||
# │ How: Simple ratio arithmetic on hardcoded benchmark values; no pint units
|
||
# │ (pure latency ms and memory MB/KB scalars from @tbl-framework-efficiency-matrix).
|
||
# │
|
||
# │ Imports: mlsysim.book (fmt, check)
|
||
# │ Exports: pytorch_ms_str, tensorrt_ms_str, perf_gap_str, pytorch_mobile_mb_str,
|
||
# │ tflite_micro_kb_str, memory_ratio_str
|
||
# └─────────────────────────────────────────────────────────────────────────────
|
||
from mlsysim.fmt import fmt, check
|
||
|
||
class FrameworkGapsCalc:
|
||
"""PyTorch vs TensorRT latency gap and PyTorch Mobile vs TFLite Micro memory gap."""
|
||
# ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
|
||
pytorch_ms = 52 # 52 ms inference
|
||
tensorrt_ms = 3 # 3 ms inference
|
||
pytorch_mobile_mb = 220 # 220 MB runtime
|
||
tflite_micro_kb = 32 # 32 KB runtime
|
||
# ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
|
||
perf_gap = pytorch_ms / tensorrt_ms # ~17x gap
|
||
memory_ratio = pytorch_mobile_mb * 1000 / tflite_micro_kb # ~6875x gap (decimal SI: 1 MB = 1000 KB)
|
||
# ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
|
||
pytorch_ms_str = fmt(pytorch_ms, precision=0, commas=False) # e.g. "52"
|
||
tensorrt_ms_str = fmt(tensorrt_ms, precision=0, commas=False) # e.g. "3"
|
||
pytorch_mobile_mb_str = fmt(pytorch_mobile_mb, precision=0, commas=False) # e.g. "220"
|
||
tflite_micro_kb_str = fmt(tflite_micro_kb, precision=0, commas=False) # e.g. "32"
|
||
perf_gap_str = fmt(perf_gap, precision=0, commas=False) # e.g. "17"
|
||
memory_ratio_str = fmt(memory_ratio, precision=0, commas=False) # e.g. "6875"
|
||
```
|
||
|
||
**Fallacy:** *"All frameworks provide equivalent performance for the same model architecture."*
|
||
|
||
Engineers assume that ResNet-50 yields identical performance across frameworks since the mathematics is the same. In production, framework implementation matters enormously. @tbl-framework-efficiency-matrix shows PyTorch achieves `{python} FrameworkGapsCalc.pytorch_ms_str` ms inference at 32% hardware utilization while TensorRT delivers `{python} FrameworkGapsCalc.tensorrt_ms_str` ms at 88% utilization---a **`{python} FrameworkGapsCalc.perf_gap_str`$\times$ performance gap** on identical hardware. The difference arises from kernel fusion depth, graph optimization strategies, and memory access patterns that vary dramatically between frameworks. Organizations that assume equivalence miss latency SLAs and require costly last-minute framework migrations.
|
||
|
||
**Pitfall:** *Choosing frameworks based on popularity rather than project requirements.*
|
||
|
||
Engineers assume the most popular framework works for any project. In reality, deployment constraints dominate. @tbl-framework-efficiency-matrix shows PyTorch Mobile requires `{python} FrameworkGapsCalc.pytorch_mobile_mb_str` MB memory while TensorFlow Lite Micro runs in `{python} FrameworkGapsCalc.tflite_micro_kb_str` KB---a **`{python} FrameworkGapsCalc.memory_ratio_str`$\times$ difference**. Teams that prototype edge applications with PyTorch face either memory bloat that exceeds device capacity or 2--3 month framework migrations after development completes. Evaluate deployment targets per @sec-ml-frameworks-deployment-targets-13f1 *before* selecting a training framework.
|
||
|
||
**Fallacy:** *"Framework abstractions eliminate the need for systems knowledge."*
|
||
|
||
Engineers assume high-level APIs handle all optimization automatically. The **Roofline Model** (@sec-machine-foundations-roofline-model-2529) proves otherwise: element-wise operations like ReLU achieve arithmetic intensity of 0.125 FLOPs/byte, using **under 0.1%** of an A100's peak compute regardless of framework sophistication. @sec-ml-frameworks-execution-strategy-matters-memory-wall-1ce8 explains why: memory bandwidth, not compute, is the bottleneck for most operations. Engineers who lack this understanding leave 80--90% of hardware capacity unused, directly translating to 5--10$\times$ higher inference costs at production scale.
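
The "under 0.1%" figure follows directly from the roofline bound. A minimal check, assuming the A100 numbers used earlier in this chapter (312 TFLOPS FP16 tensor-core peak and roughly 2 TB/s of memory bandwidth):

```python
PEAK_FLOPS = 312e12  # A100 FP16 tensor-core peak (FLOP/s)
MEM_BW = 2e12        # roughly 2 TB/s HBM bandwidth (bytes/s)


def attainable_flops(arithmetic_intensity):
    """Roofline bound: limited by compute or by bandwidth, whichever is lower."""
    return min(PEAK_FLOPS, arithmetic_intensity * MEM_BW)


ridge_point = PEAK_FLOPS / MEM_BW  # FLOPs/byte needed to become compute-bound
relu_fraction = attainable_flops(0.125) / PEAK_FLOPS

print(f"ridge point: {ridge_point:.0f} FLOPs/byte")
print(f"ReLU uses at most {relu_fraction:.2%} of peak compute")
```

No amount of kernel tuning raises the bound for a standalone element-wise operation; only fusing it into a neighbor, so its bytes are never written back to memory, changes the arithmetic intensity.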
|
||
|
||
**Pitfall:** *Ignoring vendor lock-in from framework-specific formats.*
|
||
|
||
Engineers assume framework migration is straightforward since models are "just math." Converting TensorFlow SavedModel to PyTorch requires rewriting custom operations, validating numerical equivalence across 10,000+ test cases, and retraining when operations lack exact equivalents---typically **3--6 engineer-months** for production systems. ONNX (@sec-ml-frameworks-deployment-targets-13f1) provides portability but supports only 80--85% of operations. Organizations that ignore this during initial framework selection face costly migrations when deployment requirements change or better frameworks emerge.
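
Numerical-equivalence validation of this kind can be sketched with the standard library alone. The two softmax functions below stand in for an original operator and its ported counterpart; both functions, the case count, the input range, and any tolerance applied to the result are illustrative assumptions, not a prescribed test suite.

```python
import math
import random


def softmax_reference(xs):
    exps = [math.exp(x) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_ported(xs):
    # The "ported" version subtracts the max first, a common stabilization.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def max_abs_diff(f, g, n_cases=10_000, dim=10, seed=0):
    """Largest element-wise disagreement between f and g over random inputs."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(n_cases):
        xs = [rng.uniform(-10.0, 10.0) for _ in range(dim)]
        worst = max(worst, max(abs(a - b) for a, b in zip(f(xs), g(xs))))
    return worst


worst = max_abs_diff(softmax_reference, softmax_ported)
print(f"worst absolute difference over 10,000 cases: {worst:.3e}")
```

Two mathematically identical formulations still disagree in the last bits, so a migration plan needs an explicit tolerance per operator rather than an expectation of bitwise equality.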

**Pitfall:** *Selecting development frameworks without evaluating production infrastructure.*

Engineers assume training framework choice is independent of deployment infrastructure. In practice, framework-infrastructure mismatches impose substantial operational overhead. TensorFlow Serving provides atomic model swaps with zero downtime; PyTorch deployments often require container restarts that impose 30--60 second outages. TensorFlow integrates natively with monitoring tools; PyTorch requires custom instrumentation that adds 2--4 weeks of development. Per @sec-ml-frameworks-major-framework-platform-analysis-fe96, evaluate the *complete deployment stack* during framework selection, including serving infrastructure, monitoring, and operational tooling.

```{python}
#| label: model-7b-memory
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ 7B MODEL MEMORY FOOTPRINT
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Fallacy about batch size being free optimization
# │
# │ Goal: Dispel the myth that "small" models leave infinite room for batching.
# │ Show: That 14 GB of weights consumes 18% of an A100 before training starts.
# │ How: Calculate weight memory for 7B parameters in FP16.
# │
# │ Imports: mlsysim.core.constants (A100_MEM_CAPACITY, BYTES_FP16), mlsysim.core.formulas (model_memory)
# │ Exports: model_7b_fp16_gb_str, a100_remaining_7b_gb_str, DeviceBandwidthHierarchy.a100_mem_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.core.constants import A100_MEM_CAPACITY, BYTES_FP16, GB, GiB
from mlsysim.core.formulas import model_memory
from mlsysim.fmt import fmt, check

class Model7BMemory:
    """7B-parameter FP16 weight memory to show batch-size myth under capacity constraints."""
    # ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
    params = 7e9  # 7 billion parameters
    a100_mem = A100_MEM_CAPACITY.m_as(GiB)  # 80 GB
    # ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
    fp16_gb = model_memory(params, BYTES_FP16, GB)  # 14 GB
    # ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
    model_7b_fp16_gb_str = fmt(fp16_gb, precision=0, commas=False)  # e.g. "14"
    a100_remaining_7b_gb_str = fmt(a100_mem - fp16_gb, precision=0, commas=False)  # e.g. "66"
    a100_mem_str = fmt(a100_mem, precision=0, commas=False)  # e.g. "80"
```

**Fallacy:** *"Increasing batch size is a free throughput optimization within framework memory limits."*

Engineers assume that if memory is available, larger batches always improve throughput. The Dispatch Overhead Equation (@eq-dispatch-overhead) reveals hidden costs. A 7B-parameter model in FP16 consumes `{python} Model7BMemory.model_7b_fp16_gb_str` GB, leaving `{python} Model7BMemory.a100_remaining_7b_gb_str` GB on an A100-`{python} Model7BMemory.a100_mem_str` GB. Increasing batch size from 8 to 32 quadruples activation memory, and in transformers each sample's attention activations already scale as $O(S^2)$ with sequence length $S$, so the larger batch can exhaust the remaining capacity and trigger recomputation strategies that **reduce throughput by 20--30%** despite the larger batch. Teams that blindly maximize batch size often achieve *lower* throughput than they would with smaller batches that avoid these memory management pathways.
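A rough estimate makes the scaling concrete. The sketch below counts only the attention score matrices saved for the backward pass, assuming a naive implementation that materializes every $S \times S$ matrix in FP16. The layer count, head count, and sequence length are hypothetical values for a 7B-class model, and `attention_score_bytes` is an illustrative helper; fused attention kernels avoid storing these matrices, so treat the result as an upper-bound illustration.

```python
def attention_score_bytes(batch, seq_len, n_heads, n_layers, bytes_per_value=2):
    """Bytes to keep every layer's (S x S) attention score matrix for backward."""
    return batch * n_heads * seq_len**2 * bytes_per_value * n_layers


# Hypothetical 7B-class configuration, chosen for illustration only.
CONFIG = dict(seq_len=2048, n_heads=32, n_layers=32)
GIB = 2**30

for batch in (8, 32):
    gib = attention_score_bytes(batch, **CONFIG) / GIB
    print(f"batch={batch:>2}: {gib:.0f} GiB of attention scores")

# Linear in batch size, quadratic in sequence length.
batch_ratio = attention_score_bytes(32, **CONFIG) / attention_score_bytes(8, **CONFIG)
seq_ratio = (attention_score_bytes(8, 4096, 32, 32)
             / attention_score_bytes(8, 2048, 32, 32))
print(f"4x batch -> {batch_ratio:.0f}x memory; 2x sequence -> {seq_ratio:.0f}x memory")
```

Under these assumptions the batch of 8 already needs 64 GiB for attention scores alone and the batch of 32 needs four times that, far beyond what remains after the weights are loaded. This is the point at which a framework falls back to recomputation or the job fails with an out-of-memory error.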

**Pitfall:** *Treating compilation overhead as negligible.*

```{python}
#| label: compilation-overhead-calc
#| echo: false

# ┌─────────────────────────────────────────────────────────────────────────────
# │ COMPILATION OVERHEAD BREAK-EVEN ANALYSIS
# ├─────────────────────────────────────────────────────────────────────────────
# │ Context: Pitfall about treating compilation overhead as negligible
# │
# │ Goal: Demonstrate that compilation costs can outweigh execution benefits.
# │ Show: That frequent recompilation makes eager mode roughly 44× faster than compiled.
# │ How: Model total time = (compile_time * recompiles) + (samples / throughput).
# │
# │ Imports: mlsysim.fmt (fmt, check)
# │ Exports: n_images_str, n_recompilations_str, eager_throughput_str,
# │          compiled_throughput_str, compilation_time_s_str, eager_total_str,
# │          compiled_total_str
# └─────────────────────────────────────────────────────────────────────────────
from mlsysim.fmt import fmt, check

class CompilationOverheadCalc:
    """Break-even analysis showing frequent recompilation negates compiled-mode throughput gains."""
    # ┌── 1. LOAD (Constants) ──────────────────────────────────────────────
    n_images = 10_000  # small experiment
    eager_throughput = 1_450  # images/sec (eager)
    compiled_throughput = 2_150  # images/sec (compiled)
    n_recompilations = 10  # code changes
    compilation_time_s = 30  # seconds per compile
    # ┌── 2. EXECUTE (The Compute) ────────────────────────────────────────
    eager_total = n_images / eager_throughput  # ~6.9 s
    compiled_total = (n_images / compiled_throughput +
                      n_recompilations * compilation_time_s)  # ~304.7 s
    # ┌── 4. OUTPUT (Formatting) ─────────────────────────────────────────────
    n_images_str = f"{n_images:,}"  # e.g. "10,000"
    n_recompilations_str = fmt(n_recompilations, precision=0, commas=False)  # e.g. "10"
    eager_throughput_str = f"{eager_throughput:,}"  # e.g. "1,450"
    compiled_throughput_str = f"{compiled_throughput:,}"  # e.g. "2,150"
    compilation_time_s_str = fmt(compilation_time_s, precision=0, commas=False)  # e.g. "30"
    eager_total_str = fmt(eager_total, precision=1, commas=False)  # e.g. "6.9"
    compiled_total_str = fmt(compiled_total, precision=1, commas=False)  # e.g. "304.7"
```

Engineers assume compilation overhead is a one-time cost that pays off quickly. @tbl-training-benchmark shows that `torch.compile` achieves 48% higher ResNet-50 throughput but incurs 15--60 seconds of compilation overhead per graph change. For a `{python} CompilationOverheadCalc.n_images_str`-image experiment with `{python} CompilationOverheadCalc.n_recompilations_str` code changes, eager mode completes in `{python} CompilationOverheadCalc.eager_total_str` seconds while compiled mode requires `{python} CompilationOverheadCalc.compiled_total_str` seconds (including `{python} CompilationOverheadCalc.n_recompilations_str` $\times$ `{python} CompilationOverheadCalc.compilation_time_s_str` s of recompilation overhead). Teams that enable compilation during rapid prototyping can lose hours to recompilations that negate any throughput gains.
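The same throughput and compile-time figures also give the break-even point: the number of images a run must process before compiled mode recovers its compilation cost. This is a minimal sketch; `break_even_samples` is a hypothetical helper name, and the model assumes throughput is constant and every recompilation costs the same.

```python
def break_even_samples(eager_tput, compiled_tput, compile_time_s, n_compiles=1):
    """Samples needed before compiled execution recovers its compile cost."""
    saved_per_sample = 1 / eager_tput - 1 / compiled_tput  # seconds saved per sample
    if saved_per_sample <= 0:
        return float("inf")  # compiled mode is not faster, so it never pays off
    return n_compiles * compile_time_s / saved_per_sample


# Throughputs (images/sec) and compile time (s) from the example above.
one_compile = break_even_samples(1_450, 2_150, 30)
ten_compiles = break_even_samples(1_450, 2_150, 30, n_compiles=10)
print(f"break-even after 1 compile:   {one_compile:,.0f} images")
print(f"break-even after 10 compiles: {ten_compiles:,.0f} images")
```

With these numbers a single 30-second compile needs roughly 134,000 images to pay for itself, and ten recompilations need over 1.3 million. A 10,000-image experiment is far below either threshold, while a multi-epoch production training run clears it easily.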

## Summary {#sec-ml-frameworks-summary-07f0}

Machine learning frameworks exist to solve three fundamental problems that would otherwise make deep learning impractical:

1. **The Execution Problem**: When and how should computation happen? Frameworks navigate the trade-off between eager execution (immediate, debuggable, flexible) and graph execution (deferred, optimizable, deployable). Modern hybrid approaches like `torch.compile` attempt to provide both flexibility during development and optimization for production.

2. **The Differentiation Problem**: How do we compute gradients automatically? Frameworks implement reverse-mode automatic differentiation that computes exact gradients for arbitrary operation compositions. This transforms the mathematical chain rule into a software primitive, enabling training on billions of parameters with a single `loss.backward()` call.

3. **The Abstraction Problem**: How do we target diverse hardware from a single interface? Frameworks provide tensor abstractions, intermediate representations, and runtime systems that hide hardware complexity while enabling efficient utilization across CPUs, GPUs, TPUs, and specialized accelerators.

These problems are interconnected and constrained by the **Iron Law** of performance (@sec-introduction-iron-law-ml-systems-c32a): execution strategy determines dispatch overhead ($L_{\text{lat}}$), differentiation determines memory traffic ($D_{\text{vol}}$), and abstraction determines hardware utilization ($\eta$). The memory wall often makes data movement more expensive than computation, which explains why frameworks invest in kernel fusion, activation checkpointing, mixed-precision training, and compilation pipelines.

::: {.callout-takeaways title="The Layer Between Math and Hardware"}

* **Three problems define every framework**: Execution (how to run), differentiation (how to train), and abstraction (how to express). TensorFlow prioritizes abstraction for deployment breadth, PyTorch prioritizes execution for research velocity, and JAX reframes differentiation through composable function transformations. These are infrastructure commitments, not tooling preferences.
* **The memory wall drives optimization**: Compute has grown approximately 1000$\times$ faster than memory bandwidth. Kernel fusion, activation checkpointing, mixed-precision training, and data layout optimizations all target the data movement term ($D_{\text{vol}}$) in the Iron Law, not the compute term.
* **Compilation pays off only at scale**: The Compilation Continuum principle (@eq-compilation-benefit) quantifies when compilation benefits exceed costs. Research prototyping favors eager mode; production training and inference favor progressive compilation from JIT to AOT. The Dispatch Overhead Law (@eq-dispatch-overhead) explains why small models benefit disproportionately.
* **The nn.Module pattern is widely adopted**: Automatic parameter discovery, mode-dependent behavior, and hierarchical composition with serialization appear across major frameworks, enabling million-parameter optimization in a single `optimizer.step()` call regardless of API syntax.
* **Framework choice constrains deployment by orders of magnitude**: A 17$\times$ latency gap (PyTorch vs. TensorRT) and 7,040$\times$ memory gap (PyTorch Mobile vs. TFLite Micro) on identical models demonstrate that frameworks are not interchangeable. Deployment target must be evaluated before framework selection.

:::

Understanding framework internals transforms how practitioners approach performance debugging and optimization. When a training job runs slower than expected, engineers who understand execution graphs can identify whether the bottleneck lies in eager-mode overhead, insufficient kernel fusion, or suboptimal memory layout. When deployment fails on target hardware, the compilation pipeline reveals whether the issue is operator support, quantization compatibility, or runtime configuration. This knowledge is essential for diagnosing and resolving performance issues in production systems.

::: {.callout-chapter-connection title="From Control Room to Power Plant"}

Frameworks are the software substrate that translates abstract architectures into executable kernels. Computational graphs, autograd tapes, and kernel dispatch pipelines are the control room instruments: they give engineers visibility into and control over the training process. A control room without a source of energy, however, is just a room with glowing lights. @sec-model-training puts this machinery to work, scaling mixed-precision training, gradient checkpointing, compilation pipelines, and distributed execution contexts from single-device examples to the massive multi-GPU and multi-node orchestration that powers modern AI.

:::

::: { .quiz-end }
:::