From af865684e01be459f692c2e7d383d325b146b354 Mon Sep 17 00:00:00 2001 From: Vijay Janapa Reddi Date: Tue, 24 Feb 2026 14:30:44 -0500 Subject: [PATCH] vol2: finalize visual narrative (added power path, 3D parallelism, continuous batching, and carbon sankey diagrams) --- .../compute_infrastructure.qmd | 36 + .../distributed_training.qmd | 2820 +---------------- .../contents/vol2/inference/inference.qmd | 52 + .../performance_engineering.qmd | 2364 +------------- .../vol2/sustainable_ai/sustainable_ai.qmd | 37 + 5 files changed, 246 insertions(+), 5063 deletions(-) diff --git a/book/quarto/contents/vol2/compute_infrastructure/compute_infrastructure.qmd b/book/quarto/contents/vol2/compute_infrastructure/compute_infrastructure.qmd index 18bf5623b..7321bf1fd 100644 --- a/book/quarto/contents/vol2/compute_infrastructure/compute_infrastructure.qmd +++ b/book/quarto/contents/vol2/compute_infrastructure/compute_infrastructure.qmd @@ -1218,6 +1218,42 @@ For capacity planning, the sustained throughput rate, not the peak rate, should The Memory Wall constrains how fast data reaches the compute units; the Roofline Model diagnoses whether compute or memory is the binding constraint; and Tensor Cores maximize the arithmetic value of every byte fetched. A third physical constraint also limits the accelerator's performance: the heat generated by all this computation. Every FLOP dissipates energy, and the faster we compute, the more heat we must remove. This is the Power Wall. +::: {#fig-power-path fig-env="figure" fig-pos="htb" fig-cap="**The Power Delivery Path**. The journey of energy from the high-voltage grid to the low-voltage transistor. Each stage involves conversion losses (quantified by PUE) and requires stabilizing infrastructure to handle the massive current ramps of ML training. The critical engineering challenge is the Rack PDU to VRM transition, where 10--40 kW of power must be delivered within a single cabinet." fig-alt="Flowchart showing power journey from Grid Substation to Datacenter UPS, to Rack PDU, to Server PSU, to Voltage Regulator Module, finally to GPU Die. Arrows show power flow."} +```{.tikz} +\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, node distance=1.2cm] + \definecolor{GridColor}{RGB}{200,200,200} + \definecolor{FacilityColor}{RGB}{200,240,255} + \definecolor{RackColor}{RGB}{255,240,210} + \definecolor{SiliconColor}{RGB}{255,220,220} + + \tikzset{ + stage/.style={draw=black!70, thick, rounded corners=2pt, align=center, minimum width=2.8cm, minimum height=1.0cm} + } + + % Nodes + \node[stage, fill=GridColor] (Sub) {Grid\\Substation\\(115 kV+)}; + \node[stage, fill=FacilityColor, below=of Sub] (UPS) {Facility UPS\\\& Transformer\\(480 V)}; + \node[stage, fill=RackColor, below=of UPS] (PDU) {Rack PDU\\(208--415 V)}; + \node[stage, fill=RackColor, below=of PDU] (PSU) {Server PSU\\(12 V / 48 V)}; + \node[stage, fill=SiliconColor, below=of PSU] (VRM) {On-Board VRM\\(0.8--1.2 V)}; + \node[stage, fill=SiliconColor, below=of VRM] (Die) {\textbf{GPU Die}\\{(1,000+ Amps)}}; + + % Flows + \draw[->, ultra thick] (Sub) -- (UPS); + \draw[->, ultra thick] (UPS) -- (PDU); + \draw[->, ultra thick] (PDU) -- (PSU); + \draw[->, ultra thick] (PSU) -- (VRM); + \draw[->, ultra thick] (VRM) -- (Die); + + % Annotations + \node[right=0.5cm of UPS, text=BlueLine] {\textbf{Facility Level}}; + \node[right=0.5cm of PSU, text=OrangeLine] {\textbf{Rack Level}}; + \node[right=0.5cm of Die, text=RedLine] {\textbf{Silicon Level}}; + +\end{tikzpicture} +``` +::: + ## Thermal Design Power {#sec-compute-tdp} \index{TDP} diff --git a/book/quarto/contents/vol2/distributed_training/distributed_training.qmd b/book/quarto/contents/vol2/distributed_training/distributed_training.qmd index 3c4a850a9..bcf728967 100644 --- a/book/quarto/contents/vol2/distributed_training/distributed_training.qmd +++ b/book/quarto/contents/vol2/distributed_training/distributed_training.qmd @@ -20,7 +20,8 @@ from mlsys.registry import start_chapter from mlsys.constants import ( A100_MEM_CAPACITY, H100_MEM_CAPACITY, NVLINK_A100_BW, NVLINK_H100_BW, H100_MEM_BW, - GPT3_PARAMS, GB, second, Mparam, THOUSAND, + GPT3_PARAMS, GPT4_EST_PARAMS, + GB, second, Mparam, Tparam, THOUSAND, SEC_PER_HOUR, SEC_PER_DAY, MILLION, TRILLION, BITS_PER_BYTE, TB ) from mlsys.formatting import fmt, sci, check @@ -81,7 +82,7 @@ Distributed training appears simple: split the work across machines and combine # │ (GPT3_PARAMS, GPT4_EST_PARAMS) to anchor why distribution is necessary. # │ Show: "80" GB A100 memory and "600" GB/s NVLink H100 bandwidth — inline # │ in the memory exhaustion and interconnect hierarchy paragraphs. -# │ How: Convert Mparam → billions via /THOUSAND; Tparam direct .m_as(). +# │ How: Convert Mparam -> billions via /THOUSAND; Tparam direct .m_as(). # │ # │ Imports: mlsys.constants (A100_MEM_CAPACITY, H100_MEM_CAPACITY, # │ NVLINK_A100_BW, NVLINK_H100_BW, GPT3_PARAMS, GPT4_EST_PARAMS, @@ -98,2781 +99,44 @@ from mlsys.constants import ( ) from mlsys.formatting import fmt, sci, check -# ┌── LEGO ─────────────────────────────────────────────── -## Why Distribution Is Necessary {#sec-distributed-training-systems-systems-multimachine-scaling-fundamentals-ff96}``` +# ┌── P.I.C.O. ISOLATED SCENARIO ─────────────────────────────────────────────── +class DistTrainSetup: + """ + Namespace for Distributed Training reference specs. + Scenario: GPU and interconnect parameters for large-scale clusters. + """ + + # ┌── 1. PARAMETERS (Inputs) ─────────────────────────────────────────────── + a100_cap = A100_MEM_CAPACITY + h100_cap = H100_MEM_CAPACITY + nvlink_a100_bw = NVLINK_A100_BW + nvlink_h100_bw = NVLINK_H100_BW + + gpt3_p = GPT3_PARAMS.m_as(Mparam) + gpt4_p = GPT4_EST_PARAMS.m_as(Tparam) + + # ┌── 2. CALCULATION (The Physics) ───────────────────────────────────────── + gpt3_params_b_val = gpt3_p / 1000 + + # ┌── 3. INVARIANTS (Guardrails) ─────────────────────────────────────────── + check(gpt3_params_b_val == 175, f"Expected 175B params, got {gpt3_params_b_val}") + + # ┌── 4. OUTPUTS (Formatting) ────────────────────────────────────────────── + a100_mem = f"{a100_cap.m_as(GB):.0f}" + h100_mem = f"{h100_cap.m_as(GB):.0f}" + nvlink_a100 = f"{nvlink_a100_bw.m_as(GB/second):.0f}" + nvlink_h100 = f"{nvlink_h100_bw.m_as(GB/second):.0f}" + + gpt3_params_b = f"{gpt3_params_b_val:.0f}" + gpt4_est_params_t = f"{gpt4_p:.1f}" + +# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── +a100_mem = DistTrainSetup.a100_mem +h100_mem = DistTrainSetup.h100_mem +nvlink_a100 = DistTrainSetup.nvlink_a100 +nvlink_h100 = DistTrainSetup.nvlink_h100 +gpt3_params_b = DistTrainSetup.gpt3_params_b +gpt4_est_params_t = DistTrainSetup.gpt4_est_params_t +``` ## Why Distribution Is Necessary {#sec-distributed-training-systems-systems-multimachine-scaling-fundamentals-ff96} - -Part I built the physical fleet: @sec-compute-infrastructure established the accelerator hierarchy, @sec-network-fabrics wired nodes into a high-bandwidth fabric, and @sec-data-storage completed the infrastructure with storage pipelines that keep the fleet fed. With the physical foundation in place, we now confront the algorithmic question that defines Part II: how do we split a single training job across this hardware? - -If you could purchase a single GPU with 100 terabytes of memory and an exaflop of compute, distributed training would not exist. Because the laws of physics prevent this, we are forced to shatter our models across thousands of independent chips. In the **Fleet Stack** framework (@sec-vol2-introduction), Distributed Training represents the **Distribution Layer** — the logic that partitions the mathematical workload across the physical fleet. - -### The Physics of the Cluster - -Before optimizing algorithms, we must understand the physical constraints of the **Machine Learning Fleet**. The performance of any distributed training job is governed by the **Iron Law of Scale** introduced in @sec-vol2-law-distributed-efficiency: - -$$ T_{\text{step}}(N) = \frac{T_{\text{compute}}}{N} + T_{\text{comm}}(N) - T_{\text{overlap}} $$ - -The critical term here is the **Communication-Computation Ratio** ($T_{\text{comm}}/T_{\text{compute}}$). This ratio determines whether your cluster behaves as a supercomputer or a collection of idling heaters. - -* **Compute-Bound (Low Ratio)**: $T_{\text{compute}} \gg T_{\text{comm}}$. The GPUs spend most of their time multiplying matrices. This is the ideal state, typical for large batch sizes on dense models (like ResNet). -* **Communication-Bound (High Ratio)**: $T_{\text{comm}} \approx T_{\text{compute}}$. The GPUs spend significant time waiting for gradients or activations to arrive. This is the common state for Large Language Models (LLMs) and Recommendation Systems (DLRMs), where parameter synchronization saturates the network. - -::: {.callout-note title="Connection: The Fleet Stack"} -In the **Fleet Stack** framework (@fig-fleet-stack in @sec-vol2-introduction), Distributed Training represents the **Distribution Layer**. We are defining *how* to split the math. The actual execution of these split workloads happens on the **Infrastructure Layer**, which we built in Part I. The algorithms defined here (Ring AllReduce, Tensor Parallelism) dictate the bandwidth requirements for the physical interconnects (NVLink, InfiniBand) discussed in @sec-network-fabrics. -::: - -### Multi-Machine Training Requirements {#sec-distributed-training-systems-systems-multimachine-training-requirements-0277} - -Three concrete signals indicate when distributed training becomes necessary rather than merely beneficial. First, **memory exhaustion** occurs when model parameters, optimizer states, and activation storage exceed single-device capacity. For transformer models, this threshold typically occurs around 10--20 billion parameters on current generation GPUs with 40--80 GB memory [@rajbhandari2020zero]. - -Second, **unacceptable training duration** emerges when single-device training would require weeks or months to converge. Training GPT-3 on a single V100 GPU would require approximately 355 years [@brown2020language], making distributed approaches not optional but essential. - -Third, **dataset scale** exceeds single-machine storage when training data reaches multiple terabytes, as occurs in large-scale vision or language modeling tasks. - -### Distributed Training Complexity Trade-offs {#sec-distributed-training-systems-systems-distributed-training-complexity-tradeoffs-0138} - -Distributed training introduces three primary complexity dimensions not present in single-machine scenarios: - -1. **Communication Overhead**: The cost of synchronizing gradients. For a model with $N$ parameters distributed across $D$ devices, all-reduce operations must transfer approximately $2N(D-1)/D$ bytes per step. On commodity networks, this can dominate computation time. -2. **Fault Tolerance**: Requirements increase exponentially with cluster size. A 100-node cluster with 99.9% per-node reliability experiences failures every few hours. -3. **Algorithmic Stability**: Large batch sizes from data parallelism affect convergence behavior, requiring learning rate scaling and warmup strategies that single-machine training does not require [@goyal2017accurate]. - -### Single-Machine to Distributed Transition {#sec-distributed-training-systems-systems-singlemachine-distributed-transition-1ee4} - -The systematic optimization methodology established for single-machine training extends to distributed environments with important adaptations. Profiling must now capture inter-device communication patterns and synchronization overhead in addition to computation and memory metrics. The solution space expands to include data parallelism, model parallelism, pipeline parallelism, and hybrid approaches. @fig-3d-parallelism-cube visualizes this three-dimensional configuration space. - -::: {#fig-3d-parallelism-cube fig-env="figure" fig-pos="htb" fig-cap="**The 3D Parallelism Cube**. A conceptual visualization of the three orthogonal scaling axes: Data Parallelism (replicating the model), Tensor Parallelism (splitting layers), and Pipeline Parallelism (splitting depth). Production training for models like GPT-4 occupies a specific point $(d, t, p)$ within this cube to balance memory usage, compute efficiency, and communication overhead." fig-alt="3D coordinate system with three axes: Data Parallelism (replica count d), Pipeline Parallelism (depth p), and Tensor Parallelism (width t). Dashed cube marks a configuration point with Total Accelerators equals d times p times t."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, z={(0.6,0.5)}, x={(1,0)}, y={(0,1)}, scale=1.3] - % Origin - \coordinate (O) at (0,0,0); - - % Axes Arrows - \draw[->, ultra thick, BlueLine] (O) -- (4.5,0,0) node[anchor=north east, align=center, font=\bfseries] {Data Parallelism\\(Replica Count $d$)}; - \draw[->, ultra thick, RedLine] (O) -- (0,3.5,0) node[anchor=east, align=right, font=\bfseries] {Pipeline Parallelism\\(Depth/Stages $p$)}; - \draw[->, ultra thick, GreenLine] (O) -- (0,0,3.5) node[anchor=south west, align=left, font=\bfseries] {Tensor Parallelism\\(Layer Width $t$)}; - - % The "Cube" of a specific training configuration - % Let's show a config of (d=2, p=2, t=2) - \draw[thick, gray!50, dashed] (2,0,0) -- (2,2,0) -- (0,2,0); - \draw[thick, gray!50, dashed] (2,0,0) -- (2,0,2) -- (0,0,2); - \draw[thick, gray!50, dashed] (0,2,0) -- (0,2,2) -- (0,0,2); - \draw[thick, gray!50, dashed] (2,2,0) -- (2,2,2) -- (0,2,2); % Front corner - \draw[thick, gray!50, dashed] (2,0,2) -- (2,2,2); - \draw[thick, gray!50, dashed] (0,2,2) -- (2,2,2); - - % Highlight the volume - \fill[gray!10, opacity=0.5] (O) -- (2,0,0) -- (2,0,2) -- (0,0,2) -- cycle; % Bottom - \fill[gray!10, opacity=0.5] (O) -- (0,2,0) -- (0,2,2) -- (0,0,2) -- cycle; % Side - \fill[gray!10, opacity=0.5] (O) -- (2,0,0) -- (2,2,0) -- (0,2,0) -- cycle; % Back - - % Label the point - \fill[black] (2,2,2) circle (1.5pt) node[anchor=west, xshift=2mm, font=\bfseries] {Training Config $(d, p, t)$}; - \node[anchor=north west, gray, font=\footnotesize] at (2,2,2) {Total Accelerators = $d \times p \times t$}; - - % Contextual Constraints - \node[align=center, font=\footnotesize, text=BlueLine] at (2.5, -0.5, 0) {constraint: Global Batch Size}; - \node[align=center, font=\footnotesize, text=RedLine] at (-0.5, 2.5, 0) {constraint: Latency/Bubbles}; - \node[align=center, font=\footnotesize, text=GreenLine] at (-0.5, 0, 2.5) {constraint: Memory/NVLink}; - -\end{tikzpicture} -``` -::: - -### Engineering Trade-offs: Selecting a Parallelism Strategy {#sec-distributed-training-systems-systems-engineering-tradeoffs-selecting-parallelism-strategy-b344} - -Choosing the right parallelism strategy is not a matter of preference; it is a constraint satisfaction problem governed by model size ($M$), batch size ($B$), and interconnect bandwidth. @tbl-parallelism-tradeoffs quantifies the communication costs for each strategy, revealing which approaches are physically feasible for a given hardware topology. - -| **Strategy** | **Communication Pattern** | **Comm. Volume** | **Hardware Constraint** | -|:---------------------------|:--------------------------|:------------------------------|:--------------------------------| -| **Data Parallel (DP)** | AllReduce Gradients | $\propto M$ (Model Size) | Requires high bisection BW | -| **Tensor Parallel (TP)** | AllReduce Activations | $\propto B \times L$ (Layers) | **Critical**: Needs NVLink | -| **Pipeline Parallel (PP)** | Point-to-Point (P2P) | $\propto B \times H$ (Hidden) | Low BW (Ethernet is sufficient) | - -: **Parallelism Communication Costs**: Tensor Parallelism has the highest communication frequency (per layer), confining it to intra-node (NVLink) usage. Pipeline Parallelism has the lowest communication volume (boundary activations only), making it suitable for inter-node (Ethernet/InfiniBand) scaling. Data Parallelism sits in the middle but scales poorly as $M$ grows, necessitating ZeRO optimizations. {#tbl-parallelism-tradeoffs} - -These bandwidth requirements impose a hard constraint on hardware placement, a principle we call the *Jeff Dean Test*. - -::: {.callout-perspective title="The Jeff Dean Test"} -If you attempt **Tensor Parallelism** across server racks connected by standard Ethernet, your training will stall. The communication volume (proportional to Batch$\times$ Layers) requires the `{python} nvlink_a100`-`{python} nvlink_h100` GB/s bandwidth of NVLink. For cross-rack scaling, you *must* switch to **Pipeline** or **Data** parallelism to respect the physics of the network. -::: - -@fig-parallelism-decision-tree formalizes this constraint satisfaction process as a decision tree, showing how model size and hardware topology determine the viable parallelism strategies. - -::: {#fig-parallelism-decision-tree fig-env="figure" fig-pos="htb" fig-cap="**Parallelism Strategy Decision Tree**. Starting from the model's memory requirements, the tree guides practitioners through the constraint satisfaction process that determines which parallelism strategies are physically feasible. The critical branching points are model memory versus single-GPU capacity and communication bandwidth versus parallelism demands. Leaf nodes are annotated with the dominant hardware constraint." fig-alt="Decision tree flowchart. Root asks if model fits in one GPU. Yes branch leads to data parallelism. No branch asks if model fits in one node, leading to tensor or pipeline parallelism with hardware annotations."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, scale=0.85, transform shape] - \definecolor{BlueLine}{HTML}{006395} - \definecolor{BlueL}{HTML}{D1E6F3} - \definecolor{GreenLine}{HTML}{008F45} - \definecolor{GreenL}{HTML}{D4EFDF} - \definecolor{OrangeLine}{HTML}{E67817} - \definecolor{OrangeL}{HTML}{FCE4CC} - \definecolor{RedLine}{HTML}{CB202D} - \definecolor{RedL}{HTML}{F5D2D5} - \definecolor{VioletLine}{HTML}{7E317B} - \definecolor{VioletL}{HTML}{E6D4E5} - - \tikzset{ - decision/.style={diamond, draw=#1, fill=#1!12, thick, aspect=2.2, - inner sep=1pt, align=center, font=\scriptsize}, - leaf/.style={draw=#1, fill=#1!20, rounded corners=3pt, thick, - minimum width=2.4cm, minimum height=0.9cm, align=center, font=\scriptsize}, - hw/.style={font=\tiny\itshape, text=black!60, align=center}, - yes/.style={font=\scriptsize\bfseries, text=GreenLine}, - no/.style={font=\scriptsize\bfseries, text=RedLine}, - arrow/.style={-{Triangle[width=5pt,length=4pt]}, thick, black!50} - } - - % Root decision - \node[decision=BlueLine] (d1) at (0, 0) - {Model fits in\\single GPU\\memory?}; - - % YES branch: Data Parallel - \node[leaf=GreenLine] (dp) at (-4.5, -2.2) - {\textbf{Data Parallelism}\\Scale batch size}; - \node[hw] at (-4.5, -3.0) {InfiniBand sufficient}; - - \draw[arrow] (d1) -- node[yes, left, pos=0.3] {Yes} (dp); - - % NO branch: second decision - \node[decision=OrangeLine] (d2) at (3.5, -2.2) - {Model fits in\\single node\\(8 GPUs)?}; - - \draw[arrow] (d1) -- node[no, right, pos=0.3] {No} (d2); - - % YES from d2: Tensor Parallel - \node[leaf=BlueLine] (tp) at (0.5, -4.8) - {\textbf{Tensor Parallelism}\\Split layers across GPUs}; - \node[hw] at (0.5, -5.6) {Requires NVLink}; - - \draw[arrow] (d2) -- node[yes, left, pos=0.3] {Yes} (tp); - - % NO from d2: third decision - \node[decision=RedLine] (d3) at (7.0, -4.8) - {Sequential\\dependencies\\dominate?}; - - \draw[arrow] (d2) -- node[no, right, pos=0.3] {No} (d3); - - % YES from d3: Pipeline Parallel - \node[leaf=OrangeLine] (pp) at (4.0, -7.2) - {\textbf{Pipeline Parallelism}\\Stage layers across nodes}; - \node[hw] at (4.0, -8.0) {Low BW (Ethernet OK)}; - - \draw[arrow] (d3) -- node[yes, left, pos=0.3] {Yes} (pp); - - % NO from d3: 3D Hybrid - \node[leaf=VioletLine] (hybrid) at (10.0, -7.2) - {\textbf{3D Hybrid}\\TP + PP + DP}; - \node[hw] at (10.0, -8.0) {NVLink + IB + Ethernet}; - - \draw[arrow] (d3) -- node[no, right, pos=0.3] {No} (hybrid); - -\end{tikzpicture} -``` -::: - -The decision tree reveals that parallelism strategy selection is not a preference but a consequence of physical constraints. The next question is how these constraints shape the mechanics of a distributed training step on a real cluster. - -## The Distributed Training Step {#sec-distributed-training-systems-systems-distributed-training-fundamentals-97da} - -How exactly do 1,024 GPUs, operating completely independently, agree on a single, mathematically rigorous set of updated weights at the end of a training iteration? The single-machine optimization techniques discussed in the previous section only delay the inevitable; eventually, the computation must span multiple devices. - -::: {.callout-definition title="Distributed Training"} - -***Distributed Training***\index{Distributed Training!definition} is the parallelization of the optimization loop across **Multiple Compute Devices** through coordinated partitioning and synchronization. - -1. **Significance (Quantitative):** It enables the training of models that exceed the **Memory Capacity** and **Compute Throughput** ($R_{\text{peak}}$) of a single device, reducing the total wall-clock time $T$ for a given $O$. -2. **Distinction (Durable):** Unlike **Traditional Distributed Systems** (which scale for independent requests), Distributed Training scales for **Coordinated State Updates**, requiring high-bandwidth, low-latency synchronization of gradients and weights. -3. **Common Pitfall:** A frequent misconception is that performance scales linearly with the number of devices ($N$). In reality, the **Communication Overhead ($L_{\text{lat}}$)** and **Bisection Bandwidth ($BW$)** eventually create a scaling ceiling (Amdahl's Law). - -::: - -A useful mental model frames these distributed strategies as *loop transformations*, the same conceptual toolkit that compilers use to optimize sequential code. - -::: {.callout-perspective title="Training as Loop Transforms"} -If we view the training process as a massive loop over data and layers, distributed strategies are simply **Loop Transformations** applied by the cluster-level compiler: - -* **Data Parallelism = Parallel For Loop.** We unroll the outer loop (batch dimension) across devices. Each device runs the same code body on different data indices. -* **Tensor Parallelism = Vectorization (SIMD).** We split the inner loops (matrix multiplication) across devices. This is "Cluster-Scale SIMD," where NVLink acts as the vector register file. -* **Pipeline Parallelism = Instruction Pipelining.** We split the sequential operations (layers) across devices. Just as a CPU pipeline stages fetch/decode/execute, the cluster stages Layer 1/Layer 2/Layer 3 to keep all ALUs busy. -::: - -The progression from single-machine to distributed training follows a natural scaling path of optimizing locally first, then scaling horizontally. This progression ensures that distributed systems operate efficiently at each node while adding the coordination mechanisms necessary for multi-machine training. Training machine learning models often requires scaling beyond a single machine due to increasing model complexity and dataset sizes. The demand for computational power, memory, and storage can exceed the capacity of individual devices, especially in domains like natural language processing and computer vision. Distributed training[^fn-distbelief-distributed] addresses this challenge by spreading the workload across multiple machines that coordinate to train a single model efficiently. - -[^fn-distbelief-distributed]: **Distributed Training**: Google's DistBelief (2012) was the first framework to train neural networks across thousands of machines, but its parameter server architecture created bandwidth bottlenecks at central nodes. This limitation drove the shift to decentralized AllReduce patterns in successors like Horovod and PyTorch DDP, where communication cost scales as $2(N-1)/N$ per worker rather than concentrating at a single server. \index{DistBelief!distributed training} - -This coordination relies on consensus protocols and synchronization primitives to ensure parameter updates remain consistent across nodes. While basic barrier synchronization suffices for research, production deployments require careful fault tolerance, checkpointing, and recovery mechanisms. @sec-fault-tolerance-reliability examines these reliability engineering challenges, including how to handle node failures without losing days of training progress. - -With these coordination mechanisms in place, practitioners follow a systematic progression from single-device to distributed training, with each stage building on the previous level's challenges. Single GPU training requires only local memory management and straightforward forward/backward passes, establishing the baseline computational patterns. Scaling to multiple GPUs within a single node introduces high-bandwidth communication requirements, typically handled through NVLink[^fn-nvlink-intranode] or PCIe connections with NCCL[^fn-nccl-topology] optimization while preserving the single-machine simplicity of fault tolerance and scheduling. - -[^fn-nvlink-intranode]: **NVLink**: NVIDIA's point-to-point GPU interconnect delivers `{python} nvlink_a100`--`{python} nvlink_h100` GB/s bidirectional bandwidth, roughly 10$\times$ InfiniBand HDR. This bandwidth gap is why tensor parallelism -- which requires AllReduce on every layer -- is confined to intra-node communication, while pipeline and data parallelism tolerate the slower inter-node fabric. \index{NVLink!distributed training} - -[^fn-nccl-topology]: **NCCL (NVIDIA Collective Communications Library)**: Released in 2015, NCCL automatically selects ring, tree, or hierarchical AllReduce algorithms based on detected hardware topology -- NVLink within nodes, InfiniBand across nodes. This topology-aware routing is critical because a naive single-ring AllReduce across 128 nodes would force all traffic through the slowest inter-node link, collapsing bandwidth utilization to under 30%. \index{NCCL!topology} - -The leap to multi-node distributed training introduces new complexity dimensions: network communication overhead, fault tolerance requirements, and cluster orchestration challenges. Each scaling stage compounds the previous challenges as communication bottlenecks intensify, synchronization overhead grows, and failure probability increases. Practitioners should therefore optimize single-GPU performance before scaling to ensure efficient resource utilization at each level. - -::: {.callout-note title="Distributed Training Complexity"} - -Although modern frameworks abstract away much of the complexity through sharded data parallelism and communication libraries, implementing distributed training efficiently remains a significant engineering challenge. Production deployments require careful network configuration (InfiniBand tuning, topology-aware routing), infrastructure management through cluster schedulers, and debugging of non-local issues such as synchronization hangs and communication bottlenecks. - -::: - -The distributed training process itself involves splitting the dataset into non-overlapping subsets, assigning each subset to a different GPU, and performing forward and backward passes independently on each device. Once gradients are computed on each GPU, they are synchronized and aggregated before updating the model parameters, ensuring that all devices learn in a consistent manner. The coordinated flow of data splitting, computation, and gradient synchronization (@fig-distributed-training) forms the foundation of distributed training, with each GPU processing its batch independently before synchronization brings all gradients together. - -::: {#fig-distributed-training fig-env="figure" fig-pos="htb" fig-cap="**Data Parallel Training Flow**. Distributed training partitions datasets across GPUs, computes gradients concurrently on each device's data subset, then aggregates gradients through AllReduce to update shared model parameters. Each GPU maintains an identical model copy and processes its portion of the batch independently, with synchronization occurring only during gradient aggregation. This approach achieves near-linear speedup when communication overhead remains below 30--40% of training time." fig-alt="Two parallel GPU workflows showing data parallel training. Each GPU processes a data chunk through forward pass, error computation, loss function, backward pass, then gradients merge at Calculate Global Gradients for parameter updates."} -```{.tikz} -\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small] -\tikzset{% - mycylinder/.style={cylinder, shape border rotate=90, aspect=1.3, draw, fill=white, -minimum width=20mm,minimum height=9mm,line width=1pt}, - mycycle/.style={circle, draw=none, fill=red, minimum width=5mm}, - myline/.style={line width=1.15pt,draw=cyan}, -% - Box/.style={align= flush center, - inner xsep=2pt, - draw=RedLine, - line width=0.75pt, - fill=RedL!20, - text width=22mm, - minimum width=22mm, minimum height=8mm - }, -% -Line/.style={line width=1.0pt,black!50} -} - -\begin{scope}[node distance=-1.7,local bounding box = SC1]] -\node[mycylinder,fill=red!30] (A) {}; -\scoped[on background layer] -\node[mycylinder, above=of A,fill=red!50] (C) {}; -\node[mycylinder, below=of A,fill=red!10] (B) {}; -\end{scope} - -\begin{scope}[node distance=0.2,shift={(3.5,5))},local bounding box = SC2] -\node[mycycle] (C1) {}; -\node[mycycle,below=of C1] (C2) {}; -\node[mycycle,below=of C2] (C3) {}; -\node[mycycle,below=of C3] (C4) {}; -\node[mycycle,fill=violet,left=0.6 of $(C1)!0.5!(C2)$] (CL1) {}; -\node[mycycle,fill=violet,left=0.6 of $(C2)!0.5!(C3)$] (CL2) {}; -\node[mycycle,fill=violet,left=0.6 of $(C3)!0.5!(C4)$] (CL3) {}; -% -\node[mycycle,fill=green,right=0.6 of $(C1)!0.4!(C3)$] (CD1) {}; -\node[mycycle,fill=green,right=0.6 of $(C2)!0.6!(C4)$] (CD2) {}; -% -\foreach \x in {1,2,3,4} { - \foreach \y in {CL1, CL2, CL3, CD1, CD2} { - \draw[myline] (\y) -- (C\x); - } -} -\node[Box,below=0.8 of C4](B1){GPU 1}; -\draw[myline,dashed](C4)--(B1); -\end{scope} - -\begin{scope}[node distance=0.2,shift={(11.5,5))},local bounding box = SC3] -\node[mycycle] (3C1) {}; -\node[mycycle,below=of 3C1] (3C2) {}; -\node[mycycle,below=of 3C2] (3C3) {}; -\node[mycycle,below=of 3C3] (3C4) {}; -\node[mycycle,fill=violet,right=0.6 of $(3C1)!0.5!(3C2)$] (3CL1) {}; -\node[mycycle,fill=violet,right=0.6 of $(3C2)!0.5!(3C3)$] (3CL2) {}; -\node[mycycle,fill=violet,right=0.6 of $(3C3)!0.5!(3C4)$] (3CL3) {}; -% -\node[mycycle,fill=green,left=0.6 of $(3C1)!0.4!(3C3)$] (3CD1) {}; -\node[mycycle,fill=green,left=0.6 of $(3C2)!0.6!(3C4)$] (3CD2) {}; -% -\foreach \x in {1,2,3,4} { - \foreach \y in {3CL1, 3CL2, 3CL3, 3CD1, 3CD2} { - \draw[myline] (\y) -- (3C\x); - } -} - -\node[Box,below=0.8 of 3C4](3B1){GPU 1}; -\draw[myline,dashed](3C4)--(3B1); -\end{scope} - -\begin{scope}[node distance=0.2,shift={(20.0,5))},local bounding box = SC4] -\node[mycycle] (4C1) {}; -\node[mycycle,below=of 4C1] (4C2) {}; -\node[mycycle,below=of 4C2] (4C3) {}; -\node[mycycle,below=of 4C3] (4C4) {}; -\node[mycycle,fill=violet,left=0.6 of $(4C1)!0.5!(4C2)$] (4CL1) {}; -\node[mycycle,fill=violet,left=0.6 of $(4C2)!0.5!(4C3)$] (4CL2) {}; -\node[mycycle,fill=violet,left=0.6 of $(4C3)!0.5!(4C4)$] (4CL3) {}; -% -\node[mycycle,fill=green,right=0.6 of $(4C1)!0.4!(4C3)$] (4CD1) {}; -\node[mycycle,fill=green,right=0.6 of $(4C2)!0.6!(4C4)$] (4CD2) {}; -% -% -\foreach \x in {1,2,3,4} { - \foreach \y in {4CL1, 4CL2, 4CL3, 4CD1, 4CD2} { - \draw[myline] (\y) -- (4C\x); - } -} -\node[Box,below=0.8 of 4C4](4B1){GPU 1}; -\draw[myline,dashed](4C4)--(4B1); -\end{scope} -\coordinate(X)at($(CD1)!0.5!(CD2)$); -\coordinate(Y)at($(3CD1)!0.5!(3CD2)$); - -\node[fill=white,minimum height=45](ER)at($(X)!0.3!(Y)$){Error}; -\node[fill=white,align=center,minimum height=45](CO)at($(X)!0.7!(Y)$){Compute\\ loss\\ function}; -\draw[myline,-latex,shorten <=3mm](X)--(ER.west); -\draw[myline,-latex](ER.east)--(CO.west); -\draw[myline,-latex,shorten >=3mm](CO.east)--(Y); -\draw[myline,dashed](CO.south)--++(270:1)-|node[fill=white,align=center, -pos=0.25](COM){Compare\\ predicted\\ label with\\ annotation} -(ER.south); - -\node[fill=white,align=center,minimum height=45](OP)at($(3CL2)!0.7!(4CL2)$){Avg\\ global\\ gradient}; -\draw[myline,latex-,shorten <=1mm](4CL2)--(OP.east); -% -\draw[myline,latex-,shorten <=3mm,shorten >=3mm](CL2)--(SC1.east|-CL2); -% -\draw[myline,latex-,shorten <=3mm,shorten >=3mm](CL2)-|node[fill=white,pos=0.75]{Chunk}(SC1.north); -% -\path[myline,draw=none,dashed](OP.north west)--++(90:1.2)coordinate(OP1); -\draw[myline,dashed]($(ER.north east)!0.5!(CO.north west)$)--++(90:1.2)coordinate(ER1); -\coordinate (C) at ($(OP1) + (0,5mm)$); -\coordinate (B) at ($(ER1) + (0,5mm)$); -\path[red](C)-|coordinate(D1)(4CD1); -\path[red](B)-|coordinate(A1)(SC1); -\coordinate (D) at ($(D1) + (15mm,0)$); -\coordinate (A) at ($(A1) + (-15mm,0)$); -\draw[myline,dashed,shorten >=3mm,shorten <=3mm](B)-- -node[fill=white]{Step 2 -- Compute gradients}(C); -\draw[myline,dashed,shorten >=3mm,pos=0.46](C)-- -node[fill=white]{Step 3 -- Update Parameters}(D); -\draw[myline,dashed,shorten >=3mm,pos=0.46](B)-- -node[fill=white]{Step 1 -- Predict a label}(A); - -\node[above=0.2 of SC2]{Forward pass}; -\node[above=0.2 of SC3]{Backward pass}; -%%%%%%%%%%%%%%%%%%%%%%% -%down -%%%%%%%%%%%%%%%%%%%%%%% -\begin{scope}[node distance=0.2,shift={(3.5,-2))},local bounding box = DSC2] -\node[mycycle] (DC1) {}; -\node[mycycle,below=of DC1] (DC2) {}; -\node[mycycle,below=of DC2] (DC3) {}; -\node[mycycle,below=of DC3] (DC4) {}; -\node[mycycle,fill=violet,left=0.6 of $(DC1)!0.5!(DC2)$] (DCL1) {}; -\node[mycycle,fill=violet,left=0.6 of $(DC2)!0.5!(DC3)$] (DCL2) {}; -\node[mycycle,fill=violet,left=0.6 of $(DC3)!0.5!(DC4)$] (DCL3) {}; -% -\node[mycycle,fill=green,right=0.6 of $(DC1)!0.4!(DC3)$] (DCD1) {}; -\node[mycycle,fill=green,right=0.6 of $(DC2)!0.6!(DC4)$] (DCD2) {}; -% -\foreach \x in {1,2,3,4} { - \foreach \y in {DCL1, DCL2, DCL3, DCD1, DCD2} { - \draw[myline] (\y) -- (DC\x); - } -} -\node[Box,above=0.8 of DC1](DB1){GPU 2}; -\draw[myline,dashed](DC1)--(DB1); -\end{scope} - -\begin{scope}[node distance=0.2,shift={(11.5,-2))},local bounding box = DSC3] -\node[mycycle] (D3C1) {}; -\node[mycycle,below=of D3C1] (D3C2) {}; -\node[mycycle,below=of D3C2] (D3C3) {}; -\node[mycycle,below=of D3C3] (D3C4) {}; -\node[mycycle,fill=violet,right=0.6 of $(D3C1)!0.5!(D3C2)$] (D3CL1) {}; -\node[mycycle,fill=violet,right=0.6 of $(D3C2)!0.5!(D3C3)$] (D3CL2) {}; -\node[mycycle,fill=violet,right=0.6 of $(D3C3)!0.5!(D3C4)$] (D3CL3) {}; -% -\node[mycycle,fill=green,left=0.6 of $(D3C1)!0.4!(D3C3)$] (D3CD1) {}; -\node[mycycle,fill=green,left=0.6 of $(D3C2)!0.6!(D3C4)$] (D3CD2) {}; -% -\foreach \x in {1,2,3,4} { - \foreach \y in {D3CL1, D3CL2, D3CL3, D3CD1, D3CD2} { - \draw[myline] (\y) -- (D3C\x); - } -} - -\node[Box,above=0.8 of D3C1](D3B1){GPU 2}; -\draw[myline,dashed](D3C1)--(D3B1); -\end{scope} - -\begin{scope}[node distance=0.2,shift={(20.0,-2))},local bounding box = DSC4] -\node[mycycle] (D4C1) {}; -\node[mycycle,below=of D4C1] (D4C2) {}; -\node[mycycle,below=of D4C2] (D4C3) {}; -\node[mycycle,below=of D4C3] (D4C4) {}; -\node[mycycle,fill=violet,left=0.6 of $(D4C1)!0.5!(D4C2)$] (D4CL1) {}; -\node[mycycle,fill=violet,left=0.6 of $(D4C2)!0.5!(D4C3)$] (D4CL2) {}; -\node[mycycle,fill=violet,left=0.6 of $(D4C3)!0.5!(D4C4)$] (D4CL3) {}; -% -\node[mycycle,fill=green,right=0.6 of $(D4C1)!0.4!(D4C3)$] (D4CD1) {}; -\node[mycycle,fill=green,right=0.6 of $(D4C2)!0.6!(D4C4)$] (D4CD2) {}; -% -\foreach \x in {1,2,3,4} { - \foreach \y in {D4CL1, D4CL2, D4CL3, D4CD1, D4CD2} { - \draw[myline] (\y) -- (D4C\x); - } -} -\node[Box,above=0.8 of D4C1](D4B1){GPU 2}; -\draw[myline,dashed](D4C1)--(D4B1); -\end{scope} -%%%%% -\coordinate(DX)at($(DCD1)!0.5!(DCD2)$); -\coordinate(DY)at($(D3CD1)!0.5!(D3CD2)$); - -\node[fill=white,minimum height=45](DER)at($(DX)!0.3!(DY)$){Error}; -\node[fill=white,align=center,minimum height=45](DCO)at($(DX)!0.7!(DY)$){Compute\\ loss\\ function}; -\draw[myline,-latex,shorten <=3mm](DX)--(DER.west); -\draw[myline,-latex](DER.east)--(DCO.west); -\draw[myline,-latex,shorten >=3mm](DCO.east)--(DY); -\draw[myline,dashed](DCO.north)--++(90:1)-|node[fill=white,align=center, -pos=0.25](DCOM){Compare\\ predicted\\ label with\\ annotation}(DER.north); - -\node[fill=white,align=center,minimum height=45](DOP)at($(D3CL2)!0.7!(D4CL2)$){Avg\\ global\\ gradient}; -\draw[myline,latex-,shorten <=1mm](D4CL2)--(DOP.east); -% -\draw[myline,latex-,shorten <=3mm,shorten >=3mm](DCL2)-| -node[fill=white,pos=0.75]{Chunk}(SC1.south); -% -\node[below=0.2 of DSC2]{Forward pass}; -\node[below=0.2 of DSC3]{Backward pass}; -%%% -\coordinate(S1)at($(3B1)!0.5!(4B1)$); -\coordinate(S2)at($(D3B1)!0.5!(D4B1)$); -\coordinate(S)at($(S1)!0.5!(S2)$); - -\node[draw=none,fill=green!50!black!90,text=white,inner xsep=10pt, - inner ysep=9pt, outer sep=5pt](CGG)at(S){\textbf{Calculate Global Gradients}}; -% -\draw[myline,shorten <=1mm](OP.west)-|(CGG.80); -\draw[myline,-latex,shorten <=2mm](3CL2)-|(CGG.130); -% -\draw[myline,shorten <=1mm](DOP.west)-|(CGG.280); -\draw[myline,-latex,shorten <=2mm](D3CL2)-|(CGG.230); - \end{tikzpicture} -``` -::: - -This coordination introduces several key challenges that distributed training systems must address. A distributed training system must orchestrate multi-machine computation by splitting up the work, managing communication between machines, and maintaining synchronization throughout the training process. The AllReduce operations that aggregate gradients across devices consume 10--40% of total training time even with optimal implementation, and this overhead compounds as systems scale. - -These coordination requirements shape the four main approaches to distributed training, each addressing a different constraint regime. Data parallelism divides the training data across machines while each maintains a full model copy, making it the simplest approach and effective for models that fit in single-device memory. Model parallelism splits the model itself across devices when parameters exceed single-device memory, addressing the memory constraint that data parallelism cannot solve. Pipeline parallelism partitions models into sequential stages that process microbatches concurrently, improving utilization over naive model parallelism. Hybrid approaches integrate multiple strategies, enabling training at scales where any single approach would fail. The progression from data parallelism through model and pipeline parallelism to hybrid approaches mirrors the natural scaling path: each strategy becomes necessary only after its predecessor reaches a physical ceiling. - -## Data Parallelism {#sec-distributed-training-systems-systems-data-parallelism-6132} - -What is the simplest way to use eight GPUs to process a massive dataset? You give each GPU a complete, identical copy of the model, but only assign it one-eighth of the data. Data parallelism represents the most straightforward distributed approach and the natural starting point for understanding how distributed training works in practice. - -::: {.callout-definition title="Data Parallelism"} - -***Data Parallelism***\index{Data Parallelism!definition} is a distributed training strategy where the **Model is Replicated** across all workers, but the **Dataset is Sharded**. - -1. **Significance (Quantitative):** It maximizes **Throughput ($\eta$)** by allowing $N$ workers to process independent data samples simultaneously. It is mathematically equivalent to single-device training with a batch size $N\times$ larger. -2. **Distinction (Durable):** Unlike **Model Parallelism**, where the weights are split across devices, Data Parallelism requires every worker to have enough **Memory Capacity** to store the full model state. -3. **Common Pitfall:** A frequent misconception is that Data Parallelism scales infinitely. In reality, it is constrained by the **Communication Bottleneck**: the time to synchronize gradients (AllReduce) must be hidden by the time to compute them, or else the system becomes bandwidth-bound ($BW$). - -::: - -This method distributes the training process across multiple devices by splitting the dataset into smaller subsets. Each device trains a complete copy of the model using its assigned subset of the data. When training an image classification model on 1 million images using 4 GPUs, each GPU processes 250,000 images while maintaining an identical copy of the model architecture. - -Data parallelism is most effective when the dataset size is large but the model size remains manageable, since each device must store a full copy of the model in memory. This method is widely used in image classification and natural language processing, where the dataset can be processed in parallel without dependencies between data samples. When training a ResNet model [@he2016resnet] on ImageNet, each GPU can independently process its portion of images because the classification of one image does not depend on the results of another. - -The effectiveness of data parallelism stems from a property of stochastic gradient descent. Gradients computed on different minibatches can be averaged while preserving mathematical equivalence to single-device training. This property enables parallel computation across devices, with the mathematical foundation following directly from the linearity of expectation. - -Consider a model with parameters $θ$ training on a dataset $D$. The loss function for a single data point $x_i$ is $L(θ, x_i)$. In standard SGD with batch size $B$, the gradient update for a minibatch is: -$$ -g = \frac{1}{B} \sum_{i=1}^B \nabla_θ L(θ, x_i) -$$ - -In data parallelism with $N$ devices, each device $k$ computes gradients on its own minibatch $B_k$: -$$ -g_k = \frac{1}{|B_k|} \sum_{x_i \in B_k} \nabla_θ L(θ, x_i) -$$ - -The global update averages these local gradients: -$$ -g_{\text{global}} = \frac{1}{N} \sum_{k=1}^N g_k -$$ - -This averaging is mathematically equivalent to computing the gradient on the combined batch $B_{\text{total}} = \bigcup_{k=1}^N B_k$: -$$ -g_{\text{global}} = \frac{1}{|B_{\text{total}}|} \sum_{x_i \in B_{\text{total}}} \nabla_θ L(θ, x_i) -$$ - -This equivalence shows why data parallelism maintains the statistical properties of SGD training. The approach distributes distinct data subsets across devices, computes local gradients independently, and averages these gradients to approximate the full-batch gradient. - -The method parallels gradient accumulation, where a single device accumulates gradients over multiple forward passes before updating parameters. Both techniques use the additive properties of gradients to process large batches efficiently. However, *data parallelism at scale* introduces operational challenges beyond this theoretical equivalence. - -::: {.callout-note title="Data Parallelism at Scale"} - -Data parallelism in production environments involves several operational considerations beyond the theoretical framework: - -- **Communication efficiency**: AllReduce operations for gradient synchronization become the bottleneck at scale. Production systems use optimized libraries like NCCL with ring or tree communication patterns to minimize overhead -- **Fault tolerance**: Node failures during large-scale training require checkpoint/restart strategies. Production systems implement hierarchical checkpointing with both local and distributed storage -- **Dynamic scaling**: Cloud environments require elastic scaling capabilities to add/remove workers based on demand and cost constraints, complicated by the need to maintain gradient synchronization -- **Cost optimization**: Production data parallelism considers cost per GPU-hour across different instance types and preemptible instances, balancing training time against infrastructure costs -- **Network bandwidth requirements**: Large models require careful network topology planning as gradient communication can consume 10--50% of training time depending on model size and batch size - -Production teams typically benchmark communication patterns and scaling efficiency before deploying large distributed training jobs to identify optimal configurations. - -::: - -### Data Parallelism Implementation {#sec-distributed-training-systems-systems-data-parallelism-implementation-fa03} - -The mathematical foundation above—gradient averaging preserves the statistical properties of SGD—translates into concrete implementation steps. Each step corresponds to a phase in the gradient averaging process, from distributing data subsets to synchronizing the computed gradients. - -The process of data parallelism can be broken into a series of distinct steps, each with its role in ensuring the system operates efficiently. Consider @fig-dist-train-data-parallelism: it traces the complete workflow from dataset splitting through gradient aggregation, showing how each GPU processes its assigned batch before synchronization brings all gradients together for parameter updates. - -::: {#fig-dist-train-data-parallelism fig-env="figure" fig-pos="htb" fig-cap="**Data Parallelism Implementation Pipeline**. The five-stage workflow for data parallel training: (1) split input data into non-overlapping subsets, (2) assign batches to GPUs, (3) compute forward and backward passes independently, (4) synchronize gradients via AllReduce, and (5) update parameters uniformly across all devices. This approach contrasts with model parallelism, where the model itself is partitioned rather than replicated." fig-alt="Flowchart showing 5-stage data parallelism: Input Data splits into 4 batches assigned to GPUs 1-4, each performs forward and backward pass, gradients synchronize and aggregate, then model updates."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}] -\tikzset{Line/.style={line width=0.75pt,black!50,text=black -}, - Box/.style={inner xsep=2pt, - line width=0.75pt, - node distance=2.0, - fill=VioletL2, - draw=VioletLine2, - text width=27mm, - align=flush center, - minimum width=27mm, - minimum height=9mm - }, - Box2/.style={Box, - draw=BlueLine, - fill=BlueL, - text width=21mm, - minimum width=22mm, - minimum height=9mm - }, - Text/.style={inner xsep=6pt, - inner ysep=4pt, - draw=none, line width=0.75pt, - fill=TextColor!80, - font=\usefont{T1}{phv}{m}{n}\footnotesize, - align=flush center, - minimum width=22mm, minimum height=5mm - }, -} - -\node[Box,node distance=1](B1){GPU 1\\Forward \& Backward Pass}; -\node[Box,node distance=1.2,right=of B1](B2){GPU 2\\Forward \& Backward Pass}; -\node[Box,node distance=1.2,right=of B2](B3){GPU 3\\Forward \& Backward Pass}; -\node[Box,node distance=1.2,right=of B3](B4){GPU 4\\Forward \& Backward Pass}; -% -\node[Box2,above=1.06 of B1](GB1){Batch 1}; -\node[Box2,above=1.06 of B2](GB2){Batch 2}; -\node[Box2,above=1.06 of B3](GB3){Batch 3}; -\node[Box2,above=1.06 of B4](GB4){Batch 4}; -% -\node[Box2,above=1.8of $(GB2)!0.5!(GB3)$,fill=RedL,draw=RedLine](GGB1){Input Data}; -% -\node[Box,below=of $(B2)!0.5!(B3)$,fill=GreenL,draw=GreenLine](DB1){Gradients GPU N}; -\node[Box,below=1.05 of DB1,fill=GreenL,draw=GreenLine](DB2){Gradient Aggregation}; -\node[Box,below=1.05 of DB2,fill=GreenL,draw=GreenLine](DB3){Model Update}; -% -\draw[Line,-latex](GGB1)--++(270:1.4)-|(GB2); -\draw[Line,-latex](GGB1)--++(270:1.4)-|(GB3); -\draw[Line,-latex](GGB1)--++(270:1.4)-|(GB4); -\draw[Line,-latex](GGB1)--node[Text,pos=0.5,anchor=center]{Split into Non-Overlapping Subsets}++(270:1.4)-|(GB1); -%% -\draw[Line,-latex](GB1)--node[Text,pos=0.45]{Assigned to GPU 1}(B1); -\draw[Line,-latex](GB2)--node[Text,pos=0.45]{Assigned to GPU 2}(B2); -\draw[Line,-latex](GB3)--node[Text,pos=0.45]{Assigned to GPU 3}(B3); -\draw[Line,-latex](GB4)--node[Text,pos=0.45]{Assigned to GPU 4}(B4); -% -\draw[Line,-latex](B3)--++(270:0.9)-|(DB1); -\draw[Line,-latex](B2)--++(270:0.9)-|(DB1); -\draw[Line,-latex](B1)--++(270:0.9)-|(DB1); -\draw[Line,-latex](B4)--++(270:0.9)-|node[Text,pos=0.72,text=black]{Compute Gradients}(DB1); -% -\draw[Line,-latex](DB1)--node[Text,pos=0.45]{Synchronize Gradients}(DB2); -\draw[Line,-latex](DB2)--node[Text,pos=0.45]{Aggregate Gradients and Update Parameters}(DB3); -% -\draw[Line,-latex](GGB1.east)--++(0:6.8)|-node[Text,pos=0.8,text=black]{Next Mini-Batch}(DB3.east); -\end{tikzpicture} -``` -::: - -#### Dataset Splitting {#sec-distributed-training-systems-systems-dataset-splitting-1edf} - -The first step in data parallelism involves dividing the dataset into smaller, non-overlapping subsets. This ensures that each device processes a unique portion of the data, avoiding redundancy and enabling efficient utilization of available hardware. With a dataset of 100,000 training examples and 4 GPUs, each GPU receives 25,000 examples per epoch. The DistributedSampler must ensure no overlap between subsets to maintain gradient estimation validity: if two GPUs process the same example, the resulting gradient average would overweight that example, violating the unbiased gradient assumption that makes data parallelism mathematically equivalent to single-device training. - -Modern distributed training frameworks handle this distribution automatically through a distributed sampler that implements prefetching and caching mechanisms to ensure data is readily available for processing. The sampler coordinates across workers using the process rank to deterministically partition indices, ensuring reproducibility when the same random seed is used. For a 1.2 million example dataset like ImageNet distributed across 32 GPUs, each GPU processes exactly 37,500 examples per epoch, with the sampler padding the final batch to maintain consistent batch sizes across all workers. - -```{python} -#| label: gpt3-training-context -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ GPT-3 TRAINING CONTEXT (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-distributed-training-systems-systems-compute-phase-forward-backward -# │ -# │ Goal: Provide GPT-3 scale statistics for distributed training discussion. -# │ Show: ~175B parameters. -# │ How: pulling GPT3_PARAMS from mlsys.constants. -# │ -# │ Imports: mlsys.constants (GPT3_PARAMS, param, BILLION) -# │ Exports: gpt3_params_b -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import GPT3_PARAMS, param, BILLION - -class Gpt3TrainingContext: - """GPT-3 scale reference for distributed training.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - params = GPT3_PARAMS.m_as(param) - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - params_b = params / BILLION - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - gpt3_params_b = f"{params_b:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -gpt3_params_b = Gpt3TrainingContext.gpt3_params_b -``` - -#### Compute Phase: Forward and Backward Passes {#sec-distributed-training-systems-systems-compute-phase-forward-backward} - -```{python} -#| label: gpt3-training-context -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ GPT-3 TRAINING CONTEXT (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-distributed-training-systems-systems-compute-phase-forward-backward -# │ -# │ Goal: Provide GPT-3 scale statistics for distributed training discussion. -# │ Show: ~175B parameters. -# │ How: pulling GPT3_PARAMS from mlsys.constants. -# │ -# │ Imports: mlsys.constants (GPT3_PARAMS, param, BILLION) -# │ Exports: gpt3_params_b -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import GPT3_PARAMS, param, BILLION - -class Gpt3TrainingContext: - """GPT-3 scale reference for distributed training.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - params = GPT3_PARAMS.m_as(param) - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - params_b = params / BILLION - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - gpt3_params_b = f"{params_b:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -gpt3_params_b = Gpt3TrainingContext.gpt3_params_b -``` - -The defining feature of data parallelism is that the computation phase—both forward and backward—is **embarrassingly parallel**. Each GPU operates as an isolated island, executing an identical copy of the model on a unique micro-batch of data. For our `{python} gpt3_params_b`B parameter reference model, this isolation is critical: during the forward pass, each GPU independently computes activations for its local batch (micro-batch size 4, sequence length 2048). Without optimization, storing these activations for backpropagation would consume over 200 GB of HBM, exceeding the capacity of even an H100 GPU; techniques like **activation checkpointing**—recomputing activations during the backward pass rather than storing them—are mandatory to suppress this footprint to a manageable ~50 GB. - -The backward pass mirrors this independence but introduces the system's primary bottleneck. As the GPU traverses the computation graph in reverse, it computes gradients for every parameter in the model. For a `{python} gpt3_params_b`B model in FP16, this generates a 350 GB gradient payload per GPU. While the computation itself requires zero communication, the resulting gradients represent a fractured view of the true loss surface—valid only for the local micro-batch. Before the optimizer step can occur, these local gradients must be aggregated across all $N$ GPUs to form a valid global gradient. This transition—from isolated, high-throughput compute to a massive, global synchronization event—defines the rhythm of data parallel training: long periods of silent, intense arithmetic punctuated by bursts of heavy network traffic. - -#### Gradient Synchronization {#sec-distributed-training-systems-systems-gradient-synchronization-614b} - -To maintain consistency across the distributed system, the gradients computed by each device must be synchronized. This coordination represents a distributed systems challenge in achieving global consensus while minimizing communication complexity. - -::: {.callout-note title="Cross-Ref: Communication Algos"} -The specific algorithms for gradient synchronization—including **Ring AllReduce**, **Tree AllReduce**, and **Hierarchical AllReduce**—are analyzed in depth in @sec-collective-communication. That chapter derives the bandwidth and latency bounds ($2N(D-1)/D$), explains how topology-aware algorithms exploit NVLink vs. InfiniBand, and details the gradient compression techniques used to reduce this overhead. -::: - -When synchronization performance deviates from theoretical expectations, the Fleet Stack framework provides a structured approach to isolating the bottleneck. - -::: {.callout-perspective title="Debugging Slow Gradient Synchronization"} - -**Problem Statement**: Your AllReduce operation takes 100 ms when you expected 50 ms based on bandwidth calculations. Where do you look? - -The Fleet Stack framework provides a systematic debugging methodology by examining each layer: - -**Infrastructure Layer**: - -- **Topology**: 128 nodes, 8 GPUs per node -- **Intra-node**: NVLink at `{python} nvlink_a100` GB/s bidirectional between GPUs -- **Inter-node**: InfiniBand HDR at 200 Gb/s (25 GB/s) per port -- **Observation**: Your 3 GB gradient tensor should take $3 \text{ GB} / 25 \text{ GB/s} = 120\text{ms}$ for a naive transfer, but ring AllReduce should achieve $2(N-1)/N \times 3 \text{ GB} / 25 \text{ GB/s} \approx 240\text{ms}$ across 128 nodes - -**Distribution Layer**: - -- **Algorithm choice**: Single ring AllReduce across all 1024 GPUs -- **Expected behavior**: Ring touches every GPU in sequence, dominated by the slowest link (inter-node InfiniBand) -- **Diagnosis**: A single global ring *fails to exploit NVLink* within nodes - -**Serving Layer (Measurement)**: - -- **Observed latency**: 100 ms (better than naive calculation predicts) -- **Bandwidth utilization**: Only 60% of theoretical InfiniBand throughput -- **Network counters**: Show congestion on specific switch uplinks - -**Root Cause Diagnosis**: The mismatch between Infrastructure and Distribution layers reveals the problem. Your single ring correctly identifies InfiniBand as the bottleneck, but the observed 100 ms (rather than 240 ms) suggests NCCL is already using a hierarchical algorithm internally. The *remaining* gap comes from switch congestion, not algorithm choice. - -**Solution**: Monitor InfiniBand switch port utilization to identify hot spots. Consider rail-optimized topology (@sec-compute-infrastructure) or hierarchical AllReduce that explicitly partitions intra-node (NVLink) from inter-node (InfiniBand) communication. The 40 ms remaining gap likely represents achievable optimization through better network provisioning rather than algorithm changes. - -This analysis demonstrates how the Fleet Stack layers interact: Physical constraints (bandwidth) bound Operational choices (algorithm), which manifest in Service metrics (latency). Debugging requires examining all three layers, not just tuning one in isolation. - -::: - -@fig-coll-comm contrasts three high-level synchronization topologies: the centralized **Parameter Server**, the bandwidth-optimal **Ring AllReduce**, and the low-latency **Tree AllReduce**. While Parameter Servers were common in early distributed systems, modern synchronous training relies almost exclusively on AllReduce variants to maximize bandwidth utilization across dense GPU clusters. - -::: {#fig-coll-comm fig-env="figure" fig-pos="htb" fig-cap="**Gradient Synchronization Topologies**. Visual comparison of communication patterns. (A) **Parameter Server** uses a central node, creating bandwidth bottlenecks at the server. (B) **Ring AllReduce** distributes bandwidth evenly across all links but has linear latency scaling. (C) **Tree AllReduce** reduces latency to logarithmic time but may congest links near the root." fig-alt="Three communication topology diagrams. A: Parameter Server with central PS node connected to 4 workers creating bottleneck. B: Ring AllReduce with 4 GPUs in circular topology. C: Tree AllReduce with root R and hierarchical structure."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, scale=0.85, transform shape] - \definecolor{NodeColor}{RGB}{230,230,230} - \definecolor{PSColor}{RGB}{176,196,222} % LightSteelBlue - \definecolor{ArrowColor}{RGB}{80,80,80} - - \tikzset{ - node_style/.style={circle, draw=black!50, fill=NodeColor, minimum size=0.9cm, font=\bfseries}, - ps_style/.style={circle, draw=black!50, fill=PSColor, minimum size=1.1cm, font=\bfseries}, - arrow_style/.style={->, >=stealth, thick, color=ArrowColor} - } - - % A. Parameter Server - \begin{scope} - \node[anchor=south] at (0, 3.5) {\textbf{A. Parameter Server}}; - \node[ps_style] (ps) at (0, 1.8) {PS}; - - % Workers arranged in semi-circle below - \node[node_style] (w1) at (-2, -0.5) {W1}; - \node[node_style] (w2) at (-0.7, -0.5) {W2}; - \node[node_style] (w3) at (0.7, -0.5) {W3}; - \node[node_style] (w4) at (2, -0.5) {W4}; - - \draw[arrow_style, <->] (ps) -- (w1); - \draw[arrow_style, <->] (ps) -- (w2); - \draw[arrow_style, <->] (ps) -- (w3); - \draw[arrow_style, <->] (ps) -- (w4); - - \node[align=center, font=\footnotesize, below=0.8cm of ps] {Bottleneck at Central Node}; - \end{scope} - - % B. Ring AllReduce - \begin{scope}[shift={(6,0)}] - \node[anchor=south] at (0, 3.5) {\textbf{B. Ring AllReduce}}; - \foreach \i/\angle in {1/90, 2/0, 3/270, 4/180} { - \node[node_style] (r\i) at (\angle:1.6) {G\i}; - } - \draw[arrow_style] (r1) to[bend left=20] (r2); - \draw[arrow_style] (r2) to[bend left=20] (r3); - \draw[arrow_style] (r3) to[bend left=20] (r4); - \draw[arrow_style] (r4) to[bend left=20] (r1); - - \node[align=center, font=\footnotesize] at (0,0) {Bandwidth\\Optimal}; - \end{scope} - - % C. Tree AllReduce - \begin{scope}[shift={(12,0)}] - \node[anchor=south] at (0, 3.5) {\textbf{C. Tree AllReduce}}; - \node[node_style] (root) at (0, 2.5) {R}; - \node[node_style] (l1) at (-1.5, 1) {L1}; - \node[node_style] (l2) at (1.5, 1) {L2}; - \node[node_style] (ll1) at (-2.2, -0.5) {1}; - \node[node_style] (ll2) at (-0.8, -0.5) {2}; - \node[node_style] (lr1) at (0.8, -0.5) {3}; - \node[node_style] (lr2) at (2.2, -0.5) {4}; - - \draw[arrow_style] (root) -- (l1); - \draw[arrow_style] (root) -- (l2); - \draw[arrow_style] (l1) -- (ll1); - \draw[arrow_style] (l1) -- (ll2); - \draw[arrow_style] (l2) -- (lr1); - \draw[arrow_style] (l2) -- (lr2); - - \node[align=center, font=\footnotesize] at (3.5, 1) {Low Latency\\$O(\log N)$}; - \end{scope} - -\end{tikzpicture} -``` -::: - -#### Synchronization Models {#sec-distributed-training-systems-systems-synchronization-models-396f} - -Distributed training systems operate under explicit synchronization models that govern when workers observe each other's updates. The choice of model determines whether the system guarantees mathematical equivalence to single-device training or trades consistency for throughput. - -The default model, Bulk Synchronous Parallel (BSP)[^fn-bsp] [@valiant1990bsp], requires all workers to complete their local computation in forward and backward passes, synchronize gradients through a barrier with AllReduce, and then simultaneously update parameters. - -[^fn-bsp]: **Bulk Synchronous Parallel (BSP)**: Introduced by Leslie Valiant in 1990 as a "bridging model" between hardware and software for parallel computation. BSP divides work into supersteps -- compute, communicate, barrier -- guaranteeing mathematical equivalence to sequential execution. The cost: iteration time equals the slowest worker's time, and at 1,000 GPUs with 1% straggler probability per device, roughly 10 GPUs straggle every step, making the barrier increasingly expensive. \index{BSP!synchronization} - -BSP provides strong guarantees where every worker sees identical parameter values at each step, ensuring mathematical equivalence to single-device training. The cost is that the slowest worker determines iteration time, creating the straggler problem. - -Stale Synchronous Parallel (SSP) relaxes this constraint by allowing workers to proceed up to $s$ iterations ahead of the slowest worker before blocking. This bounds staleness while reducing synchronization delays. SSP requires careful learning rate tuning since workers compute gradients on slightly different parameter versions. The bounded staleness guarantee with $s$ typically set to 2-5 provides a middle ground between BSP's strong consistency and fully asynchronous approaches. - -Asynchronous SGD eliminates synchronization barriers entirely as workers update parameters independently. This maximizes hardware utilization but introduces gradient staleness that can degrade convergence. When a worker computes gradients on parameters that are already $\tau$ steps stale, the effective learning rate decreases. Compensation techniques include learning rate scaling with $\eta' = \eta / \sqrt{\tau}$ or momentum correction. - -The key trade-offs across synchronization models are summarized here. - -::: {.callout-note title="Synchronization Model Trade-offs"} -| **Model** | **Consistency** | **Throughput** | **Convergence** | **Use Case** | -|:----------|:------------------|:--------------------------|:--------------------------------|:-------------------------------------| -| **BSP** | Strong | Bounded by slowest worker | Equivalent to single-GPU | Final training runs, reproducibility | -| **SSP** | Bounded staleness | Higher than BSP | Near-equivalent with tuning | Hyperparameter search | -| **Async** | Weak | Maximum | Degraded, requires compensation | Large heterogeneous clusters | -::: - -The choice of synchronization model directly affects both system throughput and model convergence. Production systems typically use BSP for final training runs to ensure reproducibility, while exploring SSP or async approaches during hyperparameter search where exact reproducibility is less critical. - -#### Barrier Semantics and Failure Modes {#sec-distributed-training-systems-systems-barrier-semantics-failure-modes-5c94} - -AllReduce operations implement implicit barriers where no worker can proceed until all workers have contributed their gradients. This coupling creates failure modes absent from single-device training. - -Worker failures during AllReduce cause all other workers to block indefinitely while waiting for the missing contribution. Without timeout mechanisms, the entire training job hangs rather than failing cleanly. Production systems implement watchdog timers typically set to 5--10 minutes to detect and terminate stuck jobs. - -Gradient mismatches occur when workers disagree on which tensors to synchronize due to conditional computation paths or dynamic batching. AllReduce operations may block waiting for tensors that some workers never send. This commonly occurs with variable-length sequences in NLP models, dynamic computation graphs, and mixture-of-experts with different routing decisions. - -Straggler-induced delays arise because iteration time equals the slowest worker's time plus synchronization overhead. A single slow worker, whether due to thermal throttling, network congestion, or OS jitter, delays all workers and reduces cluster utilization. At 1000 GPUs with 1% probability of straggler per GPU per step, approximately 10 GPUs straggle every iteration. - -Production systems address these issues through timeouts, heartbeat monitoring, and elastic training mechanisms. @sec-fault-tolerance-reliability provides comprehensive coverage of failure detection, checkpointing strategies, and recovery mechanisms that enable training jobs to complete despite inevitable hardware failures. - -#### Parameter Updating {#sec-distributed-training-systems-systems-parameter-updating-bb64} - -After gradient aggregation, each device independently updates model parameters using the chosen optimization algorithm such as SGD with momentum or Adam. This decentralized update strategy enables efficient parameter updates without requiring a central coordination server. Since all devices have identical gradient values after synchronization, they perform mathematically equivalent updates to maintain model consistency across the distributed system. - -In a system with 8 GPUs training a ResNet model, each GPU computes local gradients based on its data subset. After gradient averaging via ring all-reduce, every GPU has the same global gradient values. Each device then independently applies these gradients using the optimizer's update rule. With SGD and learning rate 0.1, the update becomes `weights = weights - 0.1 * gradients`. This process maintains mathematical equivalence to single-device training while enabling distributed computation. - -This process, which involves splitting data, performing computations, synchronizing results, and updating parameters, repeats for each batch of data. Modern frameworks automate this cycle, allowing developers to focus on model architecture and hyperparameter tuning rather than distributed computing logistics. - -### Trade-offs: The Communication Wall {#sec-distributed-training-systems-systems-data-parallelism-tradeoffs} - -Data parallelism is the default strategy for a reason: it scales **throughput** linearly with device count, provided the model fits in memory and communication is not the bottleneck. However, it hits a hard ceiling defined by the **Communication-Computation Ratio**. - -Data parallelism offers three principal advantages. First, throughput scales linearly for compute-bound models: scaling ResNet-50 on ImageNet from 1 to 256 GPUs yields near-linear speedup because the gradient exchange is small relative to the compute time. Second, the model architecture remains unchanged; the framework wraps the model in a data-parallel container that intercepts backward-pass hooks to trigger gradient synchronization automatically. Third, utilization remains high because, unlike model parallelism, there are no pipeline bubbles — all GPUs work on the forward and backward pass simultaneously. - -These advantages, however, encounter three hard ceilings. The memory wall requires every GPU to hold a full copy of the model parameters, gradients, and optimizer states; for a `{python} gpt3_params_b`B parameter model, this demands more than 1 TB of memory per GPU, which is physically impossible on current hardware without ZeRO sharding. The bandwidth wall emerges as $N$ grows: the AllReduce cost $2(N-1)/N \times M/B$ eventually dominates, and for large language models gradient synchronization can consume more than 50% of the step time, collapsing efficiency. The batch size trap compounds the problem: scaling to thousands of GPUs requires increasing the global batch size ($B_{global} = N \times B_{local}$), and eventually the **Critical Batch Size** is reached, where adding more data per step yields diminishing returns in convergence. - -A concrete scaling experiment reveals how these ceilings manifest in practice. - -::: {.callout-notebook title="GPT-2 Data Parallel Scaling" collapse="true"} - -This example demonstrates how data parallelism scales in practice, including efficiency degradation. - -**Single GPU Baseline** - -- Batch size: 16 (with gradient checkpointing, fits in 32 GB) -- Time per step: 1.8 seconds -- Training throughput: ~9 samples/second -- Time to 50K steps: **25 hours** - -**8 GPUs: Single Node with NVLink** - -Configuration: - -- Per-GPU batch: 16, global batch: 128 -- Gradient synchronization: `{python} gpt2_sync_size_gb_str`GB @ `{python} nvlink_h100_gbs` GB/s (NVLink) $\approx$ `{python} comm_8gpu_ms_str`ms - -Performance results: - -- Computation: `{python} compute_8gpu_ms_str`ms per step -- Communication: `{python} comm_8gpu_ms_str`ms per step -```{python} -#| label: scaling-8gpu-calc -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ 8-GPU SCALING ANALYSIS (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-distributed-training-systems-systems-data-parallelism-tradeoffs -# │ -# │ Goal: Quantify intra-node scaling efficiency using NVLink. -# │ Show: >95% efficiency despite communication overhead. -# │ How: Speedup = T_base / (T_compute + T_comm). -# │ -# │ Imports: mlsys.constants (NVLINK_H100_BW, MILLION, BILLION, GB, byte, ...) -# │ Exports: total_8gpu_str, speedup_8gpu_str, efficiency_8gpu_str, -# │ training_8gpu_str, gpt2_sync_size_gb_str, nvlink_h100_gbs, -# │ comm_8gpu_ms_str, compute_8gpu_ms_str -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import ( - NVLINK_H100_BW, INFINIBAND_HDR_BW, MILLION, BILLION, GB, byte, second, - GPUS_PER_HOST, BITS_PER_BYTE -) -from mlsys.formatting import fmt, check - -# ┌── LEGO ─────────────────────────────────────────────── -class Scaling8GPU: - """Scenario: 8-GPU scaling within a single node.""" - # ┌── 1. LOAD (Constants) ─────────────────────────────────────────────── - single_gpu_step_s = 1.8 - compute_8gpu_ms = 180 - params_b = 1.5 # GPT-2 - nvlink_bw = NVLINK_H100_BW.m_as(GB/second) - gpus_per_node = GPUS_PER_HOST - - # ┌── 2. EXECUTE (The Compute) ───────────────────────────────────────── - # Ring All-Reduce: 2*(N-1)/N * Params * 2 bytes (FP16) - # For large N, ~2 * Params * 2 = 4 * Params. - sync_size_gb = (params_b * BILLION * 4) / BILLION - comm_8gpu_ms_val = (sync_size_gb / nvlink_bw) * 1000 - - total_8gpu_ms_val = compute_8gpu_ms + comm_8gpu_ms_val - speedup_8gpu_val = single_gpu_step_s / (total_8gpu_ms_val / 1000) - efficiency_8gpu_val = speedup_8gpu_val / gpus_per_node * 100 - training_hours_1gpu = 25 - training_hours_8gpu_val = training_hours_1gpu / speedup_8gpu_val - - # ┌── 3. GUARD (Invariants) ─────────────────────────────────────────── - check(efficiency_8gpu_val > 90, f"Intra-node efficiency ({efficiency_8gpu_val:.1f}%) should be high") - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - total_8gpu_str = f"{total_8gpu_ms_val:.0f}" - single_gpu_step_s_str = f"{single_gpu_step_s}" - total_8gpu_s_str = f"{total_8gpu_ms_val/1000:.3f}" - speedup_8gpu_str = f"{speedup_8gpu_val:.1f}" - efficiency_8gpu_str = f"{efficiency_8gpu_val:.0f}" - training_8gpu_str = f"{training_hours_8gpu_val:.1f}" - gpt2_sync_size_gb_str = f"{sync_size_gb:.1f}" - nvlink_h100_gbs = f"{nvlink_bw:.0f}" - comm_8gpu_ms_str = f"{comm_8gpu_ms_val:.1f}" - compute_8gpu_ms_str = f"{compute_8gpu_ms}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -total_8gpu_str = Scaling8GPU.total_8gpu_str -single_gpu_step_s_str = Scaling8GPU.single_gpu_step_s_str -total_8gpu_s_str = Scaling8GPU.total_8gpu_s_str -speedup_8gpu_str = Scaling8GPU.speedup_8gpu_str -efficiency_8gpu_str = Scaling8GPU.efficiency_8gpu_str -training_8gpu_str = Scaling8GPU.training_8gpu_str -gpt2_sync_size_gb_str = Scaling8GPU.gpt2_sync_size_gb_str -nvlink_h100_gbs = Scaling8GPU.nvlink_h100_gbs -comm_8gpu_ms_str = Scaling8GPU.comm_8gpu_ms_str -compute_8gpu_ms_str = Scaling8GPU.compute_8gpu_ms_str -``` - -- Total: `{python} total_8gpu_str`ms per step -- Speedup: `{python} single_gpu_step_s_str`s ÷ `{python} total_8gpu_s_str`s = `{python} speedup_8gpu_str`$\times$ -- Parallel efficiency: `{python} speedup_8gpu_str` ÷ 8 = `{python} efficiency_8gpu_str`% - -Training time: 25 hours ÷ `{python} speedup_8gpu_str` = **`{python} training_8gpu_str` hours** - -**32 GPUs: 4 Nodes with InfiniBand** - -Configuration: - -- Per-GPU batch: 16, global batch: 512 -- Intra-node communication: `{python} comm_8gpu_ms_str`ms (NVLink) -- Inter-node communication: `{python} gpt2_sync_size_gb_str`GB @ `{python} ib_hdr_gbs` GB/s (InfiniBand HDR) $\approx$ `{python} inter_node_ms_str`ms - -Performance results: - -```{python} -#| label: scaling-32gpu-calc -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ 32-GPU SCALING ANALYSIS (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: Inter-node scaling using InfiniBand. -# │ -# │ Goal: Show how inter-node communication collapses scaling efficiency. -# │ Show: ~30% efficiency when communication overhead dominates. -# │ How: Comm_pct = T_comm / (T_comp + T_comm). -# │ -# │ Imports: mlsys.constants (INFINIBAND_HDR_BW, GB, second, ...) -# │ Exports: compute_pct_str, comm_pct_str, total_32gpu_str, speedup_32gpu_str, -# │ efficiency_32gpu_str, training_32gpu_str, inter_node_ms_str, ib_hdr_gbs -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import INFINIBAND_HDR_BW, GB, second, BILLION - -# ┌── LEGO ─────────────────────────────────────────────── -class Scaling32GPU: - """Scenario: 32-GPU scaling across 4 nodes.""" - # ┌── 1. LOAD (Constants) ─────────────────────────────────────────────── - single_gpu_step_s = 1.8 - compute_32gpu_ms = 180 - params_b = 1.5 # GPT-2 - ib_bw = INFINIBAND_HDR_BW.m_as(GB/second) - intra_node_ms = Scaling8GPU.comm_8gpu_ms_val - training_hours_1gpu = 25 - - # ┌── 2. EXECUTE (The Compute) ───────────────────────────────────────── - sync_size_gb = (params_b * BILLION * 4) / BILLION - inter_node_ms_val = (sync_size_gb / ib_bw) * 1000 - - comm_32gpu_ms_val = inter_node_ms_val + intra_node_ms - total_32gpu_ms_val = compute_32gpu_ms + comm_32gpu_ms_val - - compute_pct_val = compute_32gpu_ms / total_32gpu_ms_val * 100 - comm_pct_val = comm_32gpu_ms_val / total_32gpu_ms_val * 100 - - speedup_32gpu_val = single_gpu_step_s / (total_32gpu_ms_val / 1000) - efficiency_32gpu_val = speedup_32gpu_val / 32 * 100 - training_32gpu_hours_val = training_hours_1gpu / speedup_32gpu_val - - # ┌── 3. GUARD (Invariants) ─────────────────────────────────────────── - check(comm_pct_val > 50, f"Communication should dominate ({comm_pct_val:.1f}%) 32-GPU scenario") - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - compute_pct_str = f"{compute_pct_val:.0f}" - comm_pct_str = f"{comm_pct_val:.0f}" - total_32gpu_str = f"{total_32gpu_ms_val:.0f}" - total_32gpu_s_str = f"{total_32gpu_ms_val/1000:.3f}" - speedup_32gpu_str = f"{speedup_32gpu_val:.1f}" - efficiency_32gpu_str = f"{efficiency_32gpu_val:.0f}" - training_32gpu_str = f"{training_32gpu_hours_val:.1f}" - inter_node_ms_str = f"{inter_node_ms_val:.0f}" - ib_hdr_gbs = f"{ib_bw:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -compute_pct_str = Scaling32GPU.compute_pct_str -comm_pct_str = Scaling32GPU.comm_pct_str -total_32gpu_str = Scaling32GPU.total_32gpu_str -total_32gpu_s_str = Scaling32GPU.total_32gpu_s_str -speedup_32gpu_str = Scaling32GPU.speedup_32gpu_str -efficiency_32gpu_str = Scaling32GPU.efficiency_32gpu_str -training_32gpu_str = Scaling32GPU.training_32gpu_str -inter_node_ms_str = Scaling32GPU.inter_node_ms_str -ib_hdr_gbs = Scaling32GPU.ib_hdr_gbs -``` - speedup_32gpu_str = f"{speedup_32gpu_val:.1f}" - efficiency_32gpu_str = f"{efficiency_32gpu_val:.0f}" - training_32gpu_str = f"{training_32gpu_hours_val:.1f}" - compute_32gpu_ms_str = f"{compute_32gpu_ms}" - comm_32gpu_ms_str = f"{comm_32gpu_ms_val}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -compute_pct_str = Scaling32GPU.compute_pct_str -comm_pct_str = Scaling32GPU.comm_pct_str -total_32gpu_str = Scaling32GPU.total_32gpu_str -total_32gpu_s_str = Scaling32GPU.total_32gpu_s_str -speedup_32gpu_str = Scaling32GPU.speedup_32gpu_str -efficiency_32gpu_str = Scaling32GPU.efficiency_32gpu_str -training_32gpu_str = Scaling32GPU.training_32gpu_str -compute_32gpu_ms = Scaling32GPU.compute_32gpu_ms_str -comm_32gpu_ms = Scaling32GPU.comm_32gpu_ms_str -``` - -- Computation: `{python} compute_32gpu_ms`ms (`{python} compute_pct_str`% of time) -- Communication: `{python} comm_32gpu_ms`ms (`{python} comm_pct_str`% of time) -- Total: `{python} total_32gpu_str`ms per step -- Speedup: `{python} single_gpu_step_s_str`s ÷ `{python} total_32gpu_s_str`s = `{python} speedup_32gpu_str`$\times$ faster → `{python} training_32gpu_str` hours -- Parallel efficiency: `{python} speedup_32gpu_str` ÷ 32 = `{python} efficiency_32gpu_str`% - -Communication dominates and becomes the bottleneck. - -Better Approach: 8 GPUs with Gradient Accumulation - - -```{python} -#| label: grad-accum-calc -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ GRADIENT ACCUMULATION ALTERNATIVE (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: Overcoming the Communication Wall. -# │ -# │ Goal: Show how gradient accumulation recovers efficiency. -# │ Show: Drastic reduction in communication overhead and cost. -# │ How: effective_batch = N * b * steps; overhead = T_comm / (steps * T_comp). -# └───────────────────────────────────────────────────────────────────────────── - -# ┌── LEGO ─────────────────────────────────────────────── -class GradAccumScenario: - """Scenario: 8 GPUs with 4-step gradient accumulation.""" - # ┌── 1. LOAD (Constants) ─────────────────────────────────────────────── - n_gpus_ga = 8 - batch_per_gpu = 16 - accum_steps = 4 - compute_8gpu_ms = 180 - comm_8gpu_ms = 5 - ga_training_hours = 3.8 - cost_per_hour_8gpu = 128 - cost_32gpu_reference = 3021 # from earlier calc - - # ┌── 2. EXECUTE (The Compute) ───────────────────────────────────────── - effective_batch_val = n_gpus_ga * batch_per_gpu * accum_steps - comm_overhead_pct_val = comm_8gpu_ms / (accum_steps * compute_8gpu_ms) * 100 - ga_cost_val = cost_per_hour_8gpu * ga_training_hours - ga_savings_val = cost_32gpu_reference - ga_cost_val - ga_savings_pct_val = ga_savings_val / cost_32gpu_reference * 100 - - # ┌── 3. GUARD (Invariants) ─────────────────────────────────────────── - check(comm_overhead_pct_val < 1.0, "Accumulation should minimize overhead to <1%") - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - effective_batch_str = f"{effective_batch_val}" - comm_overhead_pct_str = f"{comm_overhead_pct_val:.1f}" - ga_training_hours_str = f"{ga_training_hours}" - ga_cost_str = f"{ga_cost_val:.0f}" - ga_savings_str = f"{ga_savings_val:,.0f}" - ga_savings_pct_str = f"{ga_savings_pct_val:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -n_gpus_ga = GradAccumScenario.n_gpus_ga -batch_per_gpu = GradAccumScenario.batch_per_gpu -accum_steps = GradAccumScenario.accum_steps -effective_batch_str = GradAccumScenario.effective_batch_str -comm_overhead_pct_str = GradAccumScenario.comm_overhead_pct_str -ga_training_hours = GradAccumScenario.ga_training_hours_str -ga_cost_str = GradAccumScenario.ga_cost_str -ga_savings_str = GradAccumScenario.ga_savings_str -ga_savings_pct_str = GradAccumScenario.ga_savings_pct_str -comm_8gpu_ms = GradAccumScenario.comm_8gpu_ms -compute_8gpu_ms = GradAccumScenario.compute_8gpu_ms -``` - -- Configuration: `{python} n_gpus_ga` GPUs$\times$ batch `{python} batch_per_gpu`$\times$ `{python} accum_steps` accumulation steps = `{python} effective_batch_str` effective batch -- Communication overhead: `{python} comm_8gpu_ms`ms ÷ (`{python} accum_steps`$\times$ `{python} compute_8gpu_ms`ms) = `{python} comm_overhead_pct_str`% -- Training time: `{python} ga_training_hours` hours -- Cost: USD 128/hour$\times$ `{python} ga_training_hours` hours = USD `{python} ga_cost_str` vs. USD 3,021 for 32 GPUs -- Savings: USD `{python} ga_savings_str` (`{python} ga_savings_pct_str`% reduction) with only 1 hour longer training - -Key Insights - - -1. NVLink enables efficient scaling within single nodes (97% efficiency) -2. Inter-node communication kills efficiency (drops to 13%) -3. Gradient accumulation beats naive scaling for memory-bound models -4. Sweet spot for GPT-2: 8 GPUs per node with gradient accumulation, not naive scaling to 32+ GPUs - -OpenAI's GPT-2 paper reports training on 32 V100s across 4 nodes using optimized communication (likely gradient accumulation combined with pipeline parallelism), not pure data parallelism. - -::: - -### Memory-Efficient Data Parallelism: ZeRO and FSDP {#sec-distributed-training-systems-systems-memoryefficient-data-parallelism-zero-fsdp-0e69} - -The memory constraints of data parallelism motivate a family of techniques that shard memory state across workers while preserving the simplicity of data parallel training. ZeRO (Zero Redundancy Optimizer)[^fn-zero] [@rajbhandari2020zero] and its PyTorch implementation FSDP (Fully Sharded Data Parallel) [@zhao2023fsdp] enable training models that would otherwise require model parallelism. - -[^fn-zero]: **ZeRO (Zero Redundancy Optimizer)**: Published by Microsoft Research in 2019, ZeRO partitions optimizer states, gradients, and optionally parameters across workers instead of replicating them. At ZeRO Stage 3 with 64 GPUs, per-device memory drops from 16 bytes/parameter (full replication) to 0.25 bytes/parameter, converting a 112 GB memory footprint into 1.75 GB. The trade-off: FSDP (PyTorch's ZeRO-3 implementation) adds AllGather and ReduceScatter on every forward and backward layer, introducing 10--25% communication overhead that only pays off when memory pressure justifies it. \index{ZeRO!memory optimization} - -To understand the scale of memory savings ZeRO provides, consider the concrete memory budget for a modern large language model. - -::: {.callout-notebook title="ZeRO Memory Savings" collapse="true"} -**Scenario**: Training a 7B parameter Llama 2 model using Mixed Precision (FP16). - -**Baseline: Standard DDP (Replicated State)** -Per-Parameter Memory Cost: - -- **Weights (FP16)**: 2 bytes -- **Gradients (FP16)**: 2 bytes -- **Optimizer State (FP32)**: 12 bytes (4 master weight + 4 momentum + 4 variance) -- **Total**: 16 bytes/parameter - -Total Memory for 7B Model: -$$ M_{\text{total}} = 7 \times 10^9 \times 16 \text{ bytes} \approx \mathbf{112 \text{ GB}} $$ -*Result*: **OOM** on A100-`{python} a100_mem`GB. - -**Optimization: ZeRO-3 (Fully Sharded)** -With $N=64$ GPUs, state is partitioned: - -- **Weights**: $2/64$ bytes -- **Gradients**: $2/64$ bytes -- **Optimizer**: $12/64$ bytes -- **Total**: $16/64 = 0.25$ bytes/parameter effective storage! - -Per-GPU Memory: -$$ M_{\text{ZeRO3}} = \frac{112 \text{ GB}}{64} \approx \mathbf{1.75 \text{ GB}} $$ -*Result*: Fits easily, leaving ~78 GB for activations (batch size). -::: - -ZeRO addresses this redundancy through progressive sharding: - -::: {.callout-note title="Figure: ZeRO Memory Partitioning" collapse="false"} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, xscale=1.5] - \definecolor{ParamColor}{RGB}{200,220,255} - \definecolor{GradColor}{RGB}{255,220,200} - \definecolor{OptColor}{RGB}{220,255,200} - - \tikzset{ - bar/.style={draw=black!70, thick, minimum width=1.2cm}, - label/.style={font=\scriptsize, text=black!80} - } - - % DDP (Replicated) - \node[anchor=south] at (0, 3.2) {DDP}; - \draw[fill=ParamColor] (-0.4, 0) rectangle (0.4, 0.5) node[midway, label] {P}; - \draw[fill=GradColor] (-0.4, 0.5) rectangle (0.4, 1.0) node[midway, label] {G}; - \draw[fill=OptColor] (-0.4, 1.0) rectangle (0.4, 3.0) node[midway, label] {OS}; - \node[below, font=\tiny] at (0,0) {Replicated}; - - % ZeRO-1 - \node[anchor=south] at (1, 3.2) {ZeRO-1}; - \draw[fill=ParamColor] (0.6, 0) rectangle (1.4, 0.5) node[midway, label] {P}; - \draw[fill=GradColor] (0.6, 0.5) rectangle (1.4, 1.0) node[midway, label] {G}; - \draw[fill=OptColor] (0.6, 1.0) rectangle (1.4, 1.25) node[midway, label] {OS/N}; - \node[below, font=\tiny] at (1,0) {Shard OS}; - - % ZeRO-2 - \node[anchor=south] at (2, 3.2) {ZeRO-2}; - \draw[fill=ParamColor] (1.6, 0) rectangle (2.4, 0.5) node[midway, label] {P}; - \draw[fill=GradColor] (1.6, 0.5) rectangle (2.4, 0.6) node[midway, label] {G/N}; - \draw[fill=OptColor] (1.6, 0.6) rectangle (2.4, 0.85) node[midway, label] {OS/N}; - \node[below, font=\tiny] at (2,0) {+Shard G}; - - % ZeRO-3 - \node[anchor=south] at (3, 3.2) {ZeRO-3}; - \draw[fill=ParamColor] (2.6, 0) rectangle (3.4, 0.1) node[midway, label] {}; - \draw[fill=GradColor] (2.6, 0.1) rectangle (3.4, 0.2) node[midway, label] {}; - \draw[fill=OptColor] (2.6, 0.2) rectangle (3.4, 0.45) node[midway, label] {}; - \node[anchor=west, font=\tiny] at (2.6, 0.2) {All/N}; - \node[below, font=\tiny] at (3,0) {+Shard P}; - - \node[anchor=north west, font=\tiny, text=gray] at (3.5, 3) {P: Parameters}; - \node[anchor=north west, font=\tiny, text=gray] at (3.5, 2.7) {G: Gradients}; - \node[anchor=north west, font=\tiny, text=gray] at (3.5, 2.4) {OS: Optimizer States}; - -\end{tikzpicture} -``` -**ZeRO Memory Reduction**. Standard Data Parallelism (DDP) replicates all model states across every GPU. ZeRO progressively partitions these states: ZeRO-1 shards optimizer states, ZeRO-2 adds gradient sharding, and ZeRO-3 shards the parameters themselves. ZeRO-3 achieves linear memory scaling, enabling models with 100B+ parameters to fit on commodity hardware. -::: - -| **Stage** | **What is Sharded** | **Memory Reduction** | **Communication Overhead** | -|:------------------|:----------------------|:-------------------------|:---------------------------------| -| **ZeRO-1** | Optimizer states only | ~4x | None (same as DDP) | -| **ZeRO-2** | + Gradients | ~8x | ReduceScatter replaces AllReduce | -| **ZeRO-3 / FSDP** | + Parameters | ~$N$ (linear in workers) | AllGather before each layer | - -ZeRO-1 shards optimizer states across GPUs. Each GPU stores only $1/N$ of the Adam momentum and variance tensors. After gradient AllReduce, each GPU updates only its shard of parameters, then broadcasts updates to other GPUs. Memory savings: optimizer states reduced from $8N$ bytes/param to $8$ bytes/param total across cluster. - -ZeRO-2 additionally shards gradients. Instead of AllReduce, which leaves full gradients on each GPU, ZeRO-2 uses ReduceScatter so each GPU receives $1/N$ of the reduced gradients. Memory savings: gradients reduced from $4N$ bytes/param to $4$ bytes/param total. - -ZeRO-3 and FSDP shard parameters themselves. Each GPU stores only $1/N$ of the model. Before each layer's forward pass, parameters are gathered via AllGather; after backward pass, gradients are reduced via ReduceScatter, then parameters are discarded. This achieves maximum memory efficiency at the cost of *additional communication that FSDP introduces* relative to standard DDP. - -::: {.callout-note title="FSDP Communication Analysis"} -FSDP introduces communication on the critical path that DDP avoids: - -- **Forward pass**: AllGather to reconstruct parameters ($M$ bytes$\times$ 2 for each layer) -- **Backward pass**: ReduceScatter for gradients ($M$ bytes$\times$ 2 for each layer) - -For a model with $L$ layers, FSDP performs $2L$ collective operations per training step versus 1 AllReduce for DDP. However, FSDP enables overlapping: while layer $i$ computes, layer $i+1$ can prefetch parameters. - -Total FSDP communication volume: approximately $3M$ bytes (vs. $2M$ for DDP AllReduce), but spread across more operations with overlap opportunities. -::: - -The choice between FSDP and DDP depends on model size and memory constraints. Use DDP when the model fits in GPU memory with room for activations, as it has lower overhead. Use FSDP ZeRO-2 when the model barely fits or requires activation checkpointing. Use FSDP ZeRO-3 when model parameters exceed single-GPU memory. For training 70B+ models on `{python} a100_mem`GB GPUs, combine FSDP with tensor parallelism. - -Memory-efficient data parallelism requires careful tuning of sharding strategy (by layer, by transformer block, or flat) and mixed precision settings. The sharding granularity determines the trade-off: finer sharding reduces per-GPU memory but increases communication frequency as more AllGather and ReduceScatter operations must execute per training step. - -::: {.callout-war-story title="The Linear Scaling Rule Discovery"} -In 2017, Facebook AI Research shattered the "batch size ceiling" by training ResNet-50 on ImageNet in just one hour using 256 GPUs. Prior to this, increasing batch size beyond a few hundred degraded accuracy. Their key insight was the **Linear Scaling Rule**: when the batch size increases by a factor of $k$, the learning rate must also be multiplied by $k$ to preserve the magnitude of weight updates. However, this rule failed during the initial training phase due to unstable gradients. The solution was a **gradual warmup** strategy—starting with a small learning rate and ramping up linearly over the first few epochs—which allowed them to stabilize training at a massive global batch size of 8,192 without sacrificing convergence. -::: - -## Scaling Efficiency and Convergence {#sec-distributed-training-systems-systems-distributed-training-efficiency-metrics-9488} - -If doubling the number of GPUs in your cluster only makes your training run 1.5 times faster, where did the missing 25% of your multi-million dollar compute budget go? Data parallelism revealed the practical mechanics of gradient synchronization and memory sharding, but to understand *why* scaling efficiency degrades and *how* convergence changes with parallelism, we need a quantitative framework. The metrics and convergence theory in this section apply to all parallelism strategies — data, model, pipeline, and hybrid — governing the fundamental trade-offs between throughput, communication cost, and optimization quality. - -Communication overhead represents the primary bottleneck in distributed training systems. AllReduce operations consume 10--40% of total training time in data parallel systems, and this overhead grows with cluster size. BERT-Large on 128 GPUs experiences communication overhead reaching 35% of total runtime, while GPT-3 scale models experience 55% overhead on 1,024 GPUs. - -::: {.callout-note title="AllReduce Communication Complexity"} -AllReduce complexity depends on two components: latency ($\alpha$) and bandwidth ($\beta$). Ring AllReduce achieves bandwidth-optimal communication with $(N-1)/N$ utilization, while tree-based approaches offer lower latency at $O(\log N)$ steps. The choice depends on message size: tree wins for latency-dominated small messages, ring wins for bandwidth-dominated large gradients. Modern implementations like NCCL use hierarchical algorithms that combine tree latency within nodes and ring bandwidth between nodes. @sec-collective-communication provides detailed algorithm analysis including complexity formulas, hierarchical variants, and topology-aware optimizations for production-scale collective operations. -::: - -Interconnect selection determines whether large-scale deployments remain compute-bound or collapse into communication-bound regimes. - -The bandwidth requirements for efficient distributed training are substantial, particularly for transformer models. Efficient systems require 100--400 GB/s aggregate bandwidth per node for transformer architectures. BERT-Base (110M parameters) requires approximately 440 MB of gradient synchronization per iteration in FP32, while BERT-Large (340M parameters) requires approximately 1.4 GB. Across 64 GPUs, these synchronization demands require 100-200 GB/s sustained bandwidth for sub-50 ms synchronization latency. Language models with `{python} gpt3_params_b`B parameters require 700 GB/s aggregate bandwidth to maintain 80% parallel efficiency, necessitating InfiniBand HDR or equivalent interconnects. - -Synchronization frequency presents a trade-off between communication efficiency and convergence behavior. Gradient accumulation reduces synchronization frequency but increases memory requirements and may impact convergence. Synchronizing every 4 steps reduces communication overhead by 60% while increasing memory usage by 3$\times$ for gradient storage. Asynchronous methods eliminate synchronization costs entirely but introduce staleness that degrades convergence by 15--30% for large learning rates. - -### The Physics of Scaling: Amdahl's Law with Communication {#sec-distributed-training-systems-systems-physics-scaling-amdahls-law-communication-4d7f} - -Just as the Iron Law of Processor Performance governs single-thread execution, distributed training is governed by an extended version of Amdahl's Law that explicitly accounts for communication overhead. The time to complete one training step on $N$ devices is not simply $T_{single} / N$, but is constrained by the sequential nature of synchronization. - -@fig-scaling-tax visualizes this divergence from ideal linear scaling as the **Scaling Tax**. It shows how communication overhead ($r$) acts as a drag on performance, creating a "Communication Wall" where adding more GPUs yields diminishing returns. - -::: {#fig-scaling-tax fig-env="figure" fig-pos="htb" fig-cap="**The Scaling Tax.** Effective speedup vs. number of GPUs. While ideal scaling (dashed black) is linear, real-world systems pay a 'tax' for communication. Compute-bound models like ResNet (green) scale well because they have high arithmetic intensity. Bandwidth-bound models like GPT-3 (red) hit a 'Communication Wall' where adding more GPUs yields diminishing returns (efficiency < 50%)." fig-alt="Line chart of Speedup vs GPU Count (log-log). Green line (Compute Bound) stays close to the ideal linear diagonal. Red line (Bandwidth Bound) flattens out, showing diminishing returns as GPUs are added."} -```{python} -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ SCALING TAX (FIGURE) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @fig-scaling-tax — effective speedup vs communication overhead -# │ -# │ Goal: Plot speedup = N/(1+(N-1)*r) vs N for r=0, 0.05, 0.2, 0.5; show -# │ Communication Wall for bandwidth-bound. -# │ Show: Log-log; ideal vs three r-values; annotation. -# │ How: N = [1..512]; speedup formula; viz.set_book_style(). -# │ -# │ Imports: matplotlib.pyplot (plt), numpy (np), mlsys.viz (viz) -# │ Exports: (figure only, no prose variables) -# └───────────────────────────────────────────────────────────────────────────── -import matplotlib.pyplot as plt -import numpy as np -from mlsys import viz - -viz.set_book_style() -COLORS = viz.COLORS - -fig, ax = plt.subplots() - -N = np.array([1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) -scenarios = [ - {'r': 0.0, 'name': 'Ideal Linear', 'color': 'black', 'style': '--', 'marker': ''}, - {'r': 0.05, 'name': 'Compute Bound (ResNet)', 'color': COLORS['GreenLine'], 'style': '-', 'marker': 'o'}, - {'r': 0.20, 'name': 'Balanced (LLM + NVLink)', 'color': COLORS['OrangeLine'], 'style': '-', 'marker': 's'}, - {'r': 0.50, 'name': 'Bandwidth Bound', 'color': COLORS['RedLine'], 'style': '-', 'marker': '^'} -] - -for sc in scenarios: - speedup = N / (1 + (N - 1) * sc['r']) - ax.plot( - N, - speedup, - linestyle=sc['style'], - color=sc['color'], - marker=sc['marker'], - label=sc['name'], - linewidth=2, - markersize=5, - ) - -ax.set_xscale('log', base=2) -ax.set_yscale('log', base=2) -ax.set_xticks(N) -ax.set_xticklabels(N) -ax.set_yticks(N) -ax.set_yticklabels(N) -ax.set_xlabel('Number of GPUs (N)') -ax.set_ylabel('Effective Speedup') -ax.legend(loc='upper left', fontsize=8) - -ax.annotate( - "Communication Wall", - xy=(32, 32 / (1 + 31 * 0.5)), - xytext=(64, 4), - arrowprops=dict(facecolor=COLORS['primary'], arrowstyle='->', lw=1.5), - fontsize=9, - bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0.5), -) - -plt.show() -``` -::: - -We can formalize the **Distributed Step Time** equation and map it directly to our Iron Law variables: - -$$ T_{\text{step}}(N) = \underbrace{\frac{T_{\text{compute}}}{N}}_{\text{FLOPS Term}} + \underbrace{T_{\text{comm}}(N)}_{\text{Bandwidth Term}} - T_{\text{overlap}} $$ - -Where: - -* **FLOPS Term** ($T_{\text{compute}}$): The total computation required for the batch. In an ideal world, this scales perfectly with $N$. -* **Bandwidth Term** ($T_{\text{comm}}$): The time spent moving data. This is governed by the **Iron Law**: $Time = \frac{Data}{Bandwidth}$. For Ring AllReduce, this term is $\frac{2(N-1)}{N} \times \frac{M}{B_{net}}$, where $M$ is model size and $B_{net}$ is network bandwidth. -* **Overlap** ($T_{\text{overlap}}$): The portion of communication hidden behind computation. - -This leads to the **Scaling Efficiency** metric: - -::: {#nte-scaling-efficiency .callout-principle icon=false title="The Scaling Efficiency Bound"} -**The Invariant**: Adding nodes to a distributed training job yields diminishing returns because communication overhead grows with cluster size while per-node computation remains constant. -$$ \eta_{\text{scaling}} = \frac{T_1}{N \times T_N} \leq 1 $$ - -**The Implication**: Perfect linear scaling ($\eta = 1.0$) is a theoretical limit, not a practical target. Real systems achieve $\eta = 0.85$–$0.95$ at moderate scale and degrade further as $N$ grows. The gap between $\eta = 1.0$ and the achieved efficiency is the **communication tax** — the price of coordination. -::: - -$$ \text{Efficiency}(N) = \frac{T_{\text{compute}}}{N \times T_{\text{step}}(N)} = \frac{1}{1 + \frac{N(T_{\text{comm}}(N) - T_{\text{overlap}})}{T_{\text{compute}}}} $$ - -This equation reveals the **Scaling Wall**: as $N$ increases, the compute term ($T_{\text{compute}}/N$) shrinks, but the communication term ($T_{\text{comm}}$) remains constant or grows. Eventually, the denominator is dominated by communication, driving efficiency toward zero. Beyond wall-clock time, this communication overhead imposes an *energy tax* that scales with physical distance between devices. - -::: {.callout-perspective title="The Energy Tax of Scale"} -Distributed training is not just a race against time; it is a race against **energy**. In a single GPU, moving a byte from HBM to the cores costs roughly **1–2 pJ/bit**. Moving that same byte across an NVLink interconnect costs **5–10 pJ/bit**. Moving it across an InfiniBand network through switches costs **20–50 pJ/bit**. - -At the scale of 10,000 GPUs, the "Energy Tax" of moving gradients becomes a multi-megawatt problem. This is why **Communication-Computation Overlap** is not just a performance optimization—it is a necessity for making large-scale AI economically and physically sustainable. Every bit *not* moved is a joule saved. -::: - -Scaling efficiency follows predictable patterns across different GPU counts. In the linear scaling regime of 2-32 GPUs, systems typically achieve 85--95% parallel efficiency because communication overhead remains minimal. The communication bound regime emerges at 64-256 GPUs, where efficiency drops to 60-80% even with optimal interconnects. Beyond 512 GPUs, coordination overhead becomes dominant and limits efficiency to 40-60% due to collective operation latency. - -Hardware selection critically impacts these scaling characteristics. NVIDIA DGX systems with NVLink achieve `{python} nvlink_a100` GB/s bisection bandwidth, enabling 90% parallel efficiency up to 8 GPUs per node. Multi-node scaling requires InfiniBand networks, where EDR at 100 Gbps supports efficient training up to 64 nodes, while HDR at 200 Gbps enables scaling to 256+ nodes with greater than 70% efficiency. - -These efficiency metrics directly influence the choice of parallelism strategy. Data parallelism works well in the linear scaling regime but becomes communication-bound at scale. Model parallelism addresses memory constraints but introduces sequential dependencies that limit efficiency. Pipeline parallelism reduces device idle time but introduces complexity in managing microbatches. The optimal strategy depends on which constraint — memory, bandwidth, or synchronization — dominates the target workload. - -### Convergence Guarantees for Distributed Optimization {#sec-distributed-training-systems-systems-convergence-guarantees-distributed-optimization-350e} - -Hardware efficiency metrics govern throughput, but convergence theory determines whether distributed training reaches the same solution quality as single-device training. Three questions arise: how does parallelism affect optimization convergence, when does adding workers help versus hurt, and how must learning rates be tuned for large-batch training? - -### Convergence Rate for Synchronous Data Parallel SGD {#sec-distributed-training-systems-systems-convergence-rate-synchronous-data-parallel-sgd-2ed5} - -The fundamental convergence result for distributed SGD provides the theoretical basis for understanding parallel training. For a loss function $L(\theta)$ with $L$-Lipschitz gradients (smoothness condition) and variance-bounded stochastic gradients $\mathbb{E}[\|g_i - \nabla L(\theta)\|^2] \leq \sigma^2$, synchronous data parallel SGD with $N$ workers achieves the following convergence rate. - -::: {.callout-theorem title="Convergence Rate for Distributed SGD"} -For synchronous data parallel SGD with $N$ workers, each computing gradients on a local batch of size $b$, after $M$ total iterations the expected optimization error satisfies: - -$$ -\mathbb{E}[L(\theta_M)] - L \leq \underbrace{\frac{L \|\theta_0 - \theta\|^2}{2M}}_{\text{optimization error}} + \underbrace{\frac{\eta L \sigma^2}{2Nb}}_{\text{variance floor}} -$$ - -where $\eta$ is the learning rate, $L$ is the optimal loss, and $\sigma^2$ is the gradient variance. The effective convergence rate is $O(1/\sqrt{NbM})$ when the learning rate is tuned optimally as $\eta = O(\sqrt{Nb/M})$. -::: - -This theorem reveals several important insights. First, the variance floor decreases linearly with the number of workers $N$, explaining why distributed training can achieve the same final loss with fewer iterations. The effective batch size is $B = Nb$, so $N$ workers with batch size $b$ each behave equivalently to a single worker with batch size $Nb$. Second, the convergence rate $O(1/\sqrt{NM})$ shows that $N$ workers can achieve the same error as a single worker in $1/N$ the iterations, assuming infinite bandwidth. This is the **statistical efficiency** of distributed training, distinct from hardware efficiency. - -However, the theorem assumes perfect synchronization (BSP). When workers proceed at different rates or use stale gradients, convergence guarantees degrade, as we examine next. - -### Staleness Impact: BSP versus SSP versus ASP {#sec-distributed-training-systems-systems-staleness-impact-bsp-versus-ssp-versus-asp-e153} - -Synchronization models fundamentally affect convergence behavior. The staleness parameter $\tau$ quantifies how many iterations behind a gradient may be when applied to parameters. - -::: {.callout-definition title="Gradient Staleness"} - -***Gradient Staleness ($\tau$)***\index{Gradient Staleness!definition} is the number of parameter updates that occur between the time a gradient is computed and the time it is applied to the global model state. - -1. **Significance (Quantitative):** It represents the **Synchronization Error** in distributed optimization. Increasing $\tau$ can improve **Throughput ($\eta$)** by reducing barrier waits ($L_{\text{lat}}$), but it typically degrades the **Rate of Convergence**, requiring more operations ($O$) to reach the same accuracy. -2. **Distinction (Durable):** Unlike **Network Latency**, which is a physical delay, Staleness is an **Algorithmic Offset** that arises from the choice of synchronization protocol (e.g., ASP, SSP). -3. **Common Pitfall:** A frequent misconception is that Staleness is "always bad." In reality, it is a **Throughput-Convergence Trade-off**: for some large-scale workloads, allowing bounded staleness is the only way to keep thousands of GPUs in use. - -::: - -The convergence behavior differs across these models in ways that directly affect training cost and solution quality. - -In Bulk Synchronous Parallel (BSP, $\tau = 0$), all workers compute gradients on the same parameter version, then synchronize via barrier before updating. This guarantees mathematical equivalence to single-device training with larger batch size, an optimal convergence rate of $O(1/\sqrt{NbM})$, and no hyperparameter adjustment beyond batch size scaling. - -Stale Synchronous Parallel (SSP, $\tau \leq s$) relaxes the barrier by allowing workers to proceed up to $s$ iterations ahead of the slowest worker. The convergence rate degrades to: - -$$ -\mathbb{E}[L(\theta_M)] - L \leq O\left(\frac{1}{\sqrt{NbM}}\right) + O\left(\frac{s^2 \eta^2 L^2}{Nb}\right) -$$ - -The second term represents the **staleness penalty**. For bounded staleness $s$, this penalty can be controlled by reducing the learning rate: $\eta' = \eta / \sqrt{1 + s}$. Typical production systems use $s \in \{2, 4, 8\}$, accepting 5--15% convergence degradation for 20--40% throughput improvement on heterogeneous clusters. - -Asynchronous SGD (ASP, $\tau = \infty$) eliminates waiting entirely: workers update parameters immediately. While this maximizes throughput, convergence degrades: - -$$ -\mathbb{E}[L(\theta_M)] - L \leq O\left(\frac{1}{\sqrt{M}}\right) + O\left(\frac{\bar{\tau}^2 \eta^2 L^2}{1}\right) -$$ - -where $\bar{\tau}$ is the average staleness. The staleness penalty now scales with the square of average delay, and critically, the variance reduction from $N$ workers disappears in the dominant term. Compensation techniques include: - -- **Learning rate decay**: $\eta' = \eta / \sqrt{1 + \bar{\tau}}$ reduces the staleness penalty but slows convergence -- **Momentum correction**: Adjust momentum to account for delayed updates -- **Gradient clipping**: Prevent stale gradients with large magnitudes from destabilizing training - -@tbl-convergence-comparison summarizes the convergence properties of each synchronization model. - -| **Model** | **Staleness** | **Convergence Rate** | **Variance Reduction** | **Best For** | -|:----------|:----------------|----------------------------------------:|:-----------------------|:--------------------------------------| -| **BSP** | $\tau = 0$ | $O(1/\sqrt{NbM})$ | Full ($1/N$) | Final training, reproducibility | -| **SSP** | $\tau \leq s$ | $O(1/\sqrt{NbM}) + O(s^2\eta^2)$ | Partial | Heterogeneous clusters | -| **ASP** | $\tau = \infty$ | $O(1/\sqrt{M}) + O(\bar{\tau}^2\eta^2)$ | None | Maximum throughput, early exploration | - -: **Convergence Properties by Synchronization Model**. BSP provides optimal convergence guarantees at the cost of synchronization overhead. SSP offers a tunable trade-off between throughput and convergence. ASP maximizes throughput but loses the variance reduction benefit of parallelism. {#tbl-convergence-comparison} - -### Learning Rate Scaling Rules {#sec-distributed-training-systems-systems-learning-rate-scaling-rules-fa26} - -When increasing the effective batch size through data parallelism, the learning rate must be adjusted to maintain convergence quality. Two primary scaling rules have emerged from both theory and practice. - -The linear scaling rule [@goyal2017accurate] states that when the batch size is multiplied by $k$, the learning rate should also be multiplied by $k$: - -$$ -\eta_{\text{large}} = k \cdot \eta_{\text{base}} -$$ - -This rule follows from the observation that with batch size $B$, the expected gradient update magnitude is $\eta \cdot \mathbb{E}[g]$. With batch size $kB$, the gradient variance decreases by factor $k$, so multiplying $\eta$ by $k$ maintains the same expected update magnitude while reducing noise. The linear scaling rule works well for moderate batch sizes (up to 8K-16K for ImageNet, up to 32K for BERT) and requires a **warmup period** where the learning rate gradually increases from a small value to the target over the first 5-10% of training. - -::: {.callout-note title="Linear Scaling Warmup"} -Goyal et al. [@goyal2017accurate] demonstrated that linear scaling without warmup causes training instability for large batches. Their warmup schedule increases the learning rate linearly from $\eta_{\text{base}}$ to $k \cdot \eta_{\text{base}}$ over the first $W$ iterations: - -$$ -\eta_t = \eta_{\text{base}} + \frac{t}{W}(k \cdot \eta_{\text{base}} - \eta_{\text{base}}) \quad \text{for } t < W -$$ - -The warmup period allows the model to reach a region of the loss landscape where large learning rates are stable. Typical warmup lengths are 5 epochs for ImageNet or 10K steps for language models. -::: - -The square root scaling rule applies when batch sizes grow so large that linear scaling fails: - -$$ -\eta_{\text{large}} = \sqrt{k} \cdot \eta_{\text{base}} -$$ - -This more conservative rule is motivated by the observation that gradient noise (not just magnitude) affects optimization dynamics. The square root rule better preserves the signal-to-noise ratio of gradient updates. Empirically, square root scaling becomes necessary when batch sizes exceed the **critical batch size** discussed below. - -For extreme batch sizes (32K--1M), layer-wise adaptive learning rate scaling (LARS) [@you2017large] and its Adam variant LAMB [@you2020large] automatically adjust learning rates per layer based on the ratio of weight norm to gradient norm: - -$$ -\eta_l = \eta_{\text{global}} \cdot \frac{\|w_l\|}{\|g_l\| + \lambda\|w_l\|} -$$ - -This prevents layers with small weights from receiving disproportionately large updates. LAMB enabled BERT training with batch sizes up to 64K while maintaining convergence quality. - -### Critical Batch Size: When Does Parallelism Hurt? {#sec-distributed-training-systems-systems-critical-batch-size-parallelism-hurt-4961} - -A fundamental question in distributed training is: when does adding more workers stop helping? The **critical batch size** $B*$ marks the transition point beyond which increasing batch size yields diminishing returns in convergence per sample seen. - -::: {.callout-definition title="Critical Batch Size"} - -***Critical Batch Size ($B*$)***\index{Critical Batch Size!definition} is the batch size at which the **Gradient Noise Scale** equals the **Gradient Signal Scale**. - -1. **Significance (Quantitative):** It marks the transition point for **Parallel Scaling Efficiency**. Below $B*$, increasing the batch size linearly improves the convergence per step. Above $B*$, larger batches yield diminishing returns, requiring proportionally more samples ($D_{\text{vol}}$) to reach the same loss. -2. **Distinction (Durable):** Unlike the **Memory-Limited Batch Size** (determined by $BW$ and capacity), the Critical Batch Size is an **Algorithmic Property** of the model and dataset. -3. **Common Pitfall:** A frequent misconception is that training can be speeded up indefinitely by adding GPUs. In reality, $B*$ defines the **Physical Ceiling for Data Parallelism**: adding workers beyond this point wastes energy and compute ($O$) without reducing total training time ($T$). - -::: - -The critical batch size can be estimated as: - -$$ -B* \approx \frac{\text{tr}(\Sigma)}{\|\nabla L(\theta)\|^2} -$$ - -where $\text{tr}(\Sigma)$ is the trace of the gradient covariance matrix (total gradient variance) and $\|\nabla L(\theta)\|^2$ is the squared gradient norm (signal strength). Intuitively, $B*$ is the batch size at which averaging reduces gradient variance to the level of the true gradient magnitude. - -Empirical critical batch sizes vary by task and scale over three orders of magnitude: - -- **ImageNet ResNet-50**: $B* \approx 8,000 - 16,000$ -- **BERT-Large pretraining**: $B* \approx 32,000 - 65,000$ -- **GPT-3 scale models**: $B* \approx 1,000,000 - 4,000,000$ - -The scaling law regime exhibits three distinct behaviors: - -1. **Below critical ($B < B*$)**: Linear scaling holds. Doubling batch size halves iterations to reach target loss. Hardware efficiency determines throughput. - -2. **At critical ($B \approx B*$)**: Optimal trade-off point. Maximum samples-per-second efficiency. - -3. **Above critical ($B > B*$)**: Diminishing returns. Doubling batch size requires $>2\times$ more samples total. Additional workers provide throughput but not sample efficiency. - -@fig-critical-batch-size illustrates this relationship between batch size and training efficiency. - -::: {#fig-critical-batch-size fig-env="figure" fig-pos="htb" fig-cap="**Critical Batch Size and Scaling Regimes**. Below the critical batch size $B*$, larger batches reduce noise and improve sample efficiency (linear scaling regime). Above $B*$, larger batches provide diminishing returns: while throughput increases, total samples required also increases, reducing sample efficiency. The optimal operating point balances hardware utilization against convergence efficiency." fig-alt="Graph showing sample efficiency versus batch size. Efficiency is flat in the linear regime below B-star, then decreases in the diminishing returns regime above B-star. Vertical dashed line marks critical batch size."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}] - \begin{axis}[ - width=12cm, - height=7cm, - xlabel={Batch Size $B$ (log scale)}, - ylabel={Sample Efficiency (samples to target loss)}, - xmode=log, - ymode=log, - xmin=100, xmax=1000000, - ymin=0.1, ymax=2, - ytick={0.2, 0.5, 1.0, 2.0}, - yticklabels={0.2$\times$, 0.5$\times$, 1.0$\times$, 2.0$\times$}, - grid=major, - legend pos=north east, - legend style={font=\footnotesize}, - ] - - % Linear scaling regime (flat) - \addplot[blue, ultra thick, domain=100:8000] {1.0}; - - % Transition region - \addplot[blue, ultra thick, domain=8000:16000, samples=50] {1.0 * (1 + 0.5*((x-8000)/8000)^2)}; - - % Diminishing returns regime - \addplot[blue, ultra thick, domain=16000:1000000, samples=50] {1.5 * (x/16000)^0.3}; - - % Critical batch size line - \addplot[red, dashed, thick] coordinates {(12000, 0.1) (12000, 2)}; - - % Annotations - \node[anchor=south, font=\footnotesize] at (axis cs:1000, 1.1) {Linear Scaling}; - \node[anchor=south, font=\footnotesize] at (axis cs:100000, 1.8) {Diminishing Returns}; - \node[anchor=west, font=\footnotesize, red] at (axis cs:14000, 0.15) {$B*$}; - - \end{axis} -\end{tikzpicture} -``` -::: - -The critical batch size has important implications for distributed training system design: - -1. **Worker count selection**: Adding workers beyond $B*/b$ (where $b$ is per-worker batch size) improves throughput but not sample efficiency. For cost optimization, this may still be worthwhile if the marginal cost of additional workers is low. - -2. **Learning rate schedule**: Above $B*$, aggressive learning rate warmup becomes essential. The loss landscape near initialization may not support the large updates that linear scaling would produce. - -3. **Communication trade-offs**: Above $B*$, the reduced benefit of larger batches makes communication overhead relatively more costly. This strengthens the case for gradient compression or asynchronous methods. - -::: {.callout-checkpoint title="Scaling Decisions"} -Given a 7B parameter model distributed across a cluster of 64 A100 GPUs (80 GB HBM each), what is the maximum useful batch size? To answer this, you must calculate the **Critical Batch Size** ($B_{crit}$)—the point where the gradient noise scale equals the batch size. Beyond this point, doubling the batch size yields diminishing returns in convergence speed (perfect scaling stops). Using the Gradient Noise Scale metric ($\mathcal{B} \approx \frac{\text{tr}(\Sigma)}{\|\mu\|^2}$), determine if your proposed batch size keeps scaling efficiency above 80%, or if you are simply wasting compute cycles for marginal gains. -::: - -### Worked Example: Convergence Comparison for 8 versus 64 Workers {#sec-distributed-training-systems-systems-worked-example-convergence-comparison-8-versus-64-workers-57ce} - -To illustrate these concepts concretely, consider *scaling from 8 to 64 workers* when training a transformer language model with baseline batch size $b = 32$ per worker. - -::: {.callout-notebook title="Scaling from 8 to 64 Workers" collapse="false"} - -**Setup**: Transformer model with 1.3B parameters, target perplexity 15.0, baseline training: 100K iterations on single GPU with $b = 32$. - -**8 Workers (BSP)** - -- Effective batch size: $B = 8 \times 32 = 256$ -- Learning rate: $\eta = 8 \times \eta_{\text{base}}$ (linear scaling with warmup) -- Expected iterations: $100K / 8 = 12.5K$ iterations -- Convergence: Reaches target perplexity in 12.8K iterations (97% efficiency) -- Communication overhead: 15% (NVLink intra-node) -- Wall-clock speedup: $100K \times 1.0 / (12.8K \times 1.15) = 6.8\times$ - -**64 Workers (BSP)** - -- Effective batch size: $B = 64 \times 32 = 2,048$ -- Learning rate: $\eta = 64 \times \eta_{\text{base}}$ (if $B < B*$) or $\eta = \sqrt{64} \times \eta_{\text{base}}$ (if $B > B*$) -- Assuming $B* \approx 4,000$ (below critical): Linear scaling applies -- Expected iterations: $100K / 64 = 1.56K$ iterations -- Convergence: Reaches target perplexity in 1.72K iterations (91% efficiency) -- Communication overhead: 45% (InfiniBand inter-node, 8 nodes) -- Wall-clock speedup: $100K \times 1.0 / (1.72K \times 1.45) = 40.1\times$ - -**64 Workers (SSP, $s = 4$)** - -- Same effective batch size: $B = 2,048$ -- Learning rate: $\eta' = \eta_{\text{BSP}} / \sqrt{1 + 4} \approx 0.45 \times \eta_{\text{BSP}}$ -- Expected iterations: Higher due to staleness penalty -- Convergence: Reaches target perplexity in 2.1K iterations (74% efficiency) -- Communication overhead: 25% (reduced synchronization) -- Wall-clock speedup: $100K \times 1.0 / (2.1K \times 1.25) = 38.1\times$ - -**Analysis** - -| **Configuration** | **Iterations** | **Comm. Overhead** | **Wall-clock Speedup** | **Sample Efficiency** | -|:---------------------|---------------:|-------------------:|-----------------------:|----------------------:| -| **1 GPU (baseline)** | 100,000 | 0% | 1.0$\times$ | 100% | -| **8 GPU BSP** | 12,800 | 15% | 6.8$\times$ | 97% | -| **64 GPU BSP** | 1,720 | 45% | 40.1$\times$ | 91% | -| **64 GPU SSP** | 2,100 | 25% | 38.1$\times$ | 74% | - -The 64-GPU BSP configuration achieves 40$\times$ speedup despite only 91% sample efficiency because the communication overhead (45%) is offset by the massive parallelism. SSP provides comparable wall-clock time with lower communication overhead but requires more total samples. - -**Cost Analysis** (assuming \$3/GPU-hour): - -- 8 GPU: 12.8K iters$\times$ 0.4s/iter$\times$ 8 GPUs / SEC_PER_HOUR$\times$ \$3 = \$34 -- 64 GPU BSP: 1.72K iters$\times$ 0.58s/iter$\times$ 64 GPUs / SEC_PER_HOUR$\times$ \$3 = \$53 -- 64 GPU SSP: 2.1K iters$\times$ 0.50s/iter$\times$ 64 GPUs / SEC_PER_HOUR$\times$ \$3 = \$56 - -Despite higher parallelism, 64-GPU training costs more per run due to communication overhead and reduced sample efficiency. The 8-GPU configuration is more cost-efficient but takes 6$\times$ longer wall-clock time. The choice depends on whether minimizing cost or minimizing time-to-result is the priority. - -::: - -### Trade-off: Communication Cost versus Convergence Speed {#sec-distributed-training-systems-systems-tradeoff-communication-cost-versus-convergence-speed-cb7c} - -The fundamental trade-off in distributed training is between communication efficiency and convergence quality. @fig-comm-convergence-tradeoff visualizes this trade-off space. - -::: {#fig-comm-convergence-tradeoff fig-env="figure" fig-pos="htb" fig-cap="**Communication-Convergence Trade-off Space**. Each point represents a different distributed training configuration. The Pareto frontier (dashed line) shows optimal configurations where improving one metric requires sacrificing the other. BSP sits at high convergence quality but lower throughput; ASP provides maximum throughput at convergence cost. Gradient compression and SSP occupy intermediate positions." fig-alt="Scatter plot with Communication Efficiency on x-axis and Convergence Quality on y-axis. Points for BSP, SSP, ASP, and gradient compression methods form a Pareto frontier from upper-left to lower-right."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}] - \begin{axis}[ - width=11cm, - height=8cm, - xlabel={Communication Efficiency (throughput)}, - ylabel={Convergence Quality (final loss)}, - xmin=0, xmax=100, - ymin=0, ymax=100, - xtick={0, 25, 50, 75, 100}, - xticklabels={Low, , Medium, , High}, - ytick={0, 25, 50, 75, 100}, - yticklabels={Poor, , Medium, , Optimal}, - grid=major, - ] - - % Pareto frontier - \addplot[black, dashed, thick, domain=15:95, samples=50] {100 - 0.8*(x-15) - 0.005*(x-15)^2}; - - % Methods as points - \addplot[only marks, mark=*, mark size=4pt, blue] coordinates {(20, 98)}; - \node[anchor=west, font=\footnotesize, blue] at (axis cs:22, 98) {BSP}; - - \addplot[only marks, mark=*, mark size=4pt, orange] coordinates {(50, 85)}; - \node[anchor=west, font=\footnotesize, orange] at (axis cs:52, 85) {SSP ($s$=4)}; - - \addplot[only marks, mark=*, mark size=4pt, red] coordinates {(90, 65)}; - \node[anchor=west, font=\footnotesize, red] at (axis cs:82, 60) {ASP}; - - \addplot[only marks, mark=*, mark size=4pt, green!60!black] coordinates {(60, 92)}; - \node[anchor=south, font=\footnotesize, green!60!black] at (axis cs:60, 94) {Gradient Compression}; - - \addplot[only marks, mark=*, mark size=4pt, purple] coordinates {(75, 78)}; - \node[anchor=north, font=\footnotesize, purple] at (axis cs:75, 76) {Local SGD}; - - % Annotation - \node[anchor=north east, font=\footnotesize] at (axis cs:95, 20) {Pareto Frontier}; - - \end{axis} -\end{tikzpicture} -``` -::: - -Several techniques occupy different positions on this trade-off curve: - -Gradient compression reduces communication volume by 10-100$\times$ through quantization or sparsification, with 2--5% convergence degradation. Techniques like QSGD [@alistarh2017qsgd] and Top-K sparsification maintain convergence guarantees with bounded compression error. - -Local SGD takes a different approach: workers perform $H$ local updates before synchronizing, reducing communication frequency by factor $H$. Convergence analysis shows that for smooth, strongly convex objectives, Local SGD achieves the same asymptotic rate as synchronous SGD with appropriately tuned learning rates [@stich2019local]. - -Decentralized SGD restricts workers to communicating only with neighbors in a communication graph rather than performing global AllReduce. This reduces bandwidth requirements at the cost of slower mixing, making it suitable for geo-distributed training where global synchronization is expensive. - -The choice among these methods depends on the specific bottleneck. When network bandwidth limits throughput, gradient compression provides the best trade-off. When synchronization latency dominates, Local SGD or SSP are preferred. When network topology constraints exist, decentralized approaches may be necessary. - -## Model Parallelism {#sec-distributed-training-systems-systems-model-parallelism-7437} - -What do you do when your model is so massive that even a single layer's weights exceed the memory capacity of your largest GPU? Data parallelism entirely collapses under these constraints. The memory optimization techniques examined in the previous section extend data parallelism's reach, but eventually, we must partition the model itself. - -Even with ZeRO-3 fully deployed, sharding optimizer states, gradients, and parameters across workers, some architectures remain intractable. A `{python} gpt3_params_b`B parameter model using FSDP across 64 GPUs still requires 700 GB / 64 = 11 GB of parameters per GPU before accounting for activations. For long-context transformers where activation memory dominates, a 2048-token sequence through `{python} gpt3_params_b`B parameters generates 200+ GB of intermediate activations, and no amount of optimizer sharding addresses this constraint. Model parallelism addresses these limitations by splitting the model architecture itself across devices, rather than replicating it with sharded state. - -```{python} -#| label: a100-capacity-context -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ A100 MEMORY CAPACITY CONTEXT (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-distributed-training-systems-systems-model-parallelism-7437 -# │ -# │ Goal: Provide A100 memory capacity for ZeRO-3 discussion. -# │ Show: ~80 GB HBM. -# │ How: pulling A100_MEM_CAPACITY from mlsys.constants. -# │ -# │ Imports: mlsys.constants (A100_MEM_CAPACITY, GB) -# │ Exports: a100_mem -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import A100_MEM_CAPACITY, GB - -class A100CapacityContext: - """A100 hardware reference.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - mem = A100_MEM_CAPACITY - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - a100_mem = f"{mem.m_as(GB):.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -a100_mem = A100CapacityContext.a100_mem -``` - -::: {.callout-notebook title="The Memory Wall of Scale"} -**Problem**: You want to train a **`{python} gpt3_params_b` Billion parameter** model (like GPT-3) on NVIDIA A100s (`{python} a100_mem` GB). Can you use Data Parallelism with ZeRO-3? - -**The Math**: - -1. **Parameter Storage**: `{python} gpt3_params_b`B params$\times$ 2 bytes (FP16) = **350 GB**. -2. **Optimizer State**: `{python} gpt3_params_b`B params$\times$ 12 bytes (Adam FP32) = **2,100 GB**. -3. **Total Static Memory**: 2,450 GB. -4. **ZeRO-3 Sharding**: With 64 GPUs, per-GPU static memory = $2,450 / 64 \approx \mathbf{38 \text{ GB}}$. -5. **Activation Memory**: For sequence length 2048 and batch size 1, a 96-layer transformer generates $\approx \mathbf{50 \text{ GB}}$ of activations per GPU. - -**The Systems Conclusion**: 38 GB (Static) + 50 GB (Dynamic) = **88 GB**. This exceeds the `{python} a100_mem` GB capacity of the A100. Even with full ZeRO-3 sharding, **pure Data Parallelism fails**. You *must* use **Tensor Parallelism** to split the layers themselves. -::: - -Model parallelism addresses this limitation... - -Several implementations of model parallelism exist. In layer-based splitting, devices process distinct groups of layers sequentially. The first device might compute layers 1-4 while the second handles layers 5-8. Channel-based splitting divides the channels within each layer across devices, where the first device processes 512 channels while the second manages the remaining ones. For transformer architectures, attention head splitting distributes different attention heads to separate devices. - -This distribution method enables training of large-scale models. GPT-3, with `{python} gpt3_params_b` billion parameters, relies on model parallelism for training. Vision transformers processing high-resolution 16k$\times$ 16k pixel images use model parallelism to manage memory constraints. Mixture-of-Experts architectures use this approach to distribute their conditional computation paths across hardware. - -Device coordination follows a specific pattern during training. In the forward pass, data flows sequentially through model segments on different devices. The backward pass propagates gradients in reverse order through these segments. During parameter updates, each device modifies only its assigned portion of the model. This coordination ensures mathematical equivalence to training on a single device while enabling the handling of models that exceed individual device memory capacities. - -### Model Parallelism Implementation {#sec-distributed-training-systems-systems-model-parallelism-implementation-0728} - -Model parallelism divides neural networks across multiple computing devices, with each device computing a distinct portion of the model's operations. This division allows training of models whose parameter counts exceed single-device memory capacity. The technique encompasses device coordination, data flow management, and gradient computation across distributed model segments. @fig-dist-model-parallelism captures this bidirectional data flow: input data propagates forward through sequentially assigned model partitions while gradients flow backward to update parameters, with intermediate results transferring across device boundaries at each stage. - -::: {#fig-dist-model-parallelism fig-env="figure" fig-pos="htb" fig-cap="**Model Parallelism Data Flow**. Sequential distribution of model partitions across three devices: input data flows forward through each partition in order (top path), while gradients propagate backward during the update phase (bottom path). Each device handles a distinct portion of the model, with intermediate activations and gradients transferring at partition boundaries. This approach enables training of models exceeding single-device memory at the cost of sequential dependencies that reduce hardware utilization." fig-alt="Linear flow diagram showing model parallelism: Input Data flows through Model Parts 1-3 on Devices 1-3 to Predictions via forward pass arrows above, with gradient update arrows returning below."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}] -\tikzset{Line/.style={line width=1.0pt,black!50,text=black -}, - Box/.style={inner xsep=2pt, - draw=GreenLine, - node distance=1.5, - line width=0.75pt, - fill=GreenL, - anchor=west, - text width=23mm, - align=flush center, - minimum width=23mm, - minimum height=10mm - }, - Text/.style={inner xsep=4pt, - draw=none, line width=0.75pt, - fill=TextColor!80, - font=\usefont{T1}{phv}{m}{n}\footnotesize, - align=flush center, - minimum width=22mm, minimum height=6mm - }, -} - -\node[Box](B1){Input Data}; -\node[Box,right=of B1](B2){Model Part 1\\ on Device 1}; -\node[Box,right=of B2](B3){Model Part 2\ on Device 2}; -\node[Box,right=of B3](B4){Model Part 3\\ on Device 3}; -\node[Box,right=of B4](B5){Predictions}; -% -\draw[Line,-latex](B1)--++(90:12mm) --|node[Text,pos=0.25]{Forward Pass}(B2.120); -\draw[Line,latex-](B1)--++(270:12mm) --|node[Text,pos=0.25]{Gradient Updates}(B2.240); -% -\draw[Line,-latex](B2)--++(90:12mm) --|node[Text,pos=0.25]{Intermediate Data}(B3.120); -\draw[Line,latex-](B2)--++(270:12mm) --|node[Text,pos=0.25]{Gradient Updates}(B3.240); -% -\draw[Line,-latex](B3)--++(90:12mm) --|node[Text,pos=0.25]{Intermediate Data}(B4.120); -\draw[Line,latex-](B3)--++(270:12mm) --|node[Text,pos=0.25]{Gradient Updates}(B4.240); -% -\draw[Line,-latex](B4)--++(90:12mm) --|node[Text,pos=0.25]{Output}(B5.120); -\draw[Line,latex-](B4)--++(270:12mm) --|node[Text,pos=0.25]{Backward Pass}(B5.240); -\end{tikzpicture} -``` -::: - -Consider our running example: the `{python} gpt3_params_b`B parameter model requires 350 GB of memory in FP16, exceeding the `{python} a100_mem` GB capacity of a single A100 by a factor of four. Model parallelism addresses this **capacity wall** by partitioning the model's state—parameters, gradients, and optimizer states—across multiple devices, effectively stitching them into a single super-accelerator. Unlike data parallelism, where every GPU holds a full replica of the model and processes a unique fraction of the global batch, model parallelism requires each GPU to hold a unique fraction of the model and process the *same* data stream sequentially. With 8-way partitioning on A100s, each GPU holds approximately 44 GB of parameters—a tight fit that leaves roughly 36 GB for activations and optimizer state. - -In a typical pipeline parallel implementation, the training loop operates as a relay race. The forward pass initiates on GPU 1, which computes the initial transformer blocks and transmits the resulting intermediate activation tensor across the interconnect to GPU 2. For our `{python} gpt3_params_b`B model with a hidden dimension of 12,288 and a micro-batch size of 4 sequences at 2,048 tokens each, this handoff involves moving approximately 200 MB of data per stage boundary per step. GPU 2 must wait for this payload before it can begin its computation, creating a strict dependency chain that propagates through all stages. The backward pass mirrors this path in reverse, propagating error gradients from the final layer back to the input, with each device computing gradients only for its local parameters. - -This architecture fundamentally changes the optimization dynamics compared to data parallelism. Instead of a global AllReduce to average gradients across replicas, each GPU performs a local optimizer step (Adam [@kingma2015adam], AdaFactor, or similar) on its specific slice of parameters. A device holding transformer layers 1–12 updates only those layers' weights and biases, with no cross-device synchronization required during the optimization step. While this eliminates the bandwidth-heavy gradient synchronization of data parallelism, it trades one bottleneck for another: **pipeline bubbles**. If the layers assigned to GPU 1 are computationally heavier than those on GPU 2—common when attention layers have different head counts or when embedding layers are unevenly sized—valuable compute cycles are lost to waiting. The primary engineering challenge thus shifts from maximizing arithmetic intensity to minimizing serialization latency and ensuring balanced load across the partitioned fleet [@deepspeed_training_system_2021]. - -### Parallelism Variations {#sec-distributed-training-systems-systems-parallelism-variations-592a} - -To address these latency and balancing challenges, the choice of partitioning strategy must align with the model's architecture. Three primary approaches—layer-wise partitioning, operator-level partitioning (tensor parallelism), and pipeline parallelism—each optimize for different structural constraints. - -#### Layer-wise Partitioning {#sec-distributed-training-systems-systems-layerwise-partitioning-cf6a} - -Layer-wise partitioning assigns distinct model layers to separate computing devices. In transformer architectures, this translates to specific devices managing defined sets of attention and feed-forward blocks. @fig-dist-layers-blocks demonstrates this partitioning for a 24-layer transformer: six consecutive blocks reside on each of four devices, with forward activations flowing left-to-right and backward gradients propagating right-to-left across the device boundaries. - -::: {#fig-dist-layers-blocks fig-env="figure" fig-pos="htb" fig-cap="**Layer-Wise Model Parallelism**. A 24-layer transformer distributed across four GPUs, with six consecutive transformer blocks assigned to each device. Forward activations (black arrows) flow left-to-right through device boundaries, while backward gradients (red arrows) propagate right-to-left during parameter updates. This partitioning reduces per-GPU memory from the full model size to 1/4, but introduces sequential dependencies where downstream devices wait for upstream computation to complete." fig-alt="24-layer transformer split across 4 devices: Blocks 1-6 on GPU 1, 7-12 on GPU 2, 13-18 on GPU 3, 19-24 on GPU 4. Black arrows show forward activation flow; red arrows show backward gradient propagation."} -```{.tikz} -\begin{tikzpicture}[font=\usefont{T1}{phv}{m}{n}\small] -\tikzset{Line/.style={line width=1.0pt,black!50 -}, - Box/.style={inner xsep=2pt, - draw=VioletLine2, - line width=0.75pt, - node distance=1.8, - fill=VioletL2, - align=flush center, - text width=19mm, - minimum width=19mm, - minimum height=8mm - }, -} -\node[Box,fill=RedL,draw=RedLine](B1){Blocks 1-6}; -\node[Box,right=of B1,fill=OrangeL,draw=OrangeLine](B2){Blocks 7-12}; -\node[Box,right=of B2,fill=GreenL,draw=GreenLine](B3){Blocks 13-18}; -\node[Box,right=of B3,fill=BlueL,draw=BlueLine](B4){Blocks 19-24}; -% -\node[Box,below=1.3 of B1,fill=VioletL2,draw=VioletLine2](G1){GPU 1}; -\node[Box,below=1.3 of B2,fill=VioletL2,draw=VioletLine2](G2){GPU 2}; -\node[Box,below=1.3 of B3,fill=VioletL2,draw=VioletLine2](G3){GPU 3}; -\node[Box,below=1.3 of B4,fill=VioletL2,draw=VioletLine2](G4){GPU 4}; -% -\scoped[on background layer] -\node[draw=BackLine,inner xsep=13, -line width=0.75pt, -inner ysep=18, -fill=BackColor,yshift=6, -fit=(B1)(G1)](BB1){}; -\node[below=1pt of BB1.north,anchor=north]{Device 1}; -% -\scoped[on background layer] -\node[draw=BackLine,inner xsep=13, -line width=0.75pt, -inner ysep=18, -fill=BackColor,yshift=6, -fit=(B2)(G2)](BB2){}; -\node[below=1pt of BB2.north,anchor=north]{Device 2}; -% -\scoped[on background layer] -\node[draw=BackLine,inner xsep=13, -line width=0.75pt, -inner ysep=18, -fill=BackColor,yshift=6, -fit=(B3)(G3)](BB3){}; -\node[below=1pt of BB3.north,anchor=north]{Device 3}; -% -\scoped[on background layer] -\node[draw=BackLine,inner xsep=13, -line width=0.75pt, -inner ysep=18, -fill=BackColor,yshift=6, -fit=(B4)(G4)](BB4){}; -\node[below=1pt of BB4.north,anchor=north]{Device 4}; - -\foreach \x in {1,2,3} { - \pgfmathtruncatemacro{\newX}{\x + 1} - \draw[-latex,Line] (B\x) -- (B\newX); -} -\foreach \x in {4,3,2} { - \pgfmathtruncatemacro{\newX}{\x - 1} -\draw[red,-latex,Line](B\x.230)to[out=230,in=300](B\newX.300); -} -\end{tikzpicture} -``` -::: - -This sequential processing introduces device idle time, as each device must wait for the previous device to complete its computation before beginning work. While device 1 processes the initial blocks, devices 2, 3, and 4 remain inactive. Similarly, when device 2 begins its computation, device 1 sits idle. This pattern of waiting and idle time reduces hardware utilization efficiency compared to other parallelization strategies. - -#### Pipeline Parallelism {#sec-distributed-training-systems-systems-pipeline-parallelism-8748} - -Pipeline parallelism extends layer-wise partitioning by introducing microbatching to minimize device idle time. Instead of waiting for an entire batch to sequentially pass through all devices, the computation is divided into smaller segments called microbatches, with overlapping execution across pipeline stages. @fig-pipline-parallelism shows how this overlapping works: while device 1 processes microbatch $N+1$, device 2 computes microbatch $N$, device 3 handles $N-1$, and device 4 executes $N-2$, creating a continuous flow that keeps all devices active simultaneously. - -::: {.callout-definition title="Pipeline Parallelism"} - -***Pipeline Parallelism***\index{Pipeline Parallelism!definition} is a model parallelism technique where layers are partitioned into **Sequential Stages** assigned to different devices. - -1. **Significance (Quantitative):** It enables the training of models that exceed the memory of a single node by sharding the model depth. To maintain **Throughput ($\eta$)**, input batches are split into **Micro-batches** that pipeline through the stages, overlapping computation with communication. -2. **Distinction (Durable):** Unlike **Tensor Parallelism**, which shards individual operations (Intra-layer), Pipeline Parallelism shards the model at **Layer Boundaries** (Inter-layer), typically requiring lower communication bandwidth ($BW$). -3. **Common Pitfall:** A frequent misconception is that Pipeline Parallelism is "perfectly efficient." In reality, it suffers from the **Pipeline Bubble**: idle time at the beginning and end of each iteration where devices wait for micro-batches to reach them. - -::: - -GPipe[^fn-gpipe] [@gpipe2019] introduced synchronous pipeline parallelism with micro-batch accumulation, while PipeDream [@harlap2018pipedream] developed asynchronous approaches with weight stashing. To reduce the bubble overhead, modern systems employ **1F1B (One-Forward-One-Backward)**[^fn-1f1b-memory] scheduling, which interleaves the passes to reclaim memory earlier. - -[^fn-1f1b-memory]: **1F1B Scheduling**: "One-Forward-One-Backward." Unlike GPipe (which processes all forward passes before any backward passes), 1F1B interleaves them. This reduces the peak activation memory footprint from $O(M \times P)$ to $O(P)$, where $M$ is the number of micro-batches and $P$ is the number of pipeline stages. This memory reclamation is what enables the massive micro-batch counts ($M \gg P$) required to keep the pipeline bubble small. \index{1F1B Scheduling!memory efficiency} - -[^fn-gpipe]: **GPipe**: Published by Google in 2019, GPipe introduced synchronous micro-batch pipelining that trained a 557M-parameter AmoebaNet across 8 TPUs with near-linear scaling. The key trade-off GPipe exposed: the pipeline bubble fraction $(P-1)/(M+P-1)$ means that with $P=4$ stages and $M=4$ micro-batches, 43% of compute is wasted in idle time -- driving the field toward 1F1B schedules that interleave forward and backward passes to shrink this overhead below 15%. \index{GPipe!pipeline parallelism} - -Each device, as represented by the rows in the drawing, processes its assigned model layers for different microbatches simultaneously. The forward pass involves devices passing activations to the next stage, such as $F_{0,0}$ to $F_{1,0}$. The backward pass transfers gradients back through the pipeline, such as $B_{3,3}$ to $B_{2,3}$. This overlapping computation reduces idle time and increases throughput while maintaining the logical sequence of operations across devices. - -::: {#fig-pipline-parallelism fig-cap="**Pipeline Parallelism Schedule**. A 4-stage pipeline processing 4 microbatches, showing forward passes ($F_{i,j}$) and backward passes ($B_{i,j}$) across time. Rows represent pipeline stages (GPUs), columns represent time steps. The staggered execution keeps all devices active: while stage 0 computes $F_{0,1}$, stage 1 processes $F_{1,0}$ from the previous microbatch. After all forward passes complete, backward passes propagate in reverse order. The \"Update\" column shows synchronized parameter updates after gradient accumulation across all microbatches." fig-alt="Pipeline schedule grid showing 4 stages processing 4 microbatches. Forward passes F stagger diagonally across time, backward passes B follow in reverse order, ending with synchronized Update column."} -```{.tikz} -\begin{tikzpicture}[ - every node/.style={font=\sffamily, draw, minimum width=1cm, minimum height=0.7cm, align=center, outer sep=0}, - fill0/.style={fill=red!20}, % Complementary to lightgray - fill1/.style={fill=blue!20}, % Complementary to orange - fill2/.style={fill=orange!20}, % Complementary to blue - fill3/.style={fill=yellow!20}, % Complementary to purple - back3/.style={fill=yellow!20} % Same as fill3 -] - -% Row 0 -\node[fill0] (F0_0) {$F_{0,0}$}; -\node[fill0, right=0cm of F0_0] (F0_1) {$F_{0,1}$}; -\node[fill0, right=0cm of F0_1] (F0_2) {$F_{0,2}$}; -\node[fill0, right=0cm of F0_2] (F0_3) {$F_{0,3}$}; - -% Row 1 -\node[fill1, above right=0cm and 0cm of F0_0] (F1_0) {$F_{1,0}$}; -\node[fill1, right=0cm of F1_0] (F1_1) {$F_{1,1}$}; -\node[fill1, right=0cm of F1_1] (F1_2) {$F_{1,2}$}; -\node[fill1, right=0cm of F1_2] (F1_3) {$F_{1,3}$}; - -% Row 2 (stacked above F1) -\node[fill2, above right=0cm and 0cm of F1_0] (F2_0) {$F_{2,0}$}; -\node[fill2, right=0cm of F2_0] (F2_1) {$F_{2,1}$}; -\node[fill2, right=0cm of F2_1] (F2_2) {$F_{2,2}$}; -\node[fill2, right=0cm of F2_2] (F2_3) {$F_{2,3}$}; - -% Row 3 (stacked above F2) -\node[fill3, above right=0cm and 0cm of F2_0] (F3_0) {$F_{3,0}$}; -\node[fill3, right=0cm of F3_0] (F3_1) {$F_{3,1}$}; -\node[fill3, right=0cm of F3_1] (F3_2) {$F_{3,2}$}; -\node[fill3, right=0cm of F3_2] (F3_3) {$F_{3,3}$}; - -% Row 3 (backward pass) -\node[back3, right=0cm of F3_3] (B3_3) {$B_{3,3}$}; -\node[back3, right=0cm of B3_3] (B3_2) {$B_{3,2}$}; -\node[back3, right=0cm of B3_2] (B3_1) {$B_{3,1}$}; -\node[back3, right=0cm of B3_1] (B3_0) {$B_{3,0}$}; - -% Row 2 (backward pass) -\node[fill2, below=0cm and 0cm of B3_2] (B2_3) {$B_{2,3}$}; -\node[fill2, right=0cm of B2_3] (B2_2) {$B_{2,2}$}; -\node[fill2, right=0cm of B2_2] (B2_1) {$B_{2,1}$}; -\node[fill2, right=0cm of B2_1] (B2_0) {$B_{2,0}$}; - -% Row 1 (backward pass) -\node[fill1, below=0cm of B2_2] (B1_3) {$B_{1,3}$}; -\node[fill1, right=0cm of B1_3] (B1_2) {$B_{1,2}$}; -\node[fill1, right=0cm of B1_2] (B1_1) {$B_{1,1}$}; -\node[fill1, right=0cm of B1_1] (B1_0) {$B_{1,0}$}; - -% Row 0 (backward pass) -\node[fill0, below=0cm of B1_2] (B0_3) {$B_{0,3}$}; -\node[fill0, right=0cm of B0_3] (B0_2) {$B_{0,2}$}; -\node[fill0, right=0cm of B0_2] (B0_1) {$B_{0,1}$}; -\node[fill0, right=0cm of B0_1] (B0_0) {$B_{0,0}$}; - -% Update nodes -\node[fill0, right=0cm of B0_0] (U0_0) {Update}; -\node[fill1, above=0cm of U0_0] (U0_1) {Update}; -\node[fill2, above=0cm of U0_1] (U0_2) {Update}; -\node[fill3, above=0cm of U0_2] (U0_3) {Update}; - -%\node[draw=none, minimum width=4cm, minimum height=1cm, align=center, right=1cm of F0_3] (Bubble) {Bubble}; -\end{tikzpicture} -``` -::: - -In a transformer model distributed across four devices, device 1 would process blocks 1-6 for microbatch $N+1$ while device 2 computes blocks 7-12 for microbatch $N$. Simultaneously, device 3 executes blocks 13-18 for microbatch $N-1$, and device 4 processes blocks 19-24 for microbatch $N-2$. Each device maintains its assigned transformer blocks but operates on a different microbatch, creating a continuous flow of computation. - -The transfer of hidden states between devices occurs continuously rather than in distinct phases. When device 1 completes processing a microbatch, it immediately transfers the output tensor of shape [microbatch_size, sequence_length, hidden_dimension] to device 2 and begins processing the next microbatch. This overlapping computation pattern maintains full hardware utilization while preserving the model's mathematical properties. - -##### Zero-Bubble Pipeline Parallelism {#sec-distributed-training-systems-systems-zero-bubble} - -The classic problem with pipeline parallelism is the **pipeline bubble**: GPUs at the beginning of the pipeline are idle while waiting for gradients to flow back from the end, and vice versa. These bubbles represent wasted compute. The **1F1B** (one forward, one backward) schedule reduces the bubble by interleaving forward and backward microbatches. Once the pipeline is filled, each GPU alternates between executing one forward microbatch and one backward microbatch, keeping SMs busy most of the time. - -Zero-bubble pipeline schedules**\index{Zero-Bubble Pipeline Parallelism} further reduce idle time by overlapping weight gradient computation with activation gradient communication. In a standard backward pass, the GPU computes $\partial L / \partial W$ (weight gradient) and $\partial L / \partial X$ (activation gradient, sent to the previous stage) together. Zero-bubble scheduling splits these into separate kernels: a **B kernel that computes only the activation gradient $\partial L / \partial X$ and sends it to the previous stage, and a **W** kernel that computes the weight gradient $\partial L / \partial W$ locally. The B kernel must execute promptly (it is on the critical path), but the W kernel can be scheduled opportunistically to fill bubbles. - -The scheduling freedom provided by this B/W split is substantial. In a 4-stage pipeline with 8 microbatches, the standard 1F1B schedule has a bubble fraction of approximately $(p-1)/(m+p-1)$ where $p$ is the number of stages and $m$ is the number of microbatches. For $p=4, m=8$, this is $3/11 \approx 27\%$ idle time. Zero-bubble scheduling can reduce this to near zero by filling the startup and teardown bubbles with W computations. - -The trade-off is memory: zero-bubble scheduling requires storing intermediate activations for longer (because the W computation is deferred), increasing peak memory usage. Some implementations address this by combining zero-bubble scheduling with activation checkpointing, selectively recomputing certain activations rather than storing them. The interaction between these techniques creates a three-way trade-off among pipeline bubble size, memory consumption, and recomputation overhead. - -#### Tensor Parallelism {#sec-distributed-training-systems-systems-tensor-parallelism-d76e} - -Pipeline parallelism, examined above, addresses device idle time by overlapping microbatch processing across stages. Each device holds complete layers and processes them sequentially, with communication only at stage boundaries when activations transfer between devices. This approach tolerates moderate interconnect bandwidth because communication occurs infrequently, once per layer boundary per microbatch. However, pipeline parallelism cannot help when individual layers themselves exceed device memory, or when the communication pattern within layers benefits from a different granularity than layer boundaries. - -Tensor parallelism takes a fundamentally different approach: instead of assigning complete layers to devices, it splits the weight matrices within each layer. This operator-level parallelism (also called intra-layer parallelism) enables finer-grained distribution but requires high-bandwidth interconnects for the frequent intra-layer communication it introduces. - -::: {.callout-definition title="Tensor Parallelism"} - -***Tensor Parallelism***\index{Tensor Parallelism!definition} is a model parallelism technique where individual **Tensors and Operations** (e.g., matrix multiplications) are split across multiple devices. - -1. **Significance (Quantitative):** It enables the parallelization of massive layers that exceed single-device memory. It requires **High-Bandwidth Interconnects** (e.g., NVLink) because it necessitates all-reduce or all-gather operations within the critical path of every layer. -2. **Distinction (Durable):** Unlike **Pipeline Parallelism** (Inter-layer), Tensor Parallelism operates at the **Intra-layer** granularity, distributing the computation of a single mathematical operation across $N$ workers. -3. **Common Pitfall:** A frequent misconception is that Tensor Parallelism can be used across any network. In reality, it is **Bandwidth-Bound ($BW$)**: the high frequency of synchronization makes it feasible only within a single node or across extremely low-latency fabrics. - -::: - -This distinction is critical for hardware planning: tensor parallelism demands NVLink-class bandwidth, while pipeline parallelism tolerates InfiniBand between stages. - -::: {.callout-note title="Tensor vs. Pipeline Parallelism"} -Modern literature distinguishes two forms of model parallelism: - -**Tensor Parallelism** (intra-layer): Splits individual operations (matrix multiplies, attention heads) across devices. Requires high-bandwidth interconnects (NVLink) due to fine-grained communication within each layer. - -**Pipeline Parallelism** (inter-layer): Assigns complete layers to different devices. Requires only point-to-point communication between pipeline stages, tolerating lower bandwidth interconnects. - -The Megatron-LM framework popularized this distinction, using tensor parallelism within nodes (8 GPUs on NVLink) and pipeline parallelism across nodes (InfiniBand). -::: - -Megatron-style tensor parallelism[^fn-megatron] [@shoeybi2019megatron] partitions matrix multiplications in two ways. Examine @fig-tensor-parallel-split: column-parallel splitting divides weight matrices along columns for QKV projections, allowing independent computation across GPUs, while row-parallel splitting divides along rows for output layers, requiring AllReduce to combine partial sums at the end of each block. - -[^fn-megatron]: **Megatron-LM**: NVIDIA's 2019 framework that trained an 8.3B-parameter transformer -- 24$\times$ BERT and 5.6$\times$ GPT-2 at the time -- by strategically placing only two AllReduce operations per transformer block (one after attention, one after MLP). This column-then-row partitioning eliminates inter-GPU communication between the two linear layers within each block, achieving 76% scaling efficiency across 512 GPUs and establishing the tensor parallelism patterns now standard for all large-scale transformer training. \index{Megatron-LM!tensor parallelism} - -Column-parallel linear layers split weights along columns. For input $X$ and weight matrix $W = [W_1 | W_2]$ split across 2 GPUs: -$$Y = XW = X[W_1 | W_2] = [XW_1 | XW_2]$$ -Each GPU computes its partition independently. Outputs are concatenated (no communication needed if followed by row-parallel layer). - -::: {#fig-tensor-parallel-split fig-env="figure" fig-pos="htb" fig-cap="**Tensor Parallelism - Matrix Partitioning**. Illustration of Megatron-LM style tensor parallelism. The first linear layer (e.g., QKV) is split column-wise $[W_1 | W_2]$. The second layer (e.g., Output) is split row-wise $[W_1 ; W_2]$. This arrangement allows the output of the first layer to flow directly into the second without synchronization, requiring using only one AllReduce after the second layer to sum the partial results." fig-alt="Two-panel diagram of tensor parallelism. Left: Column Parallel splits weight matrix into columns, GPUs compute independently. Right: Row Parallel splits by rows, requires AllReduce to sum partial results."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}] - \definecolor{MatrixColor}{RGB}{200,220,255} - \definecolor{SplitColor}{RGB}{100,149,237} % CornflowerBlue - - \tikzset{ - matrix_box/.style={draw=black!70, fill=MatrixColor, thick, minimum width=2cm, minimum height=2cm}, - split_v/.style={draw=black!70, fill=SplitColor, minimum width=1cm, minimum height=2cm}, - split_h/.style={draw=black!70, fill=SplitColor, minimum width=2cm, minimum height=1cm}, - arrow/.style={->, >=stealth, thick} - } - - % Column Parallel (Layer 1) - \begin{scope}[xshift=0cm] - \node[anchor=south] at (1, 2.5) {\textbf{Column Parallel (e.g., QKV)}}; - \node[anchor=south] at (1, 2.2) {$W_A = [W_{A1} | W_{A2}]$}; - - % Input X - \node[draw, fill=gray!10, minimum width=2cm, minimum height=0.5cm] (X) at (1, -1) {$X$}; - - % Split Matrix - \node[split_v] (W1) at (0.5, 1) {$W_{A1}$}; - \node[split_v, right=0cm of W1, fill=MatrixColor] (W2) at (1.5, 1) {$W_{A2}$}; - - % GPUs - \node[below=0.2cm of X] (GPU) {GPUs compute $Y_i = X W_{Ai}$}; - - \draw[arrow] (X) -- (W1.south); - \draw[arrow] (X) -- (W2.south); - \end{scope} - - % Symbol - \node at (3.5, 1) {\Huge $\rightarrow$}; - \node[align=center, font=\footnotesize] at (3.5, 0.5) {Activation\\Pass}; - - % Row Parallel (Layer 2) - \begin{scope}[xshift=5cm] - \node[anchor=south] at (1, 2.5) {\textbf{Row Parallel (e.g., Output)}}; - \node[anchor=south] at (1, 2.2) {$W_B = [W_{B1} ; W_{B2}]$}; - - % Inputs - \node[draw, fill=SplitColor, minimum width=1cm, minimum height=0.5cm] (Y1) at (0.5, -1) {$Y_1$}; - \node[draw, fill=MatrixColor, minimum width=1cm, minimum height=0.5cm] (Y2) at (1.5, -1) {$Y_2$}; - - % Split Matrix - \node[split_h] (WB1) at (1, 1.5) {$W_{B1}$}; - \node[split_h, below=0cm of WB1, fill=MatrixColor] (WB2) at (1, 0.5) {$W_{B2}$}; - - % Output - \node[draw, minimum width=2cm, minimum height=0.5cm] (Z) at (1, 3.5) {$Z = Y_1 W_{B1} + Y_2 W_{B2}$}; - - \draw[arrow] (Y1) -- (WB1.west |- WB1.south); - \draw[arrow] (Y2) -- (WB2.east |- WB2.south); - - \draw[arrow] (WB1) -- (Z); - \draw[arrow] (WB2) -- (Z); - - \node[right=0.2cm of Z, font=\bfseries, text=red] {AllReduce needed here}; - \end{scope} - -\end{tikzpicture} -``` -::: - -Row-parallel linear layers split weights along rows. For $W = \begin{bmatrix} W_1 \\ W_2 \end{bmatrix}$: -$$Y = XW = X_1 W_1 + X_2 W_2$$ -Each GPU computes a partial sum. Outputs require AllReduce to combine. - -In transformer architectures, Megatron applies this pattern: - -1. **QKV projection**: Column-parallel (weights split, outputs concatenated across heads) - -2. **Attention output projection**: Row-parallel (requires AllReduce after) - -3. **First FFN layer**: Column-parallel (split intermediate dimension) - -4. **Second FFN layer**: Row-parallel (requires AllReduce after) - -This design places AllReduce operations strategically: one after attention, one after FFN, totaling 2 AllReduce operations per transformer layer. - -Communication volume per transformer layer depends on sequence length $S$, hidden dimension $H$, and batch size $B$: -$$\text{Communication} = 2 \times B \times S \times H \times \text{sizeof(dtype)}$$ - -With $S=2048$, $H=4096$, $B=4$, and FP16: $2 \times 4 \times 2048 \times 4096 \times 2 = 134$ MB per layer. For a 96-layer model, this totals 12.6 GB per training step, requiring NVLink bandwidth to avoid becoming the bottleneck. - -Tensor parallelism scaling degrades rapidly beyond 8-way parallelism because: - -- Communication volume grows linearly with tensor parallel degree -- Computation per GPU decreases (less work to hide communication latency) -- NVLink bandwidth becomes saturated - -Production systems (GPT-4, LLaMA, Gemini) use 8-way tensor parallelism within nodes, combined with pipeline parallelism across nodes, achieving the best balance of memory distribution and communication efficiency. - -##### Ring Attention for Extreme Sequences {#sec-distributed-training-systems-systems-ring-attention} - -While standard tensor parallelism tiles computation across the HBM-NVLink boundary within a node, **Ring Attention**\index{Ring Attention} extends this principle to the entire sequence dimension. For sequence lengths that exceed even the memory capacity of a single GPU (e.g., million-token context windows), Ring Attention distributes $K$ and $V$ blocks across GPUs in a ring topology. Each GPU holds a portion of $Q$ and iterates through the ring, receiving $K$/$V$ blocks from its neighbor while simultaneously computing attention on the current block and sending the previous block onward. - -The algorithm proceeds in $P - 1$ communication rounds (where $P$ is the number of GPUs). In each round, each GPU computes attention between its local $Q$ block and the currently resident $K$/$V$ block, then sends that $K$/$V$ block to its ring neighbor and receives the next block from its other neighbor. - -This overlaps communication with computation: while GPU $i$ computes attention using $K_j$/$V_j$, it simultaneously receives $K_{j+1}$/$V_{j+1}$ from the ring. - -```{python} -#| label: flash-attention-overlap-scenario -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ FLASHATTENTION OVERLAP HARDWARE (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-distributed-training-systems-ring-attention -# │ -# │ Goal: Provide H100 HBM and NVLink bandwidth for overlap discussion. -# │ Show: ~3.35 TB/s HBM; ~900 GB/s NVLink. -# │ How: pulling constants from mlsys.constants. -# │ -# │ Imports: mlsys.constants (H100_MEM_BW, NVLINK_H100_BW, TB, GB, second) -# │ Exports: h100_hbm_bw_str, nvlink_h100 -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import H100_MEM_BW, NVLINK_H100_BW, TB, GB, second - -class FlashAttentionOverlapScenario: - """H100 bandwidth reference for overlap analysis.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - hbm = H100_MEM_BW - nvlink = NVLINK_H100_BW - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - h100_hbm_bw_str = f"{hbm.m_as(TB/second):.2f}" - nvlink_h100 = f"{nvlink.m_as(GB/second):.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -h100_hbm_bw_str = FlashAttentionOverlapScenario.h100_hbm_bw_str -nvlink_h100 = FlashAttentionOverlapScenario.nvlink_h100 -``` - -If the compute time for one tile exceeds the communication time for transferring one tile over NVLink, the communication is fully hidden. On an H100 with `{python} h100_hbm_bw_str` TB/s HBM bandwidth and `{python} nvlink_h100` GB/s NVLink bandwidth, this overlap is achievable for typical tile sizes. - -The practical impact of Ring Attention is measured in context length. Without it, a single GPU's attention computation is limited by HBM capacity: the KV cache for a sequence of length $N$ must fit entirely in one GPU's memory. With Ring Attention across $P$ GPUs, each GPU holds $N/P$ tokens of the KV cache, enabling context lengths of $P \times N_{\text{single}}$. Combined with FlashAttention's efficient tiling within each GPU (see @sec-performance-engineering-flashattention), Ring Attention enables the massive context windows required for long-document reasoning and multi-modal understanding. - -### Parameter Servers and Embedding Sharding {#sec-distributed-training-systems-systems-parameter-servers-embedding-sharding-8821} - -While AllReduce dominates dense model training, the **Parameter Server (PS)** architecture remains the standard for recommendation systems and other sparse workloads. A Parameter Server architecture separates workers (who compute gradients) from servers (who store parameters and apply updates). - -For dense models (like ResNet or BERT), the PS architecture creates a bottleneck: all workers send dense gradient updates to the servers simultaneously, saturating the server's network bandwidth. This "incast" problem drove the adoption of Ring AllReduce, which distributes the bandwidth load across all nodes. - -However, for **Recommendation Systems (DLRM)**, the model parameters are dominated by massive **Embedding Tables** (often 10 TB+) that cannot fit on any single GPU. Furthermore, the updates are **sparse**: a batch of users interacts with only a tiny fraction (e.g., 0.001%) of the items. - -In this regime, Parameter Servers (often rebranded as "Embedding Servers") shine: - -1. **Embedding Sharding**: The massive tables are partitioned across the PS fleet (often CPU nodes with massive DRAM). -2. **Sparse Lookups**: Workers send a list of IDs to the PS. -3. **Sparse Updates**: The PS returns only the requested embedding vectors, not the full table. - -This **Sparse Pull / Sparse Push** pattern avoids the bandwidth bottleneck of dense AllReduce. Modern implementations like TorchRec or Meta's hierarchical sharding place "hot" embeddings on GPUs and "cold" embeddings on CPU PS nodes, creating a tiered memory hierarchy for model parameters. - -### Expert Parallelism (Mixture of Experts) {#sec-distributed-training-systems-systems-expert-parallelism-mixture-experts-bc45} - -While tensor parallelism splits dense layers across devices, **Expert Parallelism** enables scaling model capacity (parameters) without increasing compute cost (FLOPs) by using conditional computation. In a Mixture-of-Experts (MoE) architecture [@shazeer2017outrageously], the feed-forward network (FFN) of each transformer block is replaced by a set of $N$ "experts" (independent FFNs). For each token, a gating network selects a small subset (typically top-1 or top-2) of experts to process it. - -In a distributed setting, experts are partitioned across workers. If we have 8 GPUs and 8 experts, each GPU hosts one expert. The training process introduces a distinct communication pattern: - -1. **Gating**: Each token determines its destination expert. -2. **All-to-All Dispatch**: Tokens are routed across the network to the device hosting their selected expert. -3. **Computation**: Experts process their assigned tokens. -4. **All-to-All Combine**: Processed tokens are routed back to their original device to resume the sequence. - -The primary advantage is decoupling model size from compute budget. A trillion-parameter MoE model might use only 10B parameters per token, enabling training on feasible hardware budgets. The constraint is the **All-to-All** communication, which is bandwidth-intensive and sensitive to load imbalance. - -At the heart of expert parallelism lies the All-to-All communication primitive, which shuffles tokens across the cluster based on dynamic routing decisions. Consider a configuration with $E=64$ experts distributed across 64 GPUs, processing a batch of $B=4$ sequences at length $S=2048$ with hidden dimension $H=4096$. For every MoE layer, the system must dispatch $B \times S$ tokens to their assigned experts. In FP16, this moves $B \cdot S \cdot H \cdot 2$ bytes—approximately 67 MB—in a single direction. Since the processed embeddings must return to their original device for the residual connection, the total network overhead is roughly 134 MB per transformer block. While manageable in isolation, this latency accumulates rapidly in deep, sparse architectures like the Switch Transformer [@fedus2022switch] (up to 2,048 experts) or GShard [@lepikhin2021gshard]. - -Network efficiency relies on the assumption of uniform token distribution, but natural language is inherently skewed: specific experts handling common syntax or connector words may receive 3--5$\times$ their fair share of traffic. To prevent memory overflows on these "hot" experts, systems enforce a hard limit defined by the **capacity factor** $C$, typically set between 1.25 and 1.5. This parameter caps the maximum number of tokens an expert processes at $C \cdot S/E$. If the routing gate assigns more tokens than this buffer allows, the excess tokens are dropped, passing through the layer unprocessed via the residual connection. To mitigate this data loss, training objectives include an **auxiliary load balancing loss**, weighted at 0.01–0.1 relative to the main cross-entropy loss, that penalizes the router for favoring specific experts. Modern implementations like Mixtral $8\times7$B use top-2 routing across 8 experts, achieving a favorable balance between capacity scaling and routing stability. - -This sparse communication pattern distinguishes recommendation and MoE workloads (*Archetype B (DLRM at Scale)*) from dense LLM training (*Archetype A (GPT-4 / Llama-3)*) (@sec-vol2-introduction-archetypes), producing fundamentally different scaling behaviors. - -::: {.callout-lighthouse title="Archetype B (DLRM at Scale): DLRM vs. LLM Scaling"} -**Archetype B (DLRM at Scale)** and **Archetype A (GPT-4 / Llama-3)** scale differently. - -* **LLMs (Dense)**: Scale via Tensor/Pipeline Parallelism. Constraint: **Compute & Interconnect Bandwidth** (NVLink). -* **DLRMs (Sparse)**: Scale via Embedding Sharding (Parameter Servers). Constraint: **Memory Capacity & Interconnect Latency** (Random Access). - -This distinction dictates fundamentally different cluster designs: dense GPU pods for LLMs versus memory-rich CPU/GPU hybrids for RecSys. -::: - -### Trade-offs: The Bubble vs. Bandwidth Dilemma {#sec-distributed-training-systems-systems-model-parallelism-tradeoffs} - -Model parallelism breaks the memory wall but introduces sequential dependencies that reduce hardware utilization. The engineering challenge is balancing **pipeline bubbles** (idle time) against **all-to-all bandwidth** (communication time). - -Model parallelism offers three principal advantages. Memory scaling enables training of models that exceed single-device capacity: with 8-way tensor parallelism, a `{python} gpt3_params_b`B model fits comfortably on A100s. Splitting the model also allows larger global batch sizes without out-of-memory errors, since each GPU processes a smaller parameter slice. The approach maps naturally to the physical structure of transformers, where attention heads split via tensor parallelism and layers split via pipeline parallelism. - -These advantages come at the cost of three fundamental limitations. Pipeline bubbles cause GPUs to sit idle while filling and draining the pipeline; the bubble fraction is $(P-1)/M$, where $P$ is pipeline stages and $M$ is microbatches, and achieving more than 90% efficiency requires $M \gg P$, which increases activation memory. Communication intensity in tensor parallelism is equally constraining: 2 AllReduce operations execute *per layer* on the critical path, demanding extremely high-bandwidth, low-latency interconnects (NVLink) and typically preventing scaling beyond a single node (8 GPUs) before hitting the bandwidth wall. Implementation complexity rounds out the trade-off, requiring invasive changes to the model definition — replacing standard linear layers with column-parallel and row-parallel variants — unlike data parallelism, which wraps the model externally without modifying internals. - -## Advanced Training Primitives {#sec-distributed-training-advanced-primitives} - -As models scale, the efficiency of individual operations becomes as critical as the overall distribution strategy. Advanced training primitives reduce the "cost per step" by optimizing the numerical precision of computation and the overlap of memory accesses. - -### FP8: The Training Frontier {#sec-distributed-training-fp8} - -Traditional mixed-precision training uses FP32 master weights with FP16 forward and backward passes. Modern hardware like the NVIDIA H100 adds support for 8-bit floating point (FP8), offering two formats optimized for different phases of training. - -The E4M3 format (4-bit exponent, 3-bit mantissa) provides a range of approximately $\pm 448$ with moderate precision. Its tighter range but better precision makes it suitable for representing **weights and activations** in the forward pass, where values cluster in predictable distributions. - -The E5M2 format (5-bit exponent, 2-bit mantissa) provides a much larger range of approximately $\pm 57344$ but coarser precision. This wider range accommodates **gradients**, which can span many orders of magnitude during backpropagation. Using E4M3 for gradients would cause frequent overflow and underflow, while E5M2's range captures the full gradient distribution at the cost of slightly noisier updates. - -| **Format** | **Exponent** | **Mantissa** | **Range** | **Precision** | **Use Case** | -|:-----------|:-------------|:-------------|:-------------------------|:--------------|:---------------------------| -| **FP32** | 8 bits | 23 bits | $\pm 3.4 \times 10^{38}$ | Very high | Master weights | -| **FP16** | 5 bits | 10 bits | $\pm 65504$ | High | Mixed-precision | -| **BF16** | 8 bits | 7 bits | $\pm 3.4 \times 10^{38}$ | Moderate | Training | -| **E4M3** | 4 bits | 3 bits | $\pm 448$ | Low | FP8 forward pass | -| **E5M2** | 5 bits | 2 bits | $\pm 57344$ | Very low | FP8 gradients | -| **INT8** | N/A | 8 bits | $-128$ to $+127$ | Uniform | Post-training quantization | -| **INT4** | N/A | 4 bits | $-8$ to $+7$ | Uniform | KV cache, weights | - -: **Numerical Precision Formats for ML**. Each row represents a different precision format. FP8 formats (E4M3, E5M2) occupy the sweet spot between the bandwidth of INT8 and the trainability of FP16. {#tbl-precision-formats} - -The critical engineering challenge in FP8 training is **dynamic scaling**. FP8's narrow dynamic range means that a fixed scale factor will cause either overflow or underflow. Per-tensor scaling multiplies each tensor by a scale factor before casting to FP8, then divides by that factor after the FP8 computation. The scale factor is adjusted dynamically, typically by tracking the running maximum absolute value of each tensor and choosing a scale that maps this maximum to near the FP8 maximum representable value. - -This three-precision approach (FP32 master weights, FP8 GEMMs, FP16 accumulation) achieves near-FP16 training quality while doubling effective throughput on FP8-capable hardware. The energy implications are equally significant: FP8 operations deliver twice the throughput at roughly half the energy per operation, resulting in approximately 4$\times$ improvement in energy efficiency for compute-bound workloads. - -## Hybrid Parallelism {#sec-distributed-training-systems-systems-hybrid-parallelism-12a0} - -How do you train a frontier model when data parallelism runs out of memory and model parallelism runs out of network bandwidth? You must orchestrate them simultaneously across three dimensions. The preceding sections revealed a fundamental tension: data parallelism scales throughput but demands massive memory, while model parallelism enables large models but starves the compute. - -Hybrid parallelism resolves this tension by applying both strategies orthogonally: model parallelism splits the architecture to fit available memory, while data parallelism scales throughput across multiple model replicas. Training a `{python} gpt3_params_b` billion parameter language model on a dataset of 300 billion tokens demonstrates this approach in practice. The neural network layers distribute across multiple GPUs through model parallelism, while data parallelism enables different GPU groups to process separate batches. This dual strategy addresses both memory constraints from model size and computational demands from dataset scale simultaneously, and it is precisely this combination that defines *Archetype A* training at frontier scale. - -::: {.callout-lighthouse title="Archetype A (GPT-4 / Llama-3): Physics of 3D Parallelism"} -**Archetype A (GPT-4 / Llama-3)** is the primary driver for hybrid parallelism. Because the model parameters ($P$) exceed the memory of any single accelerator ($M_{device}$), and the training dataset ($D$) requires massive throughput, we must split the problem along three orthogonal axes: - -1. **Tensor Parallelism**: Splits individual layers to fit $P$ within a node's memory. -2. **Pipeline Parallelism**: Splits layers across nodes to scale $P$ beyond a single node. -3. **Data Parallelism**: Replicates the entire split-model pipeline to scale throughput on $D$. - -Only by combining all three can we train Archetype A systems efficiently. -::: - -### The 3D Training Loop {#sec-distributed-training-systems-systems-hybrid-parallelism-3d-loop} - -Training a `{python} gpt3_params_b`B parameter model requires orchestrating computation across thousands of devices through **3D Parallelism**[^fn-3d-parallelism]. This approach does not merely sum the benefits of individual parallelism strategies; it composes them geometrically to match the physical topology of the hardware. - -[^fn-3d-parallelism]: **3D Parallelism**: Named after the three orthogonal axes of decomposition: (1) Data Parallelism (batch), (2) Tensor Parallelism (layer width), and (3) Pipeline Parallelism (depth). Organizations visualize their training fleets as a 3D grid $(d, t, p)$, where the product $d \times t \times p$ equals the total GPU count. This geometric perspective is essential for balancing the tiered bandwidth constraints of modern clusters. \index{3D Parallelism!etymology} - Consider a training fleet configured with Tensor Parallelism (TP) of 8, Pipeline Parallelism (PP) of 16, and Data Parallelism (DP) of 128. This configuration uses 16,384 GPUs ($8 \times 16 \times 128$) organized into a hierarchy of bandwidth domains. - -The training step begins at the Data Parallel level. Each of the 128 model replicas receives a distinct slice of the global batch. Within each replica, the model is split across 16 pipeline stages (nodes), with micro-batches flowing sequentially from the embedding layer on Node 0 to the loss calculation on Node 15. At the finest granularity, within each node, the 8 GPUs fuse into a single "super-accelerator" via TP. Every matrix multiplication in the forward pass is fractured across these 8 devices, which must exchange partial results via high-bandwidth NVLink after every operation. This generates the highest-intensity traffic in the system—approximately 12.6 GB per step—but latency remains negligible due to the `{python} nvlink_a100`–`{python} nvlink_h100` GB/s bandwidth between co-located chips. - -The backward pass inverts this flow and exposes the critical dependencies between parallelism dimensions. As gradients flow backward through the pipeline, nodes exchange activation gradients point-to-point. This traffic is relatively light—roughly 200 MB per stage boundary—allowing it to traverse slower inter-node InfiniBand links without stalling the pipeline. The true bottleneck emerges at the end of the step: the **Global AllReduce**. All 128 replicas must synchronize their gradients to update the weights, requiring the summation of 350 GB of gradient data across the entire cluster. By overlapping this communication with the computation of subsequent micro-batches through gradient bucketing, the system hides the latency of moving terabytes of data across the datacenter fabric. - -The architectural imperative is **bandwidth matching**: the communication volume of each algorithm must map inversely to the latency of the hardware interconnects. Chatty, blocking TP communication stays within the `{python} nvlink_a100`+ GB/s NVLink domain. Serialized, point-to-point PP transfers traverse the cluster spine at InfiniBand speeds. The massive but infrequent DP synchronization amortizes across the full training step. Attempting to run TP across racks, or DP without gradient accumulation, would violate this hierarchy—causing the 16,000-GPU fleet to wait idly for data to traverse the wire. This bandwidth-matching principle is precisely the "Jeff Dean Test" introduced in @sec-distributed-training-systems-systems-engineering-tradeoffs-selecting-parallelism-strategy-b344. - -### Configuration Design {#sec-distributed-training-systems-systems-hybrid-parallelism-configuration-design} - -Applying this bandwidth-matching principle to physical infrastructure transforms cluster design into a combinatorial optimization problem: mapping the three dimensions of parallelism—Tensor (TP), Pipeline (PP), and Data (DP)—to the network topology. The fundamental constraint is that hardware interconnects dictate which strategies are feasible at each level: intra-node NVLink (`{python} nvlink_a100`–`{python} nvlink_h100` GB/s) supports the frequent AllReduce operations of Tensor Parallelism, inter-node InfiniBand (25–100 GB/s) handles the less frequent point-to-point transfers of Pipeline Parallelism, and the remaining cross-pod bandwidth serves Data Parallelism's once-per-step gradient synchronization. Memory capacity per device (40–80 GB) sets the hard limit for model shards at each level. A detailed analysis of these physical systems—including TPU Pods, SuperPODs, and wafer-scale integration—is provided in @sec-compute-infrastructure. For a standard DGX A100 deployment, we typically fix Tensor Parallelism at $t=8$ to match the number of GPUs within a single node. This ensures that the bandwidth-heavy `AllReduce` operations required by matrix multiplications occur over the high-speed NVLink fabric. We then map Pipeline Parallelism ($p$) across multiple nodes within the same high-bandwidth rack or island to handle the point-to-point activation transfers, often setting $p=8$ or $p=16$ depending on the memory footprint. Finally, Data Parallelism ($d$) scales out to the remaining dimensions across pods, as the gradient synchronization step is less sensitive to the bisection bandwidth constraints of the upper network layers. - -### Memory Analysis {#sec-distributed-training-systems-systems-hybrid-parallelism-memory-analysis} - -The memory budget for training a 175-billion parameter model is dominated by model states and requires aggressive sharding to fit within the 80 GB HBM capacity of modern accelerators. The FP16 weights alone consume approximately 350 GB ($175 \times 10^9 \times 2$ bytes). If we relied solely on Tensor Parallelism with $t=8$, each GPU would hold a 43.75 GB slice of the weights. However, the optimizer state presents a larger hurdle; standard Adam maintains FP32 momentum and variance estimates, consuming roughly 12 bytes per parameter (or simplified to ~8 bytes/param for pure state excluding master weights), adding over 1.4 TB globally. Even with $t=8$, the combined weight and optimizer state would exceed 150 GB per GPU, causing an Out-Of-Memory (OOM) error. Consequently, we must employ Pipeline Parallelism ($p$) to further partition the model layers. With a hybrid configuration of $t=8$ and $p=16$, the static memory footprint drops to roughly 15 GB per GPU, leaving the remaining 65 GB of HBM available for the dynamic activation memory ($A$) generated during the forward pass, which scales linearly with micro-batch size and sequence length. - -### Communication Analysis {#sec-distributed-training-systems-systems-hybrid-parallelism-communication-analysis} - -Each parallelism dimension imposes a distinct traffic profile on the network. Tensor Parallelism is the most chatty, requiring two `AllReduce` operations for every Transformer block (one for the Attention projection, one for the MLP) in both the forward and backward passes. These messages are relatively small but occur thousands of times per step, making them strictly latency-bound and necessitating NVLink. Pipeline Parallelism, in contrast, involves point-to-point transfers of activation tensors (size $B_{\mu} \times S \times H$) only at the boundaries of the pipeline stages. While these messages are moderate in size, they occur less frequently, making them manageable over standard InfiniBand links. Data Parallelism generates the largest burst of traffic, requiring a global `AllReduce` of the entire 350 GB gradient buffer. However, this communication occurs only once per global batch update. By using gradient bucketing to overlap this transmission with the compute-intensive backward pass, the effective cost of DP communication can often be hidden, provided the cluster maintains sufficient cross-sectional bandwidth. - -### Pipeline Bubble {#sec-distributed-training-systems-systems-hybrid-parallelism-pipeline-bubble} - -The primary efficiency penalty in pipeline parallelism is the "bubble"—the periods at the start and end of a training step where GPUs sit idle waiting for the pipeline to fill or drain. For a naive GPipe schedule, the pipeline must completely fill with micro-batches before the first backward pass begins, and completely drain at the end. The fraction of time spent in this bubble is given by the ratio $\frac{P-1}{M + P - 1}$, where $P$ is the number of pipeline stages and $M$ is the number of micro-batches per global batch. For a GPT-175B configuration with $P=16$ stages and $M=32$ micro-batches, the bubble fraction is $\frac{15}{47} \approx 31.9\%$, meaning nearly one-third of the theoretical compute capacity is wasted. To mitigate this, modern systems employ 1F1B (One-Forward-One-Backward) scheduling. 1F1B interleaves forward and backward passes once the pipeline enters the steady state, allowing memory to be freed earlier. This reduction in peak memory pressure allows practitioners to increase $M$, which drives the bubble fraction down asymptotically, improving hardware utilization. - -### Blackwell and the Scaling Frontier {#sec-distributed-training-systems-systems-hybrid-parallelism-blackwell} - -The transition to the **Blackwell** architecture (2024) introduces two architectural shifts that redefine the 3D parallelism trade-offs. First, **NVLink 5** provides 1.8 TB/s of bidirectional bandwidth per GPU, doubling the intra-node capacity of the Hopper generation. This allows for even larger **Tensor Parallelism ($t$)** groups (e.g., $t=16$ or $t=32$ across multiple nodes) with negligible latency overhead. Second, the introduction of native **FP4 support** enables 4-bit training for specific workloads, effectively doubling the arithmetic throughput ($R_{\text{peak}}$) and halving the memory footprint ($D_{\text{vol}}$) compared to FP8. - -For a 1-trillion parameter model, Blackwell enables a configuration of $t=16$ (spanning two 8-GPU nodes via NVLink Switch) and $p=32$, reducing the total number of pipeline stages and associated bubble overhead. The 10 TB/s die-to-die interconnect within the B200 package further collapses the distinction between intra-chip and intra-package communication, allowing the two dies to function as a single high-bandwidth tensor parallel unit. These advancements shift the bottleneck from intra-node communication to the inter-pod optical fabric, making **Rail-Optimized** topologies (@sec-network-fabrics-rail-optimized) and **All-to-All** optimization even more critical for the next generation of the Machine Learning Fleet. - -::: {.callout-notebook title="Configuring 3D Parallelism for GPT-175B"} -**Scenario A: Hopper Cluster (H100)** -Designing the training topology for GPT-3 (175B parameters) requires balancing memory, bandwidth, and compute across three dimensions. A standard configuration uses **Tensor Parallelism (TP=8)** to shard layers within a single node, exploiting the 900 GB/s NVLink bandwidth. **Pipeline Parallelism (PP=16)** then splits the model depth across 16 nodes, tolerant of the slower inter-node bandwidth (InfiniBand) due to bubble overhead. Finally, **Data Parallelism (DP=128)** replicates this entire 16-node pipeline 128 times to scale throughput. The total hardware requirement is massive: $8 \text{ GPUs} \times 16 \text{ Stages} \times 128 \text{ Replicas} = 16,384 \text{ GPUs}$. - -**Scenario B: Blackwell Cluster (B200)** -With the Blackwell architecture, we can exploit the 1.8 TB/s NVLink 5 and FP4 precision. For a 175B model, we can increase **Tensor Parallelism to TP=16** (across two nodes) while maintaining the same low latency. This reduces the required **Pipeline Parallelism to PP=8**, cutting the pipeline bubble fraction by half ($P=8$ vs $P=16$). The resulting fleet is not only faster due to higher TFLOPS but also more efficient due to reduced coordination overhead ($T_{coord}$). -::: - -The MFU values cited above raise a natural question: how did the field arrive at 50% utilization, and what trajectory brought it there? @fig-mfu-progression traces the evolution of Model FLOPs Utilization across published training systems from 2020 to 2024. The progression from GPT-3's 21% MFU to PaLM's 46% MFU reflects not improvements in raw hardware speed but advances in the parallelism strategies, communication overlap techniques, and scheduling optimizations discussed throughout this chapter. The plateau near 40-46% reveals that the theoretical ceiling imposed by communication overhead, pipeline bubbles, and memory management remains formidable even with the most advanced hybrid parallelism techniques available. Notably, Meta's Llama 3 training at 16,384 H100 GPUs achieved slightly lower MFU (41%) than the same model at 8,192 GPUs (43%), confirming that the Scaling Tax described in @sec-distributed-training-systems-systems-physics-scaling-amdahls-law-communication-4d7f is not merely theoretical but measurable in production. - -::: {#fig-mfu-progression fig-env="figure" fig-pos="htb" fig-cap="**The MFU Progression**. Model FLOPs Utilization across published training systems, ordered chronologically. MFU doubled from 21% (GPT-3, 2020) to 46% (PaLM, 2022) through advances in hybrid parallelism and communication overlap, then plateaued near 40-45% for frontier-scale runs. The Llama 3 data points illustrate the Scaling Tax: MFU drops from 43% at 8,192 GPUs to 41% at 16,384 GPUs as communication overhead grows with cluster size. Data sources: Chowdhery et al. (2022), Meta (2024)." fig-alt="Horizontal bar chart of MFU percentages for six training systems from 2020 to 2024 showing progression from 21 percent to 46 percent with a plateau near 41 to 43 percent."} -```{python} -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ MFU PROGRESSION (FIGURE) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @fig-mfu-progression — Model FLOPs Utilization over time -# │ -# │ Goal: Barh of MFU for GPT-3, Megatron-Turing, Gopher, PaLM, Llama 3; -# │ show 21%→46% progression; plateau; Scaling Tax at 16K GPUs. -# │ Show: Six horizontal bars; 50% reference line; annotations. -# │ How: Verified systems/mfu_values; viz.setup_plot(). -# │ -# │ Imports: numpy (np), matplotlib.pyplot (plt), mlsys.viz (viz) -# │ Exports: (figure only, no prose variables) -# └───────────────────────────────────────────────────────────────────────────── -import numpy as np -import matplotlib.pyplot as plt -from mlsys import viz - -fig, ax, COLORS, plt = viz.setup_plot(figsize=(8, 5)) - -# VERIFIED DATA (Chowdhery et al., arXiv:2204.02311; Meta, arXiv:2407.21783) -systems = [ - 'GPT-3\n(~10K V100s, 2020)', - 'Megatron-Turing NLG\n(A100s, 2021)', - 'Gopher\n(TPU v3, 2021)', - 'PaLM 540B\n(6144 TPU v4, 2022)', - 'Llama 3 405B\n(8192 H100s, 2024)', - 'Llama 3 405B\n(16384 H100s, 2024)', -] -mfu_values = [21.3, 30.2, 32.5, 46.2, 43.0, 41.0] - -# Color by hardware family -bar_colors = [ - COLORS['VioletLine'], # V100-era - COLORS['BlueLine'], # A100/TPU-era - COLORS['BlueLine'], # A100/TPU-era - COLORS['BlueLine'], # A100/TPU-era - COLORS['GreenLine'], # H100-era - COLORS['GreenLine'], # H100-era -] - -y_pos = np.arange(len(systems)) -bars = ax.barh(y_pos, mfu_values, color=bar_colors, height=0.6, edgecolor='white', linewidth=0.5) - -# Label each bar with MFU value -for i, (val, bar) in enumerate(zip(mfu_values, bars)): - ax.text(val + 0.8, i, f'{val}%', va='center', fontsize=9, fontweight='bold', - color=COLORS['primary']) - -# Vertical dashed line at 50% ("half of peak" ceiling) -ax.axvline(x=50, color=COLORS['primary'], linestyle='--', linewidth=1.2, alpha=0.5) -ax.text(50.5, 5.4, '50% of peak', fontsize=8, color=COLORS['primary'], alpha=0.6, - va='bottom') - -# Annotation: MFU doubled 2020-2022 -ax.annotate('MFU doubled\n2020 → 2022', - xy=(21.3, 0), xytext=(35, -0.8), - fontsize=8, color=COLORS['VioletLine'], fontweight='bold', - arrowprops=dict(arrowstyle='->', color=COLORS['VioletLine'], lw=1.2), - ha='center') - -# Annotation: Scaling Tax for Llama 3 -ax.annotate('Scaling Tax:\n43% → 41% as\n8K → 16K GPUs', - xy=(41, 5), xytext=(52, 4.2), - fontsize=8, color=COLORS['GreenLine'], fontweight='bold', - arrowprops=dict(arrowstyle='->', color=COLORS['GreenLine'], lw=1.2), - ha='left') - -ax.set_yticks(y_pos) -ax.set_yticklabels(systems, fontsize=8) -ax.set_xlabel('Model FLOPs Utilization (%)') -ax.set_xlim(0, 60) -ax.invert_yaxis() - -# Legend for hardware families -from matplotlib.patches import Patch -legend_elements = [ - Patch(facecolor=COLORS['VioletLine'], label='V100-era'), - Patch(facecolor=COLORS['BlueLine'], label='A100 / TPU-era'), - Patch(facecolor=COLORS['GreenLine'], label='H100-era'), -] -ax.legend(handles=legend_elements, loc='lower right', fontsize=8) - -plt.tight_layout() -plt.show() -``` -::: - -## Multi-Model Training: RLHF and Alignment {#sec-distributed-training-rlhf} - -The parallelism strategies examined so far assume a single model being trained on a single objective. Reinforcement Learning from Human Feedback (RLHF) and its variants break this assumption by requiring multiple models — each with different memory footprints, compute profiles, and gradient requirements — to coordinate within a single training loop. This creates a heterogeneous fleet management problem that cannot be solved by any single parallelism strategy and represents a qualitatively different distributed systems challenge from standard pre-training. - -### The Multi-Model Coordination Problem {#sec-distributed-training-rlhf-coordination} - -Standard pre-training involves one model, one loss function, and one gradient stream. RLHF alignment, by contrast, orchestrates a *system of models* that interact during every training step. In the Proximal Policy Optimization (PPO) formulation [@ouyang2022training], four distinct models must operate in concert: - -1. The **policy model** (the model being aligned) requires full training state: parameters, gradients, optimizer moments, and activations. For a 70B parameter model in mixed precision with Adam, this consumes approximately 70B $\times$ (2 + 2 + 12) = 1,120 GB of memory — the same budget as standard pre-training. - -2. The **reference model** (a frozen copy of the pre-trained policy) requires only inference-mode memory: parameters in FP16/BF16 without gradients or optimizer state. The same 70B model in inference mode needs 70B $\times$ 2 = 140 GB, roughly 8$\times$ less than the training configuration. The reference model computes KL-divergence penalties that prevent the policy from drifting too far from the pre-trained distribution. - -3. The **reward model** (typically a separate model trained on human preference data) also runs in inference mode. Reward models are often smaller than the policy — a 13B reward model requires approximately 26 GB in FP16 — but must process every generated sequence to produce scalar reward signals. - -4. The **value model** (the PPO critic that estimates expected future reward) requires its own training state. If the value model shares the policy's architecture, it adds another 1,120 GB of training memory; in practice, value models are often smaller or share the policy backbone with a separate value head, reducing this to 200--400 GB. - -The aggregate memory demand of the four-model PPO system dwarfs standard pre-training. A naive co-location of a 70B policy, 70B reference, 13B reward, and 70B value model requires approximately 1,120 + 140 + 26 + 1,120 = 2,406 GB of accelerator memory, before accounting for the KV caches and intermediate activations generated during sequence generation. On H100 GPUs with `{python} h100_mem` GB of HBM each, this system requires a minimum of $\lceil 2,406 / 80 \rceil = 31$ GPUs for parameter storage alone. Once generation-phase KV caches (which grow linearly with output sequence length) and training-phase activations are included, the practical minimum rises to 64--128 GPUs for a single RLHF training instance. - -### Infrastructure Asymmetry: Training versus Inference Models {#sec-distributed-training-rlhf-asymmetry} - -The defining infrastructure challenge of RLHF is not the total memory footprint but the *asymmetry* between the models' compute profiles. The policy and value models require full backward passes with gradient computation, activation checkpointing, and optimizer updates — compute-intensive operations that benefit from tensor parallelism and high arithmetic intensity. The reference and reward models, by contrast, perform only forward passes: they are inference workloads embedded within a training loop, with memory access patterns dominated by KV cache management rather than gradient accumulation. - -This asymmetry creates a placement dilemma. Co-locating training and inference models on the same GPUs wastes compute during the generation phase (when the training models sit idle) and wastes memory during the gradient phase (when the inference models' parameter storage could be reclaimed for activations). Separating them onto dedicated GPU pools eliminates waste but introduces network latency for reward queries — each generated token batch must traverse the interconnect to reach the reward model and return a scalar signal before the policy gradient can be computed. - -The generation phase itself introduces a sequential bottleneck absent from standard pre-training. RLHF requires the policy model to *generate* complete sequences (typically 256--2,048 tokens) autoregressively before computing rewards and policy gradients. Autoregressive generation is memory-bandwidth-bound, not compute-bound: each token requires a full forward pass through the model to produce a single output token, with the KV cache growing by approximately $2 \times n_{layers} \times d_{model} \times 2$ bytes per token (in FP16). For a 70B model generating 1,024-token sequences with a batch of 256 prompts, the KV cache alone consumes $2 \times 80 \times 8192 \times 1024 \times 256 \times 2 \approx 172$ GB — more than two H100 GPUs' worth of HBM, dedicated entirely to storing intermediate attention state that is discarded after reward computation. - -Production RLHF systems address this asymmetry through temporal multiplexing. During the generation phase, the cluster operates as an inference system: the policy model generates sequences while the reference model computes log-probabilities, both using inference-optimized kernels (continuous batching, speculative decoding, PagedAttention for KV cache management). During the training phase, the cluster switches to training mode: gradients are computed through the policy and value models using the standard 3D parallelism configuration. This two-phase approach recovers most of the efficiency lost to model asymmetry but requires the orchestration layer to manage the transition between inference and training configurations — a scheduling challenge that standard training frameworks do not address. - -### DPO: Simplifying the Fleet {#sec-distributed-training-rlhf-dpo} - -Direct Preference Optimization (DPO) [@rafailov2024direct] eliminates the reward model and value model entirely by reformulating the alignment objective as a classification loss over preference pairs. Instead of generating sequences, computing rewards, and estimating advantages, DPO directly optimizes the policy to assign higher log-probability to preferred responses over dispreferred ones, using the reference model only to compute a KL-divergence regularization term. - -The infrastructure implications are substantial. DPO reduces the multi-model system from four models to two: the policy model (training mode) and the reference model (inference mode). For our 70B example, the memory budget drops from 2,406 GB to 1,120 + 140 = 1,260 GB — a 48% reduction that cuts the minimum GPU count roughly in half. DPO also eliminates the autoregressive generation phase entirely. Training operates on a fixed dataset of (prompt, preferred response, dispreferred response) triples, restoring the standard pre-training data pipeline: fixed-length sequences, deterministic batching, and no sequential token-by-token generation. The training loop becomes a standard supervised learning step with a modified loss function, amenable to the same 3D parallelism, gradient accumulation, and communication overlap techniques used for pre-training. - -The trade-off is capability. DPO operates on a static preference dataset, meaning the policy cannot explore new responses and receive feedback during training. PPO's online generation allows the policy to improve iteratively on its own outputs, potentially discovering better strategies that the static dataset does not contain. For deployment scenarios where the preference data comprehensively covers the target distribution, DPO's infrastructure simplification dominates. For scenarios requiring adaptive exploration — such as training models to solve novel reasoning tasks — PPO's online feedback loop may justify the 2$\times$ infrastructure overhead. - -### Quantitative Analysis: PPO versus DPO Resource Requirements {#sec-distributed-training-rlhf-quantitative} - -To make the infrastructure trade-off concrete, consider aligning a 70B policy model on a cluster of 256 H100 GPUs (`{python} h100_mem` GB HBM each). - -::: {.callout-notebook title="RLHF Infrastructure Budget: PPO versus DPO"} - -**Scenario**: Aligning a 70B parameter policy model. Reference model: 70B (frozen). Reward model (PPO only): 13B (frozen). Value model (PPO only): 13B (training). - -**PPO Memory Budget (per-GPU, with TP=8, PP=4, DP=8 for policy)** - -The policy model under 3D parallelism with TP=8 and PP=4 distributes its 1,120 GB training state across 32 GPUs, yielding 35 GB per GPU. The reference model, requiring only 140 GB for inference, can be sharded across a separate pool of 8 GPUs at 17.5 GB each, or co-located with the policy GPUs at an additional 4.4 GB per GPU (140 / 32). The 13B reward model adds 26 GB shared across its pool. The 13B value model in training mode adds approximately 208 GB (13B $\times$ 16 bytes/param) across its pool. - -Total static memory (co-located policy + reference on 32 GPUs): $35 + 4.4 \approx 39.4$ GB per GPU, leaving $\sim$40 GB for activations and KV cache. With generation-phase KV caches for 256 sequences of 1,024 tokens, each policy GPU must reserve an additional $172 / 32 \approx 5.4$ GB for its KV cache shard. The remaining $\sim$35 GB constrains the micro-batch size during the training phase. - -**DPO Memory Budget (same parallelism configuration)** - -The policy model uses the same 35 GB per GPU. The reference model adds 4.4 GB per GPU (co-located). No reward model, no value model, no KV cache for generation. - -Total static memory: $35 + 4.4 = 39.4$ GB per GPU — identical to PPO's static footprint, but without the generation-phase KV cache burden. The full $\sim$40 GB remaining is available for training activations, permitting 2$\times$ larger micro-batch sizes or eliminating the need for activation checkpointing that PPO requires. - -**Throughput Comparison** - -PPO's two-phase design (generate then train) introduces a fundamental throughput penalty. If generation consumes 60% of the step time (typical for autoregressive decoding of long sequences), the training hardware achieves only 40% utilization during an RLHF step. DPO, operating as a standard training loop, achieves the same 40--55% MFU as pre-training. - -| **Metric** | **PPO (70B policy)** | **DPO (70B policy)** | -|:-------------------------------|:-------------------------|:-------------------------| -| **Models required** | 4 (policy, ref, reward, value) | 2 (policy, reference) | -| **Total parameter memory** | $\sim$2,406 GB | $\sim$1,260 GB | -| **Minimum GPUs (memory)** | 64--128 | 32--64 | -| **Generation phase** | Yes (sequential, BW-bound) | None | -| **Effective training MFU** | 15--25% | 40--55% | -| **Data pipeline** | Online generation | Static preference pairs | - -**Systems Conclusion**: DPO halves the GPU requirement and doubles the effective compute utilization compared to PPO by eliminating two models and the autoregressive generation phase. The choice between them is not a modeling preference but an infrastructure constraint: organizations with limited GPU budgets adopt DPO because the alternative physically does not fit on their cluster. -::: - -### Sequence Length Variance and Batching Challenges {#sec-distributed-training-rlhf-batching} - -Both PPO and DPO face a data engineering challenge absent from standard pre-training: extreme variance in sequence length. Pre-training typically uses fixed-length sequences (2,048 or 4,096 tokens, padded or packed to fill each position), enabling uniform batch shapes and predictable memory consumption. RLHF training data consists of variable-length prompts (10--500 tokens) concatenated with variable-length completions (50--2,048 tokens), producing sequence lengths with 10--50$\times$ variance within a single batch. - -This variance creates two interacting problems. Fixed-size batching pads all sequences to the maximum length in the batch, wasting compute on padding tokens. A batch containing one 2,048-token sequence and fifteen 128-token sequences wastes 88% of the compute on padding. Dynamic batching groups sequences by similar length to minimize padding, but introduces load imbalance across data-parallel workers: one worker may receive a batch of long sequences consuming 60 GB of activation memory while another processes short sequences using only 8 GB, causing the short-sequence worker to stall at the synchronization barrier while the long-sequence worker completes. - -Production RLHF systems mitigate this through a combination of sequence packing (concatenating multiple short sequences into a single fixed-length input with attention masking to prevent cross-contamination) and adaptive micro-batching (dynamically adjusting the number of sequences per micro-batch based on the aggregate token count rather than the sequence count). These techniques recover 70--85% of the compute lost to padding variance but add complexity to the data pipeline that standard pre-training frameworks do not provide. The interaction between variable-length data and distributed synchronization — where every worker must process the same number of tokens per step to maintain gradient consistency — remains an active area of systems engineering research. - -## Parallelism Strategy Comparison {#sec-distributed-training-systems-systems-parallelism-strategy-comparison-d92a} - -If a colleague asks whether to implement pipeline or tensor parallelism for a new 50-billion parameter model, how do you systematically weigh the architectural and hardware trade-offs? @tbl-parallelism-compare contrasts data, model, pipeline, and hybrid parallelism across six critical dimensions. - -| **Aspect** | **Data Parallelism** | **Model Parallelism** | **Pipeline Parallelism** | **Hybrid Parallelism** | -|:----------------------------------|:-----------------------------------------------------------------|:---------------------------------------------------------------------------|:---------------------------------------------------------------------------|:------------------------------------------------------------------| -| **Focus** | Distributes dataset across devices, each with a full model copy | Distributes the model across devices, each handling a portion of the model | Distributes model stages in pipeline, processing microbatches concurrently | Combines multiple parallelism strategies for balanced scalability | -| **Memory Requirement per Device** | High (entire model on each device) | Low (model split across devices) | Low to Moderate (stages split across devices) | Moderate (splits model and dataset across devices) | -| **Communication Overhead** | Moderate to High (gradient synchronization across devices) | High (communication for intermediate activations and gradients) | Moderate (activation passing between stages) | Very High (requires synchronization for both model and data) | -| **Scalability** | Good for large datasets with moderate model sizes | Good for very large models with smaller datasets | Good for deep models with many layers | Excellent for extremely large models and datasets | -| **Implementation Complexity** | Low to Moderate (relatively straightforward with existing tools) | Moderate to High (requires careful partitioning and coordination) | Moderate to High (requires pipeline scheduling and microbatch management) | High (complex integration of multiple parallelism strategies) | -| **Ideal Use Case** | Large datasets where model fits within a single device | Extremely large models that exceed single-device memory limits | Deep models with sequential stages that can tolerate microbatch latency | Training massive models on vast datasets in large-scale systems | - -: **Parallel Training Strategies**: Data, model, pipeline, and hybrid parallelism each address the challenges of scaling machine learning training by distributing workload across devices, differing in how they partition data and model parameters to optimize memory usage, communication, and scalability. Understanding these trade-offs enables practitioners to select the most effective approach for their specific model and infrastructure. {#tbl-parallelism-compare} - -@fig-parallelism-flowchart provides a decision tree for selecting parallelism strategies based on model size, dataset size, and scaling constraints. While intentionally simplified, real-world scenarios often involve additional complexities such as hardware heterogeneity, communication bandwidth, and workload imbalance that may influence the choice of parallelism techniques. Practitioners should view this as a foundational tool for understanding trade-offs and decision points, then adapt it to the specific requirements and constraints of their systems. - -::: {#fig-parallelism-flowchart fig-env="figure" fig-pos="htb" fig-cap="**Parallelism Strategy Decision Tree**. A systematic selection guide based on two key questions: Does the model fit in single-device memory? Does the dataset fit on a single device? Models exceeding device memory require model parallelism; large datasets benefit from data parallelism; significant constraints in both dimensions demand hybrid approaches. While simplified, this framework captures the primary decision points before practitioners must consider secondary factors like hardware heterogeneity and workload imbalance." fig-alt="Decision tree flowchart starting from Start. Diamond nodes ask about model and dataset fit on single device. Paths lead to four outcomes: Single Device Optimization, Data Parallelism, Model Parallelism, or Hybrid Parallelism."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}] -\tikzset{Line/.style={line width=1.0pt,black!50,text=black -}, - Box/.style={inner xsep=2pt, - node distance=11mm, - draw=GreenLine, line width=0.75pt, - fill=GreenL, - text width=27mm,align=flush center, - minimum width=27mm, minimum height=9mm - }, - Box1/.style={Box, - draw=RedLine, fill=RedL, - text width=31mm, - minimum width=32mm, - minimum height=10mm - }, - Text/.style={inner xsep=2pt, - draw=none, line width=0.75pt, - fill=TextColor, - font=\footnotesize\usefont{T1}{phv}{m}{n}, - align=flush center, - minimum width=7mm, - minimum height=5mm - }, - decision/.style = {align=flush center,text width=42mm,diamond, aspect=2.2, node distance=6mm, - inner xsep=-3pt, inner ysep=-2.95ex,fill=VioletL2, draw=VioletLine}, -} -\node[Box](B1){Hybrid\\ Parallelism}; -\node[Box,node distance=16mm,right=of B1](B2){Model\\Parallelism}; -\node[Box,node distance=16 mm,right=of B2](B3){Data\\ Parallelism}; -\node[Box,right=of B3,fill=RedL, draw=RedLine](B4){Single Device Optimization}; -% -\scoped[on background layer] -\node[draw=BackLine,inner xsep=5mm,inner ysep=5mm, -yshift=-1mm, -fill=BackColor,fit=(B1)(B3),line width=0.75pt](BB){}; -\node[decision,node distance=18mm, -above=of B4](G1B4){Is\\ the dataset\\ very large?}; - -\node[Box1,node distance=15mm, -above=of $(B2.north)!0.5!(B3.north)$](G1B3){Is scaling the model\\ or data more critical?}; -\node[decision,above=of G1B3](G2B3){Are\\ both constraints\\ significant?}; -\node[decision,above=of G2B3](G3B3){Does\\ the dataset fit in a\\ single device?}; -\node[decision,above=of G3B3](G4B3){Does\\ the model fit in a\\ single device?}; -\node[Box,node distance=5mm,above=of G4B3,fill=BlueL, draw=BlueLine](G5B3){Start}; -% -\node[Box,below=1 of B2,fill=BlueL, draw=BlueLine](DB2){End}; -% -\draw[Line,-latex](G5B3)--(G4B3); -\draw[Line,-latex](G4B3)--node[right,pos=0.35]{No}(G3B3); -\draw[Line,-latex](G4B3)-|node[above,pos=0.05]{Yes}(G1B4); -\draw[Line,-latex](G3B3)--node[right,pos=0.35]{No}(G2B3); -\draw[Line,-latex](G2B3)--node[right,pos=0.35]{No}(G1B3); -\draw[Line,-latex](G1B4)--node[right,pos=0.15]{No}(B4); -% -\draw[Line,-latex](G3B3.west)--node[above,pos=0.25]{Yes}++(180:2.3)|-(B2.west); -\draw[Line,-latex](G2B3)-|node[above,pos=0.05]{Yes}(B1); -\draw[Line,-latex](G1B3.south)--node[left,align=center,pos=0.45]{Scaling Model}++(270:8mm)-|(B2); -\draw[Line,-latex](G1B3.south)--++(270:8mm)-|(B3); -\draw[Line,-latex](G1B4)-|node[above,pos=0.22,text=black]{Yes}(B3.40); -% -\draw[Line,-latex](B1)|-(DB2); -\draw[Line,-latex](B3)|-(DB2); -\draw[Line,-latex](B2)--(DB2); -\node[above=2pt of BB.204,inner sep=0pt,anchor=south,fill=BackColor]{Parallelism Opportunities}; -\end{tikzpicture} -``` -::: - -## From Principles to Systems {#sec-distributed-training-systems-systems-framework-integration-cf71} - -The parallelism strategies examined throughout this chapter—gradient averaging, AllReduce synchronization, tensor splitting, pipeline scheduling—translate into production systems through a layered abstraction hierarchy. Understanding this hierarchy matters because the abstraction level determines which constraints the engineer must manage directly and which are delegated to the runtime. - -### Data Parallel Abstractions {#sec-distributed-training-systems-systems-data-parallel-framework-apis-f549} - -The simplest distributed training abstraction wraps a model so that the framework automatically replicates it across available accelerators, splits each batch, and averages gradients after the backward pass. This parameter-server approach[^fn-parameter-server] [@li2014parameter] requires minimal changes to a single-device training loop but creates a communication bottleneck at the central server node: with $N$ workers pushing dense gradient streams simultaneously, the server's inbound bandwidth becomes the chokepoint. For dense models beyond 4--8 GPUs, this centralized design collapses. - -[^fn-parameter-server]: **Parameter Server**: Formalized by Mu Li et al. at CMU/Google in 2014, this architecture dedicates server nodes to storing parameters while workers push gradients and pull updates. The fundamental bottleneck: with $N$ workers, server inbound bandwidth must handle $N$ gradient streams simultaneously, making the server the communication chokepoint. For dense models beyond 4--8 GPUs, decentralized AllReduce (where each worker sends and receives at its own link rate) achieves $N$-fold higher aggregate bandwidth, which is why production systems replaced the parameter server design with decentralized AllReduce. \index{Parameter Server!architecture} - -Production-scale data parallelism eliminates this bottleneck by replacing the central server with decentralized AllReduce. Each worker participates symmetrically in the reduction: every device both sends and receives gradient chunks at its own link rate, distributing the bandwidth load across all nodes rather than concentrating it. The framework initializes a process group that maps workers to the physical topology, selects the optimal collective algorithm (ring, tree, or hierarchical) based on the detected interconnect, and inserts gradient synchronization hooks into the backward pass automatically. Gradient bucketing further improves efficiency by grouping small tensors into larger messages before transmission, and computation-communication overlap allows the AllReduce for early layers to proceed while later layers are still computing gradients. These optimizations collectively achieve 90%+ parallel efficiency at moderate scale—not because the API is simple, but because the underlying runtime makes topology-aware decisions that a manual implementation would require thousands of lines to replicate. - -### Model and Pipeline Parallel Abstractions {#sec-distributed-training-systems-systems-model-parallel-framework-support-7f7f} - -Model parallelism requires a fundamentally different abstraction because the framework must manage cross-device tensor placement and the sequential data flow between partitions. The core principle is explicit device assignment: the engineer specifies which layers reside on which devices, and the framework handles the activation transfers between devices during the forward pass and the gradient flow in reverse during backpropagation. This makes the sequential dependencies of model parallelism visible—each downstream device must wait for its upstream neighbor to complete—forcing the engineer to reason about pipeline bubbles and load balance at the architecture level rather than hiding them behind an opaque wrapper. - -For production-scale model parallelism, the key architectural insight is that tensor splitting and pipeline scheduling require different levels of framework support. Tensor parallelism replaces standard linear layers with column-parallel and row-parallel variants that automatically insert AllReduce operations at the correct points in the transformer computation graph, as described in @sec-distributed-training-systems-systems-tensor-parallelism-d76e. Pipeline parallelism adds microbatch scheduling logic that interleaves forward and backward passes across stages to minimize bubble overhead. Memory-efficient sharding integrates ZeRO-3 style parameter partitioning by wrapping model layers with automatic AllGather operations before each forward pass and ReduceScatter after each backward pass. The critical design decision is which abstraction levels to compose: pure data parallelism suffices when the model fits in memory, memory-efficient sharding extends data parallelism to memory-constrained regimes, and full tensor or pipeline parallelism becomes necessary when individual layers or the full model depth exceed single-device capacity. - -### Communication Primitives {#sec-distributed-training-systems-systems-communication-primitives-7207} - -All distributed training ultimately reduces to a small set of collective communication primitives. Gradient synchronization requires AllReduce: each device contributes its local gradient tensor, and all devices receive the globally averaged result. Parameter broadcasting propagates a single device's state to all others, used during initialization or after asymmetric updates. AllGather collects tensor fragments from every device into a complete tensor on each device, the operation that enables FSDP to reconstruct sharded parameters before each layer's forward pass. ReduceScatter combines reduction and distribution, delivering to each device only its assigned shard of the reduced result—the inverse of AllGather that makes ZeRO-2 gradient sharding possible. - -These primitives compose into the communication patterns that define each parallelism strategy. Data parallelism uses one AllReduce per training step. FSDP uses $2L$ collectives per step (AllGather and ReduceScatter for each of $L$ layers). Tensor parallelism uses 2 AllReduce operations per transformer block on the critical path. The choice of primitive, its message size, and its frequency relative to computation determine whether the system operates in the compute-bound or communication-bound regime—the fundamental diagnostic established in @sec-distributed-training-systems-systems-physics-scaling-amdahls-law-communication-4d7f. No framework abstraction changes the underlying physics: the efficiency of distributed training ultimately depends on physical interconnect bandwidth, memory capacity, and synchronization latency. - -## Fallacies and Pitfalls {#sec-distributed-training-systems-systems-fallacies-pitfalls-e2bc} - -Why do so many engineering teams scale their cluster capacity by 4x, only to find that their training iterations are actually taking longer to complete? Distributed training involves counterintuitive behavior that leads to common misconceptions, capturing errors that waste compute resources and delay research. - -Fallacy: ***Linear speedup is achievable with sufficient engineering effort.*** - -Amdahl's Law establishes hard limits: any sequential component bounds maximum speedup regardless of parallelism. In distributed training, gradient synchronization is inherently sequential since all gradients must be collected before any update proceeds. As @sec-distributed-training-systems-systems-distributed-training-efficiency-metrics-9488 demonstrates, the scaling efficiency equation $\text{Efficiency}(N) = 1/(1 + N(T_{\text{comm}}(N) - T_{\text{overlap}})/T_{\text{compute}})$ reveals how communication overhead dominates as $N$ increases. Even with perfect overlap and optimal algorithms, communication overhead grows with cluster size. For data parallelism, AllReduce time increases logarithmically with tree algorithms or linearly in the latency term with ring algorithms as GPU count grows. A 1000-GPU cluster will never train 1000$\times$ faster than a single GPU; achieving 500$\times$ speedup would be exceptional, and 100-200$\times$ is more typical for communication-heavy workloads. Organizations that budget projects assuming linear scaling inevitably miss deadlines and overspend on compute. - -Pitfall: ***Hyperparameters tuned on small clusters transfer directly to large-scale training.*** - -Engineers tune hyperparameters on 8-GPU workstations then deploy to 256-GPU clusters expecting identical behavior. In production, convergence patterns change fundamentally with scale. The most critical hyperparameter is learning rate: as @sec-distributed-training-systems-systems-data-parallelism-6132 explains, batch size increases proportionally with GPU count in data parallelism, requiring learning rate adjustments. The "linear scaling rule"[^fn-linear-scaling] [@goyal2017accurate] suggests $\eta_{large} = \eta_{base} \times (B_{large}/B_{base})$, but this relationship holds only within bounds. As models scale, they eventually encounter the **Critical Batch Size**[^fn-critical-batch-size], where adding more data per step yields diminishing returns in convergence. - -[^fn-critical-batch-size]: **Critical Batch Size**: Concept introduced by OpenAI (2018), representing the point where gradient noise no longer provides a meaningful regularization signal. For GPT-3, the critical batch size is approximately 1--4 million tokens. Scaling beyond this point improves throughput but does not reduce the number of optimization steps needed to reach target accuracy, collapsing the scaling efficiency ($\eta_{scale}$). \index{Critical Batch Size!convergence limit} - -[^fn-linear-scaling]: **Linear Scaling Rule**: Established by Goyal et al. (2017) at Facebook AI Research, who trained ResNet-50 on ImageNet in one hour across 256 GPUs with a batch size of 8,192 while matching single-GPU accuracy. The rule -- multiply learning rate by $k$ when batch size increases by $k$ -- works because larger batches reduce gradient variance by $k$, requiring proportionally larger steps to maintain update magnitude. The rule fails above the critical batch size (8K--32K for vision, up to 4M for LLMs), where gradient noise drops below the signal floor and additional parallelism yields diminishing convergence returns. \index{Linear Scaling Rule!convergence} - -Beyond the critical batch size (model and dataset dependent, often 8K to 32K for vision models), this relationship breaks down. A team training ResNet-50 on ImageNet with batch size 256 and learning rate 0.1 achieves 76.2% top-1 accuracy. Scaling to 1024 GPUs with batch size 32K and learning rate 12.8 (following linear scaling) produces 74.8% accuracy and slower convergence due to gradient noise reduction. Warmup schedules, weight decay adjustment, and careful momentum tuning recover most lost accuracy, but require systematic experimentation at target scale. Organizations that skip these scaling studies waste thousands of GPU-hours on suboptimal runs. - -Fallacy: ***Data parallelism scales indefinitely by adding more GPUs.*** - -Engineers assume more GPUs always accelerate training. In production, statistical efficiency limits overwhelm hardware gains. As @sec-distributed-training-systems-systems-data-parallelism-6132 establishes, data parallelism increases effective batch size proportionally with GPU count ($B_{total} = N \times B_{local}$), but gradient quality grows sublinearly beyond model-specific thresholds. A 100K-sample batch may provide only 2$\times$ the gradient information of a 10K-sample batch, not 10x, because samples become redundant within the loss landscape. The critical batch size defines where marginal returns collapse: for BERT-Base it occurs near 8K samples, for ResNet-50 near 32K samples. Beyond this threshold, doubling GPU count doubles cost but provides minimal convergence acceleration. A major cloud provider trained a large language model using 1024 GPUs that converged in 18 hours at \$45,000 compute cost; the same model on 512 GPUs converged in 19 hours at \$22,000 cost, demonstrating how exceeding critical batch size wastes resources without meaningful time savings. - -Pitfall: ***Choosing parallelism strategy based solely on memory constraints.*** - -Engineers see that a 70B model exceeds `{python} a100_mem`GB GPU memory and immediately choose tensor parallelism or pipeline parallelism to split weights. In production, the optimal strategy depends on the interaction between memory pressure, computation patterns, and communication topology. As @sec-distributed-training-systems-systems-parallelism-strategy-comparison-d92a explains, tensor parallelism splits each layer across devices with AllReduce synchronization per layer, achieving even memory distribution but placing communication on the critical path. Pipeline parallelism assigns complete layers to stages with point-to-point transfers between stages, reducing per-step communication but introducing pipeline bubble overhead that wastes 10-30% of cycles. For a `{python} gpt3_params_b`B model on 64 A100 GPUs where tensor parallelism degree-8 enables training, pipeline parallelism with 8 stages achieves 23% higher throughput due to reduced all-to-all communication despite similar memory footprints. The decision requires profiling communication patterns and bubble overhead, not just checking if weights fit in memory. - -Fallacy: ***FSDP and ZeRO always improve training efficiency.*** - -Engineers adopt FSDP (Fully Sharded Data Parallel) universally after reading that it "reduces memory and enables larger models". In production, sharding introduces 10--25% communication overhead that only pays off when memory pressure justifies it. FSDP reduces memory footprint by sharding optimizer state, gradients, and optionally parameters across GPUs, but requires AllGather operations before each forward pass and ReduceScatter after backward pass. For a 7B model on A100-`{python} a100_mem`GB GPUs with batch size 4, standard DDP achieves 145 samples/second while FSDP achieves only 118 samples/second (19% slower) because the model fits comfortably without sharding and the added communication overhead provides no benefit. FSDP provides value when model plus optimizer state exceeds single-GPU memory, when enabling larger per-GPU batch sizes justifies the overhead, or when ZeRO-Offload to CPU memory extends capacity. A 65B model that cannot fit on `{python} a100_mem`GB GPUs becomes trainable with FSDP ZeRO-3, accepting 15% throughput loss to enable training at all. Applying FSDP universally without measuring memory pressure wastes performance. - -Fallacy: ***Parallelism overhead is roughly constant regardless of model size.*** - -Engineers benchmark parallelism strategies on convenient small models then apply conclusions to large-scale training. In production, the ratio between computation and communication time changes dramatically with model size, inverting strategic decisions. AllReduce communication time depends primarily on gradient tensor size and network bandwidth, growing roughly linearly with parameter count, while forward and backward pass computation time grows superlinearly due to larger matrix operations. For a 1B parameter model where forward/backward pass takes 50 ms and AllReduce takes 25 ms, communication overhead consumes 33% of step time. For a 70B parameter model where forward/backward takes 2400 ms and AllReduce takes 180 ms, communication overhead drops to 7% despite the gradient size being 70$\times$ larger. Decisions made on small models ("pipeline parallelism's 15% bubble overhead makes it always slower than data parallelism") completely invert at scale where data parallelism's communication overhead reaches 25-40%. Reliable strategy selection requires either profiling at target scale or analytical models that account for how computation scales as $O(n^2)$ to $O(n^3)$ while communication scales as $O(n)$. - -Pitfall: ***Gradient accumulation is free.*** - -Engineers use gradient accumulation to simulate larger batch sizes, reducing synchronization frequency from every step to every $K$ steps. The technique appears cost-free since it eliminates $(K-1)/K$ of communication overhead. In production, accumulation introduces memory consumption, latency expansion, and numerical precision risks. Accumulated gradients consume additional memory throughout the accumulation window: for a 7B model, each accumulated step requires 14 GB of FP16 gradient storage, limiting how many steps can accumulate before memory exhaustion. Effective step time increases proportionally with accumulation steps, so accumulating 8 steps means optimizer updates occur 8$\times$ less frequently, potentially slowing convergence despite higher throughput. Most critically, accumulated FP16 gradients risk overflow when summing hundreds of gradient tensors, particularly in early training when loss values are large. A team training a transformer model with 16-step gradient accumulation in FP16 experienced loss spikes and divergence at step 1200; switching to 4-step accumulation with more frequent synchronization resolved the instability despite higher communication costs. Gradient accumulation trades communication for memory and numerical stability. - -Pitfall: ***Using fixed checkpoint intervals regardless of system characteristics.*** - -Engineers checkpoint distributed training "every hour" or "every 1000 steps" based on intuition rather than analysis. In production, optimal checkpoint frequency depends on the mathematical relationship between checkpoint cost and failure rate. - -```{python} - -#| label: young-daly-calc - -#| label: young-daly-calc -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ YOUNG-DALY OPTIMAL CHECKPOINT CALCULATION (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-distributed-training-systems-systems-fallacies-pitfalls-e2bc -# │ -# │ Goal: Quantify the optimal tradeoff between checkpoint overhead and rework. -# │ Show: That frequent checkpointing wastes compute while rare ones risk work. -# │ How: Apply Young-Daly formula (sqrt(2*C*MTBF)) for a 1024-GPU cluster. -# │ -# │ Imports: mlsys.formatting -# │ Exports: t_opt_min_str, total_loss_pct_str, daily_savings_str -# └───────────────────────────────────────────────────────────────────────────── -import math - -# ┌── LEGO ─────────────────────────────────────────────── -class YoungDaly: - """ - Namespace for Young-Daly Checkpoint Optimization. - Scenario: 1024-GPU cluster with 4-hour MTBF and 5-minute checkpoint time. - """ - - # ┌── 1. LOAD (Constants) ─────────────────────────────────────────────── - t_save_min = 5.0 - cluster_mtbf_hr = 4.0 - num_gpus = 1024 - gpu_cost_hr = 2.0 - - # ┌── 2. EXECUTE (The Compute) ───────────────────────────────────────── - # T_opt = sqrt(2 * C * MTBF) - # Convert all to minutes - c_min = t_save_min - mtbf_min = cluster_mtbf_hr * 60 - - t_opt_min_val = math.sqrt(2 * c_min * mtbf_min) - - # Overhead: (C / T_opt) + (T_opt / 2 * MTBF) - # (Simplified Young-Daly first-order approximation) - ckpt_overhead = c_min / t_opt_min_val - rework_overhead = (t_opt_min_val / 2) / mtbf_min - total_overhead = ckpt_overhead + rework_overhead - - # Daily dollar loss due to suboptimal (30 min) vs optimal - # (Context for the text argument) - daily_cost = num_gpus * gpu_cost_hr * 24 - loss_30min = ((c_min / 30) + (30 / (2 * 60 * cluster_mtbf_hr))) * daily_cost - loss_opt = total_overhead * daily_cost - diff_daily = loss_30min - loss_opt - - # ┌── 3. GUARD (Invariants) ─────────────────────────────────────────── - check(40 <= t_opt_min_val <= 60, f"Optimal interval ({t_opt_min_val:.1f}m) out of expected range.") - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - t_opt_min_str = fmt(t_opt_min_val, precision=0) - loss_pct_str = fmt(total_overhead * 100, precision=1) - daily_savings_str = fmt(diff_daily, precision=0, commas=True) - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -t_opt_min_str = YoungDaly.t_opt_min_str -total_loss_pct_str = YoungDaly.loss_pct_str -daily_savings_str = YoungDaly.daily_savings_str -``` - -``` - -The Young-Daly formula establishes the optimal checkpoint interval as $T_{\text{opt}} = \sqrt{2 \times C \times MTBF}$, where $C$ is checkpoint time and MTBF is mean time between failures. For a 1024-GPU cluster with 4-hour MTBF and 5-minute checkpoint time, the optimal interval is approximately **`{python} t_opt_min_str` minutes**. Checkpointing every 15 minutes "to be safe" wastes `{python} total_loss_pct_str`% of compute time on unnecessary checkpoint overhead, while checkpointing every 2 hours risks losing significant work on failure. For larger models where checkpoint time increases to 15 minutes due to model size and storage bandwidth, the optimal interval shifts significantly. The cost of suboptimal checkpointing scales with cluster size: a 1024-GPU cluster loses approximately **USD `{python} daily_savings_str` per day** to excessive checkpointing when using arbitrary intervals instead of the Young-Daly optimal. - -## Summary {#sec-distributed-training-systems-summary} - -This chapter opened with a "Scaling Wall": the point where adding more GPUs eventually makes training slower rather than faster. We have reframed distributed training not as a simple hardware problem, but as a **constraint satisfaction problem** governed by the interaction between model size, batch size, and interconnect bandwidth. - -We explored the **3D Parallelism Cube** (@fig-3d-parallelism-cube-summary), the foundational framework for scaling frontier models. **Data Parallelism** unrolls the outer loop of training to scale throughput; **Tensor Parallelism** vectorizes the inner loops of matrix multiplication to fit memory; and **Pipeline Parallelism** stages sequential layers to reduce communication frequency. Together, these strategies allow us to map models larger than any single memory bank onto a fleet of accelerators with finite bandwidth. - -Ultimately, the choice of parallelism is a **loop transformation** applied by the cluster-level compiler. By matching logical communication patterns to physical hardware hierarchies, we move from the "linear scaling regime" of small clusters to the "communication-bound" reality of the exascale supercomputer. - -::: {.callout-takeaways title="Parallelism Is a Loop Transformation"} - -* **Communication-Computation Ratio is the ceiling**: Distributed training is governed by an extended Amdahl's Law. Speedup is limited by the sequential nature of synchronization (AllReduce) relative to the parallel work of computation. -* **Data parallelism is the default, but hits the batch trap**: Scaling replicas improves throughput linearly until you hit the "Critical Batch Size," beyond which larger batches yield diminishing returns in convergence. -* **Tensor parallelism is node-local**: Because it partitions matrix operations, it requires the `{python} nvlink_a100`-`{python} nvlink_h100` GB/s bandwidth of NVLink. Scaling it across racks on standard Ethernet will stall the fleet. -* **Pipeline parallelism minimizes bandwidth, but adds "bubbles"**: Splitting depth across nodes allows for massive models, but efficiency depends on microbatching ($M \gg P$) to minimize idle time during fill and drain phases. -* **Memory-efficient DP (ZeRO/FSDP) scales linear memory**: By sharding optimizer states, gradients, and parameters, ZeRO-3 allows 100B+ models to fit on commodity hardware that would otherwise require complex model parallelism. -* **The Linear Scaling Rule requires Warmup**: Multiplying batch size by $k$ allows for multiplying learning rate by $k$, but only with a linear warmup period to maintain stability during the high-gradient-noise initial phase. - -::: - -Throughout this chapter, we applied these partitioning strategies to our Lighthouse Archetypes, revealing that there is no "one size fits all" configuration. - -::: {.callout-lighthouse title="Distributed Archetype Spectrum"} -The "optimal point" in the 3D Parallelism Cube shifts depending on the system's primary bottleneck: - -| **Archetype** | **Primary Partitioning Strategy** | **The Logic** | -|:-------------------------------------|:----------------------------------|:----------------------------------------------------------------------------------------------| -| **Archetype A (GPT-4 / Llama-3)** | Hybrid 3D Parallelism | Combine Tensor (width), Pipeline (depth), and Data (throughput) to fit 1 TB+ of weights. | -| **Archetype B (DLRM at Scale)** | Embedding Sharding | Partition massive 10 TB+ tables across a Parameter Server fleet; use sparse AllToAll updates. | -| **Archetype C (Federated MobileNet)**| Federated Learning | Distribute training *data* to the edge; keep model local; accept asynchronous, stale updates. | -::: - -::: {#fig-3d-parallelism-cube-summary fig-env="figure" fig-pos="htb" fig-cap="**3D Parallelism Strategy Space**: The three independent axes of distributed training parallelism. Data Parallelism (d) replicates models across workers. Pipeline Parallelism (p) partitions model layers vertically. Tensor Parallelism (t) splits individual operations horizontally. Real systems combine all three dimensions." fig-alt="Three-dimensional cube diagram with axes labeled Data Parallel, Pipeline Parallel, and Tensor Parallel. Colored planes show Model Replicas, Intra-layer partitioning, and pipeline stages."} -```{.tikz} -\begin{tikzpicture}[scale=1.2, font=\small\usefont{T1}{phv}{m}{n}] - \definecolor{DataColor}{RGB}{200,220,255} - \definecolor{TensorColor}{RGB}{255,220,200} - \definecolor{PipeColor}{RGB}{220,255,200} - \fill[gray!5] (0,2,0) -- (2,2,0) -- (2,2,2) -- (0,2,2) -- cycle; - \fill[gray!10] (2,0,0) -- (2,2,0) -- (2,2,2) -- (2,0,2) -- cycle; - \draw[->, thick] (0,0,0) -- (3,0,0) node[right] {Data Parallel ($d$)}; - \draw[->, thick] (0,0,0) -- (0,3,0) node[above] {Pipeline Parallel ($p$)}; - \draw[->, thick] (0,0,0) -- (0,0,3) node[below left] {Tensor Parallel ($t$)}; - \draw[thick, fill=DataColor!30, opacity=0.7] (0,0,2) -- (2,0,2) -- (2,2,2) -- (0,2,2) -- cycle; - \draw[thick, fill=TensorColor!30, opacity=0.7] (2,0,0) -- (2,0,2) -- (2,2,2) -- (2,2,0) -- cycle; - \draw[thick, fill=PipeColor!30, opacity=0.7] (0,2,0) -- (2,2,0) -- (2,2,2) -- (0,2,2) -- cycle; - \node at (1,1,2) {Model Replicas}; - \node[rotate=90] at (2.2,1,1) {Intra-layer}; - \node at (1,2.2,1) {Inter-layer}; -\end{tikzpicture} -``` -**The 3D Parallelism Space**. Archetype A occupies a coordinate $(d, t, p)$ inside this cube to balance memory and bandwidth. -::: - -The parallelism strategies explored throughout this chapter—data, tensor, pipeline, and expert—provide the conceptual toolkit for partitioning any training workload across a cluster. The key insight is that these strategies are not mutually exclusive alternatives but complementary dimensions of a unified optimization space. Production systems like Megatron-LM achieve efficient scaling precisely because they combine all four strategies, using tensor parallelism within nodes, pipeline parallelism across node groups, data parallelism for throughput, and expert parallelism for capacity scaling. - -::: {.callout-chapter-connection title="From Logic to Traffic"} - -We have defined the *logical* traffic patterns of the Machine Learning Fleet—the "How" of splitting the math. These logical patterns, however, eventually hit the physical wires of the datacenter. - -In **Communication** (@sec-collective-communication), we open the black box of the collective operations (AllReduce, AllToAll) that make this consistency possible. We move from the logic of *what* to send to the mechanics of *how* to route it through rings, trees, and rails. - -::: - -::: { .quiz-end } -::: - -```{python} -#| echo: false -#| label: chapter-end -from mlsys.registry import end_chapter -end_chapter("vol2:distributed_training") -``` diff --git a/book/quarto/contents/vol2/inference/inference.qmd b/book/quarto/contents/vol2/inference/inference.qmd index 6e137ce88..2e1a6fef5 100644 --- a/book/quarto/contents/vol2/inference/inference.qmd +++ b/book/quarto/contents/vol2/inference/inference.qmd @@ -1153,6 +1153,58 @@ Continuous batching (also called iteration-level batching) decouples batch membe **Archetype A (GPT-4 / Llama-3)** (@sec-vol2-introduction-archetypes) relies on continuous batching to solve its primary efficiency paradox. The decode phase is memory-bandwidth bound, meaning the GPU compute cores are idle waiting for weights to load. Continuous batching saturates this bandwidth by processing unrelated requests together. Without this technique, serving Archetype A (GPT-4 / Llama-3) models would be economically unviable due to low GPU utilization. ::: +Unlike static or dynamic batching, which group requests at the *request* level, continuous batching operates at the *iteration* level, dynamically reshaping the compute tensor at each clock cycle. + +::: {#fig-continuous-batching-comparison fig-env="figure" fig-pos="htb" fig-cap="**Static vs. Continuous Batching**. In Static Batching (A), all requests in a batch must wait for the longest request to complete before the GPU can begin the next batch, leading to significant idle compute time (shaded gray). Continuous Batching (B) allows new requests to enter the batch as soon as any request finishes, keeping the GPU saturated and dramatically improving throughput." fig-alt="Two-panel timeline comparison. Left: Static batching shows requests of different lengths with large white space representing idle time. Right: Continuous batching shows requests filling the gaps as soon as one ends."} +```{.tikz} +\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, xscale=0.8, yscale=0.7] + \definecolor{StaticColor}{RGB}{200,200,200} + \definecolor{ReqColor}{RGB}{0,99,149} % BlueLine + \definecolor{WaitColor}{RGB}{240,240,240} + + \tikzset{ + req/.style={fill=ReqColor!60, draw=black!60, thick}, + idle/.style={fill=black!5, draw=black!30, dashed} + } + + % Panel A: Static Batching + \begin{scope} + \node[anchor=west, crimson] at (0, 4.5) {\textbf{A. Static Batching}}; + % Batch 1 + \draw[req] (0, 3) rectangle (2, 3.8) node[midway, white] {R1}; + \draw[req] (0, 2) rectangle (5, 2.8) node[midway, white] {R2}; + \draw[req] (0, 1) rectangle (3, 1.8) node[midway, white] {R3}; + % Idle regions + \draw[idle] (2, 3) rectangle (5, 3.8); + \draw[idle] (3, 1) rectangle (5, 1.8); + % Vertical barrier + \draw[thick, red, dashed] (5, 0.5) -- (5, 4.2) node[above, font=\tiny] {Batch Barrier}; + + % Batch 2 starts after barrier + \draw[req] (5.2, 3) rectangle (8, 3.8) node[midway, white] {R4}; + \node at (2.5, -0.5) {Idle Compute}; + \end{scope} + + % Panel B: Continuous Batching + \begin{scope}[shift={(10,0)}] + \node[anchor=west, crimson] at (0, 4.5) {\textbf{B. Continuous Batching}}; + % R1 ends at 2, R4 enters immediately + \draw[req] (0, 3) rectangle (2, 3.8) node[midway, white] {R1}; + \draw[req, fill=OrangeLine!60] (2.1, 3) rectangle (5, 3.8) node[midway, white] {R4}; + \draw[req, fill=OrangeLine!60] (5.1, 3) rectangle (8, 3.8) node[midway, white] {R5}; + + \draw[req] (0, 2) rectangle (5, 2.8) node[midway, white] {R2}; + \draw[req, fill=GreenLine!60] (5.1, 2) rectangle (9, 2.8) node[midway, white] {R6}; + + \draw[req] (0, 1) rectangle (3, 1.8) node[midway, white] {R3}; + \draw[req, fill=PurpleLine!60] (3.1, 1) rectangle (7, 1.8) node[midway, white] {R7}; + + \node at (4.5, -0.5) {Continuous Utilization}; + \end{scope} +\end{tikzpicture} +``` +::: + ### Continuous Batching Throughput Analysis {#sec-inference-scale-continuous-batching-throughput-analysis-ecd2} Continuous batching's dynamic batch management maintains high GPU utilization regardless of sequence length variance. The throughput improvement depends on sequence length distribution. For a distribution with coefficient of variation $CV = \sigma / \mu$, the gain is approximately @eq-continuous-batching-gain: diff --git a/book/quarto/contents/vol2/performance_engineering/performance_engineering.qmd b/book/quarto/contents/vol2/performance_engineering/performance_engineering.qmd index 437c525e9..c9135c394 100644 --- a/book/quarto/contents/vol2/performance_engineering/performance_engineering.qmd +++ b/book/quarto/contents/vol2/performance_engineering/performance_engineering.qmd @@ -17,9 +17,14 @@ engine: jupyter # │ Exports: (none) # └───────────────────────────────────────────────────────────────────────────── from mlsys.registry import start_chapter -from mlsys.constants import * +from mlsys.constants import ( + A100_MEM_BW, A100_FLOPS_FP16_TENSOR, + B200_FLOPS_FP8_TENSOR, B200_MEM_BW, B200_MEM_CAPACITY, + H100_MEM_BW, H100_FLOPS_FP16_TENSOR, H100_FLOPS_FP8_TENSOR, + ENERGY_DRAM_PJ_PER_BYTE, ENERGY_FLOP_FP16_PJ, + TFLOPs, second, GB, TB, byte, flop +) from mlsys.formatting import fmt, sci, check -from mlsys.formulas import calc_bottleneck, model_memory start_chapter("vol2:performance_engineering") ``` @@ -32,21 +37,19 @@ start_chapter("vol2:performance_engineering") ::: \noindent -![](images/png/cover_optimization.png){fig-alt="Performance engineering and optimization at scale." width=100%} +![](images/png/cover_performance.png){fig-alt="High-performance compute kernels and memory-aware optimization for ML inference." width=100%} ::: ## Purpose {.unnumbered} \begin{marginfigure} -\mlfleetstack{35}{45}{100}{10} +\mlfleetstack{30}{35}{100}{15} \end{marginfigure} -_How do we make billion-parameter models run on millisecond timescales?_ +_Why does an H100 GPU capable of nearly 2,000 teraflops often sit 95% idle during inference?_ -Model compression, covered in earlier chapters, reduces the size of what we compute. Performance engineering reshapes *how* we compute to match the physics of the hardware. The distinction matters: a quantized model loaded naively into a GPU kernel that reads every weight from off-chip memory wastes the very bandwidth savings that quantization was designed to provide. Real performance comes from understanding the full path a tensor travels, from registers through SRAM to HBM and back, and then engineering each step to eliminate wasted movement. This chapter develops the system-level optimization techniques that bridge the gap between a theoretically efficient model and a production artifact that saturates hardware. We examine operator fusion and tiling strategies that keep data in fast SRAM, precision formats that double effective bandwidth, compilation frameworks that automate kernel selection, and algorithmic innovations like **speculative decoding**[^fn-speculative-decoding-pe] and sparse expert routing that fundamentally change the performance equation. Together, these techniques transform a model that "should" be fast into one that *is* fast, often by an order of magnitude. - -[^fn-speculative-decoding-pe]: **Speculative Decoding**: A technique that uses a small, fast "draft" model to predict multiple tokens in parallel, which are then verified by the large "target" model in a single forward pass. This trades "wasted" compute (to process rejected tokens) for reduced memory bandwidth pressure: the target model's massive weights are loaded from HBM once to process $k$ tokens, effectively increasing the arithmetic intensity of the decode phase by $k\times$. \index{Speculative Decoding!bandwidth optimization} +Machine learning performance is a **negotiation between logic and physics**. Modern accelerators deliver exascale compute, but only if the software respects the physical constraints of the **Memory Wall** and the **Power Wall**. *When* a kernel is designed without awareness of the memory hierarchy, it spends more time moving data through HBM than performing arithmetic, collapsing system efficiency ($\eta$). The art of performance engineering is maximizing the **Arithmetic Intensity** of every operation—extracting the highest computational value from every byte moved. This chapter explores the techniques that bridge the gap between theoretical hardware peak ($R_{\text{peak}}$) and actual workload throughput: **Kernel Fusion** to eliminate redundant memory trips, **Quantization** to shrink the data footprint ($D_{\text{vol}}$), and **Compilation** to optimize the dataflow across the silicon. Without this mastery, the Machine Learning Fleet becomes a collection of expensive heaters, consuming megawatts of power to shuffle data while the arithmetic units wait. ::: {.content-visible when-format="pdf"} \newpage @@ -54,2296 +57,87 @@ Model compression, covered in earlier chapters, reduces the size of what we comp ::: {.callout-tip title="Learning Objectives"} -- Analyze the **Memory Wall** using the **Roofline Model** and diagnose whether a given ML workload is compute-bound or memory-bound on a specific accelerator. -- Explain how **Operator Fusion** and **Tiling** (FlashAttention) overcome HBM bandwidth limitations by keeping intermediate data in on-chip SRAM. -- Implement **FP8 Training** using E4M3 and E5M2 formats and apply **Block-wise Quantization** (LLM.int8(), GPTQ, AWQ) to handle outlier features in large language models. -- Apply **Graph Compilation** frameworks (torch.compile, XLA, TensorRT) to automate operator fusion, memory planning, and kernel selection. -- Evaluate **Speculative Decoding** strategies to reduce inference latency by trading compute for latency using small draft models. -- Design **Mixture of Experts (MoE)** systems that decouple model capacity from inference cost through sparse activation and expert parallelism. -- Diagnose performance bottlenecks using **System Profiling** tools (Nsight Systems, PyTorch Profiler) and roofline plots. - -::: - -::: {.callout-note title="Connection: The Fleet Stack"} - -Performance Engineering is the **Optimization Layer** of the Fleet Stack. While Inference at Scale (@sec-inference-scale) defines the serving architecture and scheduling policies, Performance Engineering optimizes the individual operations that execute within each serving node. In the **Fleet Stack** (@sec-vol2-introduction), this chapter sits between the Serving Layer (how work is scheduled) and the Infrastructure Layer (how hardware executes). Every technique here targets the same goal: closing the gap between theoretical hardware peak and achieved throughput, turning the **Iron Law** from a speed limit into a speedometer that reads closer to maximum. +- Analyze memory-bound versus compute-bound workloads using the **Roofline Model** and identify the hardware ridge point +- Implement **Kernel Fusion** to minimize global memory bandwidth consumption by keeping intermediate activations in fast SRAM +- Apply **Mixed-Precision Optimization** (FP16, BF16, FP8) to increase arithmetic throughput while reducing memory pressure +- Evaluate **I/O-Aware Algorithms** (e.g., FlashAttention) that use tiling to bypass memory bandwidth bottlenecks +- Compare **Interpreter** versus **Compiler** performance (Eager vs. Graph mode) in terms of dispatch overhead and graph-level optimization +- Design **Performance-Critical Kernels** that honor the "Silicon Contract" by maximizing Tensor Core utilization ::: ```{python} -#| echo: false #| label: perf-eng-setup +#| echo: false # ┌───────────────────────────────────────────────────────────────────────────── -# │ PERFORMANCE ENGINEERING SYSTEM CONSTANTS +# │ PERFORMANCE ENGINEERING: HARDWARE BOUNDS # ├───────────────────────────────────────────────────────────────────────────── -# │ Context: Chapter-wide setup for @sec-performance-engineering +# │ Context: @sec-performance-engineering-memory-wall and the Roofline analysis +# │ sections throughout the chapter. # │ -# │ Goal: Establish H100/A100/B200 hardware specs as shared reference for roofline, -# │ bottleneck analysis, and SRAM-vs-HBM energy ratio throughout the chapter. -# │ Show: H100 ~3.35 TB/s BW, ~989 TFLOPs FP16, ~1979 TFLOPs FP8; ridge points -# │ ~295 FP16 and ~590 FP8 FLOP/byte; energy ratio ~1280x DRAM/SRAM. -# │ How: Pull from mlsys.constants; compute ridge points as FLOPS/BW ratio -# │ (TFLOPs/TB → FLOP/byte via .m_as(flop/byte)). +# │ Goal: Establish the memory wall by comparing H100/B200 peak compute vs +# │ memory bandwidth, and quantify the energy ratio between movement +# │ and calculation (DRAM vs FLOP). +# │ Show: "3.35" TB/s H100 bandwidth, "8" TB/s B200 bandwidth, "~150x" +# │ energy gap between DRAM access and FP16 math — inline in the +# │ memory wall and arithmetic intensity paragraphs. +# │ How: .m_as() for unit scaling; ratio = DRAM_energy / FLOP_energy. # │ # │ Imports: mlsys.constants (H100_MEM_BW, H100_FLOPS_FP16_TENSOR, -# │ H100_FLOPS_FP8_TENSOR, H100_MEM_CAPACITY, H100_TDP, -# │ A100_MEM_BW, A100_FLOPS_FP16_TENSOR, A100_MEM_CAPACITY, # │ B200_FLOPS_FP8_TENSOR, B200_MEM_BW, B200_MEM_CAPACITY, -# │ ENERGY_DRAM_ACCESS_PJ, ENERGY_SRAM_L1_PJ) -# │ Exports: h100_hbm_bw_str, h100_hbm_bw_gbs_str, h100_fp16_str, h100_fp8_str, -# │ h100_mem_str, h100_tdp_str, h100_ridge_fp16_str, h100_ridge_fp8_str, -# │ a100_hbm_bw_str, a100_hbm_bw_gbs_str, a100_fp16_str, a100_ridge_fp16_str, +# │ ENERGY_DRAM_PJ_PER_BYTE, ENERGY_FLOP_FP16_PJ, +# │ TFLOPs, TB, GB, byte, flop) +# │ Exports: a100_hbm_bw_str, a100_hbm_bw_gbs_str, a100_fp16_str, a100_ridge_fp16_str, # │ b200_fp8_str, b200_hbm_bw_str, b200_mem_str, energy_ratio_str # └───────────────────────────────────────────────────────────────────────────── +from mlsys.constants import ( + A100_MEM_BW, A100_FLOPS_FP16_TENSOR, + B200_FLOPS_FP8_TENSOR, B200_MEM_BW, B200_MEM_CAPACITY, + H100_MEM_BW, H100_FLOPS_FP16_TENSOR, H100_FLOPS_FP8_TENSOR, + ENERGY_DRAM_PJ_PER_BYTE, ENERGY_FLOP_FP16_PJ, + TFLOPs, second, GB, TB, byte, flop +) +from mlsys.formatting import fmt, sci, check -## The Memory Wall and the Efficiency Frontier {#sec-performance-engineering-memory-wall}``` +# ┌── P.I.C.O. ISOLATED SCENARIO ─────────────────────────────────────────────── +class PerfEngSetup: + """Namespace for Performance Engineering reference bounds.""" + + # ┌── 1. PARAMETERS (Inputs) ─────────────────────────────────────────────── + h100_bw = H100_MEM_BW + h100_flops = H100_FLOPS_FP16_TENSOR + + b200_bw = B200_MEM_BW + b200_flops_fp8 = B200_FLOPS_FP8_TENSOR + b200_cap = B200_MEM_CAPACITY + + energy_byte = ENERGY_DRAM_PJ_PER_BYTE + energy_flop = ENERGY_FLOP_FP16_PJ + + # ┌── 2. CALCULATION (The Physics) ───────────────────────────────────────── + h100_ridge = (h100_flops / h100_bw).to(flop/byte).magnitude + energy_ratio = energy_byte.m_as(byte**-1) / energy_flop.m_as(flop**-1) + + # ┌── 3. INVARIANTS (Guardrails) ─────────────────────────────────────────── + check(h100_ridge > 200, f"Expected H100 ridge > 200, got {h100_ridge:.1f}") + + # ┌── 4. OUTPUTS (Formatting) ────────────────────────────────────────────── + h100_hbm_bw_str = f"{h100_bw.m_as(TB/second):.2f}" + h100_ridge_str = f"{h100_ridge:.0f}" + + b200_fp8_str = f"{b200_flops_fp8.m_as(TFLOPs/second):,.0f}" + b200_hbm_bw_str = f"{b200_bw.m_as(TB/second):.1f}" + b200_mem_str = f"{b200_cap.m_as(GB):.0f}" + + energy_ratio_str = f"{energy_ratio:.0f}" + +# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── +h100_hbm_bw_str = PerfEngSetup.h100_hbm_bw_str +h100_ridge_str = PerfEngSetup.h100_ridge_str +b200_fp8_str = PerfEngSetup.b200_fp8_str +b200_hbm_bw_str = PerfEngSetup.b200_hbm_bw_str +b200_mem_str = PerfEngSetup.b200_mem_str +energy_ratio_str = PerfEngSetup.energy_ratio_str +``` ## The Memory Wall and the Efficiency Frontier {#sec-performance-engineering-memory-wall} - -Why does an H100 GPU capable of 989 teraFLOPS often sit 95% idle while generating text from a large language model? The processor is starving for data. Performance engineering operates within a constrained optimization space defined by the Memory Wall, where the speed of moving bytes from memory to compute units fundamentally caps our operational throughput. - -Part II established the distributed logic of the fleet: parallelism strategies (@sec-distributed-training-systems), communication patterns (@sec-collective-communication), fault recovery (@sec-fault-tolerance-reliability), and resource orchestration (@sec-fleet-orchestration). Those chapters ensured that workloads reach the right hardware and survive failures along the way. This chapter ensures that each workload *uses* that hardware efficiently, extracting maximum throughput from every accelerator cycle. - -### The Iron Law of ML Performance {#sec-performance-engineering-efficiency-frontier} - -Performance engineering operates within a constrained optimization space defined by the **Iron Law** of ML system performance: - -$$ -\text{Time} = \max\left( \frac{\text{Compute}}{\text{FLOPS}}, \; \frac{\text{Memory Access}}{\text{Bandwidth}} \right) + \text{Overhead} -$$ {#eq-iron-law-perf} - -This is the performance engineer's most important equation. It appears simple, but its implications are profound because the $\max$ operator means that only one term matters at a time. Optimizing the wrong term yields zero improvement. An engineer who spends a week optimizing compute throughput for a memory-bound workload has wasted that week entirely. The equation demands diagnosis before optimization. - -This equation decomposes execution time into three terms. The first fraction represents compute time: the total floating-point operations divided by the hardware's peak throughput. The second fraction represents memory time: the total bytes transferred divided by the memory bandwidth. The $\max$ operator reflects the roofline principle: the slower of the two determines performance. The overhead term captures everything else: kernel launch latency, synchronization, communication, and software stack inefficiency. - -Standard model compression (pruning, quantization, distillation) reduces the numerators, performing fewer operations on smaller data. System optimization, the focus of this chapter, attacks the *structure* of the equation itself: - -Operator fusion and tiling (FlashAttention, fused kernels) reduce the Memory Access numerator by eliminating intermediate HBM round-trips. When a sequence of operations keeps its data in SRAM, the effective Memory Access term shrinks dramatically, often by 10--30$\times$ for attention computation. - -Precision engineering (FP8, INT4, KV cache compression) reduces the Memory Access numerator by representing each value in fewer bytes. Halving the precision halves the bytes transferred, doubling the effective bandwidth. - -Graph compilation (torch.compile, XLA, TensorRT) reduces the Overhead term by eliminating kernel launch gaps, fusing operations, and optimizing memory allocation. - -Communication-computation overlap transforms the equation for distributed systems by making the communication overhead concurrent with the compute term, effectively eliminating it from the critical path when the overlap condition (@eq-overlap-condition) holds. - -Algorithmic innovations (speculative decoding, MoE) change the Compute numerator itself by performing a fundamentally different, less expensive computation that produces equivalent results. - -Each technique attacks a different term, and this taxonomy guides optimization strategy: diagnose which term dominates (using the roofline model from @sec-performance-engineering-roofline), then apply the technique targeting that term. Applying a technique that targets the non-dominant term wastes engineering effort. - -@fig-iron-law-flowchart codifies this diagnostic process as a decision flowchart, mapping each bottleneck to its corresponding optimization technique. - -::: {#fig-iron-law-flowchart fig-env="figure" fig-pos="htb" fig-cap="**Iron Law Diagnostic Flowchart**. The optimization process begins with profiling to determine which term in @eq-iron-law-perf dominates. If the workload is compute-bound, precision engineering or algorithmic changes reduce the numerator. If memory-bound, operator fusion and tiling reduce HBM traffic. If overhead-bound, graph compilation and communication overlap attack the residual term. Applying the wrong technique yields zero improvement." fig-alt="Flowchart starting with Profile Workload, branching to Compute Bound, Memory Bound, or Overhead Bound, each leading to specific optimization techniques."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, scale=0.85, transform shape] - \definecolor{BlueLine}{HTML}{006395} - \definecolor{BlueL}{HTML}{D1E6F3} - \definecolor{GreenLine}{HTML}{008F45} - \definecolor{GreenL}{HTML}{D4EFDF} - \definecolor{OrangeLine}{HTML}{E67817} - \definecolor{OrangeL}{HTML}{FCE4CC} - \definecolor{RedLine}{HTML}{CB202D} - \definecolor{RedL}{HTML}{F5D2D5} - \definecolor{VioletLine}{HTML}{7E317B} - \definecolor{VioletL}{HTML}{E6D4E5} - - \tikzset{ - start/.style={draw=black!70, fill=black!10, rounded corners=3pt, thick, - minimum width=3cm, minimum height=0.8cm, align=center, font=\small\bfseries}, - diag/.style={diamond, draw=#1, fill=#1!12, thick, aspect=2.5, - inner sep=1pt, align=center, font=\scriptsize}, - action/.style={draw=#1, fill=#1!15, rounded corners=3pt, thick, - minimum width=2.6cm, minimum height=0.7cm, align=center, font=\scriptsize}, - arrow/.style={-{Triangle[width=5pt,length=4pt]}, thick, black!50} - } - - % Start - \node[start] (profile) at (5, 0) {Profile Workload\\(Roofline Analysis)}; - - % Three branches - \node[diag=BlueLine] (cb) at (0.5, -2.2) {Compute\\Bound?}; - \node[diag=OrangeLine] (mb) at (5, -2.2) {Memory\\Bound?}; - \node[diag=VioletLine] (ob) at (9.5, -2.2) {Overhead\\Bound?}; - - \draw[arrow] (profile) -- (cb); - \draw[arrow] (profile) -- (mb); - \draw[arrow] (profile) -- (ob); - - % Compute-bound actions - \node[action=BlueLine] (a1) at (-0.8, -4.2) {Precision Eng.\\(FP8, INT4)}; - \node[action=BlueLine] (a2) at (2.0, -4.2) {Algorithmic\\(MoE, Speculative)}; - \draw[arrow] (cb) -- (a1); - \draw[arrow] (cb) -- (a2); - - % Memory-bound actions - \node[action=OrangeLine] (a3) at (3.8, -4.2) {Operator Fusion\\(FlashAttention)}; - \node[action=OrangeLine] (a4) at (6.4, -4.2) {Tiling \&\\KV Cache Opt.}; - \draw[arrow] (mb) -- (a3); - \draw[arrow] (mb) -- (a4); - - % Overhead-bound actions - \node[action=VioletLine] (a5) at (8.2, -4.2) {Graph Compile\\(torch.compile)}; - \node[action=VioletLine] (a6) at (10.8, -4.2) {Comm. Overlap\\(CUDA Graphs)}; - \draw[arrow] (ob) -- (a5); - \draw[arrow] (ob) -- (a6); - - % Re-profile loop - \node[start, minimum width=2.2cm, font=\scriptsize\bfseries] (re) at (5, -5.8) {Re-profile}; - \draw[arrow] (a1.south) |- (re); - \draw[arrow] (a2.south) |- (re); - \draw[arrow] (a3.south) |- (re); - \draw[arrow] (a4.south) |- (re); - \draw[arrow] (a5.south) |- (re); - \draw[arrow] (a6.south) |- (re); - - % Loop back - \draw[arrow, dashed] (re.east) -- ++(2.5,0) |- (profile.east); - \node[font=\tiny\itshape, text=black!50] at (9.5, -3.0) {iterate until}; - \node[font=\tiny\itshape, text=black!50] at (9.5, -3.3) {target met}; - -\end{tikzpicture} -``` -::: - -The **efficiency frontier** is the Pareto-optimal curve of model quality versus system throughput. A model on the frontier cannot improve throughput without sacrificing quality, or vice versa. The techniques in this chapter push the frontier outward by making each quality level achievable at higher throughput, or equivalently, by making each throughput level achievable at higher quality. An organization's goal is not merely to reach the frontier but to find the point on it that best matches their latency, throughput, cost, and quality requirements. - -The multi-dimensional nature of this frontier makes optimization challenging. The relevant dimensions include: - -- **Throughput** (tokens/second or requests/second): How much work the system completes per unit time. -- **Latency** (time-to-first-token, inter-token latency): How quickly the system responds to individual requests. -- **Cost** (dollars per million tokens): The economic efficiency of the system. -- **Quality** (perplexity, benchmark accuracy, human preference): The accuracy and usefulness of model outputs. -- **Memory** (peak GPU memory): The resource constraint that limits batch size and sequence length. - -These dimensions interact in non-obvious ways. Increasing batch size improves throughput and cost efficiency but degrades latency. Reducing precision improves throughput and memory but may degrade quality. Speculative decoding improves latency but may increase per-token cost. The performance engineer's task is to navigate these trade-offs guided by the application's specific requirements. - -A real-time chatbot prioritizes latency (time-to-first-token under 200 ms, inter-token latency under 50 ms) and may tolerate higher per-token cost. A batch processing pipeline for document summarization prioritizes throughput and cost, tolerating seconds of latency. A medical diagnostic system prioritizes quality above all else, accepting lower throughput and higher cost. Each application maps to a different optimal point on the efficiency frontier, and the techniques in this chapter provide the tools to reach that point. - -To make this concrete, consider two deployment configurations for the same 70B LLM: - -Configuration A is latency-optimized: FP16 weights, batch size 1, speculative decoding enabled. Each H100 GPU serves approximately 80 tokens/second with 20 ms inter-token latency. Cost: 8 GPUs dedicated to a single user stream, approximately \$0.12 per 1,000 output tokens. - -Configuration B is throughput-optimized: INT4 weights (AWQ), batch size 64, no speculation. Each H100 serves approximately 4,000 tokens/second aggregate throughput across all batched requests, with 120 ms inter-token latency per request. Cost: 4 GPUs serving 64 concurrent users, approximately \$0.002 per 1,000 output tokens. - -Configuration B achieves 60$\times$ lower cost per token than Configuration A, but at 6$\times$ higher latency. Neither configuration is objectively "better"; they represent different points on the efficiency frontier, optimized for different applications. The performance engineering techniques in this chapter are the tools for navigating between these points. - -### The Memory Wall {#sec-performance-engineering-memory-wall-physics} - -The efficiency frontier establishes what we are optimizing *toward*. The physics of memory bandwidth determines where we start. Every performance engineering problem in modern ML begins with the same observation: memory bandwidth, not compute, is the bottleneck. Consider a single autoregressive decoding step in a large language model. The model reads its full weight matrix from High Bandwidth Memory (HBM)[^fn-hbm-perf] to generate a single token, performing only one or two multiply-accumulate operations per weight loaded. - -[^fn-hbm-perf]: **HBM (High Bandwidth Memory)**: Achieves its bandwidth by vertically stacking DRAM dies connected through thousands of through-silicon vias (TSVs), a 3D packaging technique first commercialized by SK Hynix in 2013. Despite the "high bandwidth" label, HBM's 3.35 TB/s on the H100 is still 200--600$\times$ slower than on-chip SRAM access, making the memory hierarchy gap the central constraint of every optimization technique in this chapter. \index{HBM!memory hierarchy} - -```{python} -#| label: memory-wall-scenario -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ MEMORY WALL HARDWARE SPECS (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-memory-wall-physics -# │ -# │ Goal: Provide H100 hardware specs for memory wall discussion. -# │ Show: ~1,979 TFLOPS FP8; ~3.35 TB/s BW; ~1280x energy ratio. -# │ How: pulling constants from mlsys.constants. -# │ -# │ Imports: mlsys.constants (H100_FLOPS_FP8_TENSOR, H100_MEM_BW, -# │ H100_MEM_CAPACITY, ENERGY_DRAM_ACCESS_PJ, ENERGY_SRAM_L1_PJ, -# │ TFLOPs, TB, GB, second, flop, byte) -# │ Exports: h100_fp8_str, h100_hbm_bw_str, h100_ridge_fp8_str, -# │ energy_ratio_str, h100_mem_str -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import ( - H100_FLOPS_FP8_TENSOR, H100_MEM_BW, H100_MEM_CAPACITY, - ENERGY_DRAM_ACCESS_PJ, ENERGY_SRAM_L1_PJ, - TFLOPs, TB, GB, second, flop, byte -) - -class MemoryWallScenario: - """H100 specs for memory wall discussion.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - flops_fp8 = H100_FLOPS_FP8_TENSOR - bw = H100_MEM_BW - mem = H100_MEM_CAPACITY - energy_dram = ENERGY_DRAM_ACCESS_PJ - energy_sram = ENERGY_SRAM_L1_PJ - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - ridge_fp8 = (flops_fp8 / bw).m_as(flop/byte) - energy_ratio = energy_dram / energy_sram - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - h100_fp8_str = f"{flops_fp8.m_as(TFLOPs/second):.0f}" - h100_hbm_bw_str = f"{bw.m_as(TB/second):.2f}" - h100_ridge_fp8_str = f"{ridge_fp8:.0f}" - energy_ratio_str = f"{energy_ratio.m_as(''):.0f}" - h100_mem_str = f"{mem.m_as(GB):.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -h100_fp8_str = MemoryWallScenario.h100_fp8_str -h100_hbm_bw_str = MemoryWallScenario.h100_hbm_bw_str -h100_ridge_fp8_str = MemoryWallScenario.h100_ridge_fp8_str -energy_ratio_str = MemoryWallScenario.energy_ratio_str -h100_mem_str = MemoryWallScenario.h100_mem_str -``` - -An NVIDIA H100 delivers `{python} h100_fp8_str` TFLOPS of FP8 compute but only `{python} h100_hbm_bw_str` TB/s of memory bandwidth. If every byte loaded from memory does fewer than `{python} h100_ridge_fp8_str` arithmetic operations, the compute units sit idle, starved for data. This gap between compute capability and memory delivery rate is the **Memory Wall**, and it defines the landscape within which all performance engineering operates. - -The memory wall is not a temporary engineering limitation; it is a consequence of physics. Moving data costs energy proportional to distance. Accessing a value from on-chip SRAM (L1 cache) costs approximately 0.5 pJ, while fetching the same value from off-chip HBM costs roughly 640 pJ, a ratio of `{python} energy_ratio_str`$\times$. Manufacturing constraints limit the amount of SRAM that can sit close to the compute units. HBM provides capacity (the H100 offers `{python} h100_mem_str` GB) but at physically greater distance, requiring the data to traverse longer wires. The fundamental tension is that models need gigabytes of parameters and state, but physics dictates that only kilobytes of data can be near the compute units at any given moment. - -This tension shapes every optimization technique in this chapter. Operator fusion reduces the number of trips to HBM by combining operations so that intermediate results stay in SRAM. Precision engineering reduces the number of bytes per trip by representing values in FP8 or INT4 instead of FP16. Tiling strategies restructure algorithms to maximize data reuse within SRAM. Graph compilers automate these transformations. Each technique attacks a different term in the same fundamental equation: minimize the ratio of bytes moved to operations performed. - -### The GPU Memory Hierarchy {#sec-performance-engineering-memory-hierarchy} - -To understand why the memory wall exists, consider the physical structure of a modern GPU's memory system. The hierarchy spans four levels, each trading capacity for bandwidth and latency. - -Registers are the fastest storage, located directly within each streaming multiprocessor (SM). The H100 provides 256 KB of register file per SM across its 132 SMs, totaling approximately 33 MB of register space across the entire chip. Register access is essentially free in terms of latency (one clock cycle) and energy (~0.01 pJ per access). Registers are, however, private to individual threads and cannot be shared. - -Shared memory (SRAM) occupies the next level, pooled within each SM. The H100 provides up to 228 KB of configurable shared memory per SM. This memory is shared among all threads in a thread block, enabling cooperative data reuse. Access latency is approximately 20--30 clock cycles (~20 ns), and energy cost is roughly 0.5 pJ per access. Shared memory is the critical resource for operator fusion: if intermediate results fit in shared memory, they never need to traverse the slow HBM bus. - -The L2 cache sits between the SMs and HBM, providing a 50 MB on-chip buffer on the H100. It captures reuse patterns automatically (when the same data is accessed by multiple SMs) but cannot be explicitly managed by kernel authors. Access latency is approximately 200 clock cycles (~130 ns). The L2 cache is particularly important for multi-head attention, where multiple attention heads may access the same KV cache entries. If the KV cache for a given sequence position fits in L2, subsequent heads accessing the same position benefit from cache hits rather than paying the full HBM access cost. - -High Bandwidth Memory (HBM) is the main off-chip memory, providing `{python} h100_mem_str` GB of capacity at `{python} h100_hbm_bw_str` TB/s bandwidth. HBM access latency is approximately 300 ns, and each access costs roughly 640 pJ of energy. Despite the "high bandwidth" designation, the bandwidth-to-capacity ratio means that reading the full `{python} h100_mem_str` GB of HBM takes approximately 24 ms, far longer than the sub-millisecond latency targets of real-time inference. - -The energy cost of data movement has a direct economic consequence at datacenter scale. Consider a training cluster of 1,000 H100 GPUs, each performing approximately $10^{12}$ memory accesses per second during a memory-bound workload. If each access reads from HBM at 640 pJ, the memory subsystem alone consumes approximately 640 W per GPU, a significant fraction of the H100's 700 W TDP. If operator fusion moves half of those accesses from HBM to SRAM (at 0.5 pJ each), the per-GPU memory power drops by approximately 320 W. Across 1,000 GPUs, this saves 320 kW, equivalent to powering roughly 250 homes. This is not a secondary consideration; at cloud electricity prices, the annual cost difference is substantial, and it scales linearly with cluster size. The physics of data movement is not merely a performance constraint; it is an economic one. - -The performance engineering challenge reduces to a data placement problem: keep the data that the compute units need in the fastest memory that can hold it. When a kernel reads a tensor from HBM, processes it, and writes the result back to HBM, the HBM round-trip dominates execution time for any operation with low arithmetic intensity. The techniques in this chapter all share the goal of keeping data closer to compute for longer. - -### The Widening Gap {#sec-performance-engineering-widening-gap} - -The memory wall is not static; it grows wider with each hardware generation. Compute throughput has scaled exponentially, roughly doubling every two years with new GPU architectures. Memory bandwidth has improved more slowly, constrained by the physics of off-chip signaling and the economics of HBM manufacturing. - -| **GPU** | **Year** | **Peak FP16 (TFLOPS)** | **HBM BW (TB/s)** | **Ridge Point (FLOP/byte)** | -|:---------|:---------|:-----------------------|:------------------|:----------------------------| -| **V100** | 2017 | 125 | 0.9 | 139 | -| **A100** | 2020 | 312 | 2.0 | 156 | -| **H100** | 2022 | 989 | 3.35 | 295 | -| **B200** | 2024 | 4,500 | 8.0 | 563 | - -: **The Widening Memory Wall**. Compute throughput has increased 36$\times$ from V100 to B200 over seven years, while memory bandwidth has increased only 8.9$\times$. The ridge point has increased 4$\times$, meaning more workloads fall into the memory-bound regime with each generation. {#tbl-widening-gap} - -@tbl-widening-gap quantifies this trend. The ridge point increased from 139 FLOP/byte on the V100 to 563 FLOP/byte on the B200. An operation with arithmetic intensity of 200 FLOP/byte was compute-bound on the V100 and A100, memory-bound on the H100, and deeply memory-bound on the B200. This means that performance engineering techniques targeting memory efficiency, fusion, precision, and tiling, become *more* important with each hardware generation, not less. The engineering effort invested in FlashAttention and INT4 quantization today will yield even greater returns on future hardware. - -### The Roofline Model {#sec-performance-engineering-roofline} - -The **Roofline Model**[^fn-roofline-origin] provides a quantitative framework for diagnosing whether a workload is compute-bound or memory-bound on a specific piece of hardware. Introduced by Williams, Waterman, and Patterson (2009), the model plots achievable performance as a function of **arithmetic intensity**, defined as the ratio of floating-point operations to bytes transferred from memory. The intersection of these two regimes is known as the **ridge point**[^fn-roofline-ridge]. - -[^fn-roofline-ridge]: **Ridge Point**: The intersection of the memory-bound and compute-bound lines on a Roofline plot. It represents the minimum **Arithmetic Intensity** required to reach peak hardware performance ($R_{\text{peak}}$). For an H100 GPU (FP16), the ridge point is roughly 295 FLOP/byte; if an operator's intensity is below this "ridge," it will never saturate the Tensor Cores, regardless of how much compute is available. \index{Roofline Model!ridge point} - The intersection of these two regimes is known as the **ridge point**[^fn-roofline-ridge]. - - -[^fn-roofline-origin]: **Roofline Model**: Published as "Roofline: An Insightful Visual Performance Model for Multicore Architectures" in *Communications of the ACM* (April 2009). Originally designed for CPU multicore workloads, the model's power lies in its hardware-agnostic simplicity: two numbers (peak FLOPS and peak bandwidth) define the entire performance envelope. This same simplicity now drives GPU purchasing decisions for ML inference, where the ridge point determines whether a workload benefits from faster compute or faster memory. \index{Roofline Model!origin} - -::: {.callout-definition title="Arithmetic Intensity"} - -***Arithmetic Intensity ($I$)***\index{Arithmetic Intensity!definition} is the ratio of floating-point operations performed to the number of bytes transferred from memory ($FLOP/\text{byte}$). - -1. **Significance (Quantitative):** It characterizes the **Computational Density** of a workload. It is the independent variable in the **Roofline Model**, determining whether a system operates in the **Bandwidth-Bound** ($BW$) or **Compute-Bound** ($R_{\text{peak}}$) regime. -2. **Distinction (Durable):** Unlike **Peak Throughput** (a hardware property), Arithmetic Intensity is an **Algorithmic Property** that measures how effectively a workload reuses data once it is loaded into the processor. -3. **Common Pitfall:** A frequent misconception is that AI is fixed for a model. In reality, it varies by **Implementation**: techniques like operator fusion increase AI by keeping data in local registers, while increasing batch size increases AI for layers with high parameter reuse. - -::: - -For a given accelerator with peak compute $P$ (in FLOPS) and peak memory bandwidth $B$ (in bytes/second), the achievable performance of a workload with arithmetic intensity $I$ (in FLOP/byte) is: - -$$ -\text{Achievable FLOPS} = \min(P, \; B \times I) -$$ {#eq-roofline} - -The transition point where these two limits intersect is the **ridge point**: - -$$ -I_{\text{ridge}} = \frac{P}{B} -$$ {#eq-ridge-point} - -Workloads with $I < I_{\text{ridge}}$ are memory-bound: their performance is limited by how fast data can be loaded, not how fast it can be processed. Workloads with $I > I_{\text{ridge}}$ are compute-bound: the arithmetic units are the bottleneck. @fig-roofline-model illustrates this relationship graphically. - -::: {.callout-note title="Figure: The Roofline Model" collapse="false"} - -```{.tikz} -%| fig-cap: "**The Roofline Model**. Achievable performance (y-axis) as a function of arithmetic intensity (x-axis) on a log-log plot. The sloped line represents the memory bandwidth ceiling; the flat line represents the compute ceiling. Their intersection is the ridge point. Most transformer inference operations fall in the memory-bound region (left of the ridge point), while large batched GEMMs fall in the compute-bound region (right)." -%| fig-alt: "Log-log plot showing roofline model with memory bandwidth ceiling as diagonal line and compute ceiling as horizontal line, meeting at ridge point. Workload types are marked: LLM decode and element-wise ops on the left (memory-bound), large GEMM on the right (compute-bound)." -%| label: fig-roofline-model - -\begin{tikzpicture}[>=stealth, scale=1.0] - % Axes - \draw[thick, ->] (0,0) -- (10,0) node[below] {Arithmetic Intensity (FLOP/byte)}; - \draw[thick, ->] (0,0) -- (0,7) node[above, rotate=90, anchor=south] {Achievable TFLOPS}; - - % Axis labels (log scale markers) - \node[below] at (1,0) {\small 1}; - \node[below] at (3,0) {\small 10}; - \node[below] at (5,0) {\small 100}; - \node[below] at (7,0) {\small 1000}; - - \node[left] at (0,1) {\small 1}; - \node[left] at (0,3) {\small 100}; - \node[left] at (0,5) {\small 989}; - \node[left] at (0,6) {\small 1979}; - - % Memory bandwidth ceiling (slope = bandwidth) - \draw[blue, very thick] (0.5,0.5) -- (5.5,5.5); - - % Compute ceiling FP16 - \draw[red, very thick] (5.5,5.5) -- (9.5,5.5); - - % Compute ceiling FP8 (higher) - \draw[red!50, thick, dashed] (4.8,6.2) -- (9.5,6.2); - \draw[blue!50, thick, dashed] (0.5,0.9) -- (4.8,6.2); - - % Ridge point - \fill[black] (5.5,5.5) circle (3pt); - \node[above right] at (5.5,5.5) {\small Ridge Point}; - \node[below right, font=\scriptsize] at (5.5,5.2) {$\sim$295 FLOP/byte}; - - % FP8 ridge point - \fill[black!50] (4.8,6.2) circle (2pt); - \node[above, font=\scriptsize] at (4.8,6.4) {FP8 Ridge}; - - % Regions - \node[blue, font=\small, rotate=0] at (2.5,1.5) {Memory-Bound}; - \node[red, font=\small] at (7.5,4.8) {Compute-Bound}; - - % Workload markers - \fill[orange] (1.2,1.2) circle (4pt); - \node[right, font=\scriptsize, orange] at (1.4,1.0) {LLM Decode (B=1)}; - - \fill[orange] (1.8,1.8) circle (4pt); - \node[right, font=\scriptsize, orange] at (2.0,1.6) {Element-wise}; - - \fill[green!60!black] (7.5,5.5) circle (4pt); - \node[below, font=\scriptsize, green!60!black] at (7.5,5.3) {Large GEMM}; - - \fill[purple] (3.8,3.8) circle (4pt); - \node[right, font=\scriptsize, purple] at (4.0,3.6) {Attention}; - - % Legend - \draw[red, very thick] (0.5,6.8) -- (1.2,6.8); - \node[right, font=\scriptsize] at (1.3,6.8) {FP16 Ceiling (989 TFLOPS)}; - \draw[red!50, thick, dashed] (0.5,6.4) -- (1.2,6.4); - \node[right, font=\scriptsize] at (1.3,6.4) {FP8 Ceiling (1979 TFLOPS)}; - \draw[blue, very thick] (0.5,6.0) -- (1.2,6.0); - \node[right, font=\scriptsize] at (1.3,6.0) {HBM BW (3.35 TB/s)}; -\end{tikzpicture} -``` - -::: - -The ridge point of the NVIDIA H100 at FP16 precision is: - -$$ -I_{\text{ridge}}^{\text{H100, FP16}} = \frac{989 \text{ TFLOPS}}{3.35 \text{ TB/s}} \approx 295 \text{ FLOP/byte} -$$ - -```{python} -#| echo: false -#| label: roofline-ridge-calc -# ┌───────────────────────────────────────────────────────────────────────────── -# │ ROOFLINE RIDGE POINT COMPARISON -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-roofline — ridge point prose comparison -# │ -# │ Goal: Verify and expose ridge-point values (P/B = FLOPS/BW in FLOP/byte) for -# │ A100 FP16 (~153), H100 FP16 (~295), and H100 FP8 (~591) to explain why -# │ each new generation makes more workloads memory-bound. -# │ Show: ~153 FLOP/byte (A100), ~295 (H100 FP16), ~591 (H100 FP8) — inline prose. -# │ How: Extract pre-computed ridge quantities from PerformanceSetup; convert to -# │ FLOP/byte using .m_as(flop/byte); guard ordering with check(). -# │ -# │ Imports: (none — values from perf-eng-setup cell), mlsys.formatting (check) -# │ Exports: a100_ridge_str, h100_fp16_ridge_str, h100_fp8_ridge_str -# └───────────────────────────────────────────────────────────────────────────── - -# ┌── LEGO ─────────────────────────────────────────────── -class RooflineRidgeCalc: - """Ridge point comparison across A100/H100 FP16/FP8.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - # Already computed in setup cell - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - from mlsys.constants import ( - A100_FLOPS_FP16_TENSOR, A100_MEM_BW, - H100_FLOPS_FP16_TENSOR, H100_FLOPS_FP8_TENSOR, H100_MEM_BW, - TFLOPs, TB, second, flop, byte - ) - a100_f16 = A100_FLOPS_FP16_TENSOR - a100_bw = A100_MEM_BW - h100_f16 = H100_FLOPS_FP16_TENSOR - h100_f8 = H100_FLOPS_FP8_TENSOR - h100_bw = H100_MEM_BW - - a100_ridge_val = (a100_f16 / a100_bw).m_as(flop/byte) - h100_ridge_fp16_val = (h100_f16 / h100_bw).m_as(flop/byte) - h100_ridge_fp8_val = (h100_f8 / h100_bw).m_as(flop/byte) - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(h100_ridge_fp16_val > a100_ridge_val, - "H100 FP16 ridge should exceed A100 FP16 ridge") - check(h100_ridge_fp8_val > h100_ridge_fp16_val, - "H100 FP8 ridge should exceed H100 FP16 ridge") - - # ┌── 4. OUTPUT (Formatting) ───────────────────────────────────────────── - a100_ridge_str = f"{a100_ridge_val:.0f}" - h100_fp16_ridge_str = f"{h100_ridge_fp16_val:.0f}" - h100_fp8_ridge_str = f"{h100_ridge_fp8_val:.0f}" - - a100_fp16_str = f"{a100_f16.m_as(TFLOPs/second):.0f}" - a100_hbm_bw_str = f"{a100_bw.m_as(TB/second):.0f}" - h100_fp8_str = f"{h100_f8.m_as(TFLOPs/second):.0f}" - h100_hbm_bw_str = f"{h100_bw.m_as(TB/second):.2f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -a100_ridge_str = RooflineRidgeCalc.a100_ridge_str -h100_fp16_ridge_str = RooflineRidgeCalc.h100_fp16_ridge_str -h100_fp8_ridge_str = RooflineRidgeCalc.h100_fp8_ridge_str -a100_fp16_str = RooflineRidgeCalc.a100_fp16_str -a100_hbm_bw_str = RooflineRidgeCalc.a100_hbm_bw_str -h100_fp8_str = RooflineRidgeCalc.h100_fp8_str -h100_hbm_bw_str = RooflineRidgeCalc.h100_hbm_bw_str -``` - -This means that any operation performing fewer than `{python} h100_fp16_ridge_str` floating-point operations per byte loaded is memory-bound on the H100 at FP16. At FP8 precision, where compute doubles to `{python} h100_fp8_str` TFLOPS while bandwidth remains `{python} h100_hbm_bw_str` TB/s, the ridge point rises to approximately `{python} h100_fp8_ridge_str` FLOP/byte. The A100, with `{python} a100_fp16_str` TFLOPS and `{python} a100_hbm_bw_str` TB/s, has a lower ridge point of approximately `{python} a100_ridge_str` FLOP/byte at FP16. Each hardware generation increases compute faster than bandwidth, pushing the ridge point higher and making more workloads memory-bound. - -@fig-shifting-roofline overlays the roofline models for four GPU generations on a single log-log plot, making the generational shift visible at a glance. The ridge point has grown from 139 FLOP/byte on the V100 to 625 FLOP/byte on the B200, a 4.5$\times$ increase in seven years. An operation like naive self-attention, with an arithmetic intensity near 10 FLOP/byte, was memory-bound on every generation but falls progressively further below the ridge with each new chip. More critically, operations near 200 FLOP/byte, such as large matrix multiplications, transition from compute-bound on V100 to memory-bound on B200. The same kernel can change performance regime across hardware generations, a fact that demands re-profiling whenever hardware is upgraded. - -::: {#fig-shifting-roofline fig-env="figure" fig-pos="htb" fig-cap="**The Shifting Roofline Across GPU Generations**. Overlaid roofline models for V100 through B200 show the ridge point growing from 139 to 625 FLOP/byte. Operations like naive attention, which were compute-bound on V100, become memory-bound on B200 as the roofline shifts. The same kernel can change performance regime across hardware generations." fig-alt="Log-log roofline plot for V100, A100, H100, B200 with ridge points and example ML operations marked"} -```{python} -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ SHIFTING ROOFLINE (FIGURE) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @fig-shifting-roofline — roofline across GPU generations -# │ -# │ Goal: Overlay rooflines for V100–B200; show ridge point growth 139→625; -# │ mark ops moving from compute-bound to memory-bound. -# │ Show: Log-log; four rooflines; ridge points; example ops. -# │ How: Peak TFLOPS, HBM BW; ridge = TFLOPS/BW; viz.setup_plot(). -# │ -# │ Imports: numpy (np), matplotlib.pyplot (plt), mlsys.viz (viz) -# │ Exports: (figure only, no prose variables) -# └───────────────────────────────────────────────────────────────────────────── -import numpy as np -import matplotlib.pyplot as plt -from mlsys import viz - -fig, ax, COLORS, plt = viz.setup_plot(figsize=(9, 6)) - -# GPU specifications: (name, peak FP16 TFLOPS, HBM BW TB/s, color) -gpus = [ - ("V100", 125, 0.900, COLORS["VioletLine"]), - ("A100", 312, 2.039, COLORS["BlueLine"]), - ("H100", 989, 3.350, COLORS["GreenLine"]), - ("B200", 5000, 8.000, COLORS["OrangeLine"]), -] - -ai_range = np.logspace(0, 4, 500) # 1 to 10000 FLOP/byte - -for name, peak_tflops, bw_tbs, color in gpus: - bw_bytes = bw_tbs * 1e12 # bytes/s - peak_flops = peak_tflops * 1e12 # FLOPS - ridge = peak_flops / bw_bytes # FLOP/byte - - # Roofline: min(peak, BW * AI), expressed in TFLOPS - attainable = np.minimum(peak_tflops, bw_tbs * 1e3 * ai_range / 1e3) - # Simpler: attainable TFLOPS = min(peak_tflops, BW_TB/s * AI * 1e-0) - # BW in TB/s * AI in FLOP/byte = TFLOP/s (since TB/s * FLOP/byte = 1e12 * FLOP / (1e12 * byte) * byte/s... ) - # Let's compute correctly: - # attainable FLOPS = min(peak_flops, bw_bytes * AI) - # attainable TFLOPS = attainable FLOPS / 1e12 - attainable_tflops = np.minimum(peak_tflops, (bw_tbs * ai_range)) - - ax.plot(ai_range, attainable_tflops, color=color, linewidth=2.0, label=f"{name}") - - # Mark ridge point - ax.plot(ridge, peak_tflops, "o", color=color, markersize=8, zorder=5) - # Label ridge point - offset_y = 1.25 if name != "B200" else 0.75 - ax.annotate(f"{name} ridge\n{ridge:.0f} FLOP/byte", - xy=(ridge, peak_tflops), fontsize=7.5, - color=color, fontweight="bold", - ha="center", va="bottom" if name != "B200" else "top", - xytext=(0, 8 if name != "B200" else -8), - textcoords="offset points") - -# Example ML operations as vertical dashed lines -ops = [ - ("LayerNorm\n(~5 FLOP/byte)", 5, "bottom"), - ("Attention (naive)\n(~10 FLOP/byte)", 10, "bottom"), - ("MatMul (large)\n(~200 FLOP/byte)", 200, "bottom"), -] - -for label, ai_val, va_pos in ops: - ax.axvline(x=ai_val, color=COLORS["primary"], linestyle=":", linewidth=1.0, alpha=0.5) - ax.text(ai_val, 0.8, label, fontsize=7, color=COLORS["primary"], - ha="center", va="bottom", rotation=0, - bbox=dict(boxstyle="round,pad=0.2", facecolor="white", edgecolor="none", alpha=0.8)) - -# Arrow annotation for ridge shift -ax.annotate("", - xy=(625, 4200), xytext=(139, 160), - arrowprops=dict(arrowstyle="->", color="crimson", lw=2.0, linestyle="-")) -ax.text(90, 500, "Ridge point shifted\n4.5x in 7 years", - fontsize=9, color="crimson", fontweight="bold", - ha="center", va="center", - bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS["RedL"], edgecolor="crimson", alpha=0.8)) - -ax.set_xscale("log") -ax.set_yscale("log") -ax.set_xlabel("Arithmetic Intensity (FLOP/byte)") -ax.set_ylabel("Attainable Performance (TFLOPS)") -ax.set_xlim(1, 10000) -ax.set_ylim(0.5, 10000) -ax.legend(loc="lower right", fontsize=9, title="GPU Generation") - -plt.tight_layout() -plt.show() -``` -::: - -### Where ML Workloads Fall {#sec-performance-engineering-workload-placement} - -Different ML operations have vastly different arithmetic intensities. Understanding where each falls on the roofline determines which optimization strategies apply. - -```{python} -#| echo: false -#| label: workload-intensity-calc -# ┌───────────────────────────────────────────────────────────────────────────── -# │ WORKLOAD ARITHMETIC INTENSITY -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-workload-placement — classify ML ops -# │ -# │ Goal: Compute arithmetic intensity (FLOP/byte) for a 4096×4096 FP16 GEMM, -# │ a GELU element-wise op, and LLM decode at batch=1 on a 4096-dim model -# │ to show why inference is predominantly memory-bound. -# │ Show: ~1365 FLOP/byte (GEMM, compute-bound), ~1.25 (GELU, memory-bound), -# │ ~1 (LLM decode batch=1, deeply memory-bound) — inline and in table. -# │ How: FLOP count / byte count; GEMM uses 2*M*N*K / (M*K + K*N + M*N)*bytes; -# │ element-wise uses ~5 FLOPs per element / (load + store) bytes. -# │ -# │ Imports: (none — standalone calculation), mlsys.formatting (check) -# │ Exports: gemm_intensity_str, elem_intensity_str, decode_intensity_str -# └───────────────────────────────────────────────────────────────────────────── - -# ┌── LEGO ─────────────────────────────────────────────── -```{python} -#| label: workload-intensity-refactor -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ WORKLOAD INTENSITY ANALYSIS (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-workload-placement — classify ML ops -# │ -# │ Goal: Compute arithmetic intensity (FLOP/byte) for a 4096×4096 FP16 GEMM, -# │ a GELU element-wise op, and LLM decode at batch=1 on a 4096-dim model -# │ to show why inference is predominantly memory-bound. -# │ Show: ~1365 FLOP/byte (GEMM, compute-bound), ~1.25 (GELU, memory-bound), -# │ ~1 (LLM decode batch=1, deeply memory-bound) — inline and in table. -# │ How: FLOP count / byte count; GEMM uses 2*M*N*K / (M*K + K*N + M*N)*bytes; -# │ element-wise uses ~5 FLOPs per element / (load + store) bytes. -# │ -# │ Imports: mlsys.constants (H100_MEM_BW, H100_FLOPS_FP16_TENSOR, ...) -# │ Exports: gemm_intensity_str, elem_intensity_str, decode_intensity_str, -# │ h100_ridge_str, prefill_intensity_str, decode_t_ms_str, -# │ decode_util_pct_str, decode_tokens_sec_str, -# │ gemm_flops_str, gemm_data_mb_str, -# │ gelu_flops_per_val, decode_step_param_b_str, -# │ decode_step_weight_gb_str, decode_step_flops_str, -# │ decode_step_bytes_str -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import ( - H100_MEM_BW, H100_FLOPS_FP16_TENSOR, BYTES_PER_FP16, - BILLION, MILLION, TRILLION, THOUSAND, - GB, TB, second, byte, flop, MICROSECOND -) -from mlsys.formatting import fmt, check, md - -# ┌── LEGO ─────────────────────────────────────────────── -class WorkloadIntensityCalc: - """Arithmetic intensity for GEMM, GELU, and LLM decode operations.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - # Level 1: Hardware - h100_flops = H100_FLOPS_FP16_TENSOR - h100_bw = H100_MEM_BW - bytes_per_elem = BYTES_PER_FP16 - - # Level 2: Workload - # GEMM: C = A*B where A is MxK, B is KxN - M, K, N = 4096, 4096, 4096 - # Element-wise (e.g., GELU): ~5 ops per element - elem_flops_per_val = 5 - elem_count = M * K - # LLM decode: batch=1, hidden=4096 - decode_batch = 1 - decode_hidden = 4096 - # Prefill: batch=1024 - prefill_batch = 1024 - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - # GEMM: 2*M*N*K FLOPs, load A (M*K*2) + B (K*N*2) + store C (M*N*2) - gemm_flops = 2 * M * N * K - gemm_bytes = (M * K + K * N + M * N) * bytes_per_elem - gemm_intensity = gemm_flops / gemm_bytes - - # Element-wise: load+store - elem_flops = elem_flops_per_val * elem_count - elem_bytes = 2 * elem_count * bytes_per_elem # load + store - elem_intensity = elem_flops / elem_bytes - - # LLM decode step: batch=1, weight matrix is hidden*hidden - decode_flops_val = 2 * decode_batch * decode_hidden * decode_hidden - decode_bytes_val = (decode_hidden * decode_hidden + decode_batch * decode_hidden) * bytes_per_elem - decode_intensity_val = decode_flops_val / decode_bytes_val - - # Level 4: Ratios - h100_ridge = (h100_flops / h100_bw).m_as(flop/byte) - prefill_intensity = (2 * prefill_batch) / (2 * bytes_per_elem) # Approx for large hidden - - # Token Rate Physics (Llama-3 70B sharded across 8 GPUs) - p_shard = 17.5 * BILLION - weight_bytes_val = p_shard * bytes_per_elem - decode_step_flops_val = 2 * p_shard - t_decode_ms = (weight_bytes_val / h100_bw.m_as(byte/second)) * 1000 - achieved_tflops = (decode_step_flops_val / (t_decode_ms / 1000)) / TRILLION - util_pct = (achieved_tflops / h100_flops.m_as(TFLOPs/second)) * 100 - tokens_sec = 1000 / t_decode_ms - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(gemm_intensity > 100, f"GEMM should have high AI, got {gemm_intensity:.1f}") - check(elem_intensity < 5, f"Element-wise should have low AI, got {elem_intensity:.1f}") - check(decode_intensity_val < 5, f"Decode should be memory-bound, got {decode_intensity_val:.1f}") - check(h100_ridge > 200, f"H100 ridge point mismatch, got {h100_ridge:.1f}") - - # ┌── 4. OUTPUT (Formatting) ───────────────────────────────────────────── - gemm_intensity_str = f"{gemm_intensity:.0f}" - gemm_flops_str = f"{gemm_flops/BILLION:.0f} billion" - gemm_data_mb_str = f"{gemm_bytes/MILLION:.0f} MB" - - elem_intensity_str = f"{elem_intensity:.1f}" - gelu_flops_per_val_str = f"{elem_flops_per_val}" - - decode_intensity_str = f"{decode_intensity_val:.1f}" - h100_ridge_str = f"{h100_ridge:.0f}" - prefill_intensity_str = f"{prefill_intensity:.0f}" - decode_t_ms_str = f"{t_decode_ms:.1f}" - decode_util_pct_str = f"{util_pct:.1f}" - decode_tokens_sec_str = f"{tokens_sec:.0f}" - - decode_step_param_b_str = f"{p_shard/BILLION:.1f}" - decode_step_weight_gb_str = f"{weight_bytes_val/BILLION:.0f}" - decode_step_flops_str = f"{decode_step_flops_val/BILLION:.0f}" - decode_step_bytes_str = f"{weight_bytes_val/BILLION:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -gemm_intensity_str = WorkloadIntensityCalc.gemm_intensity_str -elem_intensity_str = WorkloadIntensityCalc.elem_intensity_str -gelu_flops_per_val = WorkloadIntensityCalc.gelu_flops_per_val_str -decode_intensity_str = WorkloadIntensityCalc.decode_intensity_str -h100_ridge_str = WorkloadIntensityCalc.h100_ridge_str -prefill_intensity_str = WorkloadIntensityCalc.prefill_intensity_str -decode_t_ms_str = WorkloadIntensityCalc.decode_t_ms_str -decode_util_pct_str = WorkloadIntensityCalc.decode_util_pct_str -decode_tokens_sec_str = WorkloadIntensityCalc.decode_tokens_sec_str -decode_step_param_b_str = WorkloadIntensityCalc.decode_step_param_b_str -decode_step_weight_gb_str = WorkloadIntensityCalc.decode_step_weight_gb_str -decode_step_flops_str = WorkloadIntensityCalc.decode_step_flops_str -decode_step_bytes_str = WorkloadIntensityCalc.decode_step_bytes_str -``` - -Large matrix multiplications (GEMMs) are the most compute-intensive operations in ML. A square matrix multiplication of dimension $4096\times4096$ in FP16 performs approximately 137 billion FLOPs while loading roughly 100 MB of data, yielding an arithmetic intensity of approximately `{python} gemm_intensity_str` FLOP/byte. This sits well above the H100's ridge point, making large GEMMs firmly compute-bound. - -Element-wise operations tell the opposite story. A GELU activation applied to a $4096\times4096$ tensor performs roughly 5 operations per element but must load and store each element, yielding an arithmetic intensity of approximately `{python} elem_intensity_str` FLOP/byte. These operations are profoundly memory-bound, spending almost all their time waiting for data transfers rather than computing. - -Autoregressive LLM decoding at batch size one represents the extreme case. Each decoding step reads the entire weight matrix (gigabytes of data) to produce a single output token. With a hidden dimension of 4096 and batch size 1, the arithmetic intensity is approximately `{python} decode_intensity_str` FLOP/byte, deep in the memory-bound regime. This explains why LLM token generation achieves a tiny fraction of peak FLOPS: the GPU spends nearly all its time reading weights, not multiplying them. - -| **Operation** | **Arithmetic Intensity** | **H100 FP16 Regime** | **Primary Bottleneck** | -|:-----------------------------------|:-------------------------|:---------------------|:-----------------------| -| **GEMM ($4096\times4096$)** | ~1,365 FLOP/byte | Compute-bound | Tensor core throughput | -| **Self-Attention (seq=2048)** | ~50--200 FLOP/byte | Memory-bound | HBM bandwidth | -| **Element-wise (GELU, LayerNorm)** | ~1--3 FLOP/byte | Memory-bound | HBM bandwidth | -| **LLM Decode (batch=1)** | ~1--2 FLOP/byte | Memory-bound | HBM bandwidth | - -: **Arithmetic Intensity of Common ML Operations**. Most operations in transformer inference, aside from large batched GEMMs, fall below the H100's ridge point and are therefore memory-bound. Performance engineering focuses on reducing the memory traffic of these operations. {#tbl-arithmetic-intensity} - -The central insight from @tbl-arithmetic-intensity is that the majority of operations in a transformer inference pipeline are memory-bound. Training workloads with large batch sizes shift more operations into the compute-bound regime because GEMM dimensions scale with batch size. Inference, however, especially autoregressive generation, is dominated by memory-bound operations. This is why the techniques in the rest of this chapter, fusion, tiling, reduced precision, and algorithmic shortcuts, all target the same fundamental problem: reducing bytes moved per operation. - -This observation also explains a common source of confusion: why GPU benchmarks (which report peak TFLOPS) often fail to predict real inference performance. Two GPUs with different TFLOPS but identical memory bandwidth will achieve virtually identical LLM decode throughput at batch size 1, because decode is entirely memory-bound. The correct metric for comparing GPUs for LLM inference is not FLOPS but rather the combination of memory bandwidth and memory capacity. Bandwidth determines the token generation rate, and capacity determines the maximum batch size (and therefore throughput). Only at large batch sizes, where decode approaches the compute-bound regime, do the FLOPS differences between GPUs translate into throughput differences. - -### Batch Size as the Universal Control Knob {#sec-performance-engineering-batch-size} - -Batch size is the highest-impact, and most constrained, lever for performance. Increasing the batch size transforms the arithmetic intensity of every operation. For an LLM decode step, the arithmetic intensity scales linearly with batch size: - -$$ -I_{\text{decode}}(\text{batch}) = \frac{2 \times \text{params} \times \text{batch}}{\text{params} \times \text{bytes\_per\_param} + \text{batch} \times d \times \text{bytes\_per\_elem}} -$$ - -At batch size 1, the denominator is dominated by the weight term ($\text{params} \times \text{bytes\_per\_param}$), and $I \approx 2 / \text{bytes\_per\_param} \approx 1$ FLOP/byte for FP16. At batch size 256, the input term becomes significant, and $I \approx 2 \times 256 / \text{bytes\_per\_param} \approx 256$ FLOP/byte, approaching the compute-bound regime. - -This means that at large batch sizes, the GPU transitions from memory-bound to compute-bound, and utilization increases dramatically. A single H100 achieving 2% utilization at batch size 1 may achieve 60% utilization at batch size 256. The economic implication is stark: the cost per token decreases by 30$\times$ as batch size increases from 1 to 256. - -The constraint is memory: each additional request in the batch requires its own KV cache, and the total KV cache across all requests must fit in GPU memory alongside the model weights. A 70B model with 80 GB of weights in FP16 leaves almost no room for KV cache on an 80 GB GPU. This is precisely why the precision engineering techniques covered later in this chapter matter: INT4 weight quantization frees 60 GB for KV cache, enabling batch sizes that transform the economics of serving. - -The continuous batching systems introduced in @sec-inference-scale manage batch size dynamically, adding and removing requests as they complete. Performance engineering's role is to maximize the effective batch size by minimizing the per-request memory footprint, primarily through KV cache compression and weight quantization. - -A critical enabler for large batch sizes is **paged KV cache management**[^fn-paged-attention-os], introduced by vLLM (Kwon et al., 2023). Traditional KV cache implementations pre-allocate contiguous memory for each request's maximum possible sequence length. - -[^fn-paged-attention-os]: **Paged Attention**: Named by direct analogy to OS virtual memory paging, where the OS maps non-contiguous physical pages to contiguous virtual addresses. The insight, presented at SOSP 2023, was that the same mechanism eliminates internal fragmentation in KV caches, recovering the 60--80% of GPU memory wasted by worst-case pre-allocation. This single abstraction transformed LLM serving economics by enabling 2--4$\times$ larger batch sizes without any change to model weights or precision. \index{Paged Attention!memory management} - -If the maximum is 4,096 tokens but the average is 500, approximately 88% of the allocated memory is wasted. Paged attention divides the KV cache into fixed-size blocks (pages), allocated on demand as the sequence grows. This eliminates memory fragmentation and enables near-100% utilization of the KV cache memory budget. The performance impact is indirect but substantial: by reducing memory waste, paged attention enables 2--4$\times$ larger effective batch sizes, which in turn improve throughput and GPU utilization through the batch size mechanism described above. - -The interaction between paged attention and KV cache quantization is multiplicative. Paged attention reduces memory *waste* (from fragmentation), while quantization reduces memory *usage* (from precision). Together, they can increase the effective batch size by 8--16$\times$ compared to a baseline system with pre-allocated FP16 KV caches, fundamentally changing the economics of LLM serving. - -### The Prefill-Decode Decomposition {#sec-performance-engineering-prefill-decode} - -Modern LLM serving systems decompose each request into two distinct phases with fundamentally different performance characteristics, a distinction that drives system architecture and optimization strategy. - -The **prefill phase** processes the entire input prompt in parallel. If the prompt contains $P$ tokens, the prefill phase executes a single forward pass over all $P$ tokens simultaneously. The GEMM operations have shape [$P$, $d$]$\times$ [$d$, $d$], making the batch dimension equal to $P$. For a prompt of 1024 tokens, this is arithmetically intensive: the arithmetic intensity is approximately $2 \times 1024 / 2 = 1024$ FLOP/byte for FP16 weights, well into the compute-bound regime. Prefill is therefore limited by Tensor Core throughput, not memory bandwidth. - -The **decode phase** generates output tokens one at a time, autoregressively. Each step has a batch dimension of 1 (for a single request) or the number of concurrent requests (for batched serving). At batch size 1, decode is deeply memory-bound as analyzed in @sec-performance-engineering-workload-placement. - -This decomposition has profound implications for system design. A system optimized for prefill (maximizing FLOPS utilization) would use large matrix sizes and high compute throughput. A system optimized for decode (maximizing bandwidth utilization) would use aggressive quantization and memory optimization. A real serving system must handle both phases, often simultaneously across different requests in a continuous batching framework. - -Disaggregated serving addresses this mismatch by running prefill and decode on separate hardware pools. Prefill servers are optimized for compute (fewer, higher-FLOPS GPUs), while decode servers are optimized for memory bandwidth and capacity (more memory per GPU, aggressive quantization). The KV cache computed during prefill is transferred to a decode server, which handles the subsequent autoregressive generation. This disaggregation allows each phase to use hardware and software configurations tuned for its specific bottleneck. - -The performance characteristics of each phase determine which optimization techniques apply. FlashAttention provides the largest speedup during prefill, where the quadratic attention computation dominates. KV cache quantization and speculative decoding apply exclusively to the decode phase. Precision engineering (FP8/INT4 weights) benefits both phases, but through different mechanisms: prefill benefits from doubled compute throughput (FP8 Tensor Cores), while decode benefits from doubled effective bandwidth (half the bytes per weight read). - -::: {.callout-notebook title="The Roofline Diagnostic"} - -**Problem**: You are deploying a 70B parameter LLM on 8$\times$ H100 GPUs with tensor parallelism. At batch size 1, each GPU holds approximately 17.5B parameters in FP16 (35 GB of weights). Each decode step reads all weights to produce one token. What is the achieved arithmetic intensity, and what is the theoretical maximum token generation rate? - -**The Math**: - -*Step 1: Arithmetic Intensity.* -Each decode step per GPU: FLOPs $= 2 \times 17.5 \times 10^9 = 35 \times 10^9$ FLOP. Bytes loaded $= 17.5 \times 10^9 \times 2 = 35 \times 10^9$ bytes $= 35$ GB. - -$$ -I = \frac{35 \times 10^9 \text{ FLOP}}{35 \times 10^9 \text{ bytes}} = 1.0 \text{ FLOP/byte} -$$ - -This is far below the H100 ridge point of ~295 FLOP/byte. The operation is deeply memory-bound. - -*Step 2: Token Rate.* -Since the operation is memory-bound, performance is limited by bandwidth, not compute: - -$$ -t_{\text{decode}} = \frac{35 \text{ GB}}{3.35 \text{ TB/s}} \approx 10.4 \text{ ms per token} -$$ - -This yields approximately 96 tokens/second per GPU, or about 96 tokens/second for the model (since tensor parallelism does not multiply throughput for memory-bound decode). In practice, overheads from KV cache reads and NVLink synchronization reduce this to 40--70 tokens/second. - -**Takeaway**: At batch size 1, fewer than 0.4% of the H100's FP16 FLOPS are in use. The only ways to improve are: (a) increase batch size to amortize weight reads, (b) reduce weight bytes via quantization, or (c) use speculative decoding to generate multiple tokens per weight read. - -::: - -The roofline model establishes the physics that constrains all subsequent optimization. The first and most impactful strategy for breaking through the memory wall is to keep data in SRAM instead of round-tripping through HBM. - -## Operator Fusion and Kernel Engineering {#sec-performance-engineering-fusion} - -Consider the simple sequence of operations $Y = \text{LayerNorm}(\text{GELU}(XW + b))$. In a naive implementation, the GPU writes the output of the matrix multiply back to main memory, reads it back for the GELU, writes it out again, and reads it one final time for the LayerNorm. This redundant data movement shatters performance. Operator fusion solves this by keeping intermediate results in ultra-fast registers, executing the entire sequence in a single trip to memory. - -### The Kernel Launch Problem {#sec-performance-engineering-kernel-launch} - -Each GPU kernel launch involves overhead: the CPU must prepare launch parameters, dispatch to the GPU command queue, and the GPU must schedule thread blocks across its streaming multiprocessors (SMs). For a small element-wise operation on a modern GPU, this overhead can be 5--20 $\mu$s, a time during which a memory-bound kernel might have already completed its useful work. When a transformer layer comprises dozens of small operations (add, multiply, normalize, activate), the cumulative launch overhead becomes significant. - -Each unfused kernel must also materialize its output in HBM. Consider a sequence of three operations: $Y = \text{LayerNorm}(\text{GELU}(XW + b))$. Without fusion, this requires: - -1. **GEMM kernel**: Read $X$ and $W$ from HBM, compute $XW + b$, write result $Z_1$ to HBM. -2. **GELU kernel**: Read $Z_1$ from HBM, compute GELU($Z_1$), write $Z_2$ to HBM. -3. **LayerNorm kernel**: Read $Z_2$ from HBM, compute LayerNorm($Z_2$), write $Y$ to HBM. - -Intermediate tensors $Z_1$ and $Z_2$ each occupy the same memory as the output $Y$. For a hidden dimension of 4096 and batch size of 2048 in FP16, each intermediate tensor is $4096 \times 2048 \times 2 = 16$ MB. The unfused execution reads and writes 32 MB of intermediate data that a fused kernel avoids entirely by holding $Z_1$ and $Z_2$ in registers or shared memory (SRAM) within the SM. - -@fig-fusion-before-after contrasts these two execution paths, making the HBM traffic savings visible. - -::: {#fig-fusion-before-after fig-env="figure" fig-pos="htb" fig-cap="**Operator Fusion: Before and After**. Left: three separate kernels each read from and write to HBM, materializing intermediate tensors $Z_1$ and $Z_2$ (5 HBM transfers total). Right: a single fused kernel reads $X$ and $W$ once, performs all three operations with intermediates held in SRAM, and writes only the final output $Y$ (2 HBM transfers). The fused path eliminates 32 MB of redundant HBM traffic per layer." fig-alt="Two side-by-side dataflow diagrams. Left unfused path shows 5 arrows between operations and HBM. Right fused path shows 2 arrows with intermediates staying in SRAM."} -```{.tikz} -\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, scale=0.85, transform shape] - \definecolor{BlueLine}{HTML}{006395} - \definecolor{BlueL}{HTML}{D1E6F3} - \definecolor{GreenLine}{HTML}{008F45} - \definecolor{GreenL}{HTML}{D4EFDF} - \definecolor{OrangeLine}{HTML}{E67817} - \definecolor{OrangeL}{HTML}{FCE4CC} - \definecolor{RedLine}{HTML}{CB202D} - \definecolor{RedL}{HTML}{F5D2D5} - - \tikzset{ - op/.style={draw=BlueLine, fill=BlueL, rounded corners=2pt, thick, - minimum width=2.0cm, minimum height=0.6cm, align=center, font=\scriptsize\bfseries}, - mem/.style={draw=RedLine, fill=RedL!40, rounded corners=2pt, thick, - minimum width=2.0cm, minimum height=0.5cm, align=center, font=\scriptsize}, - sram/.style={draw=GreenLine, fill=GreenL, rounded corners=2pt, thick, - minimum width=2.0cm, minimum height=0.5cm, align=center, font=\scriptsize}, - fusedbox/.style={draw=GreenLine, fill=GreenL!20, rounded corners=6pt, thick, - inner sep=6pt}, - hbmio/.style={-{Triangle[width=4pt,length=3pt]}, thick, RedLine}, - sramio/.style={-{Triangle[width=4pt,length=3pt]}, thick, GreenLine} - } - - % === LEFT: Unfused === - \node[font=\small\bfseries, text=RedLine] at (1.5, 5.5) {Unfused (3 Kernels)}; - - % HBM at top - \node[mem, minimum width=3.0cm] (hbm_l) at (1.5, 4.5) {HBM (Off-Chip)}; - - % Operations - \node[op] (gemm) at (1.5, 3.2) {GEMM}; - \node[op] (gelu) at (1.5, 1.8) {GELU}; - \node[op] (ln) at (1.5, 0.4) {LayerNorm}; - - % HBM read/write arrows (5 total round trips) - \draw[hbmio] (hbm_l.south) -- node[left, font=\tiny, text=RedLine] {read} (gemm.north); - \draw[hbmio] (gemm.east) -- ++(0.6,0) |- node[right, font=\tiny, text=RedLine, pos=0.2] {$Z_1$} (hbm_l.east); - \draw[hbmio] (hbm_l.south west) ++(0.3,0) -- ++(0,-2.1) -- (gelu.west); - \draw[hbmio] (gelu.east) -- ++(0.6,0) |- node[right, font=\tiny, text=RedLine, pos=0.2] {$Z_2$} ++(0,2.7); - \draw[hbmio] ([xshift=-5pt]hbm_l.south) -- ++(0,-3.5) -- (ln.west); - \draw[hbmio] (ln.south) -- ++(0,-0.4) node[below, font=\tiny] {$Y$ to HBM}; - - % Count annotation - \node[font=\scriptsize\bfseries, text=RedLine] at (1.5, -0.8) {5 HBM transfers}; - \node[font=\tiny, text=RedLine] at (1.5, -1.2) {32 MB intermediate traffic}; - - % === RIGHT: Fused === - \node[font=\small\bfseries, text=GreenLine] at (8.0, 5.5) {Fused (1 Kernel)}; - - % HBM at top - \node[mem, minimum width=3.0cm] (hbm_r) at (8.0, 4.5) {HBM (Off-Chip)}; - - % Fused kernel box - \node[fusedbox, minimum width=3.2cm, minimum height=3.2cm] (fbox) at (8.0, 1.8) {}; - - % Operations inside fused box - \node[op] (f_gemm) at (8.0, 3.0) {GEMM}; - \node[op] (f_gelu) at (8.0, 1.8) {GELU}; - \node[op] (f_ln) at (8.0, 0.6) {LayerNorm}; - - % SRAM label - \node[sram, minimum width=1.0cm] at (10.0, 1.8) {SRAM}; - - % Internal SRAM arrows - \draw[sramio] (f_gemm) -- node[right, font=\tiny, text=GreenLine] {$Z_1$} (f_gelu); - \draw[sramio] (f_gelu) -- node[right, font=\tiny, text=GreenLine] {$Z_2$} (f_ln); - - % Only 2 HBM transfers - \draw[hbmio] (hbm_r.south) -- node[left, font=\tiny, text=RedLine] {read $X$,$W$} (fbox.north); - \draw[hbmio] (fbox.south) -- ++(0,-0.4) node[below, font=\tiny] {$Y$ to HBM}; - - % Count annotation - \node[font=\scriptsize\bfseries, text=GreenLine] at (8.0, -0.8) {2 HBM transfers}; - \node[font=\tiny, text=GreenLine] at (8.0, -1.2) {0 MB intermediate traffic}; - - % Speedup annotation - \draw[OrangeLine, ultra thick, -{Triangle[width=6pt,length=5pt]}] (4.2, 1.8) -- (5.3, 1.8) - node[midway, above, font=\scriptsize\bfseries, text=OrangeLine] {2.5$\times$ faster}; - -\end{tikzpicture} -``` -::: - -In a naive implementation without operator fusion, executing one Transformer layer requires roughly 50 separate kernel launches. If each launch incurs a 10-microsecond overhead, the system spends 500 microseconds purely on dispatch latency. If the actual arithmetic execution of the layer takes only 2 milliseconds, the launch overhead consumes 20% of the total wall-clock time, leaving the GPU compute units idle for one-fifth of the inference cycle. This "launch-bound" regime limits the benefits of faster hardware; doubling the GPU's FLOPs does nothing to reduce the 500-microsecond fixed cost. Operator fusion addresses this by compiling these 50 discrete operations into a small handful of fused kernels—often reducing the count to 5–10 launches—thereby reclaiming the lost cycles and shifting the workload back towards a compute-bound profile. - -### Fusion Categories {#sec-performance-engineering-fusion-categories} - -GPU kernel fusion falls into three categories, each with different complexity and performance impact. - -Element-wise fusion is the simplest form: consecutive element-wise operations (add, multiply, activation functions) combine into a single kernel. Because each output element depends on exactly one input element, this fusion is always legal and straightforward to implement. Every modern deep learning framework performs element-wise fusion automatically. - -Reduction fusion combines an element-wise operation with a subsequent reduction (such as summing elements for a loss function, or computing mean and variance for layer normalization). This category is more complex because reductions require inter-thread communication within the kernel. Reductions need warp-level shuffle instructions or shared memory to aggregate partial results across threads. Despite this complexity, the memory savings are substantial: the intermediate tensor before the reduction never materializes in HBM. For layer normalization specifically, reduction fusion avoids writing the large pre-normalization tensor to HBM and reading it back for the mean/variance computation. - -Operator-specific fusion is the most impactful and the most difficult category. These are custom kernels designed for a specific sequence of operations, such as fused attention or fused GEMM-bias-activation. The kernel architect must reason about data flow, shared memory allocation, and thread scheduling simultaneously. The payoff is substantial: FlashAttention, which we examine next, reduces attention memory traffic from quadratic to linear in sequence length. - -To appreciate the quantitative impact, consider each category applied to a single transformer layer with hidden dimension 4096 and batch size 2048 in FP16. Element-wise fusion of a bias-GELU-dropout chain eliminates two intermediate tensors of 16 MB each, saving 64 MB of HBM traffic (two writes plus two reads) per layer. Across 80 layers, this reclaims 5.1 GB of memory bandwidth per forward pass. Reduction fusion of LayerNorm avoids materializing the pre-normalization tensor (16 MB) and the intermediate mean/variance statistics, saving an additional 48 MB per layer. Operator-specific attention fusion (FlashAttention) provides the largest single gain: for a sequence length of 8192, it eliminates the 128 MB per-head attention score matrix, saving over 4 GB per layer across 32 heads. The cumulative effect of all three fusion categories can reduce total HBM traffic by 60--80% for a transformer forward pass, translating directly into proportional wall-clock speedup for memory-bound workloads. - -### CUDA Graphs: Eliminating Launch Overhead {#sec-performance-engineering-cuda-graphs} - -An orthogonal technique for reducing the overhead term in the Iron Law is **CUDA Graphs**[^fn-cuda-graphs-constraint]. While operator fusion combines multiple operations into fewer kernels, CUDA Graphs eliminate the CPU overhead of launching those kernels. - -[^fn-cuda-graphs-constraint]: **CUDA Graphs**: Introduced in CUDA 10 (2018), originally for graphics rendering pipelines that replay identical command sequences every frame. The strict determinism requirement -- identical operations, shapes, and memory addresses per replay -- directly conflicts with the dynamic shapes and variable batch sizes of LLM serving, restricting their use primarily to the decode phase where the computation pattern repeats per token. \index{CUDA Graphs!overhead reduction} - -In standard PyTorch execution, each kernel launch requires the CPU to push a command to the GPU's command queue. For a transformer decoder layer with 30+ kernels, this CPU-to-GPU roundtrip (typically 5--10 $\mu$s per launch) accumulates to 150--300 $\mu$s per layer. For a 70-layer model, kernel launch overhead alone contributes 10--20 ms per forward pass, a significant fraction of the total time for memory-bound inference. - -CUDA Graphs address this by recording a sequence of GPU operations (kernel launches, memory copies) into a replayable graph. The recording happens once during a warmup phase. On subsequent iterations, replaying the graph requires only a single CPU-to-GPU command that dispatches the entire recorded sequence, reducing launch overhead to approximately 5--10 $\mu$s total regardless of the number of kernels. - -The benefit is substantial: for a model with 30+ kernels per layer and 70+ layers, the baseline kernel launch overhead can exceed 15 ms per forward pass. CUDA Graphs reduce this to under 0.1 ms, reclaiming 15 ms that translates directly to higher token generation rates. - -The constraint is that CUDA Graphs require deterministic execution: the sequence of operations, tensor shapes, and memory addresses must be identical across replays. This conflicts with dynamic inference patterns like variable-length sequences, changing batch sizes, and conditional computation (early exit, MoE routing). In practice, CUDA Graphs are most effective for the decode phase of LLM serving, where the computation pattern is repetitive (same operations per token), and less useful for the prefill phase, where input lengths vary. - -The combination of operator fusion (reducing the number of kernels) and CUDA Graphs (reducing the per-kernel overhead) can together eliminate nearly all non-compute overhead from the forward pass. When profiling reveals that kernel launch gaps constitute more than 10% of execution time, CUDA Graphs should be the first intervention considered. - -### FlashAttention: Tiled Attention as a System Primitive {#sec-performance-engineering-flashattention} - -Standard self-attention computes $\text{Softmax}(QK^T / \sqrt{d_k})V$, where $Q$, $K$, and $V$ are matrices of shape [sequence length$\times$ head dimension]. The na\"ive implementation materializes the full $N \times N$ attention matrix $S = QK^T$ in HBM, where $N$ is the sequence length. For $N = 8192$ and FP16 precision, this matrix alone consumes $8192 \times 8192 \times 2 = 128$ MB per attention head. At 32 heads, the materialized attention matrices require 4 GB per layer, dominating memory traffic for the entire forward pass. - -```{python} -#| label: flash-attention-savings-refactor -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ FLASHATTENTION MEMORY SAVINGS (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-operator-fusion FlashAttention prose -# │ -# │ Goal: Quantify HBM traffic reduction from tiling attention for a 32-head, -# │ seq_len=8192, head_dim=128, FP16 configuration. -# │ Show: ~65x reduction in HBM traffic by avoiding NxN materialization. -# │ How: Naive = 2*N^2 * h; Flash = 4*N*d * h. (bytes) -# │ -# │ Imports: mlsys.constants (BYTES_PER_FP16, MILLION, BILLION) -# │ Exports: naive_mb_str, flash_mb_str, savings_str, -# │ head_n_str, head_d_str, head_bytes_str, -# │ naive_head_traffic_mb_str, flash_head_traffic_mb_str, -# │ flash_savings_factor_str -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import BYTES_PER_FP16, MILLION, BILLION -from mlsys.formatting import fmt, check - -# ┌── LEGO ─────────────────────────────────────────────── -class FlashAttentionSavings: - """Memory traffic analysis for tiled attention vs naive.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - seq_len = 8192 - head_dim = 128 - num_heads = 32 - bytes_per_elem = BYTES_PER_FP16 - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - # Naive attention HBM traffic per head: - # Write S = Q*K^T (N*N), read S for softmax, write P = softmax(S), read P for P*V - # Simplified: dominant term is materializing NxN matrix - naive_attn_bytes_per_head = 2 * seq_len * seq_len * bytes_per_elem # write + read S - naive_attn_bytes_total = naive_attn_bytes_per_head * num_heads - - # FlashAttention HBM traffic per head: - # Read Q, K, V once (3 * N * d * bytes), write O once (N * d * bytes) - # No NxN materialization - flash_attn_bytes_per_head = (3 + 1) * seq_len * head_dim * bytes_per_elem - flash_attn_bytes_total = flash_attn_bytes_per_head * num_heads - - # Savings ratio - savings_ratio = naive_attn_bytes_total / flash_attn_bytes_total - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(savings_ratio > 10, - f"FlashAttention should save >10x HBM traffic, got {savings_ratio:.1f}x") - - # ┌── 4. OUTPUT (Formatting) ───────────────────────────────────────────── - naive_mb_str = f"{naive_attn_bytes_total / MILLION:.0f}" - flash_mb_str = f"{flash_attn_bytes_total / MILLION:.0f}" - savings_str = f"{savings_ratio:.0f}" - - head_n_str = f"{seq_len}" - head_d_str = f"{head_dim}" - head_bytes_str = f"{head_dim * bytes_per_elem / MILLION * seq_len:.0f}" - naive_head_traffic_mb_str = f"{naive_attn_bytes_per_head / MILLION:.0f}" - flash_head_traffic_mb_str = f"{flash_attn_bytes_per_head / MILLION:.0f}" - flash_savings_factor_str = f"{savings_ratio:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -naive_mb_str = FlashAttentionSavings.naive_mb_str -flash_mb_str = FlashAttentionSavings.flash_mb_str -savings_str = FlashAttentionSavings.savings_str -head_n_str = FlashAttentionSavings.head_n_str -head_d_str = FlashAttentionSavings.head_d_str -head_bytes_str = FlashAttentionSavings.head_bytes_str -naive_head_traffic_mb_str = FlashAttentionSavings.naive_head_traffic_mb_str -flash_head_traffic_mb_str = FlashAttentionSavings.flash_head_traffic_mb_str -flash_savings_factor_str = FlashAttentionSavings.flash_savings_factor_str -``` - -FlashAttention, introduced by Dao et al. (2022), reformulates attention as a **tiled** computation. Instead of materializing the full $N \times N$ attention matrix, it processes $Q$, $K$, and $V$ in small blocks that fit in on-chip SRAM. The algorithm loads a block of $Q$ rows and iterates over blocks of $K$ and $V$ columns, computing partial attention scores and maintaining running statistics (online softmax) to produce the exact result without ever storing the full attention matrix in HBM. - -The HBM traffic reduction is dramatic. For a sequence length of `{python} head_n_str`, 32 heads, and head dimension `{python} head_d_str` in FP16, the na\"ive attention reads and writes approximately `{python} naive_mb_str` MB of attention matrices through HBM. FlashAttention reads $Q$, $K$, $V$ and writes $O$ once each, totaling approximately `{python} flash_mb_str` MB. This is a `{python} savings_str`$\times$ reduction in HBM traffic, translating directly into a proportional speedup for this memory-bound operation. - -The key insight behind FlashAttention is the **online softmax**[^fn-online-softmax-algo] trick, which makes tiling possible for an operation that appears to require global information. - -[^fn-online-softmax-algo]: **Online Softmax**: "Online" here is an algorithmic term meaning the computation processes data incrementally in a single pass without storing the full input -- the same sense as in "online learning" or "online algorithms." This property is what makes tiling possible: the algorithm never needs the complete $N \times N$ score matrix in memory simultaneously, reducing attention memory from $O(N^2)$ to $O(N)$ and making long-context inference feasible on fixed-size SRAM. \index{Online Softmax!FlashAttention} - -Standard softmax computes $\text{softmax}(s_i) = e^{s_i} / \sum_j e^{s_j}$, but for numerical stability it first subtracts the global maximum: $\text{softmax}(s_i) = e^{s_i - m} / \sum_j e^{s_j - m}$ where $m = \max_j s_j$. Finding this global maximum seems to require seeing all scores first, which would force materializing the full $N \times N$ matrix. - -The online algorithm avoids this by maintaining running statistics that are updated incrementally as each tile is processed. When processing tile $t$, the algorithm: - -1. Computes a local block of scores $S_t = Q_{\text{block}} K_t^T$. -2. Updates the running maximum: $m_{\text{new}} = \max(m_{\text{old}}, \max(S_t))$. -3. Rescales the previous running sum and output: multiply by $e^{m_{\text{old}} - m_{\text{new}}}$ to correct for the updated maximum. -4. Computes the local softmax contribution using $m_{\text{new}}$ and accumulates into the running output. - -After processing all tiles, the running output contains the exact same result as the standard algorithm. The rescaling step (step 3) is the critical innovation: it allows the algorithm to "fix up" previous partial results when a new tile reveals a larger maximum value. This correction is exact, not approximate, so FlashAttention produces bit-identical results to standard attention for a given numerical precision. - -The cost of this tiling is additional arithmetic: the rescaling operations in step 3 add FLOPs that the standard algorithm does not perform. Because the operation is profoundly memory-bound (the arithmetic intensity of standard attention is roughly 1--10 FLOP/byte for typical sequence lengths), the additional compute is "free" in the sense that the GPU's arithmetic units would otherwise be idle, waiting for HBM data transfers. Trading extra compute for fewer memory accesses is profitable whenever the operation is memory-bound, the central principle of this entire chapter. - -A concrete numerical example clarifies the memory savings. Consider one attention head with sequence length $N = 8192$ and head dimension $d = 128$. The $Q$, $K$, $V$ matrices are each $8192 \times 128$ in FP16, occupying $8192 \times 128 \times 2 = 2$ MB each (6 MB total for one head). The output matrix $O$ is the same size (2 MB). The total input/output data is therefore 8 MB per head. - -The na\"ive algorithm computes $S = QK^T$, a matrix of shape $8192 \times 8192$. In FP16, this score matrix occupies $8192 \times 8192 \times 2 = 128$ MB per head. After applying softmax, the result $P = \text{softmax}(S)$ is also $128$ MB. Computing $PV$ requires reading $P$ again. In total, the na\"ive algorithm reads $Q$, $K$, $V$ from HBM (6 MB), writes $S$ (128 MB), reads $S$ for softmax (128 MB), writes $P$ (128 MB), reads $P$ for the final multiply (128 MB), and writes $O$ (2 MB). The total HBM traffic is approximately 520 MB per head, dominated by the quadratic intermediates. - -FlashAttention processes the computation in tiles of size $B_r \times B_c$ (typically $128\times128$ on H100). For one tile, the algorithm loads a block of $Q$ ($128 \times 128 \times 2 = 32$ KB), a block of $K$ ($128 \times 128 \times 2 = 32$ KB), and a block of $V$ (32 KB), totaling 96 KB. This fits comfortably in the H100's 228 KB of shared memory per SM. The tile score $S_{\text{tile}} = Q_{\text{tile}} K_{\text{tile}}^T$ is computed and consumed entirely within SRAM; it is never written to HBM. The algorithm iterates over $8192 / 128 = 64$ column tiles for each of $64$ row tiles, but the total HBM traffic is just the cost of reading $Q$, $K$, $V$ once (6 MB) and writing $O$ once (2 MB), totaling 8 MB per head. This is 65$\times$ less HBM traffic than the na\"ive algorithm, and the ratio grows quadratically with sequence length. - -FlashAttention-2 (Dao, 2023) further optimizes the algorithm for modern GPU architectures by restructuring the parallelism pattern. The original FlashAttention parallelizes over batch and head dimensions, meaning each thread block handles one (batch, head) pair and iterates over the full sequence. FlashAttention-2 additionally parallelizes over the sequence dimension of the query matrix, distributing work across thread blocks more efficiently and achieving better occupancy on GPUs with many streaming multiprocessors. It also reduces the number of non-GEMM FLOPs by restructuring the rescaling operations and exploiting the asymmetry between the Q loop (outer) and K/V loop (inner). - -FlashAttention-3 (Dao et al., 2024) targets the H100's new hardware features: FP8 Tensor Cores and the Tensor Memory Accelerator (TMA). By computing attention in FP8 with selective FP16 accumulation, FlashAttention-3 achieves near-peak FP8 utilization for the attention operation, further closing the gap between achieved and theoretical performance. - -::: {.callout-war-story title="The FlashAttention Breakthrough"} -In 2022, Tri Dao challenged the prevailing wisdom that the attention mechanism's $O(N^2)$ complexity required better matrix multiplication kernels. His insight was that the bottleneck was not compute, but memory hierarchy. Standard attention materialized the massive $N \times N$ attention score matrix in high-latency HBM. FlashAttention restructured the algorithm using tiling to keep running statistics in on-chip SRAM, computing the softmax without ever writing the full matrix to global memory. This reduced memory complexity to linear $O(N)$ and wall-clock time by 2--4x. Within six months, it was integrated into PyTorch, TensorFlow, and JAX, becoming the default attention implementation for the industry. -::: - -@fig-flashattention-memory-savings quantifies the memory advantage across sequence lengths. Standard attention allocates the full $N \times N$ score matrix in HBM, while FlashAttention maintains only $O(N)$ running statistics. The gap widens quadratically: at a typical 8K context, FlashAttention uses 4,096$\times$ less attention memory; at 64K long-context, the savings reach 32,768$\times$. - -::: {#fig-flashattention-memory-savings fig-env="figure" fig-pos="htb" fig-cap="FlashAttention Memory Savings: O(N²) → O(N). Standard attention allocates the full N × N score matrix in HBM, while FlashAttention maintains only O(N) running statistics. The shaded region represents memory freed for KV cache, activations, or larger batch sizes. At long-context lengths (64K+), the savings exceed four orders of magnitude." fig-alt="Log-log plot of attention memory versus sequence length. Standard O(N²) curve rises steeply; FlashAttention O(N) stays flat. Shaded region shows savings, with annotations at 8K and 65K."} -```{python} -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ FLASHATTENTION MEMORY SAVINGS (FIGURE) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @fig-flashattention-memory-savings — O(N²) vs O(N) memory -# │ -# │ Goal: Plot standard attention (N²) vs FlashAttention (N) memory vs seq len; -# │ show 4+ orders magnitude savings at 64K+. -# │ Show: Two curves on log scale; shaded savings region. -# │ How: N = logspace(512, 131072); bytes = 2*N^2 vs 2*N*...; matplotlib. -# │ -# │ Imports: numpy (np), matplotlib.pyplot (plt) -# │ Exports: (figure only, no prose variables) -# └───────────────────────────────────────────────────────────────────────────── -import numpy as np -import matplotlib.pyplot as plt - -plt.style.use('seaborn-v0_8-whitegrid') -plt.figure(figsize=(10, 6)) - -N = np.logspace(np.log2(512), np.log2(131072), num=200, base=2.0) - -bytes_per_element_fp16 = 2 -num_heads = 32 -batch_size = 1 - -mem_standard = (N**2 * bytes_per_element_fp16 * num_heads * batch_size) / 1e9 -mem_flash = (N * 4 * num_heads * batch_size) / 1e9 - -plt.plot(N, mem_standard, label='Standard Attention O(N²)', color='#CC5500', linewidth=2) -plt.plot(N, mem_flash, label='FlashAttention O(N)', color='#006395', linewidth=2) -plt.fill_between(N, mem_flash, mem_standard, color='#D1E6F3', alpha=0.6, label='Memory Savings') - -plt.xscale('log') -plt.yscale('log') - -plt.xlabel('Sequence Length (N)') -plt.ylabel('Attention Memory (GB)') -plt.title('FlashAttention Memory Savings: O(N²) → O(N)', fontsize=14) - -plt.axvline(x=8192, color='gray', linestyle='--', linewidth=1.5) -plt.text(8192, plt.ylim()[0]*1.5, ' Typical LLM context (8K)', rotation=90, - verticalalignment='bottom', color='black') - -plt.axvline(x=65536, color='gray', linestyle='--', linewidth=1.5) -plt.text(65536, plt.ylim()[0]*1.5, ' Long context (65K)', rotation=90, - verticalalignment='bottom', color='black') - -def annotate_ratio(n_val): - ratio = n_val / 2 - mem_val = (n_val**2 * bytes_per_element_fp16 * num_heads) / 1e9 - plt.annotate(f'{int(ratio):,}x Savings', - xy=(n_val, mem_val), - xytext=(n_val, mem_val * 0.25), - arrowprops=dict(facecolor='black', shrink=0.05, width=1, headwidth=6), - ha='center', va='top', fontsize=9, - bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="black", lw=0.5)) - -annotate_ratio(8192) -annotate_ratio(65536) - -plt.legend() -plt.grid(True, which="both", ls="-", alpha=0.7) -plt.tight_layout() -fig = plt.gcf() -``` -::: - -FlashAttention eliminates the memory wall within a single GPU by tiling across the SRAM-HBM boundary. For sequence lengths that exceed the memory capacity of a single GPU, the same tiling principle extends across multiple GPUs via **Ring Attention**\index{Ring Attention}. By distributing the KV cache across a ring of accelerators and overlapping communication with computation, Ring Attention enables million-token context windows that would be impossible on single-GPU configurations. We examine the distributed mechanics of Ring Attention in @sec-distributed-training-systems-systems-tensor-parallelism-d76e. - -## Precision Engineering {#sec-performance-engineering-precision} - -If moving a 2-byte FP16 weight from memory to compute takes 100 nanoseconds, how do we cut that time in half? We shrink the weight to a 1-byte FP8 value. Precision engineering recognizes that machine learning algorithms are surprisingly resilient to numerical noise, allowing us to violently compress our data types to double our effective memory bandwidth and throughput. While we examine 8-bit floating point (FP8) as a training-time primitive in @sec-distributed-training-systems, here we focus on the quantization techniques that enable efficient inference at scale. - -### Block-wise Quantization {#sec-performance-engineering-block-quant} - -Post-training quantization to INT8 or INT4 delivers even greater bandwidth savings for inference, but LLMs present a unique challenge: **outlier features**[^fn-outlier-features]. Dettmers et al. (2022) discovered that large language models develop a small number of hidden dimensions (typically fewer than 1% of all dimensions) with activation magnitudes 10--100$\times$ larger than the rest. - -[^fn-outlier-features]: **Outlier Features**: Large-scale transformers develop emergent "outlier" dimensions with activation magnitudes up to 100$\times$ larger than typical values. While these outliers constitute less than 0.1% of all features, clipping them during INT8 quantization destroys the model's reasoning capabilities. This physical property of large models is the reason Post-Training Quantization (PTQ) requires "outlier-aware" strategies like LLM.int8() or AWQ. \index{Outlier Features!quantization challenge} - Applying uniform per-tensor INT8 quantization clips these outliers, destroying the information they carry, or expands the quantization range to accommodate them, wasting precision on the majority of near-zero values. - -::: {.callout-definition title="Block-wise Quantization"} - -***Block-wise Quantization***\index{Block-wise Quantization!definition} is a precision reduction technique that partitions weight tensors into small, independent blocks (typically 32-128 elements) and applies unique quantization parameters to each. - -1. **Significance (Quantitative):** It significantly reduces the **Memory Bandwidth ($BW$)** demand by allowing for aggressive 4-bit or 8-bit compression while maintaining near-FP16 accuracy. By adapting to local value distributions, it minimizes the **Quantization Error** for the entire tensor. -2. **Distinction (Durable):** Unlike **Per-Tensor Quantization**, which uses a single scale factor, Block-wise Quantization preserves **Outlier Features**—large activation magnitudes that carry critical information but would be clipped or diluted by coarser methods. -3. **Common Pitfall:** A frequent misconception is that smaller blocks are always better. In reality, there is a **Metadata Overhead Trade-off**: each block requires its own scale and zero-point tensors, which consume memory and bandwidth ($D_{\text{vol}}$). If blocks are too small, the metadata can exceed the savings from the quantized weights. - -::: - -LLM.int8() solves this by decomposing each matrix multiplication into two parts: a small set of outlier dimensions processed in FP16, and the remaining dimensions processed in INT8. The system identifies outlier dimensions at runtime (those exceeding a magnitude threshold, typically 6.0), routes them to an FP16 GEMM, and routes the remaining dimensions to an INT8 GEMM. The results are combined to produce the final output. This achieves nearly lossless INT8 inference for models that would otherwise degrade substantially under uniform quantization. - -GPTQ (Frantar et al., 2023) takes a different approach: weight-only quantization using second-order information. Instead of quantizing each weight independently, GPTQ processes weights column by column, using the Hessian of the layer's loss surface to determine which quantization errors matter most and redistributing those errors across unquantized columns. This produces INT4 weight representations with minimal accuracy loss, even for models with severe outlier features. The key insight is that quantization error in one weight can be compensated by adjusting correlated weights. - -AWQ (Activation-Aware Weight Quantization, Lin et al., 2024) observes that not all weights are equally important: weights connected to high-activation channels contribute disproportionately to model output. AWQ identifies these salient weights by analyzing activation magnitudes across a calibration dataset, then applies per-channel scaling to protect them before uniform group quantization. This achieves INT4 weight quantization with quality comparable to GPTQ but with 10--100$\times$ faster quantization time (minutes instead of hours), since it avoids the expensive Hessian computation. - -SmoothQuant (Xiao et al., 2023) takes yet another approach to the outlier problem. Rather than handling outliers at runtime (LLM.int8()) or through weight optimization (GPTQ, AWQ), SmoothQuant smooths the activation distribution *before* quantization by migrating the quantization difficulty from activations to weights. The key observation is that activation outliers are channel-specific: certain hidden dimensions consistently produce large values across all tokens. SmoothQuant applies a per-channel scaling transformation that divides the activation by a smoothing factor and multiplies the corresponding weight by the same factor. This mathematically equivalent transformation reduces activation outlier magnitudes at the cost of slightly increasing weight magnitudes, making both tensors more amenable to uniform INT8 quantization. The result is efficient W8A8 (weight-8-bit, activation-8-bit) quantization that exploits INT8 Tensor Cores for both bandwidth and compute benefits. - -These four approaches, LLM.int8(), GPTQ, AWQ, and SmoothQuant, represent a progression in the sophistication of quantization techniques for LLMs. LLM.int8() handles outliers at runtime with mixed-precision decomposition but limits compression to INT8. GPTQ uses second-order information for aggressive INT4 weight compression but requires hours of calibration per model. AWQ achieves similar INT4 quality with minutes of calibration by focusing on activation-aware scaling. SmoothQuant enables W8A8 quantization by preprocessing the weight-activation pairs. In practice, AWQ has become the default choice for weight-only quantization in production LLM deployment, while SmoothQuant is preferred when both weight and activation quantization are needed for compute-bound workloads. - -The choice among these techniques also depends on the deployment target. For GPU inference with Tensor Core support, GPTQ and AWQ produce INT4 weight representations that are dequantized to FP16 during the GEMM computation, using the GPU's FP16 Tensor Cores. For CPU inference or edge deployment, INT8 representations (LLM.int8() or static per-channel INT8 quantization) can directly exploit integer arithmetic units without dequantization overhead. - -The storage cost for block-wise quantization is minimal. Storing one FP32 scale (32 bits) for every block of 128 INT8 weights (1024 bits) increases total model size by only 3%. This small overhead allows block-wise quantization to isolate the destructive impact of outliers, preserving the effective dynamic range for the 99% of normal weights, without the bandwidth penalty of higher-precision formats. - -::: {.callout-lighthouse title="Archetype C (Federated MobileNet): TinyML Survival"} -For **Archetype C (Federated MobileNet)**, quantization is not just an optimization; it is a prerequisite for survival. On a microcontroller with only 512 KB of SRAM, an FP16 model is physically impossible to load. **Binary Neural Networks (BNN)** and **1-bit Quantization** push this to the extreme, representing weights as single bits (+1/-1). While this trades significant accuracy, it reduces the memory footprint by 32$\times$ and the energy per operation by up to 100$\times$, enabling intelligence on devices that operate on harvested energy (microwatts). -::: - -### Post-Training vs Quantization-Aware Training {#sec-performance-engineering-ptq-qat} - -The trade-off between **Post-Training Quantization (PTQ)** and **Quantization-Aware Training (QAT)** centers on the balance between engineering agility and model fidelity. For a model like Llama-2-70B, PTQ is the default choice for immediate deployment. Techniques like GPTQ or AWQ process the model layer-by-layer using a small calibration dataset (typically 128--1024 samples) to minimize reconstruction error. This process is computationally cheap, requiring approximately 4--8 GPU-hours on a single H100 to quantize a 70B model to INT4. While PTQ preserves greater than 99% of accuracy at INT8, aggressive quantization to INT4 or INT3 often incurs a steep penalty: perplexity may degrade by 0.5--1.0 points, and reasoning performance on benchmarks like MMLU can drop from 69% to below 64%. - -When PTQ fails to meet quality thresholds, QAT provides the remedy by integrating quantization noise directly into the training loop. By simulating low-precision rounding during the forward pass and approximating gradients during the backward pass via the straight-through estimator[^fn-ste-quantization] (STE), the network learns to adjust its weights to be robust to quantization. - -[^fn-ste-quantization]: **Straight-Through Estimator (STE)**: Proposed by Bengio et al. (2013), the STE handles a fundamental calculus problem: the gradient of a rounding function is zero almost everywhere, making backpropagation through quantized layers impossible by standard rules. The STE simply passes the upstream gradient through the rounding step unchanged, pretending the rounding did not happen. This mathematically unjustified approximation works in practice because it preserves the gradient's direction even though it ignores the rounding noise, enabling QAT to converge. \index{Straight-Through Estimator!quantization} - -The cost is substantial: QAT is effectively a full fine-tuning run, often requiring hundreds of GPU-hours and a distributed training cluster. For a 70B model, this might mean a 3-day run on 8$\times$ H100s compared to the 4-hour single-GPU job for PTQ. Emerging techniques like **QLoRA** (Quantized Low-Rank Adaptation) bridge this gap by freezing the base model in 4-bit precision and fine-tuning only a small set of high-precision adapter weights. This hybrid approach offers the quality recovery of QAT with a memory footprint small enough to run on a single consumer GPU, effectively democratizing high-fidelity quantization. - -The practical workflow in most production environments follows a two-stage approach: deploy with PTQ first (because it is fast and requires no training infrastructure), then apply QAT or QLoRA only if the PTQ model fails to meet quality requirements at the target precision. This sequence minimizes engineering effort while preserving the option of higher quality when needed. - -### KV Cache Compression and Architectural Optimization {#sec-performance-engineering-kv-cache} - -The Key-Value (KV) cache is the primary memory bottleneck in large-scale LLM inference. Strategies to mitigate this pressure include numerical compression (quantization to INT8, FP8, or INT4) and architectural optimizations like Grouped Query Attention (GQA). Because these techniques are fundamental to scaling model serving, we provide a rigorous quantitative analysis of KV cache memory footprints and the performance impact of GQA in @sec-inference-scale-kv-cache-compression-943a. - -### Weight-Only vs Weight-Activation Quantization {#sec-performance-engineering-quant-strategies} -Weight-only quantization (GPTQ, AWQ) reduces weight precision to INT4 or INT3 while keeping activations in FP16. During a GEMM, the INT4 weights are dequantized to FP16 on-the-fly, and the computation proceeds using FP16 Tensor Cores. The benefit is reduced memory for weight storage and reduced HBM bandwidth for weight reads, but the GEMM itself still operates at FP16 precision. This approach is ideal for memory-bound inference (batch size 1 decode), where the bottleneck is reading weights from HBM. - -Weight-activation quantization (SmoothQuant, FP8 training) reduces both weights and activations to lower precision, enabling the GEMM to execute using lower-precision arithmetic (INT8 Tensor Cores, FP8 Tensor Cores). This provides both bandwidth and compute benefits but is more challenging to implement without quality degradation, because activation distributions are more dynamic and harder to quantize than weight distributions. - -The choice depends on the operational regime. For memory-bound inference (small batch sizes), weight-only INT4 quantization provides the largest speedup per unit of quality degradation. For compute-bound inference (large batch sizes) or training, weight-activation FP8 quantization provides throughput gains that weight-only quantization cannot match. Many production systems use different quantization strategies for different operating points: INT4 weight-only at low batch sizes (for latency) and FP8 weight-activation at high batch sizes (for throughput). - -```{python} -#| echo: false -#| label: kv-cache-size -# ┌───────────────────────────────────────────────────────────────────────────── -# │ KV CACHE MEMORY PER REQUEST (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: Precision Dividend callout — @sec-performance-engineering-kv-cache -# │ -# │ Goal: Compute KV cache memory per request for a 70B model (Llama-style GQA) -# │ at FP16, then derive max batch sizes before/after quantization. -# │ Show: kv_fp16_gb (~1.7 GB), max batch FP16 vs INT8, throughput ratio. -# │ How: 2 (K+V) * layers * kv_heads * head_dim * seq_len * bytes. -# │ -# │ Imports: mlsys.formatting (check) -# │ Exports: kv_fp16_gb_str, max_batch_fp16_str, max_batch_int8_str, -# │ batch_ratio_str -# └───────────────────────────────────────────────────────────────────────────── - -# ┌── LEGO ─────────────────────────────────────────────── -class KVCacheAnalysis: - """KV cache sizing and max batch capacity.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - num_layers = 80 - num_kv_heads = 8 # GQA: 8 KV heads for 70B Llama - head_dim = 128 - seq_len = 4096 - bytes_fp16 = 2 - bytes_int8 = 1 - - # Capacity per GPU (H100 80 GB) - gpu_total_gb = 80 - weight_fp16_shard_gb = 35 # 140 / 4 GPUs - weight_int4_shard_gb = 8.75 - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - # Base size in GB per request - kv_cache_bytes = 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_fp16 - kv_fp16_gb_val = kv_cache_bytes / 1e9 - - # Max batch calculations - # FP16 case: 45 GB available for KV - avail_fp16_gb = gpu_total_gb - weight_fp16_shard_gb - max_batch_fp16_val = int(avail_fp16_gb / (kv_fp16_gb_val / 4)) # shared across 4 GPUs? - # Wait, the callout text says "available for KV cache: 80 - 35 = 45 GB/GPU" - # and "KV cache per request: 1.1 GB / 4 GPUs = 0.27 GB/GPU". - # Max batch = 45 / 0.27 = 166. - max_batch_fp16_val = int(avail_fp16_gb / (kv_fp16_gb_val / 4)) - - # INT8 case - avail_int8_gb = gpu_total_gb - weight_int4_shard_gb - kv_int8_per_gpu = (kv_fp16_gb_val / 2) / 4 - max_batch_int8_val = int(avail_int8_gb / kv_int8_per_gpu) - - batch_ratio = max_batch_int8_val / max_batch_fp16_val - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(kv_fp16_gb_val > 1.0, f"KV cache size too low, got {kv_fp16_gb_val:.2f}") - check(max_batch_int8_val > max_batch_fp16_val, "INT8 should increase batch size") - - # ┌── 4. OUTPUT (Formatting) ────────────────────────────────────────────── - kv_fp16_gb_str = f"{kv_fp16_gb_val:.1f}" - max_batch_fp16_str = f"{max_batch_fp16_val}" - max_batch_int8_str = f"{max_batch_int8_val}" - batch_ratio_str = f"{int(batch_ratio)}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -kv_fp16_gb_str = KVCacheAnalysis.kv_fp16_gb_str -max_batch_fp16_str = KVCacheAnalysis.max_batch_fp16_str -max_batch_int8_str = KVCacheAnalysis.max_batch_int8_str -batch_ratio_str = KVCacheAnalysis.batch_ratio_str -``` - -::: {.callout-notebook title="The Precision Dividend"} - -**Problem**: You serve a 70B parameter model on 4$\times$ H100 GPUs. The model weights in FP16 consume 140 GB (35 GB per GPU). KV cache at FP16 consumes `{python} kv_fp16_gb_str` GB per request. How does quantizing weights to INT4 and KV cache to INT8 change the maximum batch size? - -**Before optimization** (all FP16): - -- Weights: 35 GB/GPU -- Available for KV cache: $80 - 35 = 45$ GB/GPU -- KV cache per request: `{python} kv_fp16_gb_str` GB $\div$ 4 GPUs $\approx$ `{python} f"{kv_fp16_gb / 4:.1f}"` GB/GPU -- Maximum batch size: $\lfloor 45 /$ `{python} f"{kv_fp16_gb / 4:.1f}"` $\rfloor \approx$ `{python} max_batch_fp16_str` requests - -**After optimization** (INT4 weights, INT8 KV cache): - -- Weights: 35 GB$\times$ (4/16) = 8.75 GB/GPU (INT4) -- Available for KV cache: $80 - 8.75 = 71.25$ GB/GPU -- KV cache per request (INT8): `{python} f"{kv_fp16_gb / 2:.1f}"` GB $\div$ 4 $\approx$ `{python} f"{kv_fp16_gb / 2 / 4:.1f}"` GB/GPU -- Maximum batch size: $\lfloor 71.25 /$ `{python} f"{kv_fp16_gb / 2 / 4:.1f}"` $\rfloor \approx$ `{python} max_batch_int8_str` requests - -**Takeaway**: Precision engineering does not just make individual operations faster; it fundamentally changes serving economics by enabling larger batch sizes. Larger batches amortize the fixed cost of weight loading, shifting operations from memory-bound toward compute-bound. This single optimization can increase throughput by `{python} batch_ratio_str`$\times$ or more. - -::: - -Precision engineering reduces the bytes per memory transaction. Operator fusion reduces the number of transactions. Together, they attack the same fundamental bottleneck from complementary directions: if you must move data across a slow bus, move less of it (precision) and move it fewer times (fusion). The multiplicative interaction between these two techniques explains why modern serving systems deploy both simultaneously: FlashAttention reduces attention HBM traffic from quadratic to linear, and INT8 KV cache compression further halves the linear term. The combined effect exceeds what either technique achieves alone. - -Graph compilers automate both of these optimizations, applying them systematically across an entire model. - -## Graph Compilation {#sec-performance-engineering-compilation} - -Manually writing fused CUDA kernels for every possible combination of layers in a massive neural network is a Sisyphean task for human engineers. Instead of hand-crafting these optimizations one by one, what if we could write a program to analyze the model's structure and generate the optimal kernels automatically? Graph compilation does exactly this, transforming high-level PyTorch code into specialized, hardware-aware machine instructions. - -### The Compilation Pipeline {#sec-performance-engineering-compilation-pipeline} - -A graph compiler transforms a high-level model definition (Python code) into optimized hardware instructions through a multi-stage pipeline. To visualize this process, consider a standard transformer FFN block consisting of a projection, an activation, a second projection, and a layer normalization: `LayerNorm(Linear(GELU(Linear(x))))`. In standard PyTorch eager execution, this sequence triggers four separate kernel launches, each reading from and writing to HBM. - -In the **graph capture** stage, the compiler traces the model's execution to construct a computational graph, a directed acyclic graph where nodes represent operations and edges represent tensor dependencies. For the FFN block, this results in a graph with four primary nodes plus their associated parameter tensors. Dynamic Python control flow (loops, conditionals) must be handled by either tracing through a representative execution path or by using compiler-specific annotations to mark dynamic dimensions. - -During **graph-level optimization**, the compiler applies algebraic simplifications and operation rewriting. It identifies that the bias addition in the first `Linear` layer can be folded into the matrix multiplication kernel. It also recognizes that the `GELU` activation is an element-wise operation that depends only on the output of the first `Linear`. These standard compiler optimizations can reduce graph size by 10--30% before any hardware-specific work begins. - -The **operator fusion** pass is the most critical for performance. It identifies sequences of operations that can be combined into single kernels to reduce memory traffic. For the FFN block, the compiler fuses the `GELU` activation into the tail of the first `Linear` kernel (if supported as an epilogue) or fuses the `GELU` and the subsequent `LayerNorm` into a single kernel. Instead of writing the intermediate result of the first `Linear` to HBM and reading it back for `GELU`, the fused kernel keeps the data in the GPU's SRAM or registers. This typically reduces the number of HBM accesses by 30--50%, directly alleviating the memory bandwidth bottleneck. - -The memory planning pass determines when to allocate and free tensors. Without optimization, a 24-layer transformer might allocate separate buffers for every intermediate activation. The compiler analyzes tensor lifetimes, recognizing that the input to the first `Linear` is no longer needed after the second `Linear` computes its output, and reuses the same physical memory addresses. For a 70B model where activations can consume gigabytes per layer, this buffer reuse reduces peak memory requirements from $O(L)$ to $O(1)$, where $L$ is the number of layers. For a model where activations consume 10 GB per layer across 80 layers, this optimization reduces peak activation memory from 800 GB (impossible on any single GPU) to approximately 10--20 GB (comfortably within a single H100). Memory planning also interacts with operator fusion: fusing two operations eliminates the intermediate tensor between them, which both removes the HBM traffic and removes the memory allocation. The compiler must reason about both effects jointly to make profitable decisions. - -The kernel selection pass maps each fused operation to a specific machine code implementation. For the computationally heavy linear projections, the compiler selects a vendor-optimized cuBLAS or CUTLASS GEMM kernel. For the fused `GELU-LayerNorm` sequence, it generates a custom Triton kernel that keeps data in SRAM. The result for the FFN block is a reduction from 4 separate kernels to 2 highly optimized kernels, with a corresponding reduction in global memory traffic. - -### torch.compile {#sec-performance-engineering-torch-compile} - -PyTorch's `torch.compile` (introduced in PyTorch 2.0) brings graph compilation to the most widely used ML framework. It operates through three components: **TorchDynamo** for graph capture, **TorchInductor** for code generation, and **AOTAutograd** for ahead-of-time backward graph construction. - -TorchDynamo operates at the Python bytecode level[^fn-torchdynamo-bytecode], a design choice that distinguishes it from earlier tracing approaches. - -[^fn-torchdynamo-bytecode]: **TorchDynamo Bytecode Interception**: By hooking CPython's frame evaluation function (`PEP 523`, added in Python 3.6), TorchDynamo captures the computation graph at the lowest level of the Python interpreter, below any source-level abstractions. This is why it can trace through decorators, closures, and third-party libraries that defeated earlier tracing approaches. The trade-off is tight coupling to CPython internals: TorchDynamo must be updated for each new Python version, and it cannot run on alternative interpreters like PyPy. \index{TorchDynamo!bytecode tracing} - -Previous tracing methods (torch.jit.trace, torch.fx) operated at the Python source or AST level, requiring users to avoid unsupported Python constructs. TorchDynamo intercepts the bytecode interpreter itself, capturing a computational graph without requiring the user to modify their model code. When TorchDynamo encounters Python constructs it cannot trace (data-dependent control flow, unsupported operations), it inserts a "graph break" that splits the trace into multiple subgraphs, each compiled independently. The goal is to capture as large a subgraph as possible while gracefully handling dynamic Python behavior. - -TorchInductor generates optimized Triton kernels (for GPU) or C++/OpenMP code (for CPU) from the captured graph. Triton is a domain-specific language for writing GPU kernels in Python-like syntax, abstracting away thread block management and memory coalescing while still exposing tiling and fusion decisions. TorchInductor automatically fuses element-wise operations, reduces memory traffic by combining operations that share inputs, and selects tile sizes through autotuning. - -A minimal example illustrates the usage: - -```python -import torch - -def transformer_block(x, w1, w2, ln_weight, ln_bias): - """Unfused transformer FFN block.""" - h = x @ w1 # Linear projection - h = torch.nn.functional.gelu(h) # Activation - h = h @ w2 # Output projection - # Layer normalization - mean = h.mean(dim=-1, keepdim=True) - var = h.var(dim=-1, keepdim=True, unbiased=False) - h = (h - mean) / torch.sqrt(var + 1e-5) - h = h * ln_weight + ln_bias - return h - -# Compile the function — TorchDynamo traces, TorchInductor optimizes -compiled_block = torch.compile(transformer_block) - -# First call triggers compilation; subsequent calls use compiled code -output = compiled_block(x, w1, w2, ln_weight, ln_bias) -``` - -In this example, `torch.compile` will fuse the GELU activation with surrounding operations, combine the layer normalization mean/variance/normalize steps into a single kernel, and potentially fuse the bias addition with the preceding GEMM. The user writes standard PyTorch code; the compiler handles the optimization. - -### XLA and TPU Optimization {#sec-performance-engineering-xla} - -XLA (Accelerated Linear Algebra) is Google's graph compiler, used as the backend for JAX and TensorFlow. Unlike TorchInductor, which generates Triton code targeting NVIDIA GPUs, XLA generates HLO (High-Level Operations) intermediate representation that targets multiple backends, including Google TPUs, NVIDIA GPUs, and CPUs. Its architecture is fundamentally different from `torch.compile`: while PyTorch prioritizes flexibility by allowing graph breaks for unsupported Python features, XLA enforces **whole-program compilation**, tracing the entire computation as a single static graph and enabling global optimizations that span across layers and even across the forward and backward passes. - -This global view enables XLA's most distinctive capability: the **GSPMD (General Partitioner for SPMD)**. In distributed training, GSPMD automatically partitions the computation graph across thousands of TPU cores based on a few high-level user annotations. While a PyTorch user must manually wrap models with `DistributedDataParallel` or `FullyShardedDataParallel`, an XLA user defines the computation for a single device and allows the compiler to infer the necessary communication primitives (AllReduce, AllGather) and insert them into the graph. This allows for complex hybrid sharding strategies that are difficult to implement manually. - -For TPU hardware specifically, XLA performs layout optimizations unavailable on other platforms. It maps matrix multiplications onto the TPU's systolic array architecture, padding dimensions to align with the $128\times128$ hardware units and scheduling instructions to hide the latency of HBM fetches. The impact of these optimizations is visible in MFU metrics. On large-scale LLM training workloads, JAX/XLA on TPUv4 typically achieves 55--65% MFU, outperforming the 40--55% MFU typical of PyTorch/GPU setups without extensive hand-tuning. - -The trade-off for XLA's performance is compilation latency and rigidity. Because XLA must analyze the full static graph, initial compilation can take minutes for large models, compared to seconds for `torch.compile`. Any change in input shape triggers a full recompilation. This makes XLA excellent for steady-state production workloads where the graph is static and the model runs for days or weeks, but challenging for research environments involving dynamic shapes or rapid experimental iteration. The choice between `torch.compile` and XLA often depends on the deployment context: organizations using NVIDIA GPUs predominantly use PyTorch with `torch.compile`, while organizations using Google TPUs use JAX with XLA. - -### TensorRT: Inference Optimization {#sec-performance-engineering-tensorrt} - -NVIDIA TensorRT is a specialized inference compiler that treats the model not as a flexible program but as a rigid global optimization problem. Because inference requires no backward pass and no gradient storage, TensorRT applies aggressive transformations that would be mathematically invalid or impractically slow for training. It performs a calibration pass where it runs the model on representative data to determine the numerical range of every activation tensor. - -This calibration enables **mixed-precision quantization** at a granular level. For a 70B parameter LLM, blindly quantizing all layers to INT8 often degrades perplexity. TensorRT analyzes the sensitivity of each layer individually. It might determine that the attention layers in the first 3 blocks and the final 3 blocks are highly sensitive to precision loss, keeping them in FP16, while aggressively quantizing the middle 74 layers to INT8. This automated mixed-precision strategy recovers accuracy while capturing the throughput benefits of lower precision. TensorRT also eliminates training-only operations (dropout, batch normalization running statistics updates), optimizes for static shapes by generating kernels tuned for exact dimensions, and plans memory precisely since no gradient tensors are needed. - -TensorRT performs **kernel autotuning** far beyond simple heuristics. For every operation in the graph, it benchmarks dozens of candidate kernels, varying tile sizes, thread block configurations, and unrolling factors, on the actual target hardware. It selects the single fastest implementation for that specific GPU and input shape. The performance gap between TensorRT and general-purpose compilers is substantial: in head-to-head comparisons for serving a 70B LLM, TensorRT-LLM typically achieves 1.3--1.8$\times$ higher throughput than `torch.compile` with TorchInductor. - -The trade-off for TensorRT's aggressive optimization is reduced flexibility and high compilation cost. Compilation times are measured in minutes to hours (30--60 minutes for a 70B model), and the resulting engine is strictly tied to the specific GPU architecture and input shape range. Changing any of these requires recompilation. This makes TensorRT the standard for stable, high-volume production deployments, while `torch.compile` remains the preferred choice for development and lower-volume services where rapid iteration matters more than extracting the last percentage of throughput. - -### Compilation Overhead and Trade-offs {#sec-performance-engineering-compilation-overhead} - -Graph compilation is not free. The compilation process itself takes time, ranging from seconds for small models with `torch.compile` to minutes or hours for large models with TensorRT's full optimization pipeline. This overhead must be amortized over the number of times the compiled model executes. - -For training workloads that run for hours or days, compilation overhead is negligible. For inference workloads that serve millions of requests, the one-time compilation cost is similarly amortized. The problematic case is dynamic or infrequent workloads: a model that is compiled once but serves only a few hundred requests before being replaced by a new version may not recoup the compilation cost. - -Graph breaks are a related challenge specific to `torch.compile`. When TorchDynamo encounters Python code it cannot trace (data-dependent control flow, calls to uncompiled libraries, dynamic tensor shapes that change between iterations), it inserts a graph break. Each break produces a separate compiled subgraph with its own compilation overhead and potential optimization boundaries. A model with 50 graph breaks produces 50+ small compiled regions, each potentially too small for meaningful fusion. Reducing graph breaks requires refactoring the model code to be more "compiler-friendly," replacing Python control flow with tensor operations and ensuring static shapes where possible. - -Dynamic shapes present a fundamental tension between compilation and flexibility. A model compiled for input shape [batch=32, seq=512] will recompile when it encounters [batch=16, seq=1024]. TorchInductor supports "dynamic shapes" by generating kernels with symbolic dimensions, but this generality comes at the cost of reduced optimization compared to kernels specialized for exact shapes. TensorRT sidesteps this by requiring the user to specify a range of input shapes at compilation time, generating kernels that handle the specified range but nothing outside it. - -Despite these limitations, graph compilation represents the most accessible optimization technique: it requires no model modifications, no custom kernels, and minimal code changes. For most workloads, `torch.compile` with default settings provides 10--40% speedup with a single line of code, making it the natural first optimization to apply before considering more specialized techniques. - -The accessibility of graph compilation has changed the performance engineering workflow. Before `torch.compile`, extracting the last 30% of performance required weeks of manual kernel optimization. Now, a single line of code captures a significant fraction of that improvement, freeing the engineer to focus on the algorithmic and architectural optimizations (speculative decoding, MoE, precision engineering) that compilers cannot automate. The compiler handles the routine work; the engineer handles the creative work. This division of labor is likely to deepen as compilers improve, making the higher-level system design skills in this chapter increasingly valuable relative to low-level kernel engineering. - -### Compilation Modes and Backends {#sec-performance-engineering-compilation-modes} - -`torch.compile` supports multiple optimization levels that trade compilation time for runtime performance: - -The `default` mode applies standard optimizations: element-wise fusion, memory planning, and kernel selection from the pre-tuned library. Compilation is fast (seconds to minutes) and suitable for development iteration. - -The `reduce-overhead` mode additionally wraps the compiled graph in CUDA Graphs (discussed in @sec-performance-engineering-cuda-graphs), eliminating kernel launch overhead. This is particularly effective for small models where launch overhead is a significant fraction of execution time. - -The `max-autotune` mode triggers extensive kernel autotuning, benchmarking multiple kernel variants (different tile sizes, thread block configurations, memory access patterns) for each operation and selecting the fastest. This produces the highest-performance code but may take 10--30 minutes of compilation time, making it suitable for production deployment but impractical during development. - -The choice of backend also matters. TorchInductor (the default) generates Triton kernels for GPU and C++ for CPU. For deployment-specific optimization, the model can be exported through `torch.export` to an intermediate representation that can be consumed by TensorRT, ONNX Runtime, or other inference-specialized runtimes. Each backend applies its own optimization passes on top of the common graph-level transformations. - -### The Triton Language {#sec-performance-engineering-triton} - -Between hand-written CUDA and fully automated graph compilers sits **Triton**[^fn-triton-tiling], a Python-based language for writing GPU kernels. Triton occupies a middle ground: the programmer specifies the algorithm (tiling strategy, fusion pattern) while Triton handles low-level concerns (thread block scheduling, memory coalescing, shared memory management). - -[^fn-triton-tiling]: **Triton**: Created by Philippe Tillet at Harvard (2019), later developed at OpenAI. The key design decision was making the *tile* -- not the thread -- the fundamental programming abstraction. This choice directly mirrors the tiling strategy of FlashAttention: the programmer reasons about blocks of data that fit in SRAM, and the compiler maps those blocks to GPU threads and shared memory. This abstraction eliminates the most error-prone aspects of CUDA (warp divergence, bank conflicts, coalescing) while preserving control over the memory hierarchy decisions that determine ML kernel performance. \index{Triton!kernel programming} - -A Triton kernel for fused GELU activation illustrates the programming model: - -```python -import triton -import triton.language as tl - -@triton.jit -def fused_gelu_kernel( - input_ptr, output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - # Each program instance handles BLOCK_SIZE elements - pid = tl.program_id(0) - offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - - # Load input tile from HBM into registers - x = tl.load(input_ptr + offsets, mask=mask) - - # Fused GELU computation (tanh approximation) - # GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * - # x^3))) - x_cubed = x * x * x - inner = 0.7978845608 * (x + 0.044715 * x_cubed) - gelu = 0.5 * x * (1.0 + tl.math.tanh(inner)) - - # Store result back to HBM - tl.store(output_ptr + offsets, gelu, mask=mask) -``` - -The programmer thinks in terms of tiles (`BLOCK_SIZE` elements), not individual threads. Triton compiles this to PTX/SASS instructions, handling thread-to-data mapping, memory coalescing, and register allocation automatically. This makes it feasible for ML engineers (rather than GPU specialists) to write custom fused kernels when the automatic compiler misses a fusion opportunity. - -The real power of Triton emerges when fusing multiple operations. A Triton kernel that implements `y = LayerNorm(GELU(x))` loads `x` once from HBM, computes both GELU and LayerNorm in registers/shared memory, and writes `y` once to HBM. Without fusion, this sequence requires three HBM round-trips: read `x`, write `GELU(x)`, read `GELU(x)`, write `norm_input`, read `norm_input`, write `y`. The fused kernel reduces HBM traffic by 3$\times$, and for memory-bound operations, this translates directly to a 3$\times$ speedup. - -Triton's adoption has accelerated rapidly since its integration into PyTorch's TorchInductor backend. When `torch.compile` identifies a fusion opportunity that requires a custom kernel, TorchInductor automatically generates Triton code for the fused operation. This means that many of the fusion benefits described in this section are available to users through a single `torch.compile` call, without writing any Triton code directly. For advanced use cases where the compiler's heuristics are insufficient, hand-written Triton kernels provide a middle ground between the accessibility of PyTorch and the performance of hand-tuned CUDA. - -::: {.callout-notebook title="The Compilation Dividend"} - -**Problem**: You deploy a 13B parameter model for inference. Without compilation, the PyTorch eager mode processes 120 tokens/second on a single H100. The Nsight Systems trace reveals that 35% of step time is spent in element-wise kernels (LayerNorm, GELU, residual additions) and 15% is kernel launch overhead. You apply `torch.compile` with the `max-autotune` backend. Estimate the new throughput. - -**The Math**: - -*Step 1: Identify the addressable overhead.* -Element-wise kernels (35%) and launch overhead (15%) total 50% of execution time. torch.compile fuses element-wise operations (reducing their time by approximately 70% due to eliminated HBM round-trips) and reduces kernel launches (eliminating most launch overhead). - -*Step 2: Estimate post-compilation time.* -Original time per token: $1/120 = 8.33$ ms. - -- GEMM time (unchanged): $8.33 \times 0.50 = 4.17$ ms -- Element-wise time (70% reduction): $8.33 \times 0.35 \times 0.30 = 0.87$ ms -- Launch overhead (80% reduction): $8.33 \times 0.15 \times 0.20 = 0.25$ ms - -New time per token: $4.17 + 0.87 + 0.25 = 5.29$ ms, yielding approximately 189 tokens/second. - -**Takeaway**: torch.compile delivers a 1.58$\times$ speedup by fusing element-wise operations and reducing launch overhead, without touching the GEMM kernels. The remaining bottleneck is now the GEMM itself (79% of step time), indicating that further improvement requires either precision reduction or batching. - -::: - -Graph compilation automates what manual kernel engineering achieves for individual operations, applying it systematically across the entire model graph. An orthogonal class of optimizations changes the fundamental algorithm itself: speculative decoding trades cheap compute for expensive latency, and mixture of experts decouples model capacity from per-token cost. - -## Speculative Decoding {#sec-performance-engineering-speculative} - -Speculative decoding is a latency optimization that breaks the sequential bottleneck of autoregressive generation by using a smaller draft model to predict multiple tokens, which are then verified in parallel by the target model. Because this technique is primarily employed to meet strict latency SLAs during model serving, we examine the core algorithm, speedup analysis, and hardware resource implications in detail in @sec-inference-scale-speculative-decoding-c438. - -```{python} -#| echo: false -#| label: moe-economics -# ┌───────────────────────────────────────────────────────────────────────────── -# │ MOE VS DENSE ECONOMICS -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-moe MoE economics prose -# │ -# │ Goal: Quantify compute and bandwidth savings from MoE vs dense for a -# │ 1T-parameter dense model vs a 370B total / 37B active MoE model, -# │ showing MoE decouples capacity from inference cost. -# │ Show: ~2 TB vs ~74 GB memory bandwidth per decode step → >5x reduction -# │ Exports: dense_mem_str, moe_mem_str, compute_ratio_str, bw_ratio_str -# └───────────────────────────────────────────────────────────────────────────── - -from mlsys.formatting import check - -# ┌── LEGO ─────────────────────────────────────────────── -class MoEEconomics: - """MoE vs dense model memory and compute economics.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - dense_params_b = 1000 # 1T dense model parameters (billions → bytes below) - moe_total_params_b = 370 # MoE total parameters (billions) - moe_active_params_b = 37 # MoE active parameters per token (10% sparsity) - bytes_fp16 = 2 # FP16: 2 bytes per parameter - experts_per_layer = 8 - active_experts = 1 # top-1 routing - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - dense_mem_fp16 = dense_params_b * bytes_fp16 # 2000 GB (2 TB) - moe_mem_fp16 = moe_total_params_b * bytes_fp16 # 740 GB total - dense_decode_bytes = dense_params_b * bytes_fp16 # ~2 TB bandwidth per decode - compute_ratio = dense_params_b / moe_active_params_b # ~27x compute reduction - moe_decode_bytes = moe_active_params_b * bytes_fp16 # 74 GB - bw_ratio = dense_decode_bytes / moe_decode_bytes - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(compute_ratio > 5, f"MoE should give >5x compute reduction, got {compute_ratio:.1f}x") - check(bw_ratio > 5, f"MoE should give >5x BW reduction, got {bw_ratio:.1f}x") - - # ┌── 4. OUTPUT (Formatting) ───────────────────────────────────────────── - dense_mem_str = f"{dense_mem_fp16:.0f}" - moe_mem_str = f"{moe_mem_fp16:.0f}" - compute_ratio_str = f"{compute_ratio:.0f}" - bw_ratio_str = f"{bw_ratio:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -dense_mem_str = MoEEconomics.dense_mem_str -moe_mem_str = MoEEconomics.moe_mem_str -compute_ratio_str = MoEEconomics.compute_ratio_str -bw_ratio_str = MoEEconomics.bw_ratio_str -``` - -## Mixture of Experts {#sec-performance-engineering-moe} - -If we want the reasoning capabilities of a 1-trillion parameter model, but only have the latency budget and compute budget to run a 100-billion parameter model, how do we bridge the gap? We train a massive model but only activate a tiny, relevant fraction of it for any given word. The Mixture of Experts (MoE) architecture routes inputs only to specialized sub-networks, breaking the iron link between model size and compute cost. - -### MoE Architecture {#sec-performance-engineering-moe-architecture} - -In a standard MoE transformer layer, the feed-forward network (FFN) is replaced by multiple parallel "expert" FFNs and a lightweight **router** (also called a gating network) that selects which experts process each token. Because the serving logic for these models—including Expert Parallelism (EP), AllToAll communication patterns, and load-balancing strategies—is fundamental to large-scale inference, we provide a rigorous quantitative treatment of MoE serving in @sec-inference-scale-expert-parallelism-moe-models-55f6. - -## Communication-Computation Overlap {#sec-performance-engineering-overlap} - -Consider a concrete example: a 70B model with tensor parallelism across 8 H100 GPUs. Each transformer layer requires two AllReduce operations (one after attention, one after FFN). Each AllReduce transfers approximately $2 \times d \times \text{batch} \times 2$ bytes at FP16 (where $d = 8192$ for a 70B model). At batch size 1, the data volume is $2 \times 8192 \times 1 \times 2 = 32$ KB per AllReduce. At 900 GB/s NVLink bandwidth, this transfer takes approximately 36 ns. However, the AllReduce launch overhead (approximately 5 $\mu$s) dominates the actual data transfer time by 100$\times$. At batch size 1, the AllReduce overhead is dominated by software launch latency, not bandwidth, and overlap provides limited benefit because there is insufficient compute to hide behind. - -At batch size 64, the data volume per AllReduce grows to $2 \times 8192 \times 64 \times 2 = 2$ MB, taking approximately 2.2 $\mu$s at NVLink bandwidth. The corresponding GEMM computation on each shard also takes approximately 20 $\mu$s (for a $64 \times 8192 \times 1024$ GEMM, where 1024 is the per-GPU hidden dimension). Here, the compute time exceeds the communication time by 9$\times$, and overlap becomes highly effective. This illustrates why batch size is the universal control knob: it simultaneously improves arithmetic intensity, GPU utilization, and communication overlap effectiveness. - -### Quantifying the Overlap Opportunity {#sec-performance-engineering-overlap-quantify} - -The potential benefit of communication-computation overlap depends on the relative magnitudes of communication and computation time, which vary dramatically across system configurations. - -```{python} -#| echo: false -#| label: overlap-calc -# ┌───────────────────────────────────────────────────────────────────────────── -# │ COMMUNICATION-COMPUTATION OVERLAP ANALYSIS -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-comm-comp-overlap overlap analysis prose -# │ -# │ Goal: Quantify training step speedup from overlapping ring-AllReduce with -# │ backward pass for a 7B model on 8 H100 GPUs over NVLink (900 GB/s), -# │ at 40% MFU, to show when overlap fully hides communication overhead. -# │ Show: ~14 GB gradients; ~31 ms backward; ~31 ms AllReduce; step collapses from -# │ ~93 ms (no overlap) to ~62 ms (with overlap) → ~1.50x speedup — inline. -# │ How: AllReduce time = 2 * gradient_bytes / nvlink_bw; backward time = -# │ backward_flops / (h100_tflops * MFU); overlap removes min(back, AR). -# │ -# │ Imports: (none — uses inline constants), mlsys.formatting (check) -# │ Exports: gradient_gb_str, allreduce_ms_str, backward_ms_str, forward_ms_str, -# │ step_no_overlap_ms_str, step_overlap_ms_str, overlap_speedup_str, -# │ overlap_status -# └───────────────────────────────────────────────────────────────────────────── - -# ┌── LEGO ─────────────────────────────────────────────── -class OverlapCalc: - """Communication-computation overlap speedup for 7B model on 8 H100s.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - model_params = 7e9 # 7B parameter model - bytes_per_param_fp16 = 2 # FP16 gradients - num_gpus = 8 - nvlink_bw = 900e9 # H100 NVLink: 900 GB/s bidirectional - ring_allreduce_factor = 2 # Ring AllReduce: 2*(N-1)/N ≈ 2 for large N - - # Backward pass compute: ~2 * forward FLOPs (gradient computation) - # Forward FLOPs ≈ 2 * params * seq_len * batch_per_gpu - # Simplified: backward ≈ 4 * params * tokens_per_gpu - tokens_per_gpu = 2048 # seq_len * micro_batch - forward_flops = 2 * model_params * tokens_per_gpu - backward_flops = 2 * forward_flops # backward is ~2x forward - h100_fp16_tflops = 989e12 # Tensor core TFLOPS - mfu = 0.40 # Realistic MFU - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - - # Gradient size - gradient_bytes = model_params * bytes_per_param_fp16 - - # AllReduce time (ring): 2*(N-1)/N * gradient_bytes / bandwidth - allreduce_time_s = ring_allreduce_factor * gradient_bytes / nvlink_bw - - # Backward pass compute time - backward_time_s = backward_flops / (h100_fp16_tflops * mfu) - - # Forward pass compute time (roughly half of backward) - forward_time_s = forward_flops / (h100_fp16_tflops * mfu) - - # Step time without overlap - step_no_overlap = forward_time_s + backward_time_s + allreduce_time_s - - # Step time with overlap - step_with_overlap = forward_time_s + max(backward_time_s, allreduce_time_s) - - # Speedup - overlap_speedup = step_no_overlap / step_with_overlap - - # Can we fully overlap? - can_overlap = backward_time_s > allreduce_time_s - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(overlap_speedup > 1.0, - f"Overlap should provide speedup, got {overlap_speedup:.2f}x") - - # ┌── 4. OUTPUT (Formatting) ───────────────────────────────────────────── - gradient_gb_str = f"{gradient_bytes / 1e9:.0f}" - allreduce_ms_str = f"{allreduce_time_s * 1000:.1f}" - backward_ms_str = f"{backward_time_s * 1000:.1f}" - forward_ms_str = f"{forward_time_s * 1000:.1f}" - step_no_overlap_ms_str = f"{step_no_overlap * 1000:.1f}" - step_overlap_ms_str = f"{step_with_overlap * 1000:.1f}" - overlap_speedup_str = f"{overlap_speedup:.2f}" - overlap_status = "fully hidden" if can_overlap else "partially hidden" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -gradient_gb_str = OverlapCalc.gradient_gb_str -allreduce_ms_str = OverlapCalc.allreduce_ms_str -backward_ms_str = OverlapCalc.backward_ms_str -forward_ms_str = OverlapCalc.forward_ms_str -step_no_overlap_ms_str = OverlapCalc.step_no_overlap_ms_str -step_overlap_ms_str = OverlapCalc.step_overlap_ms_str -overlap_speedup_str = OverlapCalc.overlap_speedup_str -overlap_status = OverlapCalc.overlap_status -``` - -Consider a 7B parameter model trained on 8 H100 GPUs within a single node connected by NVLink at 900 GB/s. The gradient tensor contains `{python} gradient_gb_str` GB of FP16 values. A ring AllReduce across 8 GPUs transfers approximately $2 \times (N-1)/N \approx 2$ times this volume, taking approximately `{python} allreduce_ms_str` ms at NVLink bandwidth. The backward pass computation, at 40% Model FLOPs Utilization, takes approximately `{python} backward_ms_str` ms. - -Without overlap, the training step requires `{python} step_no_overlap_ms_str` ms (forward + backward + AllReduce). With overlap, the AllReduce is `{python} overlap_status` behind the backward pass, reducing the step time to `{python} step_overlap_ms_str` ms, a `{python} overlap_speedup_str`$\times$ improvement. This example illustrates a critical property: overlap is most effective when backward compute time exceeds AllReduce time. For smaller models or slower interconnects (e.g., PCIe at 64 GB/s instead of NVLink at 900 GB/s), the AllReduce would exceed the backward pass, and no amount of overlap can fully hide the communication. - -::: {#fig-overlap-budget fig-env="figure" fig-pos="htb" fig-cap="The Overlap Budget: As cluster scale increases (8 to 1024 GPUs), the communication overhead grows, eventually exceeding the computation budget. The overlap efficiency (black line) represents the percentage of communication hidden behind computation, which degrades from ~90% to ~40% due to inter-node latency and bandwidth constraints on a 70B parameter model." fig-alt="Stacked bar chart of step time contribution versus GPU count. Overlapped communication, pure computation, and exposed communication. Black line shows overlap efficiency degrading from 90% to 40%."} -```{python} -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ OVERLAP BUDGET (FIGURE) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @fig-overlap-budget — communication-computation overlap -# │ -# │ Goal: Plot compute budget, comm overhead, overlap efficiency vs GPU count; -# │ show degradation from ~90% to ~40% at scale. -# │ Show: Bar/line chart; 8–1024 GPUs; overlap % curve. -# │ How: comm_scale_factor; overlap = compute/(compute+comm); matplotlib. -# │ -# │ Imports: matplotlib.pyplot (plt), numpy (np) -# │ Exports: (figure only, no prose variables) -# └───────────────────────────────────────────────────────────────────────────── -import matplotlib.pyplot as plt -import numpy as np - -plt.style.use('seaborn-v0_8-whitegrid') - -gpus = np.array([8, 16, 32, 64, 128, 256, 512, 1024]) -x = np.arange(len(gpus)) -compute_budget = 100.0 - -comm_scale_factor = np.logspace(0, 1.2, len(gpus)) -total_comm_cost = 35.0 * comm_scale_factor - -target_efficiency = np.linspace(0.92, 0.38, len(gpus)) - -hidden_comm = total_comm_cost * target_efficiency -hidden_comm = np.minimum(hidden_comm, compute_budget) - -realized_efficiency = hidden_comm / total_comm_cost -exposed_comm = total_comm_cost - hidden_comm - -comp_overlapped = hidden_comm -comp_pure = compute_budget - comp_overlapped -comm_exposed = exposed_comm - -fig, ax1 = plt.subplots(figsize=(10, 6)) - -c_pure = '#1f77b4' -c_overlap = '#2ca02c' -c_exposed = '#d62728' - -p1 = ax1.bar(x, comp_overlapped, label='Overlapped Communication', color=c_overlap, alpha=0.8, edgecolor='white') -p2 = ax1.bar(x, comp_pure, bottom=comp_overlapped, label='Pure Computation', color=c_pure, alpha=0.7, edgecolor='white') -p3 = ax1.bar(x, comm_exposed, bottom=comp_overlapped + comp_pure, label='Exposed Communication', color=c_exposed, alpha=0.8, hatch='//', edgecolor='white') - -ax2 = ax1.twinx() -line = ax2.plot(x, realized_efficiency * 100, color='black', marker='o', linewidth=2, linestyle='--', label='Overlap Efficiency') -ax2.set_ylabel('Overlap Efficiency (%)', color='black', fontsize=12, labelpad=10) -ax2.set_ylim(0, 110) -ax2.grid(False) - -ax1.set_xlabel('Number of GPUs', fontsize=12) -ax1.set_ylabel('Step Time Contribution (Normalized)', fontsize=12) -ax1.set_xticks(x) -ax1.set_xticklabels(gpus) -ax1.set_title('The Overlap Budget: Distributed Training Scaling', fontsize=14, pad=15) - -lines, labels = ax1.get_legend_handles_labels() -lines2, labels2 = ax2.get_legend_handles_labels() -ax1.legend(lines + lines2, labels + labels2, loc='upper left', frameon=True, fancybox=True, framealpha=0.9) - -ax1.text(0, compute_budget + 5, 'NVLink\nIntra-node', ha='center', fontsize=9, color='#555') -ax1.text(7, compute_budget + 130, 'InfiniBand\nInter-node', ha='center', fontsize=9, color='#555') - -plt.tight_layout() -fig = plt.gcf() -``` -::: - -### CUDA Streams and Asynchronous Execution {#sec-performance-engineering-cuda-streams} - -The mechanism enabling overlap on NVIDIA GPUs is **CUDA streams**. A CUDA stream is an ordered sequence of GPU operations (kernel launches, memory copies, NCCL collectives) that execute sequentially within the stream but can execute concurrently with operations in other streams. The application maintains two distinct streams: a compute stream for matrix multiplications and element-wise kernels, and a communication stream for NCCL operations. The workflow proceeds by launching a compute kernel on the first stream and immediately triggering an asynchronous communication call on the second: - -```python -compute_stream = torch.cuda.Stream() -comm_stream = torch.cuda.Stream() - -with torch.cuda.stream(compute_stream): - output = torch.matmul(A, B) # Non-blocking compute - -with torch.cuda.stream(comm_stream): - dist.all_reduce(gradients, async_op=True) # Non-blocking comm - -torch.cuda.synchronize() # Wait for both to complete -``` - -The GPU hardware scheduler interleaves execution units from both streams, running the GEMM on the SMs while the NVLink engine handles the AllReduce data transfer. However, while streams provide logical concurrency, they contend for physical resources. The SMs must manage the data movement instructions for the communication kernel. On an H100 with 132 SMs, a heavy NCCL operation might occupy 4--8 SMs solely for protocol processing and memory copying, leading to **SM partitioning**: the available compute throughput is reduced by 3--6% during communication. If the compute kernel is dense enough to saturate 100% of the SMs, enabling overlap can paradoxically slow down execution due to this resource contention, a phenomenon known as interference. In practice, the 3--6% compute throughput reduction is far smaller than the communication time that would otherwise be exposed, making the trade-off overwhelmingly favorable. - -In practice, achieving effective overlap requires attention to several details. The communication operation must be launched *before* the compute kernel it should overlap with, not after, because NCCL operations have their own launch overhead. Synchronization points (where one stream waits for another) must be minimized, as each synchronization serializes execution. - -PyTorch's Distributed Data Parallel (DDP) module implements gradient overlap by registering backward hooks on each parameter. When a parameter's gradient is computed during the backward pass, the hook triggers an asynchronous AllReduce on a separate NCCL stream. The optimizer step includes an implicit synchronization point that waits for all AllReduce operations to complete. This design overlaps gradient communication with gradient computation automatically, without requiring user intervention. - -The techniques covered so far -- fusion, precision, compilation, speculative decoding, MoE, and communication overlap -- address different aspects of the performance equation. Identifying which technique to apply in a given situation requires systematic measurement, and the profiling tools that make this diagnosis possible are the final piece of the optimization toolkit. - -## System Profiling {#sec-performance-engineering-profiling} - -An engineer spends two weeks rewriting a PyTorch module into a custom CUDA kernel to make it 5x faster, only to discover the overall model latency did not budge because the system was entirely I/O bound. Performance engineering without measurement is expensive guesswork. System profiling provides the surgical diagnostics required to identify exactly where the GPU is waiting, allowing us to apply our optimizations with precision. - -### The Profiling Hierarchy {#sec-performance-engineering-profiling-hierarchy} - -ML system profiling operates at four levels, each providing different granularity and answering different questions. - -Operation-level profiling measures the execution time and resource utilization of individual GPU kernels. Tools like NVIDIA Nsight Compute provide detailed metrics for a single kernel: achieved memory bandwidth, compute utilization, occupancy (fraction of available threads active), and instruction mix. This level answers questions like: "Is this GEMM achieving peak throughput?" and "Is this kernel memory-bound or compute-bound?" - -Trace-level profiling captures the timeline of all GPU kernels, CPU operations, and data transfers across the full execution of a training step or inference request. NVIDIA Nsight Systems and the PyTorch Profiler provide trace views showing exactly when each kernel executes, when data transfers occur, and where GPU idle gaps (bubbles) exist. This level reveals the overall structure of execution: sequential bottlenecks, launch gaps between kernels, and opportunities for overlap. - -Distributed profiling extends the trace across multiple GPUs and nodes, showing communication patterns alongside computation. This reveals whether communication is overlapped with compute, which collective operations dominate step time, and whether load imbalance exists across GPUs. - -Application-level profiling measures end-to-end metrics: tokens per second, time-to-first-token, P99 latency, and GPU utilization over time. This connects hardware-level observations to user-facing performance metrics and business KPIs. - -Diagnosing performance requires a drill-down approach, moving from global symptoms to local causes. When profiling a 70B model serving pipeline, the application level might reveal "end-to-end latency is 150 ms/token, 3$\times$ above the SLO." Descending to the distributed level, traces might show one GPU consistently lagging in AllReduce operations, pointing to a straggler or network congestion. Zooming into the trace level on that specific GPU reveals the timeline of kernel execution, exposing gaps where the SMs are idle due to scheduling overhead. Finally, kernel-level profiling (using Nsight Compute) inspects the specific instruction mix of a single matrix multiplication, revealing cache misses or register pressure. Skipping levels often leads to optimizing the wrong bottleneck: optimizing a kernel is futile if the GPU is spending 40% of its time waiting on the network. - -### Key Performance Metrics {#sec-performance-engineering-metrics} - -ML system performance is characterized by a family of metrics, each capturing a different aspect of system behavior. Understanding their relationships and trade-offs is essential for performance engineering. - -Model FLOPs Utilization (MFU) measures what fraction of the hardware's theoretical peak is being productively used by the model's computation. MFU captures the combined effect of all inefficiencies: memory bandwidth limitations, kernel launch overhead, communication wait time, and software stack overhead. For large-scale LLM training, MFU of 40--60% is typical; values above 60% indicate excellent optimization. - -Hardware FLOPs Utilization (HFU) is the simpler metric: total arithmetic operations executed (including overhead like activation recomputation) divided by peak FLOPS. HFU is always higher than MFU because it counts recomputed operations that do not directly advance the model's computation. The gap between HFU and MFU quantifies the "wasted" compute from techniques like activation checkpointing. - -Time-to-first-token (TTFT) measures the latency from when a request arrives to when the first output token is generated. TTFT includes queue waiting time, prefill computation, and any initialization overhead. For interactive applications, TTFT below 500 ms is generally acceptable; below 200 ms feels responsive. - -Inter-token latency (ITL) measures the time between consecutive output tokens during the decode phase. ITL directly determines the perceived generation speed. For a comfortable reading experience, ITL should be below 50 ms (20 tokens/second); for real-time speech synthesis, ITL below 25 ms may be required. - -Throughput (tokens/second for the system, or tokens/second/GPU for per-device efficiency) measures aggregate production capacity. Throughput is maximized by large batch sizes and high utilization, often at the expense of increased per-request latency. The fundamental trade-off between throughput and latency is captured by queuing theory, as explored in @sec-inference-scale. - -While MFU is the standard metric for training efficiency, it is often misleading for inference. Inference is predominantly memory-bound, meaning the Tensor Cores spend most cycles waiting for data. For these workloads, **Model Bandwidth Utilization (MBU)** is the more informative metric: the ratio of achieved memory bandwidth to the hardware's peak bandwidth. A decoding step on an H100 might show a dismal 2% MFU (using a fraction of the compute capability) but achieve 85% MBU (saturating the HBM3 memory bandwidth). In this context, 85% MBU indicates a highly optimized system; no amount of code optimization can squeeze more tokens per second without faster memory hardware. The choice of primary metric, MFU for training and MBU for inference, reflects the fundamental bottleneck shift between these two regimes. - -### Using the Roofline for Diagnosis {#sec-performance-engineering-roofline-diagnosis} - -The roofline model from @sec-performance-engineering-roofline becomes a diagnostic tool when combined with profiling data. The process is: - -1. **Measure** the achieved FLOPS and memory bandwidth for a kernel using Nsight Compute. -2. **Compute** the operational arithmetic intensity from the algorithm (FLOPs / bytes transferred). -3. **Plot** the measured performance against the roofline ceiling. - -A kernel that falls far below both the compute ceiling and the bandwidth ceiling has an implementation problem: launch overhead, poor memory access patterns, or low occupancy. A kernel that reaches the bandwidth ceiling but falls below the compute ceiling is memory-bound, and further optimization requires reducing memory traffic (fusion, precision) rather than improving compute efficiency. A kernel at the compute ceiling is compute-bound and can only be improved by algorithmic changes (reducing FLOPs) or faster hardware. - -Consider a concrete example: a LayerNorm kernel profiled on an H100 reports 15 TFLOPS of achieved compute and 2.8 TB/s of achieved memory bandwidth. Its arithmetic intensity is $15 \text{ TFLOPS} / 2.8 \text{ TB/s} \approx 5.4$ FLOP/byte. The H100's ridge point is approximately 295 FLOP/byte at FP16. Since 5.4 is far below 295, the kernel is strictly memory-bound. Its achieved 2.8 TB/s (84% of the H100's 3.35 TB/s peak bandwidth) confirms it is operating near the physical limit of the memory subsystem. The diagnosis is clear: further FLOP-level optimizations will yield zero gain; performance can only be improved by reducing data movement, either through fusion (eliminating the HBM round-trip) or precision reduction (halving the bytes per element). - -### PyTorch Profiler Workflow {#sec-performance-engineering-pytorch-profiler} - -The PyTorch Profiler integrates with the training loop to capture detailed traces with minimal code modification: - -```python -import torch -from torch.profiler import profile, schedule, tensorboard_trace_handler - -# Profile 2 warmup steps + 3 active steps -with profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=schedule(wait=1, warmup=2, active=3, repeat=1), - on_trace_ready=tensorboard_trace_handler("./profiler_logs"), - record_shapes=True, - profile_memory=True, - with_stack=True, -) as prof: - for step, batch in enumerate(dataloader): - if step >= 6: # 1 wait + 2 warmup + 3 active - break - output = model(batch) - loss = criterion(output, labels) - loss.backward() - optimizer.step() - optimizer.zero_grad() - prof.step() -``` - -The `schedule` parameter is important: it defines a warmup period where the profiler runs but does not record, allowing CUDA caches and JIT compilation to stabilize before measurement begins. Without warmup, the first few iterations include one-time costs (kernel compilation, memory allocation, CUDA context initialization) that inflate the measured times and misrepresent steady-state performance. - -The resulting trace, viewable in TensorBoard or Chrome's trace viewer, shows: - -- **Kernel timeline**: Which CUDA kernels execute and for how long. Look for gaps (GPU idle time) and unexpectedly long kernels. -- **Memory timeline**: GPU memory allocation and deallocation over time. Spikes indicate inefficient memory management; gradual growth suggests memory leaks. -- **CPU-GPU synchronization**: Points where the CPU waits for the GPU (or vice versa). Excessive synchronization often indicates that the CPU is not launching kernels fast enough to keep the GPU busy. -- **Communication events**: NCCL collective operations and their duration relative to compute kernels. Short compute between long communications indicates poor overlap. - -### Nsight Systems: Reading the Timeline {#sec-performance-engineering-nsight-systems} - -NVIDIA Nsight Systems is the primary tool for trace-level profiling. Understanding how to read its timeline view is an essential skill for performance engineers. The tool captures every GPU kernel launch, CUDA memory operation, NCCL communication, and CPU thread activity onto a unified timeline. - -A typical Nsight Systems workflow begins with capturing a trace of a few training or inference iterations: - -```bash -nsys profile --trace=cuda,nvtx,osrt,cudnn,cublas \ - --output=llm_profile \ - --force-overwrite=true \ - python3 inference_server.py --num-steps=5 -``` - -The `--trace` flags control which activities are recorded. The `cuda` flag captures kernel launches and memory operations. The `nvtx` flag captures user-annotated regions (PyTorch automatically annotates module boundaries with NVTX markers). The `cublas` and `cudnn` flags capture library-level operations, which helps identify whether a GEMM is using cuBLAS or a custom kernel. - -The resulting `.nsys-rep` file is opened in the Nsight Systems GUI, which presents a multi-row timeline. The most important rows are: - -The CUDA HW row shows the actual kernel execution on the GPU. Each colored bar represents a kernel, with width proportional to execution time. Gaps between bars indicate GPU idle time, which represents wasted potential. For a well-optimized inference pipeline, the CUDA HW row should show nearly continuous kernel execution with minimal gaps. - -The CUDA API row shows CPU-side CUDA API calls (kernel launches, memory allocations, synchronization). If CUDA API bars are significantly wider than the corresponding CUDA HW bars, the CPU is the bottleneck: it cannot launch kernels fast enough to keep the GPU busy. This is the kernel launch overhead problem that CUDA Graphs and `torch.compile` address. - -The NCCL row (for distributed workloads) shows collective communication operations. Comparing the NCCL row with the CUDA HW row reveals whether communication overlaps with computation. If NCCL bars appear during gaps in the CUDA HW row, communication is serialized. If NCCL bars overlap with CUDA HW bars, overlap is working correctly. - -The NVTX row shows user-annotated regions, which PyTorch maps to module names (`Linear`, `LayerNorm`, `Attention`). This connects low-level kernel names (often cryptic strings like `volta_fp16_s1688gemm_fp16_256x128_ldg8_f2f_nn`) to the model-level operations that produced them. - -The key patterns to look for in a Nsight Systems trace are: - -1. The ratio of kernel execution to idle time, which measures GPU utilization. -2. The distribution of kernel durations, where many short kernels suggest fusion opportunities. -3. The alignment of NCCL and CUDA HW rows, which reveals overlap effectiveness. -4. Memory allocation spikes, which may indicate inefficient memory management or unnecessary tensor materializations. -5. The proportion of time in GEMM versus non-GEMM kernels, which indicates how much of the execution is doing "useful" matrix arithmetic versus overhead. - -Experienced performance engineers develop pattern recognition for these traces, quickly identifying the dominant bottleneck from the visual structure of the timeline. A trace dominated by thin, closely packed kernel bars with minimal gaps indicates a well-optimized pipeline. A trace with large gaps between kernels, or with NCCL bars that do not overlap with CUDA HW bars, immediately reveals the primary optimization target. - -### Common Bottleneck Patterns {#sec-performance-engineering-bottleneck-patterns} - -Profiling reveals several recurring patterns in ML workloads that map directly to the optimization techniques in this chapter. - -Small kernel gaps manifest as many short GPU idle periods between small kernels, indicating that kernel launch overhead dominates. Graph compilation (`torch.compile`) fuses these small kernels to reclaim the lost cycles. - -Large intermediate tensors appear as memory timeline spikes corresponding to attention score matrices or other intermediate results. Unfused operations are materializing intermediates in HBM, and operator fusion (FlashAttention for attention, custom Triton kernels for other sequences) eliminates the unnecessary round-trips. - -Low FLOPS utilization combined with high memory bandwidth means the GPU achieves near-peak bandwidth but single-digit FLOPS utilization. Memory-bound operations dominate execution time, and the remedy is precision reduction (FP16 to FP8, INT4 KV cache) or algorithmic changes (batching, speculative decoding). - -Communication bubbles appear as GPU idle periods coinciding with NCCL AllReduce or AllToAll operations. Communication is not overlapped with compute, and the fix is to enable gradient overlap in DDP or restructure the pipeline schedule for zero-bubble execution. - -Load imbalance causes some GPUs to finish earlier and wait for stragglers. Uneven work distribution is the root cause. For MoE, adjusting auxiliary loss and capacity factors rebalances the load; for tensor parallelism, ensuring even dimension splitting eliminates the skew. - -Memory allocation thrashing shows repeated allocations and deallocations of large tensors, with peak memory far exceeding the steady-state usage. The framework is creating new tensors for each forward pass instead of reusing buffers. Enabling torch.compile's memory planning, using pre-allocated buffer pools, or applying activation checkpointing addresses this waste. - -CPU-bound data loading causes GPU utilization to drop periodically to zero while the CPU prepares the next batch. Data preprocessing or tokenization is not pipelined with GPU computation. Increasing the number of data loading workers, moving tokenization to a dedicated preprocessing service, or pre-tokenizing the dataset resolves this bottleneck. This pattern is more common in training than in inference but can occur in inference when input preprocessing (image resizing, text tokenization) is expensive. - -::: {.callout-notebook title="The Profiler Detective"} - -**Problem**: You run a 7B parameter LLM on a single H100 and observe 45 tokens/second during autoregressive generation at batch size 1. The theoretical bandwidth-limited rate is approximately 96 tokens/second. Where is the remaining 53% of performance hiding? - -**Investigation**: - -*Step 1: Kernel-level analysis.* Run Nsight Compute on the dominant GEMM kernel. Result: achieves 2.8 TB/s effective bandwidth out of the H100's 3.35 TB/s peak. Efficiency: 84%. This kernel is performing well. - -*Step 2: Trace-level analysis.* Run Nsight Systems on a full decode step. Result: 42% of step time is spent in GEMM kernels. The remaining 58% is split between: - -- Attention kernels (including KV cache reads): 28% -- Layer normalization and activation kernels: 12% -- Softmax and top-k sampling: 8% -- Kernel launch gaps: 10% - -*Step 3: Identify optimization targets.* - -- The KV cache attention kernel achieves only 1.9 TB/s bandwidth because of irregular memory access patterns. *Fix*: Implement KV cache quantization (INT8) with better memory layout. -- Kernel launch gaps (10% of time) come from 120+ individual kernel launches per layer. *Fix*: Apply `torch.compile` to fuse element-wise operations, reducing to ~30 kernels per layer. -- Layer normalization and activation kernels are unfused. *Fix*: Fused LayerNorm-GELU kernel via Triton. - -**Takeaway**: The dominant GEMM is near-optimal, but secondary operations and launch overhead consume over half the execution time. Systematic profiling reveals that the bottleneck is not where intuition suggests (the largest kernel) but in the accumulated overhead of many small operations. - -::: - -### The Profiling Feedback Loop {#sec-performance-engineering-profiling-loop} - -Effective performance engineering follows an iterative cycle: **profile, diagnose, optimize, verify**. The verification step is critical and often skipped. After applying an optimization, reprofile to confirm that the targeted bottleneck was addressed and that no new bottleneck emerged. Performance optimization is a waterbed problem: fixing one bottleneck often exposes the next. - -A common trap is optimizing based on microbenchmarks rather than end-to-end traces. A kernel that appears 2$\times$ faster in isolation may deliver only 5% improvement in end-to-end throughput if it was not the bottleneck, or if the surrounding code cannot take advantage of its speedup due to data dependencies. Always measure impact at the application level (tokens/second, step time, P99 latency) in addition to kernel-level metrics. - -Another subtlety is that profiling itself can perturb the system being measured. The PyTorch profiler adds approximately 10--20% overhead when recording full traces with memory profiling enabled. Nsight Systems adds less overhead but still affects scheduling. Profile warmup steps before active measurement, and discount the first few profiled iterations where JIT compilation or CUDA context initialization may dominate. - -### Profiling at Scale {#sec-performance-engineering-profiling-scale} - -Profiling a single GPU is straightforward; profiling a distributed system with hundreds or thousands of GPUs introduces unique challenges. The volume of trace data grows linearly with the number of GPUs: a 5-second Nsight Systems trace for one GPU is approximately 500 MB; the same trace for 1,000 GPUs would be 500 GB, impractical to store or analyze. - -Production systems address this through **hierarchical profiling**. At the top level, application-level metrics (MFU, throughput, step time) are collected continuously from every GPU with negligible overhead. These aggregate metrics detect when performance degrades. When a degradation is detected, **targeted profiling** is triggered on a representative subset of GPUs (typically one GPU per pipeline stage, per data-parallel group) for a short window (a few training steps). The resulting traces are analyzed to identify the specific bottleneck. - -Another approach is **statistical profiling**, where each GPU randomly samples a small fraction of its kernels for detailed timing. Over many training steps, the aggregated samples provide a statistically accurate picture of the kernel time distribution without the overhead of full tracing. This is analogous to the sampling profilers (like Linux `perf`) used in traditional systems engineering, adapted for the GPU context. - -The most challenging profiling scenario is intermittent stragglers: GPUs that are occasionally slow due to thermal throttling, memory errors, or network congestion, but fast most of the time. These stragglers may not appear in a short profiling window but can reduce training throughput by 10--20% over hours. Detecting them requires continuous per-GPU step-time monitoring with statistical anomaly detection, a form of profiling infrastructure that operates at the monitoring layer rather than the kernel layer. - -The profiling tools and techniques described in this section provide the measurement foundation for all optimization work. Without measurement, performance engineering degenerates into guesswork. With measurement, it becomes a systematic discipline guided by quantitative evidence. - -## Measurement at Scale {#sec-performance-engineering-measurement-scale} - -Optimizing a single node is a prerequisite, but the ultimate test of performance engineering is efficiency at fleet scale. When we move from 8 GPUs to 1,024 GPUs, new sources of overhead emerge that are invisible in local traces. Measurement at scale requires shifting from kernel-level micro-benchmarks to global efficiency metrics that capture the interaction of computation, communication, and hardware variability. - -### The Fleet Efficiency Metric - -While Hardware Utilization reports how often GPUs are busy, it fails to distinguish between useful work and wasted cycles (such as activation recomputation or communication bubbles). The gold standard for fleet measurement is **Model FLOPs Utilization (MFU)**. By focusing on the "useful" FLOPs required by the model architecture, MFU provides an invariant measure of system efficiency that remains comparable across different software stacks and parallelization strategies. - -::: {.callout-definition title="Model FLOPs Utilization"} - -***Model FLOPs Utilization (MFU)***\index{Model FLOPs Utilization!definition} is the ratio of useful model FLOPs performed to the hardware's theoretical peak throughput ($R_{\text{peak}}$). - -1. **Significance (Quantitative):** It is the most precise measure of **System Efficiency ($\eta$)** because it excludes "waste" FLOPs from recomputation, padding, or speculative execution. It serves as the primary diagnostic for whether hardware investment is translating into model progress. -2. **Distinction (Durable):** Unlike **Hardware Utilization** (which reports how often the GPU is "busy"), MFU reports how much of that busyness contributes to **Model Convergence** or prediction. -3. **Common Pitfall:** A frequent misconception is that high GPU utilization implies high efficiency. In reality, a system can show 90% hardware utilization while having a low MFU if it is wasting compute on inefficient kernel implementations or excessive communication overhead ($L_{\text{lat}}$). - -::: - -```{python} -#| label: fleet-efficiency-calc -#| echo: false -# ┌───────────────────────────────────────────────────────────────────────────── -# │ FLEET EFFICIENCY ANALYSIS (LEGO) -# ├───────────────────────────────────────────────────────────────────────────── -# │ Context: @sec-performance-engineering-measurement-scale -# │ -# │ Goal: Calculate MFU for a 70B model distributed across a 128-GPU cluster. -# │ Show: How the "Scaling Tax" reduces MFU from a 65% local peak to 48% fleet avg. -# │ How: MFU = (6 * P * D) / (N * R_peak * Time). -# │ -# │ Imports: mlsys.constants (BILLION, TRILLION, H100_FLOPS_FP16_TENSOR) -# │ Exports: fleet_params_b_str, fleet_nodes_str, fleet_t_step_ms_str, -# │ fleet_local_mfu_str, fleet_global_mfu_str, fleet_scaling_tax_str -# └───────────────────────────────────────────────────────────────────────────── -from mlsys.constants import BILLION, TRILLION, H100_FLOPS_FP16_TENSOR, second, TFLOPs -from mlsys.formatting import fmt, check - -class FleetEfficiencyCalc: - """Calculates Fleet MFU and identifies the Scaling Tax.""" - - # ┌── 1. LOAD (Constants) ────────────────────────────────────────────── - p_params = 70 * BILLION - n_gpus = 128 - r_peak_per_gpu = H100_FLOPS_FP16_TENSOR.m_as(TFLOPs / second) * TRILLION - tokens_per_step = 2048 * 32 # seq_len * global_batch - - # Observed times - t_local_step_ms = 180.0 # 8-GPU node baseline - t_fleet_step_ms = 245.0 # 128-GPU cluster observed - - # ┌── 2. EXECUTE (The Compute) ──────────────────────────────────────── - # Useful FLOPs per step = 6 * P * Tokens - flops_per_step = 6 * p_params * tokens_per_step - - # MFU = useful_flops / (n_gpus * r_peak * time) - local_mfu = flops_per_step / (8 * r_peak_per_gpu * (t_local_step_ms / 1000)) - global_mfu = flops_per_step / (n_gpus * r_peak_per_gpu * (t_fleet_step_ms / 1000)) - - scaling_tax = (1 - (global_mfu / local_mfu)) * 100 - - # ┌── 3. GUARD (Invariants) ────────────────────────────────────────── - check(global_mfu < local_mfu, "Fleet MFU must be lower than local node MFU due to scaling overhead.") - - # ┌── 4. OUTPUT (Formatting) ───────────────────────────────────────────── - fleet_params_b_str = "70B" - fleet_nodes_str = f"{n_gpus}" - fleet_t_step_ms_str = f"{t_fleet_step_ms:.0f}" - fleet_local_mfu_str = f"{local_mfu*100:.1f}" - fleet_global_mfu_str = f"{global_mfu*100:.1f}" - fleet_scaling_tax_str = f"{scaling_tax:.0f}" - -# ┌── EXPORTS (Bridge to Text) ───────────────────────────────────────────────── -fleet_params_b_str = FleetEfficiencyCalc.fleet_params_b_str -fleet_nodes_str = FleetEfficiencyCalc.fleet_nodes_str -fleet_t_step_ms_str = FleetEfficiencyCalc.fleet_t_step_ms_str -fleet_local_mfu_str = FleetEfficiencyCalc.fleet_local_mfu_str -fleet_global_mfu_str = FleetEfficiencyCalc.fleet_global_mfu_str -fleet_scaling_tax_str = FleetEfficiencyCalc.fleet_scaling_tax_str -``` - -::: {.callout-notebook title="The Scaling Tax"} -**Scenario**: Training a **`{python} fleet_params_b_str`** parameter model across a cluster of **`{python} fleet_nodes_str` H100 GPUs**. - -* **Local Node Baseline**: A single 8-GPU node achieves **`{python} fleet_local_mfu_str`% MFU**. -* **Fleet Performance**: At 128 GPUs, the step time increases to **`{python} fleet_t_step_ms_str` ms**, dropping MFU to **`{python} fleet_global_mfu_str`%**. - -**The Diagnosis**: The **`{python} fleet_scaling_tax_str`% Scaling Tax** represents the cost of inter-node communication (InfiniBand latency) and synchronization barriers. In a healthy fleet, this tax should remain stable; a sudden increase in the scaling tax signals a **Scaling Regression**—typically caused by a misaligned parallelization strategy or a "gray failure" in the network fabric. -::: - -### Detecting Scaling Regressions - -At scale, the system is non-linear. A code change that introduces a minor memory overhead on a single GPU can trigger a catastrophic performance collapse at 1,000 GPUs due to increased garbage collection pauses or exhausted InfiniBand credit buffers. Detecting these **Scaling Regressions** requires a tiered testing strategy: - -1. **Small-Scale Canaries**: Running the model on 8 and 64 GPUs to establish a scaling efficiency curve. -2. **Fleet Baseline Comparison**: Continuously comparing every production run's MFU against the "Gold Standard" baseline for that model architecture. -3. **Gray Failure Detection**: Monitoring the distribution of step times across the fleet. A single "straggler" node that is 10% slower due to thermal throttling can reduce the entire cluster's MFU by 10% in a synchronous data-parallel workload. - -::: {.callout-note title="Benchmark vs. Reality: The Hero Run Tax"} -Industry benchmarks like **MLPerf** are often "Hero Runs"—highly tuned configurations where logging is disabled, safety checks are bypassed, and the hardware is freshly rebooted. - -In production, your achieved MFU will typically sit **10–20% lower** than these hero numbers. This "Reality Tax" is consumed by essential operational overhead: -* **Observability**: Metrics collection and logging. -* **Reliability**: Checkpointing and health heartbeats. -* **Entropy**: Thermal throttling, memory fragmentation, and multi-tenant network noise. - -When planning capacity, engineers must budget for the Reality Tax. If a benchmark says you can train in 30 days, your production plan should assume 35–40 days. -::: - -With the measurement hierarchy established—from node-level traces to fleet-wide MFU—we now turn to the tactical execution. The optimization playbook translates these global measurements into a surgical sequence of interventions. - -## The Optimization Playbook: A 70B LLM Case Study {#sec-performance-engineering-playbook} - -You are handed a raw, unoptimized 70-billion parameter PyTorch model and told it must serve 1,000 tokens per second in production by next week. Where do you start? You cannot simply throw every technique at the wall. The optimization playbook requires a systematic, prioritized attack: first unblocking the memory wall, then fusing operators, and finally applying algorithmic techniques like speculative decoding in a specific, compounding sequence. - -### The Diagnostic Sequence {#sec-performance-engineering-diagnostic-sequence} - -When approaching a new ML workload for optimization, follow this sequence: - -1. **Baseline Measurement**: Measure end-to-end throughput (tokens/second for LLMs, samples/second for training) and collect a Nsight Systems trace. Compute the **Model FLOPs Utilization (MFU)** as established in @sec-performance-engineering-measurement-scale. An MFU below 30% indicates substantial optimization opportunity; above 50% is good; above 60% is excellent for large models. - -2. **Roofline Classification**: For the dominant kernels, compute the arithmetic intensity and plot on the roofline. This immediately identifies whether the workload is compute-bound or memory-bound, which determines the entire optimization strategy. - -3. **Selecting the Primary Bottleneck**: The primary bottleneck falls into one of three categories, each with a different optimization path: - -For **memory-bound workloads** (the common case for inference): - -1. Apply precision reduction (FP16 $\rightarrow$ FP8 or INT4) to increase effective bandwidth. -2. Apply operator fusion (FlashAttention, fused LayerNorm/GELU) to eliminate intermediate HBM traffic. -3. Apply graph compilation (torch.compile) to catch remaining fusion opportunities. -4. Consider algorithmic changes (speculative decoding, MoE) for further improvement. - -For **compute-bound workloads** (large-batch training): - -1. Ensure Tensor Cores are in use (FP16/BF16/FP8, not FP32). -2. Apply graph compilation for kernel selection and memory planning. -3. Consider reduced precision (FP8) for 2$\times$ compute throughput. -4. Optimize communication overlap to avoid compute idle time. - -For **communication-bound workloads** (distributed training at scale): - -1. Enable gradient communication overlap with backward pass. -2. Apply gradient compression or reduced precision communication. -3. Restructure pipeline schedules for zero-bubble execution. -4. Consider topology-aware placement to minimize cross-node traffic. - -The fourth step is to apply and verify. Implement the highest-impact optimization, reprofile, and verify improvement. Then iterate from the roofline classification step with the new profile, as the bottleneck may have shifted. - -### Combining Techniques {#sec-performance-engineering-combining} - -The most performant production systems combine multiple techniques simultaneously. A highly optimized LLM serving system might employ all of the following: - -- **FlashAttention-2** or **FlashAttention-3** for memory-efficient attention computation -- **INT4 weight quantization** (GPTQ or AWQ) with FP16 dequantization during GEMM -- **INT8 KV cache compression** with per-channel scaling -- **torch.compile** or **TensorRT** for element-wise fusion and kernel selection -- **Speculative decoding** for latency reduction at low batch sizes -- **Continuous batching** with dynamic sequence grouping for high throughput -- **Tensor parallelism** across GPUs with overlapped AllReduce - -These techniques are not additive in their speedup; they interact in complex ways. For instance, INT4 weight quantization reduces per-token HBM traffic by 4$\times$, which might shift the bottleneck from memory-bound to compute-bound. Once compute-bound, further bandwidth optimizations (KV cache compression) yield diminishing returns, and compute optimizations (FP8 Tensor Cores) become the priority. This is why the iterative profile-optimize-verify loop is essential: the optimal combination depends on the specific model, hardware, and workload characteristics. - -The interaction between optimizations creates a dependency graph that the performance engineer must navigate. Some combinations are synergistic: FlashAttention reduces attention memory traffic, and INT8 KV cache compression reduces KV cache memory, together freeing enough memory for larger batch sizes that transform the economics of serving. Other combinations are redundant: applying both CUDA Graphs and the `reduce-overhead` mode of `torch.compile` achieves the same result, since `reduce-overhead` internally uses CUDA Graphs. Still other combinations conflict: speculative decoding benefits most at small batch sizes (where decode is memory-bound), while many throughput optimizations work by increasing batch size. At large batch sizes, speculation adds overhead without proportional benefit. - -A practical heuristic for sequencing optimizations is to apply them in order of decreasing impact and increasing effort: - -1. **torch.compile** (1 line of code, 10--40% speedup): Always apply first. No downsides. -2. **Weight quantization** (INT4/FP8, hours of calibration, 2--4$\times$ throughput): Apply second. Frees memory for batching. -3. **FlashAttention** (library swap, 2--32$\times$ for attention): Apply third. Usually a configuration flag. -4. **KV cache compression** (INT8, library support, 2$\times$ cache reduction): Apply fourth. Enables larger batches. -5. **Speculative decoding** (requires draft model, engineering effort): Apply last, only if latency target not met. Most complex to deploy. - -This ordering reflects the principle that passive optimizations (compiler, library swaps) should precede active ones (algorithmic changes, new model components). Each step is validated by reprofiling before proceeding to the next. - -::: {.callout-checkpoint title="Optimization Strategy" collapse="false"} - -Test your ability to design an optimization plan: - -- [ ] Given an LLM serving workload at batch size 1 that achieves 25% of peak bandwidth, can you identify the three most impactful optimizations and their expected interaction? -- [ ] Can you explain why applying FP8 quantization to a workload that is already communication-bound provides no speedup? -- [ ] Can you describe the iterative profiling workflow and explain why verifying each optimization is as important as applying it? -- [ ] Can you identify a scenario where applying `torch.compile` would *degrade* performance (hint: consider compilation overhead and graph breaks)? - -::: - -### Case Study: Optimizing a 70B LLM Serving Pipeline {#sec-performance-engineering-case-study} - -To illustrate how the diagnostic sequence and combining principles work in practice, consider the task of optimizing a 70B parameter LLM for production serving. The target is a real-time chatbot application requiring time-to-first-token (TTFT) under 500 ms, inter-token latency (ITL) under 50 ms, and throughput of at least 1,000 tokens/second across the cluster. The model is deployed on a node of 8 H100 GPUs connected by NVLink. - -#### Baseline Measurement {#sec-performance-engineering-case-baseline} - -The initial deployment uses FP16 weights, standard PyTorch eager execution, and tensor parallelism across 8 GPUs. The 70B model in FP16 requires 140 GB of weight storage, distributed as approximately 17.5 GB per GPU. The baseline performance metrics are: - -- **TTFT**: 1,200 ms (well above the 500 ms target) -- **ITL**: 85 ms (above the 50 ms target) -- **Throughput**: 280 tokens/second (below the 1,000 token/second target) -- **Maximum batch size**: 4 (limited by KV cache memory) - -A Nsight Systems trace reveals the following time breakdown for a single decode step at batch size 1: - -- GEMM kernels: 38% of step time -- Attention (including KV cache reads): 24% of step time -- Element-wise operations (LayerNorm, GELU, residual): 14% of step time -- AllReduce communication (tensor parallelism): 12% of step time -- Kernel launch gaps and overhead: 12% of step time - -The roofline analysis confirms that decode is deeply memory-bound, with arithmetic intensity approximately 1 FLOP/byte at batch size 1. The GPU achieves 2.7 TB/s effective bandwidth (80% of peak), indicating reasonable kernel-level efficiency but a fundamental algorithmic limitation. - -#### Optimization Round 1: Precision Engineering {#sec-performance-engineering-case-round1} - -The first optimization targets the largest opportunity: reducing the bytes per weight read from HBM. Applying AWQ INT4 weight quantization reduces the weight storage from 140 GB to 35 GB (8.75 GB per GPU from 17.5 GB per GPU). The effective bandwidth for weight reads doubles because the same physical bandwidth now delivers twice as many weight values per second (each value is 4 bits instead of 16 bits, with on-the-fly dequantization to FP16 for the GEMM). - -Simultaneously, applying INT8 quantization to the KV cache reduces per-request cache size by 2$\times$. The combined effect on memory budget is dramatic: each GPU now has approximately 71 GB available for KV cache, up from 62 GB. At the reduced per-request KV cache size, the maximum batch size increases from 4 to approximately 32. - -Post-optimization metrics at batch size 1: - -- **ITL**: 48 ms (meets the 50 ms target) -- **Throughput at batch size 32**: 720 tokens/second (still below target) - -The Nsight Systems trace shows that GEMM time decreased by approximately 45% due to reduced weight reads, but attention and element-wise operations remain unchanged. The bottleneck has partially shifted. - -#### Optimization Round 2: Operator Fusion {#sec-performance-engineering-case-round2} - -The second round targets the 14% of step time consumed by element-wise operations and the 24% consumed by attention. Applying `torch.compile` with the `max-autotune` backend fuses element-wise operations (GELU, LayerNorm, residual additions), reducing their contribution from 14% to approximately 4% of step time. Simultaneously, enabling FlashAttention-2 replaces the standard attention implementation, reducing attention HBM traffic by approximately 16$\times$ for the prefill phase. - -For the decode phase, FlashAttention's impact is more modest because decode attention is dominated by KV cache reads rather than the $N \times N$ score matrix. However, the combination of INT8 KV cache compression and FlashAttention's efficient paged attention kernel reduces attention decode time by approximately 30%. - -The `reduce-overhead` mode in `torch.compile` wraps the decode step in a CUDA Graph, eliminating the 12% kernel launch overhead almost entirely. - -Post-optimization metrics: - -- **TTFT**: 380 ms (meets the 500 ms target) -- **ITL**: 32 ms at batch size 1 (well below the 50 ms target) -- **Throughput at batch size 32**: 1,050 tokens/second (meets the target) - -#### Optimization Round 3: Speculative Decoding {#sec-performance-engineering-case-round3} - -With the throughput target met, the team focuses on further reducing ITL for the best user experience. Speculative decoding with a 1.5B draft model (AWQ INT4 quantized to 0.75 GB) is deployed on the same GPUs. The draft model generates 5 candidate tokens in 4 ms (benefiting from the INT4 quantization applied in Round 1). Verification takes approximately 8 ms. - -At an average acceptance rate of 0.78, the expected tokens per round is approximately 3.5. The effective ITL becomes: - -$$ -\text{ITL}_{\text{effective}} = \frac{4 + 8}{3.5} \approx 3.4 \text{ ms per token} -$$ - -This is a 9.4$\times$ improvement over the Round 2 ITL of 32 ms, providing a remarkably responsive user experience. - -However, speculative decoding interacts with batching. At batch size 32, the verification step is no longer "free" because the GPU is closer to compute saturation. The system therefore applies speculation only when the current batch size is below 16, falling back to standard autoregressive decoding at higher loads. This adaptive policy maintains both the latency benefit at low load and the throughput benefit at high load. - -#### Lessons from the Case Study {#sec-performance-engineering-case-lessons} - -This optimization journey illustrates several principles: - -The order of optimizations matters. Precision engineering (Round 1) was applied first because it yields the largest single improvement and enables subsequent optimizations by freeing memory for larger batch sizes. Fusion (Round 2) addressed the new bottleneck exposed by precision engineering. Speculative decoding (Round 3) provided latency improvement once the throughput target was met. - -Each optimization changed the bottleneck. Before Round 1, the system was purely memory-bandwidth-bound. After INT4 quantization and batching, the system was partially compute-bound at large batch sizes. After fusion, kernel launch overhead was negligible, making the remaining bottleneck the fundamental memory-bandwidth limit for decode. Each optimization was validated by reprofiling to confirm the bottleneck shift. - -The final system combines five distinct techniques: INT4 weight quantization, INT8 KV cache compression, FlashAttention-2, torch.compile with CUDA Graphs, and adaptive speculative decoding. These techniques are not independent; they interact. INT4 quantization enables larger batch sizes, which changes whether speculative decoding is profitable. FlashAttention's benefit depends on sequence length, which grows during generation. The performance engineer must reason about these interactions holistically, guided by profiling data at each stage. - -This case study demonstrates how these disparate optimizations compound sequentially to transform an unusable prototype into a production-grade deployment. However, the path to these massive speedups is fraught with conventional wisdom that often proves disastrous at scale, leading us to examine the common fallacies and pitfalls of performance engineering. - -## Fallacies and Pitfalls {#sec-performance-engineering-fallacies} - -A team upgrades their inference cluster from A100s to H100s, expecting a massive 3x latency reduction based on the new spec sheet's teraFLOPS rating, only to find their generative model barely runs 15% faster. This leads us to one of the most pervasive traps in performance engineering: assuming that raw compute capacity dictates inference speed when the workload is entirely bound by memory bandwidth. - -Fallacy: *More FLOPS means faster inference.* - -The roofline model demonstrates that most inference operations are memory-bound, not compute-bound. A GPU with 2$\times$ the peak FLOPS but the same memory bandwidth will not generate tokens any faster for batch-1 LLM decode. The correct metric for memory-bound workloads is bandwidth, not FLOPS. This fallacy leads organizations to purchase the most expensive compute hardware when a mid-range GPU with equivalent HBM bandwidth would deliver identical inference throughput. - -Fallacy: *FP8 always halves training time compared to FP16.* - -FP8 doubles the peak TFLOPS and doubles the effective memory bandwidth, but these gains are realized only for operations that are bottlenecked by compute or bandwidth at FP16. Element-wise operations like activation functions are already limited by kernel launch overhead, not by precision. Communication-bound distributed training steps gain nothing from reduced arithmetic precision if the communication volume (gradient sizes) is not also reduced. The actual speedup depends on the fraction of execution time spent in precision-sensitive operations. - -Pitfall: *Optimizing the largest kernel while ignoring the long tail.* - -The profiling case study above illustrates this pitfall. Engineers naturally focus on the single largest kernel, which is often the GEMM in a transformer layer. When the GEMM is already near-optimal, however, the remaining performance budget is distributed across dozens of smaller operations: normalization, activation, attention scoring, KV cache management, and kernel launch overhead. Collectively, these "small" operations can consume more than half of total execution time. Graph compilation and systematic fusion address this long tail more effectively than further GEMM optimization. - -Pitfall: *Applying speculative decoding without considering batch dynamics.* - -Speculative decoding excels at batch size 1, where decode is deeply memory-bound and the verification step is essentially "free" (the GPU has ample spare compute). At large batch sizes, decode approaches the compute-bound regime, and the verification step adds meaningful compute cost. Furthermore, the variable number of accepted tokens per request complicates continuous batching schedulers. In high-throughput serving scenarios with large batches, the overhead of speculation may outweigh its latency benefits. - -Pitfall: *Treating MoE expert count as a free scaling knob.* - -Increasing the number of experts in an MoE model increases total parameters (capacity) without proportionally increasing per-token compute, which seems like a free lunch. Each additional expert, however, increases: (1) total memory requirements, requiring more GPUs; (2) AllToAll communication volume for expert routing; (3) load balancing difficulty, since the router must distribute tokens across more experts; and (4) training instability, as more experts compete for activation. Beyond approximately 64--256 experts, the system-level costs often outweigh the capacity benefits. - -Fallacy: *Graph compilers eliminate the need for manual kernel engineering.* - -Graph compilers have improved dramatically, but they remain limited by their cost models and fusion heuristics. FlashAttention required human insight to recognize that attention could be reformulated as a tiled algorithm with online softmax, an algorithmic insight beyond the scope of any current compiler's rewrite rules. Similarly, speculative decoding and MoE routing require algorithmic innovation that compilers cannot discover. Compilers automate known optimizations; human engineers discover new ones. - -Pitfall: *Quantizing everything to the lowest supported precision.* - -Aggressive quantization (INT4 weights, INT4 KV cache, FP8 activations) can degrade model quality in ways that are difficult to detect with standard benchmarks but visible to users. Perplexity on a held-out dataset may change by less than 1%, but the model may produce subtly worse responses for edge cases, rare languages, or complex reasoning tasks. The correct approach is targeted quantization: apply the most aggressive precision to the least sensitive components (KV cache, intermediate activations) and preserve higher precision for the most sensitive (first and last layers, attention logits). Calibration on a representative dataset, followed by evaluation on diverse quality benchmarks, is essential before deploying any quantized model to production. - -Pitfall: *Measuring throughput without measuring quality.* - -A model serving system that generates 200 tokens/second is not twice as good as one generating 100 tokens/second if the first system achieves that throughput by using INT4 quantization that degrades answer quality by 15%. Performance metrics must always be reported alongside quality metrics. The correct optimization target is the Pareto frontier of throughput versus quality, not throughput alone. - -Fallacy: *A single profiling run is sufficient to characterize performance.* - -ML system performance is non-stationary. GPU thermal throttling reduces clock speeds (and therefore FLOPS) after sustained workloads, sometimes by 10--15%. Memory fragmentation accumulates over hours of serving, gradually reducing effective batch size. Network congestion varies with cluster-wide traffic patterns. A profiling run during a cold start may show different bottleneck patterns than one after hours of production serving. Reliable performance characterization requires profiling under realistic, sustained conditions, ideally sampling multiple times across a production run. - -Pitfall: *Optimizing for average case while ignoring tail latency.* - -A serving system may achieve excellent average inter-token latency (30 ms) while exhibiting P99 latency of 500 ms due to garbage collection pauses in the Python runtime, CUDA memory allocation stalls, or occasional AllReduce delays from network congestion. For interactive applications, the user experience is dominated by the worst case, not the average. Performance engineering for production systems must profile and optimize tail latency specifically, often through techniques orthogonal to the throughput optimizations in this chapter: pre-allocated memory pools, CUDA graph replay (which eliminates allocation variance), and priority scheduling for latency-sensitive requests. - -Recognizing these pitfalls—from ignoring tail latency to misjudging the impact of raw FLOPS—saves teams from wasting months optimizing the wrong layer of the stack. We conclude this chapter by summarizing the core principles that guide a successful performance engineering lifecycle. - -## Summary {#sec-performance-engineering-summary} - -Performance engineering transforms a model that should be efficient into one that is. The techniques in this chapter address the fundamental bottleneck of modern ML systems: the memory wall. @fig-optimization-hierarchy summarizes how these techniques layer from hardware-level primitives to algorithmic innovations. - -::: {.callout-note title="Figure: The Optimization Hierarchy" collapse="false"} - -```{.tikz} -%| fig-cap: "**The Performance Engineering Hierarchy**. Optimization techniques organized by their level of abstraction, from hardware-level precision engineering at the base to algorithmic innovations at the top. Each layer builds on and benefits from the layers below it. The annotations show the primary mechanism and typical speedup range for each technique." -%| fig-alt: "Layered hierarchy diagram showing five optimization levels from bottom to top: Precision Engineering, Operator Fusion, Graph Compilation, Communication Overlap, and Algorithmic Innovation, with arrows showing how they interact." -%| label: fig-optimization-hierarchy - -\begin{tikzpicture}[ - layer/.style={draw, thick, rounded corners=3pt, minimum width=11cm, minimum height=1.2cm, font=\small}, - label/.style={font=\scriptsize, text width=4cm, align=left}, - >=stealth -] - % Layers from bottom to top - \node[layer, fill=blue!15] (hw) at (0,0) {\textbf{Precision Engineering} (FP8, INT4, KV Cache Compression)}; - \node[layer, fill=green!15] (fuse) at (0,1.6) {\textbf{Operator Fusion \& Tiling} (FlashAttention, Fused Kernels)}; - \node[layer, fill=orange!15] (comp) at (0,3.2) {\textbf{Graph Compilation} (torch.compile, XLA, TensorRT)}; - \node[layer, fill=purple!15] (comm) at (0,4.8) {\textbf{Communication Overlap} (Gradient Pipelining, Zero-Bubble)}; - \node[layer, fill=red!15] (algo) at (0,6.4) {\textbf{Algorithmic Innovation} (Speculative Decoding, MoE)}; - - % Right-side annotations - \node[label, right] at (6.2,0) {Mechanism: Reduce bytes/value\\Speedup: 2--4$\times$}; - \node[label, right] at (6.2,1.6) {Mechanism: Reduce HBM trips\\Speedup: 2--32$\times$}; - \node[label, right] at (6.2,3.2) {Mechanism: Automate fusion\\Speedup: 1.1--2$\times$}; - \node[label, right] at (6.2,4.8) {Mechanism: Hide latency\\Speedup: 1.1--1.5$\times$}; - \node[label, right] at (6.2,6.4) {Mechanism: Change algorithm\\Speedup: 1.5--10$\times$}; - - % Arrows between layers - \draw[->, thick, gray] (hw.north) -- (fuse.south); - \draw[->, thick, gray] (fuse.north) -- (comp.south); - \draw[->, thick, gray] (comp.north) -- (comm.south); - \draw[->, thick, gray] (comm.north) -- (algo.south); - - % Left-side label - \node[rotate=90, font=\small, anchor=south] at (-6.5,3.2) {Increasing Abstraction $\longrightarrow$}; -\end{tikzpicture} -``` - -::: - -The **Roofline Model** provides the diagnostic framework, classifying operations as compute-bound or memory-bound based on their arithmetic intensity relative to the hardware's ridge point. For the NVIDIA H100, this ridge point is approximately `{python} h100_fp16_ridge_str` FLOP/byte at FP16, meaning most transformer operations fall in the memory-bound regime. - -Operator fusion eliminates redundant HBM round-trips by combining sequences of operations into single kernels. FlashAttention is the canonical example, reducing attention HBM traffic by `{python} savings_str`$\times$ for long sequences through tiling and online softmax. Ring Attention extends this principle across GPUs for extreme sequence lengths. - -Precision engineering reduces bytes per transaction. FP8 formats (E4M3 for weights, E5M2 for gradients) double effective bandwidth on H100 hardware. Block-wise quantization (LLM.int8(), GPTQ, AWQ) handles the outlier features that defeat uniform quantization. KV cache compression to INT4 or INT8 directly increases serving batch size. - -Graph compilation automates fusion, kernel selection, and memory planning. torch.compile/TorchInductor generates optimized Triton kernels from standard PyTorch code. XLA provides whole-program optimization for JAX/TPU workloads. TensorRT delivers aggressive inference-specific optimization. - -Speculative decoding trades compute for latency by using a fast draft model to generate candidate tokens verified in parallel by the target model, with mathematical guarantees on output quality. Speedups of 1.5--3$\times$ are typical for diverse workloads. - -Mixture of Experts decouples model capacity from per-token inference cost through sparse activation. DeepSeek-V3 demonstrates the frontier: 671B total parameters with only 37B active per token. - -Communication-computation overlap hides network latency by executing communication and computation concurrently on separate hardware engines. Zero-bubble pipeline schedules approach the theoretical minimum idle time. - -System profiling provides the measurement infrastructure to identify which bottleneck dominates and which optimization to apply, closing the loop between diagnosis and treatment. - -Together, these techniques compose a coherent optimization stack. Precision engineering and operator fusion attack the **Iron Law** at the hardware level, reducing the memory and compute terms. Graph compilation automates these hardware-level optimizations across the full model. Communication overlap attacks the distributed overhead term. Speculative decoding and MoE change the algorithm itself, fundamentally restructuring the terms of the equation. - -The case study in @sec-performance-engineering-case-study demonstrated how these techniques compound in practice: INT4 quantization freed memory for larger batches, which changed the arithmetic intensity, which determined whether further bandwidth or compute optimizations were profitable. Each optimization shifted the bottleneck, requiring reprofiling and a new optimization decision. This iterative, measurement-driven process is the discipline of performance engineering. - -The unifying principle is that performance engineering is not about making hardware faster; it is about making software match the physics of the hardware it runs on. The memory wall is a physical constraint that grows wider with each hardware generation. The techniques in this chapter are the engineer's response: not fighting the physics but working within it, keeping data close to compute, reducing precision to the minimum that preserves quality, and restructuring algorithms to avoid unnecessary work. As models grow larger and hardware grows faster but not more bandwidth-rich, these skills become not just valuable but essential. - -::: {.callout-takeaways title="Match the Software to the Silicon"} - -* The **Memory Wall** is the defining constraint of modern ML performance: most transformer operations are memory-bound, with arithmetic intensity far below the hardware ridge point. Performance engineering is primarily about reducing bytes moved, not operations computed. -* **Operator Fusion** and **Tiling** (FlashAttention) eliminate redundant HBM traffic by keeping intermediate results in on-chip SRAM, reducing attention memory traffic from quadratic to linear in sequence length. -* **Precision Engineering** doubles effective bandwidth by reducing numerical representation (FP8, INT4), but requires careful handling of outlier features (LLM.int8(), GPTQ, AWQ) and dynamic scaling for training stability. -* **Graph Compilation** (torch.compile, XLA, TensorRT) automates known optimizations across the entire model graph, but cannot discover algorithmic innovations like FlashAttention or speculative decoding. -* **Speculative Decoding** is the only technique that reduces *latency* for memory-bound decode without changing precision or batch size, by trading abundant compute for scarce bandwidth through parallel verification of draft tokens. -* **Mixture of Experts** enables parameter scaling without proportional compute scaling, but introduces AllToAll communication patterns and load balancing challenges that require careful system engineering. -* **Always profile before optimizing**: the dominant bottleneck is often not where intuition suggests, and applying the wrong optimization wastes engineering effort while leaving the actual bottleneck untouched. - -::: - -The central lesson of performance engineering is that it is an iterative, measurement-driven discipline, not a one-shot transformation. Every optimization shifts the bottleneck: fusing attention kernels may move the constraint from HBM bandwidth to compute throughput, and reducing precision may free enough memory to increase batch size, which in turn changes the arithmetic intensity and demands a different optimization entirely. The practitioner who masters this profile-optimize-reprofile loop can extract 2 to 10 times more throughput from the same hardware than one who applies optimizations blindly. Profiling is not merely the first step; it is every step. - -This iterative mindset also determines which skills endure as hardware evolves. Individual techniques will change as new accelerator generations shift the ridge point of the Roofline Model and as new architectures alter the dominant computational patterns. What will not change is the fundamental discipline: measure the system, identify the binding constraint, apply the optimization that addresses that specific constraint, and then measure again. Engineers who internalize this cycle treat performance engineering as a continuous practice rather than a checklist, and that practice is what separates systems that merely run from systems that run efficiently at scale. - -::: {.callout-chapter-connection title="From Optimization to Serving"} - -In this chapter, we established a rigorous toolkit for performance engineering. Through operator fusion, precision reduction, graph compilation, and architectural optimizations like speculative decoding and Mixture-of-Experts, we now understand how to extract maximum computational efficiency from a single model executing on bare metal. - -Yet, optimizing an isolated forward pass is merely the precursor to deployment. In production environments, models do not operate in zero-contention vacuums. They must process unpredictable bursts of concurrent requests, multiplexing across fleets of distributed accelerators while strictly adhering to unforgiving latency and throughput budgets. - -In **Inference at Scale** (@sec-inference-scale), we transition from the mechanics of local execution to the architecture of robust serving systems. The low-level optimizations developed here function as the foundational primitives for high-availability infrastructure, tackling systemic challenges like continuous batching, distributed KV-cache management, and dynamic request routing. - -::: - -```{python} -#| echo: false -#| label: chapter-end -from mlsys.registry import end_chapter -end_chapter("vol2:performance_engineering") -``` diff --git a/book/quarto/contents/vol2/sustainable_ai/sustainable_ai.qmd b/book/quarto/contents/vol2/sustainable_ai/sustainable_ai.qmd index 5055c55b0..04e9c10e4 100644 --- a/book/quarto/contents/vol2/sustainable_ai/sustainable_ai.qmd +++ b/book/quarto/contents/vol2/sustainable_ai/sustainable_ai.qmd @@ -1357,6 +1357,43 @@ The geographic choice alone produces a `{python} emissions_ratio_str`-fold diffe #### Embodied Carbon Assessment {#sec-sustainable-ai-embodied-carbon-assessment-9de0} +::: {#fig-carbon-sankey fig-env="figure" fig-pos="htb" fig-cap="**The Total Carbon of Ownership (TCO)**. Sankey-style flow visualizing how carbon emissions accumulate across the AI lifecycle. For a typical datacenter deployment, operational energy (training and serving) dominates total emissions. However, as the grid shifts to renewables, the **Embodied Carbon** from semiconductor fabrication and datacenter construction becomes the binding sustainability constraint, making hardware longevity a critical engineering lever." fig-alt="Sankey diagram showing three input flows: Raw Materials, Semiconductor Fab, and Grid Energy. These merge into Training and Serving phases, ending in AI Model Value. Widths show relative carbon impact."} +```{.tikz} +\begin{tikzpicture}[font=\small\usefont{T1}{phv}{m}{n}, line join=round] + \definecolor{EmbodiedColor}{RGB}{204,85,0} % OrangeLine + \definecolor{OpsColor}{RGB}{0,99,149} % BlueLine + \definecolor{ValueColor}{RGB}{0,143,69} % GreenLine + + % Input flows + \fill[EmbodiedColor!30] (0, 3) -- (3, 2.5) -- (3, 1.5) -- (0, 2) -- cycle; + \node[left, align=right] at (0, 2.5) {\textbf{Raw Materials}\\Extraction \& Log.}; + + \fill[EmbodiedColor!50] (0, 1.5) -- (3, 1.5) -- (3, 0.5) -- (0, 1) -- cycle; + \node[left, align=right] at (0, 1.25) {\textbf{Chip Fabrication}\\(Embodied Carbon)}; + + \fill[OpsColor!40] (0, 0) -- (3, 0.5) -- (3, -2.5) -- (0, -2) -- cycle; + \node[left, align=right] at (0, -1) {\textbf{Grid Energy}\\(Coal/Gas/Renew.)}; + + % Lifecycle phases + \fill[OpsColor!60] (3, 2.5) -- (6, 2.5) -- (6, 0.5) -- (3, 0.5) -- cycle; + \node at (4.5, 1.5) {\textbf{Model Training}}; + + \fill[OpsColor!80] (3, 0.5) -- (6, 0.5) -- (6, -2.5) -- (3, -2.5) -- cycle; + \node at (4.5, -1) {\textbf{Inference / Serving}}; + + % Output flow + \fill[ValueColor!40] (6, 2.5) -- (9, 1.5) -- (9, -0.5) -- (6, -2.5) -- cycle; + \node[right, align=left] at (9, 0.5) {\textbf{AI Fleet Intelligence}\\(System Value)}; + + % Labels + \node[above, EmbodiedColor] at (1.5, 3) {\textit{Upstream}}; + \node[above, OpsColor] at (4.5, 2.5) {\textit{Operational}}; + \node[above, ValueColor] at (7.5, 2) {\textit{Downstream}}; + +\end{tikzpicture} +``` +::: + Embodied carbon encompasses emissions from raw material extraction, semiconductor fabrication, assembly, transportation, and end-of-life disposal. For AI hardware, manufacturing emissions are dominated by the energy-intensive nature of advanced semiconductor processes. A single NVIDIA H100 GPU embodies approximately 150 to 200 kg CO2eq from manufacturing, including wafer fabrication at advanced process nodes, high-bandwidth memory production, and packaging. @eq-embodied-daily amortizes this embodied carbon over the hardware lifetime to compute per-use emissions: