import marimo
__generated_with = "0.19.6"
app = marimo.App(width="full")
# ─────────────────────────────────────────────────────────────────────────────
# LAB 11: THE ROOFLINE
#
# Core Invariant:
# Every compute kernel is either compute-bound or memory-bandwidth-bound.
# The Roofline Model makes this concrete:
# attainable_perf = min(peak_flops, bandwidth x arithmetic_intensity)
# The ridge point = peak_flops / peak_bandwidth separates the two regimes.
# MFU (Model FLOP Utilization) = achieved_FLOPS / peak_FLOPS.
#
# Contexts: Cloud H100 (80 GB HBM3e, 3350 GB/s, 1979 TFLOPS FP16) vs
# Edge Jetson Orin NX (16 GB, 102 GB/s, ~12 TFLOPS FP16)
#
# New Instrument: Interactive Roofline Model (first introduction in curriculum)
#
# Act I — The Memory Wall (12-15 min)
# Scenario: GPU Kernel Engineer — GEMM achieves 15.7% MFU on H100. Why?
# Prediction: Why is MFU only 15.7%? (memory-bound, below ridge point)
# Instrument: Interactive Roofline — adjust M×N×K, precision, device
# Reflection: Most effective MFU improvement when memory-bound?
# Correct: Increase tile dimensions to raise arithmetic intensity
#
# Act II — The Design Challenge (20-25 min)
# Scenario: ML Infra Lead — 3 LLM kernel types, which are bottlenecks?
# Prediction: Which kernel types are memory-bound vs compute-bound?
# Instrument: Multi-operation Roofline + kernel fusion strategies
# Failure state: KV-cache + activations exceed device RAM
# Reflection: Why does kernel fusion improve memory-bound ops?
# Correct: Eliminates redundant reads/writes — fused ops load data once
#
# Key constants (all from NVIDIA spec sheets and hw_acceleration.qmd):
# H100_BW_GBS = 3350 # H100 SXM5 HBM3e bandwidth, NVIDIA spec
# H100_TFLOPS_FP16 = 1979 # H100 FP16 tensor core TFLOPS, NVIDIA spec
# H100_RAM_GB = 80 # H100 HBM3e capacity, NVIDIA spec
# H100_RIDGE_PT = 591 # FLOP/byte = 1979e12 / 3350e9, derived
# ORIN_BW_GBS = 102 # Jetson Orin NX 16GB, NVIDIA spec
# ORIN_TFLOPS_FP16 = 12 # Jetson Orin NX FP16 TFLOPS, estimated from GPU die
# ORIN_RAM_GB = 16 # Jetson Orin NX RAM, NVIDIA spec
# ORIN_RIDGE_PT = 10 # conservative FLOP/byte estimate per chapter text
# ─────────────────────────────────────────────────────────────────────────────
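# Worked check of the Act I scenario numbers (illustrative arithmetic only):
#   MFU = achieved / peak = 312 / 1979 = 15.7%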
# ─── CELL 0: SETUP (hide_code=False — leave visible for instructors) ──────────
@app.cell
def _():
import marimo as mo
import sys
from pathlib import Path
import plotly.graph_objects as go
import numpy as np
_root = Path(__file__).resolve().parents[2]
if str(_root) not in sys.path:
sys.path.insert(0, str(_root))
from labs.core.state import DesignLedger
from labs.core.style import COLORS, LAB_CSS, apply_plotly_theme
ledger = DesignLedger()
# ── Hardware constants (all plain floats, no pint units) ──────────────────
# Cloud: NVIDIA H100 SXM5 (NVIDIA spec sheet, 2023)
H100_BW_GBS = 3350 # GB/s HBM3e memory bandwidth
H100_TFLOPS_FP16 = 1979 # TFLOPS FP16 tensor core peak
H100_TFLOPS_FP32 = 67 # TFLOPS FP32 (non-tensor)
H100_TFLOPS_INT8 = 3958 # TFLOPS INT8 tensor core peak (2x FP16)
H100_RAM_GB = 80 # GB HBM3e total capacity
H100_TDP_W = 700 # Watts TDP
H100_RIDGE_PT = 591 # FLOP/byte = 1979e12 / 3350e9 (derived)
# Edge: NVIDIA Jetson Orin NX 16GB (NVIDIA Jetson product brief 2023)
ORIN_BW_GBS = 102 # GB/s LPDDR5 memory bandwidth
ORIN_TFLOPS_FP16 = 12 # TFLOPS FP16 estimated from GPU die specs
ORIN_TFLOPS_INT8 = 100 # TOPS INT8 (advertised on product page)
ORIN_RAM_GB = 16 # GB LPDDR5
ORIN_TDP_W = 25 # Watts maximum TDP
ORIN_RIDGE_PT = 10 # FLOP/byte conservative (chapter text value)
# Note: exact ORIN FP16 ridge = 12e12/102e9 = 118; chapter uses ~10 as
# a conservative figure accounting for sustained vs peak bandwidth
# Bytes per element by precision
BYTES_FP32 = 4 # bytes per FP32 element
BYTES_FP16 = 2 # bytes per FP16/BF16 element
BYTES_INT8 = 1 # bytes per INT8 element
# Llama-3-8B architecture constants (public model card, Meta 2024)
LLAMA3_DMODEL = 4096 # embedding dimension
LLAMA3_LAYERS = 32 # transformer layers
LLAMA3_HEADS = 32 # attention heads
LLAMA3_PARAMS = 8e9 # total parameter count
return (
mo, ledger, go, np,
COLORS, LAB_CSS, apply_plotly_theme,
H100_BW_GBS, H100_TFLOPS_FP16, H100_TFLOPS_FP32, H100_TFLOPS_INT8,
H100_RAM_GB, H100_TDP_W, H100_RIDGE_PT,
ORIN_BW_GBS, ORIN_TFLOPS_FP16, ORIN_TFLOPS_INT8,
ORIN_RAM_GB, ORIN_TDP_W, ORIN_RIDGE_PT,
BYTES_FP32, BYTES_FP16, BYTES_INT8,
LLAMA3_DMODEL, LLAMA3_LAYERS, LLAMA3_HEADS, LLAMA3_PARAMS,
)
# ─── HEADER ──────────────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo, LAB_CSS):
mo.vstack([
LAB_CSS,
mo.md("""
<div style="background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
padding: 36px 44px; border-radius: 16px; color: white;
box-shadow: 0 8px 32px rgba(0,0,0,0.3);">
<div style="font-size: 0.72rem; font-weight: 700; letter-spacing: 0.18em;
color: #475569; text-transform: uppercase; margin-bottom: 10px;">
Machine Learning Systems · Volume I · Lab 11
</div>
<h1 style="margin: 0 0 10px 0; font-size: 2.4rem; font-weight: 900;
color: #f8fafc; line-height: 1.1; letter-spacing: -0.02em;">
The Roofline
</h1>
<p style="margin: 0 0 20px 0; font-size: 1.05rem; color: #94a3b8;
max-width: 680px; line-height: 1.65;">
A kernel that achieves 312 TFLOPS on a 1979 TFLOPS accelerator
is running at 15.7% efficiency. Is the algorithm broken?
No — the roof is somewhere else. This lab forces you to find it.
</p>
<div style="display: flex; gap: 12px; flex-wrap: wrap;">
<span style="background: rgba(99,102,241,0.15); color: #a5b4fc;
padding: 5px 14px; border-radius: 20px; font-size: 0.8rem;
font-weight: 600; border: 1px solid rgba(99,102,241,0.25);">
Cloud vs Edge
</span>
<span style="background: rgba(16,185,129,0.15); color: #6ee7b7;
padding: 5px 14px; border-radius: 20px; font-size: 0.8rem;
font-weight: 600; border: 1px solid rgba(16,185,129,0.25);">
35-40 min
</span>
<span style="background: rgba(99,102,241,0.15); color: #a5b4fc;
padding: 5px 14px; border-radius: 20px; font-size: 0.8rem;
font-weight: 600; border: 1px solid rgba(99,102,241,0.25);">
Prerequisite: hw_acceleration.qmd
</span>
<span style="background: rgba(203,32,45,0.15); color: #fca5a5;
padding: 5px 14px; border-radius: 20px; font-size: 0.8rem;
font-weight: 600; border: 1px solid rgba(203,32,45,0.25);">
New instrument: Roofline Model
</span>
</div>
<div style="display: flex; gap: 12px; flex-wrap: wrap; margin-top: 14px;">
<span class="badge badge-ok">AI above ridge = Compute-Bound</span>
<span class="badge badge-warn">AI below ridge = Memory-Bound</span>
<span class="badge badge-fail">KV-cache exceeds device RAM = OOM</span>
</div>
</div>
"""),
])
return
# ─── RECOMMENDED READING ─────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
mo.callout(mo.md("""
**Recommended Reading** — Complete the following before this lab:
- **@sec-roofline-model** — The Roofline Model: memory wall, compute ceiling, ridge point
- **@sec-arithmetic-intensity** — Arithmetic intensity definition and GEMM derivation
- **@sec-mfu** — Model FLOP Utilization: what it measures and why 100% is never achievable
- **@sec-kernel-fusion** — Fusing elementwise ops: why it reduces memory traffic
The lab assumes you know what arithmetic intensity (FLOP/byte) means.
If the term *ridge point* is unfamiliar, re-read @sec-roofline-model first.
"""), kind="info")
return
# ─── CONTEXT TOGGLE ──────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
context_toggle = mo.ui.radio(
options={
"Cloud (H100 SXM5, 80 GB HBM3e, 1979 TFLOPS FP16)": "cloud",
"Edge (Jetson Orin NX, 16 GB LPDDR5, ~12 TFLOPS FP16)": "edge",
},
value="Cloud (H100 SXM5, 80 GB HBM3e, 1979 TFLOPS FP16)",
label="Deployment context:",
inline=True,
)
mo.vstack([
mo.md("---"),
mo.md("## Select Your Deployment Context"),
mo.md(
"The ridge point separating memory-bound from compute-bound is a hardware "
"constant. On an H100 it is 591 FLOP/byte; on a Jetson Orin NX it is roughly "
"10 FLOP/byte. Select your context — the entire roofline changes with it."
),
context_toggle,
])
return (context_toggle,)
# =============================================================================
# ACT I — THE MEMORY WALL
# =============================================================================
@app.cell(hide_code=True)
def _(mo, context_toggle, COLORS):
_ctx = context_toggle.value
_color = COLORS["Cloud"] if _ctx == "cloud" else COLORS["Edge"]
_bg = "#EBF4FA" if _ctx == "cloud" else COLORS["RedLL"]
_device_name = "H100 SXM5" if _ctx == "cloud" else "Jetson Orin NX"
_peak_tflops = "1979 TFLOPS FP16" if _ctx == "cloud" else "~12 TFLOPS FP16"
_peak_bw = "3350 GB/s" if _ctx == "cloud" else "102 GB/s"
_ridge_str = "591 FLOP/byte" if _ctx == "cloud" else "~10 FLOP/byte"
_achieved_str = "312 TFLOPS" if _ctx == "cloud" else "1.9 TFLOPS"
mo.vstack([
mo.md("---"),
mo.Html(f"""
<div style="border-left:4px solid {_color}; background:{_bg};
border-radius:0 10px 10px 0; padding:16px 22px; margin:12px 0;">
<div style="font-size:0.72rem; font-weight:700; color:{_color};
text-transform:uppercase; letter-spacing:0.1em; margin-bottom:6px;">
Incoming Message · GPU Kernel Engineer
</div>
<div style="font-style:italic; font-size:1.0rem; color:#1e293b; line-height:1.65;">
"Our matrix multiply kernel achieves {_achieved_str} on the {_device_name}.
Peak is {_peak_tflops}. MFU is about 15.7%.
We have spent three weeks tuning the arithmetic — register tiling,
loop unrolling, mixed precision. The math has not moved one TFLOP.
Our manager is asking why we keep hitting the same wall.
What ceiling are we running into?"
</div>
</div>
"""),
mo.md(f"""
## Act I — The Memory Wall
The engineer has been optimizing the wrong bottleneck for three weeks.
The {_device_name} has a peak compute of {_peak_tflops} but a memory
bandwidth of {_peak_bw}. The **ridge point** is the arithmetic intensity
at which a kernel transitions from memory-bound to compute-bound: **{_ridge_str}**.
A kernel running *below* the ridge point is bottlenecked by data movement,
not by floating-point throughput. No amount of arithmetic optimization
will improve a memory-bound kernel.
"""),
])
return
# ─── ACT I PREDICTION ────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
act1_pred = mo.ui.radio(
options={
"A) The algorithm has a bug — matrix multiply should always be compute-bound": "A",
"B) The kernel is not memory-bound — increase the tile size": "B",
"C) The operation is memory-bandwidth-bound — arithmetic intensity is below the ridge point": "C",
"D) 15% MFU is already the hardware maximum for matrix multiply": "D",
},
label="Your prediction: why is MFU only 15.7% despite correct arithmetic?",
)
mo.vstack([
mo.md("### Your Prediction"),
mo.md(
"*Before touching the simulator, commit to your hypothesis. "
"The roofline is locked until you do.*"
),
act1_pred,
])
return (act1_pred,)
@app.cell(hide_code=True)
def _(mo, act1_pred):
mo.stop(
act1_pred.value is None,
mo.callout(
mo.md("Select your prediction above to unlock the Roofline simulator."),
kind="warn",
)
)
mo.callout(
mo.md(f"**Prediction locked:** Option {act1_pred.value}. Now explore the simulator below."),
kind="info",
)
return
# ─── ACT I ROOFLINE CONTROLS ─────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
act1_matrix_n = mo.ui.slider(
start=128, stop=8192, value=512, step=128,
label="Matrix dimension N (square M=N=K)",
)
act1_precision = mo.ui.dropdown(
options={
"FP32 (4 bytes/element)": "fp32",
"FP16 / BF16 (2 bytes/element)": "fp16",
"INT8 (1 byte/element)": "int8",
},
value="FP16 / BF16 (2 bytes/element)",
label="Precision",
)
mo.vstack([
mo.md("---"),
mo.md("### The Roofline Simulator"),
mo.md(
"Adjust matrix dimension and precision. Watch the operation point move. "
"The vertical dashed line is the ridge point — the boundary between "
"memory-bound (left) and compute-bound (right)."
),
mo.hstack([act1_matrix_n, act1_precision], justify="start", gap=2),
])
return (act1_matrix_n, act1_precision)
# ─── ACT I ROOFLINE PLOT ─────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(
mo, go, np,
context_toggle,
act1_matrix_n, act1_precision,
COLORS,
H100_BW_GBS, H100_TFLOPS_FP16, H100_TFLOPS_FP32, H100_TFLOPS_INT8, H100_RIDGE_PT,
ORIN_BW_GBS, ORIN_TFLOPS_FP16, ORIN_TFLOPS_INT8, ORIN_RIDGE_PT,
BYTES_FP32, BYTES_FP16, BYTES_INT8,
):
_ctx = context_toggle.value
_N = act1_matrix_n.value
_precision = act1_precision.value
# ── Device specs based on context and precision ───────────────────────────
if _ctx == "cloud":
_peak_bw_gbs = H100_BW_GBS
_ridge_pt = H100_RIDGE_PT
_device_label = "H100 SXM5"
_ctx_color = COLORS["Cloud"]
if _precision == "fp32":
_peak_flops_t = H100_TFLOPS_FP32
elif _precision == "int8":
_peak_flops_t = H100_TFLOPS_INT8
else:
_peak_flops_t = H100_TFLOPS_FP16
else:
_peak_bw_gbs = ORIN_BW_GBS
_ridge_pt = ORIN_RIDGE_PT
_device_label = "Jetson Orin NX"
_ctx_color = COLORS["Edge"]
if _precision == "int8":
_peak_flops_t = ORIN_TFLOPS_INT8
else:
_peak_flops_t = ORIN_TFLOPS_FP16
# ── Precision bytes ────────────────────────────────────────────────────────
if _precision == "fp32":
_bytes_elem = BYTES_FP32
_prec_label = "FP32"
elif _precision == "int8":
_bytes_elem = BYTES_INT8
_prec_label = "INT8"
else:
_bytes_elem = BYTES_FP16
_prec_label = "FP16"
# ── GEMM arithmetic intensity ─────────────────────────────────────────────
# For square M=N=K GEMM at B bytes/element:
# FLOPs = 2 x N^3 (N multiply-adds per dot product, N^2 outputs)
# Bytes = (MN + NK + MK) x B = 3 x N^2 x B (read A, B, write C)
# AI = 2N^3 / (3N^2 x B) = 2N / (3B)
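# e.g. N=512 at FP16 (B=2): AI = 2*512/(3*2) ~ 170.7 FLOP/byte, well below
# the H100 ridge of 591 (illustrative check of the formula above)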
_flops_gemm = 2.0 * _N * _N * _N
_bytes_gemm = 3.0 * _N * _N * _bytes_elem
_ai_gemm = _flops_gemm / _bytes_gemm
# ── Attainable performance ────────────────────────────────────────────────
_peak_flops_per_s = _peak_flops_t * 1e12
_peak_bw_per_s = _peak_bw_gbs * 1e9
_attain_t = min(_peak_flops_t, (_peak_bw_per_s * _ai_gemm) / 1e12)
_mfu_pct = (_attain_t / _peak_flops_t) * 100.0
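# e.g. N=512 FP16 on H100: min(1979, 3350*170.7/1000) ~ 572 TFLOPS -> 28.9% MFU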
_is_mem_bound = _ai_gemm < _ridge_pt
_regime_label = "Memory-Bound" if _is_mem_bound else "Compute-Bound"
_regime_color = COLORS["OrangeLine"] if _is_mem_bound else COLORS["GreenLine"]
# ── Build roofline curve ───────────────────────────────────────────────────
_ai_axis = np.logspace(-1, 4, 500)
_mem_slope = (_peak_bw_per_s * _ai_axis) / 1e12
_comp_ceil = np.full_like(_ai_axis, _peak_flops_t)
_roofline = np.minimum(_mem_slope, _comp_ceil)
# ── Build figure ──────────────────────────────────────────────────────────
_fig = go.Figure()
# Memory-bound segment (left of ridge)
_mask_m = _ai_axis <= _ridge_pt
_fig.add_trace(go.Scatter(
x=_ai_axis[_mask_m], y=_roofline[_mask_m],
mode="lines", name="Memory-bound roof",
line=dict(color=COLORS["OrangeLine"], width=3),
))
# Compute-bound segment (right of ridge)
_mask_c = _ai_axis >= _ridge_pt
_fig.add_trace(go.Scatter(
x=_ai_axis[_mask_c], y=_roofline[_mask_c],
mode="lines", name="Compute ceiling",
line=dict(color=COLORS["GreenLine"], width=3),
))
# Ridge point vertical dashed line
_fig.add_vline(
x=_ridge_pt,
line=dict(color="#64748b", width=2, dash="dash"),
annotation_text=f"Ridge: {_ridge_pt:.0f} FLOP/byte",
annotation_position="top right",
annotation_font=dict(size=11, color="#64748b"),
)
# Shaded zones
_fig.add_vrect(x0=0.1, x1=_ridge_pt,
fillcolor=COLORS["OrangeLine"], opacity=0.06, layer="below", line_width=0)
_fig.add_vrect(x0=_ridge_pt, x1=10000,
fillcolor=COLORS["GreenLine"], opacity=0.04, layer="below", line_width=0)
# Zone text labels
_fig.add_annotation(x=0.18, y=0.12, xref="paper", yref="paper",
text="Memory-Bound Zone", font=dict(size=10, color=COLORS["OrangeLine"]),
showarrow=False)
_fig.add_annotation(x=0.82, y=0.12, xref="paper", yref="paper",
text="Compute-Bound Zone", font=dict(size=10, color=COLORS["GreenLine"]),
showarrow=False)
# Vertical drop to x-axis from operation point
_fig.add_shape(
type="line",
x0=_ai_gemm, y0=1e-3,
x1=_ai_gemm, y1=_attain_t,
line=dict(color=_regime_color, width=1, dash="dot"),
layer="below",
)
# Operation point
_fig.add_trace(go.Scatter(
x=[_ai_gemm], y=[_attain_t],
mode="markers+text",
name=f"GEMM {_N}x{_N}x{_N} ({_prec_label})",
marker=dict(size=16, color=_regime_color,
line=dict(color="white", width=2), symbol="circle"),
text=[f" {_attain_t:.0f} TFLOPS ({_mfu_pct:.1f}% MFU)"],
textposition="middle right",
textfont=dict(size=11, color=_regime_color),
))
_fig.update_layout(
title=dict(
text=f"Roofline Model — {_device_label} ({_prec_label})",
font=dict(size=15, color=COLORS["Text"]), x=0.02,
),
xaxis=dict(type="log", title="Arithmetic Intensity (FLOP / byte)",
range=[-1, 4], gridcolor="#f1f5f9", linecolor=COLORS["Border"]),
yaxis=dict(type="log", title="Attainable Performance (TFLOPS)",
range=[-2, 4] if _ctx == "cloud" else [-2, 2.5],
gridcolor="#f1f5f9", linecolor=COLORS["Border"]),
legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.85)"),
height=460,
plot_bgcolor="white", paper_bgcolor="white",
font_family="Inter, sans-serif", font_color=COLORS["Text"],
margin=dict(l=60, r=30, t=50, b=60),
)
# ── Metric cards ──────────────────────────────────────────────────────────
_mfu_color = (COLORS["GreenLine"] if _mfu_pct > 50
else COLORS["OrangeLine"] if _mfu_pct > 20
else COLORS["RedLine"])
_ai_color = COLORS["GreenLine"] if not _is_mem_bound else COLORS["OrangeLine"]
# Public aliases for other cells (marimo keeps underscore-prefixed names cell-local)
ai_gemm, mfu_pct, is_mem_bound = _ai_gemm, _mfu_pct, _is_mem_bound
ridge_pt, regime_label = _ridge_pt, _regime_label
mo.vstack([
mo.as_html(_fig),
mo.Html(f"""
<div style="display:flex; gap:16px; flex-wrap:wrap; margin:16px 0 8px 0;">
<div style="padding:18px 24px; border:1px solid #e2e8f0; border-radius:10px;
min-width:165px; text-align:center; background:white;
box-shadow:0 2px 6px rgba(0,0,0,0.04);">
<div style="color:#64748b; font-size:0.78rem; font-weight:600;
text-transform:uppercase; letter-spacing:0.05em; margin-bottom:4px;">
Arithmetic Intensity
</div>
<div style="font-size:2rem; font-weight:800; color:{_ai_color};">{_ai_gemm:.0f}</div>
<div style="font-size:0.73rem; color:#94a3b8;">FLOP / byte</div>
</div>
<div style="padding:18px 24px; border:1px solid #e2e8f0; border-radius:10px;
min-width:165px; text-align:center; background:white;
box-shadow:0 2px 6px rgba(0,0,0,0.04);">
<div style="color:#64748b; font-size:0.78rem; font-weight:600;
text-transform:uppercase; letter-spacing:0.05em; margin-bottom:4px;">
Attainable Perf
</div>
<div style="font-size:2rem; font-weight:800; color:{_mfu_color};">{_attain_t:.0f}</div>
<div style="font-size:0.73rem; color:#94a3b8;">TFLOPS</div>
</div>
<div style="padding:18px 24px; border:1px solid #e2e8f0; border-radius:10px;
min-width:165px; text-align:center; background:white;
box-shadow:0 2px 6px rgba(0,0,0,0.04);">
<div style="color:#64748b; font-size:0.78rem; font-weight:600;
text-transform:uppercase; letter-spacing:0.05em; margin-bottom:4px;">
MFU
</div>
<div style="font-size:2rem; font-weight:800; color:{_mfu_color};">{_mfu_pct:.1f}%</div>
<div style="font-size:0.73rem; color:#94a3b8;">vs peak {_peak_flops_t:.0f} TFLOPS</div>
</div>
<div style="padding:18px 24px; border:1px solid #e2e8f0; border-radius:10px;
min-width:165px; text-align:center; background:white;
box-shadow:0 2px 6px rgba(0,0,0,0.04);">
<div style="color:#64748b; font-size:0.78rem; font-weight:600;
text-transform:uppercase; letter-spacing:0.05em; margin-bottom:4px;">
Regime
</div>
<div style="font-size:1.4rem; font-weight:800; color:{_regime_color}; margin:6px 0 2px 0;">
{_regime_label}
</div>
<div style="font-size:0.73rem; color:#94a3b8;">AI {'<' if _is_mem_bound else '>='} {_ridge_pt:.0f}</div>
</div>
</div>
"""),
mo.md(f"""
**Physics (visible):**
```
GEMM FLOPs = 2 x M x N x K = 2 x {_N}^3 = {_flops_gemm/1e9:.1f} GFLOPs
GEMM Bytes = (MN + NK + MK) x {_bytes_elem} = 3 x {_N}^2 x {_bytes_elem} = {_bytes_gemm/1e6:.1f} MB
Arith. Intens. = {_flops_gemm:.2e} / {_bytes_gemm:.2e} = {_ai_gemm:.1f} FLOP/byte
Ridge Point = peak_FLOPS / BW = {_peak_flops_t} TFLOPS / {_peak_bw_gbs} GB/s = {_ridge_pt:.0f} FLOP/byte
Attainable = min(peak_FLOPS, BW x AI)
= min({_peak_flops_t}, {_peak_bw_gbs}e9 x {_ai_gemm:.1f} / 1e12) TFLOPS
= min({_peak_flops_t:.0f}, {(_peak_bw_gbs*1e9*_ai_gemm/1e12):.0f}) TFLOPS
= {_attain_t:.1f} TFLOPS
MFU = {_attain_t:.1f} / {_peak_flops_t:.0f} = {_mfu_pct:.1f}%
```
"""),
])
return (ai_gemm, mfu_pct, is_mem_bound, ridge_pt, regime_label)
# ─── ACT I PREDICTION VS REALITY ─────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo, act1_pred, ai_gemm, ridge_pt, is_mem_bound, regime_label):
_pred = act1_pred.value
_feedback = {
"A": (
"**Not quite.** There is no bug. A square GEMM with N=512 has an arithmetic "
"intensity of only ~170 FLOP/byte — far below the H100 ridge point of 591 FLOP/byte. "
"Matrix multiply *can* be compute-bound, but only for large enough matrices. "
"The ratio of FLOPs to bytes grows linearly with N."
),
"B": (
"**Close, but the logic is inverted.** The kernel *is* memory-bound — "
"that is exactly the problem. Increasing tile size is the correct *solution*, "
"not evidence against memory-boundedness. Larger tiles increase arithmetic "
"intensity by reusing data in on-chip SRAM instead of re-fetching from HBM."
),
"C": (
"**Correct.** With N=512, arithmetic intensity is ~170 FLOP/byte — well below "
"the H100 ridge point of 591 FLOP/byte. The kernel is memory-bandwidth-bound. "
"Optimizing FP arithmetic does nothing when the bottleneck is data movement. "
"The kernel spends most cycles waiting for data, not computing."
),
"D": (
"**Incorrect.** 15% MFU is not a hardware maximum. It reflects a specific "
"operating point below the ridge. At N=8192 the same hardware achieves >70% MFU "
"because arithmetic intensity rises to ~2730 FLOP/byte, well past the ridge."
),
}
_correct = _pred == "C"
_actual = "memory-bound" if is_mem_bound else "compute-bound"
mo.vstack([
mo.md("### Prediction vs Reality"),
mo.callout(
mo.md(
f"{_feedback.get(_pred, '')}\n\n"
f"**Actual result:** AI = {ai_gemm:.0f} FLOP/byte, "
f"ridge = {ridge_pt:.0f} FLOP/byte. "
f"Operation is **{_actual}** ({regime_label})."
),
kind="success" if _correct else "warn",
),
])
return
# ─── ACT I REFLECTION ────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
act1_reflect = mo.ui.radio(
options={
"A) Increase clock speed": "A",
"B) Use larger batch sizes or tile dimensions to raise arithmetic intensity": "B",
"C) Reduce floating-point precision": "C",
"D) Add more GPU memory capacity": "D",
},
label="Reflection: what is the most effective way to improve MFU for a memory-bound kernel?",
)
mo.vstack([
mo.md("---"),
mo.md("### Act I Reflection"),
act1_reflect,
])
return (act1_reflect,)
@app.cell(hide_code=True)
def _(mo, act1_reflect):
mo.stop(
act1_reflect.value is None,
mo.callout(mo.md("Select your answer to see the explanation."), kind="warn")
)
_reflect_feedback = {
"A": (
"**Incorrect.** Higher clock speed increases raw FLOP/s, but bandwidth is the "
"bottleneck. If data cannot move faster, the compute units remain starved "
"regardless of frequency."
),
"B": (
"**Correct.** Larger tiles allow the kernel to reuse data held in fast on-chip "
"SRAM across more arithmetic operations before evicting it back to HBM. "
"This increases FLOPs per byte — shifting the operation rightward on the "
"roofline toward and past the ridge point. This is exactly what cuBLAS tiles "
"and FlashAttention's block structure accomplish."
),
"C": (
"**Incorrect.** Reducing precision halves bytes per element, which *doubles* "
"arithmetic intensity, but peak FLOP/s rises at least as fast (FP16 to INT8 "
"doubles it; FP32 to FP16 tensor cores raise it ~30x on an H100). MFU is "
"measured against that higher peak, so the ratio does not improve."
),
"D": (
"**Incorrect.** More capacity (larger HBM) is not the same as more bandwidth. "
"The bottleneck is the *rate* of data transfer. A larger memory pool that "
"arrives at the same rate still starves the compute units."
),
}
_correct = act1_reflect.value == "B"
mo.callout(
mo.md(_reflect_feedback[act1_reflect.value]),
kind="success" if _correct else "warn",
)
return
# ─── ACT I MATHPEEK ──────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
mo.accordion({
"The governing equations": mo.md("""
**Roofline attainable performance:**
```
attainable_perf(I) = min(peak_flops, BW x I)
```
`I` = arithmetic intensity (FLOP / byte), `BW` = peak bandwidth (byte/s),
`peak_flops` = peak compute (FLOP/s).
**The ridge point** is where the two bounds are equal:
```
ridge_point = peak_flops / BW [FLOP / byte]
```
- H100 SXM5: ridge = 1979e12 / 3350e9 = 591 FLOP/byte
- Jetson Orin NX: ridge ~10 FLOP/byte (conservative, per chapter text)
**GEMM arithmetic intensity** for square M=N=K at B bytes/element:
```
FLOPs = 2 x N^3
Bytes = 3 x N^2 x B (read A, read B, write C)
I_gemm = 2N^3 / (3N^2 x B) = 2N / (3B)
```
At FP16 (B=2): I_gemm = N / 3. So I grows *linearly with N*.
| N | AI (FP16) | H100 regime |
|-------|-----------|--------------------|
| 128 | 43 | Memory-Bound |
| 512 | 170 | Memory-Bound |
| 1024 | 341 | Memory-Bound |
| 1773 | 591 | AT ridge point |
| 4096 | 1365 | Compute-Bound |
| 8192 | 2730 | Compute-Bound |
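A minimal sketch in plain Python that reproduces the table above (assuming
the 2N/(3B) formula and the 591 FLOP/byte ridge stated in this accordion):
```python
def gemm_ai(n, bytes_per_elem=2):        # square M=N=K GEMM, FP16 by default
    return (2 * n**3) / (3 * n**2 * bytes_per_elem)

for n in [128, 512, 1024, 1773, 4096, 8192]:
    regime = "compute" if gemm_ai(n) >= 591 else "memory"  # 1773 sits at the ridge
    print(f"N={n:5d}  AI={gemm_ai(n):6.0f} FLOP/byte  {regime}-bound")
```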
**MFU (Model FLOP Utilization):**
```
MFU = achieved_FLOPS / peak_FLOPS
```
MFU directly tells you how much of the advertised peak you are using.
A memory-bound kernel is capped at MFU <= (BW x I) / peak_flops.
"""),
})
return
# =============================================================================
# ACT II — THE DESIGN CHALLENGE
# =============================================================================
@app.cell(hide_code=True)
def _(mo, context_toggle, COLORS):
_ctx = context_toggle.value
_color = COLORS["Cloud"] if _ctx == "cloud" else COLORS["Edge"]
_bg = "#EBF4FA" if _ctx == "cloud" else COLORS["RedLL"]
_device = "H100" if _ctx == "cloud" else "Jetson Orin NX"
_ram = "80 GB" if _ctx == "cloud" else "16 GB"
_ridge2 = "591 FLOP/byte" if _ctx == "cloud" else "~10 FLOP/byte"
mo.vstack([
mo.md("---"),
mo.Html(f"""
<div style="border-left:4px solid {_color}; background:{_bg};
border-radius:0 10px 10px 0; padding:16px 22px; margin:12px 0;">
<div style="font-size:0.72rem; font-weight:700; color:{_color};
text-transform:uppercase; letter-spacing:0.1em; margin-bottom:6px;">
Incoming Message · ML Infrastructure Lead
</div>
<div style="font-style:italic; font-size:1.0rem; color:#1e293b; line-height:1.65;">
"We are designing an LLM inference service on the {_device}
({_ram} RAM, ridge point {_ridge2}).
We have three kernel types: (1) large GEMM for the attention
projection layers, (2) layer normalization, (3) softmax over
attention scores. We have a budget to optimize exactly one kernel.
Half the team says GEMM. Half says fuse the elementwise ops.
Which operations are the actual bottlenecks and what do we do first?"
</div>
</div>
"""),
mo.md(f"""
## Act II — The Design Challenge
An LLM inference stack mixes kernels with wildly different arithmetic intensities.
Large GEMM operations can become compute-bound at large batch sizes.
Elementwise operations — layer norm, softmax — have arithmetic intensities
near 0.8 FLOP/byte and are *always* memory-bound regardless of batch size.
Kernel fusion addresses memory-bound ops by eliminating redundant HBM round-trips.
But push batch size too far and the KV-cache alone exhausts device RAM.
"""),
])
return
# ─── ACT II PREDICTION ───────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
act2_pred = mo.ui.radio(
options={
"A) Matrix multiply is always the bottleneck — optimize it first": "A",
"B) Softmax and layer norm are likely memory-bound; GEMM may be compute-bound at large batch sizes": "B",
"C) All three operations have similar arithmetic intensity": "C",
"D) Layer norm is compute-bound because it computes square roots": "D",
},
label="Your prediction: which kernel types are memory-bound vs compute-bound in LLM inference?",
)
mo.vstack([
mo.md("### Your Prediction"),
mo.md("*Commit before configuring the multi-operation design.*"),
act2_pred,
])
return (act2_pred,)
@app.cell(hide_code=True)
def _(mo, act2_pred):
mo.stop(
act2_pred.value is None,
mo.callout(
mo.md("Select your prediction above to unlock the design instruments."),
kind="warn",
)
)
mo.callout(
mo.md(f"**Prediction locked:** Option {act2_pred.value}. Configure the design below."),
kind="info",
)
return
# ─── ACT II DESIGN CONTROLS ──────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
act2_batch = mo.ui.slider(
start=1, stop=256, value=8, step=1,
label="Batch size (concurrent sequences)",
)
act2_seqlen = mo.ui.slider(
start=256, stop=4096, value=512, step=256,
label="Sequence length (tokens)",
)
act2_fusion = mo.ui.dropdown(
options={
"Separate kernels (no fusion)": "none",
"Fuse softmax + layer norm": "partial",
"Fuse all elementwise ops (softmax+LN+bias)": "full",
},
value="Separate kernels (no fusion)",
label="Kernel fusion strategy",
)
mo.vstack([
mo.md("---"),
mo.md("### Design Instruments"),
mo.md(
"Adjust batch size, sequence length, and fusion strategy. "
"Watch the operation points shift across the roofline and monitor the OOM boundary."
),
mo.hstack([act2_batch, act2_seqlen], justify="start", gap=2),
act2_fusion,
])
return (act2_batch, act2_seqlen, act2_fusion)
# ─── ACT II MULTI-OPERATION ROOFLINE ─────────────────────────────────────────
@app.cell(hide_code=True)
def _(
mo, go, np,
context_toggle,
act2_batch, act2_seqlen, act2_fusion,
COLORS,
H100_BW_GBS, H100_TFLOPS_FP16, H100_RIDGE_PT, H100_RAM_GB,
ORIN_BW_GBS, ORIN_TFLOPS_FP16, ORIN_RIDGE_PT, ORIN_RAM_GB,
LLAMA3_DMODEL, LLAMA3_LAYERS, LLAMA3_HEADS, LLAMA3_PARAMS,
BYTES_FP16,
):
_ctx = context_toggle.value
_B = act2_batch.value
_S = act2_seqlen.value
_fusion = act2_fusion.value
# ── Device ────────────────────────────────────────────────────────────────
if _ctx == "cloud":
_peak_bw_gbs = H100_BW_GBS
_peak_flops_t = H100_TFLOPS_FP16
_ridge_pt2 = H100_RIDGE_PT
_device_ram = H100_RAM_GB
_dev_label = "H100 SXM5"
_ctx_color = COLORS["Cloud"]
else:
_peak_bw_gbs = ORIN_BW_GBS
_peak_flops_t = ORIN_TFLOPS_FP16
_ridge_pt2 = ORIN_RIDGE_PT
_device_ram = ORIN_RAM_GB
_dev_label = "Jetson Orin NX"
_ctx_color = COLORS["Edge"]
_peak_flops_per_s = _peak_flops_t * 1e12
_peak_bw_per_s = _peak_bw_gbs * 1e9
# ── Op 1: GEMM (attention Q/K/V projection) ───────────────────────────────
# Shape: [B, S, D] x [D, D] -> [B, S, D]
# FLOPs = 2 x B x S x D x D (one projection; 4 total Q,K,V,out)
# Bytes = (B*S*D + D*D + B*S*D) x 2 (read input, weight, write output)
_D = LLAMA3_DMODEL
_gemm_flops = 2.0 * _B * _S * _D * _D
_gemm_bytes = (_B * _S * _D + _D * _D + _B * _S * _D) * BYTES_FP16
_gemm_ai2 = _gemm_flops / _gemm_bytes
# ── Op 2: Layer Norm [B, S, D] ────────────────────────────────────────────
# FLOPs per token: 5*D (mean, var, normalize, scale, shift)
# Bytes per token: 3*D*bytes (load, compute, store)
# AI = 5D / (3D * 2) = 5/6 ~ 0.83 FLOP/byte
_ln_flops_tok = 5.0 * _D
_ln_bytes_tok = 3.0 * _D * BYTES_FP16
_ln_ai_base = _ln_flops_tok / _ln_bytes_tok # ~0.83 FLOP/byte
# ── Op 3: Softmax [B, heads, S, S] ────────────────────────────────────────
# FLOPs per element: ~5 (subtract max, exp, accumulate sum, divide)
# Bytes: 3 * bytes (load, intermediate, store)
# AI = 5 / (3*2) ~ 0.83 FLOP/byte
_sm_ai_base = 5.0 / (3.0 * BYTES_FP16) # ~0.83 FLOP/byte
# ── Fusion adjustment ─────────────────────────────────────────────────────
# Fusion eliminates intermediate HBM round-trips.
# Unfused: each op reads+writes full tensor — byte factor = 1.0
# partial (softmax+LN fused): one fewer round-trip each — factor ~0.50
# full (all elementwise fused): two fewer round-trips — factor ~0.35
# This raises effective AI by reducing denominator bytes.
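# e.g. unfused LN AI ~0.83 -> partial: 0.83/0.50 ~ 1.67, full: 0.83/0.35 ~ 2.38
# (the 0.50 / 0.35 factors are modeling assumptions for this lab, not measured)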
_ff = {"none": 1.0, "partial": 0.5, "full": 0.35}[_fusion]
_ln_ai_eff = _ln_ai_base / _ff
_sm_ai_eff = _sm_ai_base / _ff
# ── Attainable performance ────────────────────────────────────────────────
def _attain2(ai_val):
return min(_peak_flops_t, (_peak_bw_per_s * ai_val) / 1e12)
_gemm_perf2 = _attain2(_gemm_ai2)
_ln_perf2 = _attain2(_ln_ai_eff)
_sm_perf2 = _attain2(_sm_ai_eff)
# ── Memory footprint ──────────────────────────────────────────────────────
# KV-cache: 2 tensors (K,V) each [B, S, heads, d_head, layers] at FP16
# d_head = D / heads = 4096/32 = 128
_d_head = _D // LLAMA3_HEADS # 128
_kv_gb = (2.0 * _B * _S * LLAMA3_HEADS * _d_head
* LLAMA3_LAYERS * BYTES_FP16) / 1e9
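# e.g. at the defaults B=8, S=512: 2*8*512*32*128*32*2 / 1e9 ~ 2.15 GB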
# Weights in FP16: params * 2 bytes / 1e9
_weights_gb = (LLAMA3_PARAMS * BYTES_FP16) / 1e9 # ~16 GB for 8B
# Activations (inference only, no gradients): rough 0.5x weights
_activ_gb = 0.5 * _weights_gb
_total_mem_gb = _weights_gb + _kv_gb + _activ_gb
_oom2 = _total_mem_gb > _device_ram
# ── Throughput estimate (tokens/sec) ─────────────────────────────────────
# Dominated by GEMM: FLOPs per token = 2 * D^2 * 4 projections * layers
_flops_per_tok = 2.0 * _D * _D * 4 * LLAMA3_LAYERS
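# Illustrative ceiling: compute-bound H100 FP16 gives 1979e12 / 4.295e9 ~ 460k
# tok/s; attention-score and MLP FLOPs are ignored, so real decode is lower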
_tokens_per_sec2 = (_gemm_perf2 * 1e12) / _flops_per_tok if not _oom2 else 0.0
# ── Regime helper ─────────────────────────────────────────────────────────
def _regime2(ai_val):
if ai_val < _ridge_pt2:
return "Memory-Bound", COLORS["OrangeLine"]
return "Compute-Bound", COLORS["GreenLine"]
_gemm_reg2, _gemm_rc2 = _regime2(_gemm_ai2)
_ln_reg2, _ln_rc2 = _regime2(_ln_ai_eff)
_sm_reg2, _sm_rc2 = _regime2(_sm_ai_eff)
# ── OOM banner ────────────────────────────────────────────────────────────
if _oom2:
_oom_banner = mo.callout(
mo.md(
f"**OOM — Infeasible Design.** "
f"Required: {_total_mem_gb:.1f} GB "
f"(weights {_weights_gb:.1f} GB + KV-cache {_kv_gb:.1f} GB "
f"+ activations {_activ_gb:.1f} GB) "
f"| Available: {_device_ram:.0f} GB ({_dev_label}). "
f"Reduce batch size or sequence length to stay within device RAM."
),
kind="danger",
)
else:
_head_gb = _device_ram - _total_mem_gb
_oom_banner = mo.callout(
mo.md(
f"**Memory budget OK.** "
f"Total: {_total_mem_gb:.1f} GB / {_device_ram:.0f} GB used. "
f"Headroom: {_head_gb:.1f} GB."
),
kind="success",
)
# ── Build multi-operation Roofline figure ─────────────────────────────────
_ai_axis2 = np.logspace(-2, 4, 500)
_mem_perf2 = (_peak_bw_per_s * _ai_axis2) / 1e12
_comp_c2 = np.full_like(_ai_axis2, _peak_flops_t)
_roof2 = np.minimum(_mem_perf2, _comp_c2)
_fig2 = go.Figure()
_mask_m2 = _ai_axis2 <= _ridge_pt2
_fig2.add_trace(go.Scatter(
x=_ai_axis2[_mask_m2], y=_roof2[_mask_m2],
mode="lines", name="Memory-bound roof",
line=dict(color=COLORS["OrangeLine"], width=3),
))
_mask_c2 = _ai_axis2 >= _ridge_pt2
_fig2.add_trace(go.Scatter(
x=_ai_axis2[_mask_c2], y=_roof2[_mask_c2],
mode="lines", name="Compute ceiling",
line=dict(color=COLORS["GreenLine"], width=3),
))
_fig2.add_vline(
x=_ridge_pt2,
line=dict(color="#64748b", width=2, dash="dash"),
annotation_text=f"Ridge: {_ridge_pt2:.0f}",
annotation_position="top right",
annotation_font=dict(size=10, color="#64748b"),
)
_fig2.add_vrect(x0=0.01, x1=_ridge_pt2,
fillcolor=COLORS["OrangeLine"], opacity=0.05, layer="below", line_width=0)
_fig2.add_vrect(x0=_ridge_pt2, x1=10000,
fillcolor=COLORS["GreenLine"], opacity=0.04, layer="below", line_width=0)
# Drop lines
for _ai_pt, _perf_pt, _col_pt in [
(_gemm_ai2, _gemm_perf2, "#6366f1"),
(_ln_ai_eff, _ln_perf2, COLORS["OrangeLine"]),
(_sm_ai_eff, _sm_perf2, COLORS["BlueLine"]),
]:
_fig2.add_shape(
type="line",
x0=_ai_pt, y0=1e-4, x1=_ai_pt, y1=_perf_pt,
line=dict(color=_col_pt, width=1, dash="dot"), layer="below",
)
# Operation scatter points
_ops_data = [
("GEMM (attn. projection)", _gemm_ai2, _gemm_perf2, "#6366f1", "circle"),
("Layer Norm", _ln_ai_eff, _ln_perf2, COLORS["OrangeLine"], "diamond"),
("Softmax", _sm_ai_eff, _sm_perf2, COLORS["BlueLine"], "square"),
]
for _op_nm, _op_ai, _op_perf, _op_color, _op_sym in _ops_data:
_fig2.add_trace(go.Scatter(
x=[_op_ai], y=[_op_perf],
mode="markers+text",
name=_op_nm,
marker=dict(size=15, color=_op_color, symbol=_op_sym,
line=dict(color="white", width=2)),
text=[f" {_op_ai:.1f} F/B"],
textposition="middle right",
textfont=dict(size=10, color=_op_color),
))
_fig2.update_layout(
title=dict(
text=f"Multi-Operation Roofline — {_dev_label} | B={_B}, S={_S}, fusion={_fusion}",
font=dict(size=14, color=COLORS["Text"]), x=0.02,
),
xaxis=dict(type="log", title="Arithmetic Intensity (FLOP / byte)",
range=[-2, 4], gridcolor="#f1f5f9"),
yaxis=dict(type="log", title="Attainable Performance (TFLOPS)",
range=[-3, 4] if _ctx == "cloud" else [-3, 2],
gridcolor="#f1f5f9"),
legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.88)"),
height=490,
plot_bgcolor="white", paper_bgcolor="white",
font_family="Inter, sans-serif", font_color=COLORS["Text"],
margin=dict(l=60, r=30, t=50, b=60),
)
# ── Metric cards ──────────────────────────────────────────────────────────
def _card_html(label, value, unit, color):
return f"""
<div style="padding:14px 18px; border:1px solid #e2e8f0; border-radius:10px;
min-width:148px; text-align:center; background:white;
box-shadow:0 2px 6px rgba(0,0,0,0.04);">
<div style="color:#64748b; font-size:0.73rem; font-weight:600;
text-transform:uppercase; letter-spacing:0.05em; margin-bottom:4px;">
{label}
</div>
<div style="font-size:1.5rem; font-weight:800; color:{color}; margin:4px 0 2px 0;">
{value}
</div>
<div style="font-size:0.7rem; color:#94a3b8;">{unit}</div>
</div>"""
_tps_color = (COLORS["GreenLine"] if _tokens_per_sec2 > 1000
else COLORS["OrangeLine"] if _tokens_per_sec2 > 100
else COLORS["RedLine"])
_mem_color = (COLORS["RedLine"] if _oom2
else COLORS["OrangeLine"] if _total_mem_gb > _device_ram * 0.8
else COLORS["GreenLine"])
_cards_html = f"""
<div style="display:flex; gap:12px; flex-wrap:wrap; margin:14px 0 8px 0;">
{_card_html("GEMM", _gemm_reg2, f"AI={_gemm_ai2:.0f} F/B", _gemm_rc2)}
{_card_html("LayerNorm", _ln_reg2, f"AI={_ln_ai_eff:.1f} F/B", _ln_rc2)}
{_card_html("Softmax", _sm_reg2, f"AI={_sm_ai_eff:.1f} F/B", _sm_rc2)}
{_card_html("Throughput", f"{_tokens_per_sec2:,.0f}" if not _oom2 else "OOM",
"tokens / sec", _tps_color)}
{_card_html("Memory", f"{_total_mem_gb:.1f}",
f"/ {_device_ram:.0f} GB used", _mem_color)}
</div>"""
# Public aliases for other cells (marimo keeps underscore-prefixed names cell-local)
oom2, tokens_per_sec2, fusion = _oom2, _tokens_per_sec2, _fusion
gemm_ai2, gemm_reg2 = _gemm_ai2, _gemm_reg2
ln_reg2, sm_reg2, ridge_pt2 = _ln_reg2, _sm_reg2, _ridge_pt2
mo.vstack([
mo.as_html(_fig2),
mo.Html(_cards_html),
_oom_banner,
mo.md(f"""
**Physics (visible):**
```
Layer Norm AI (unfused) = 5D / (3D x 2) = {_ln_ai_base:.2f} FLOP/byte
Softmax AI (unfused) = 5 / (3 x 2) = {_sm_ai_base:.2f} FLOP/byte
Fusion byte factor ({_fusion}): {_ff:.2f}x of unfused traffic remains
Layer Norm AI (fused) = {_ln_ai_eff:.2f} FLOP/byte
Softmax AI (fused) = {_sm_ai_eff:.2f} FLOP/byte
KV-cache = 2 x B x S x heads x d_head x layers x 2 bytes
= 2 x {_B} x {_S} x 32 x 128 x 32 x 2
= {_kv_gb:.2f} GB
Weights = {LLAMA3_PARAMS/1e9:.0f}B params x 2 bytes = {_weights_gb:.1f} GB
Total = {_total_mem_gb:.1f} GB | Limit: {_device_ram:.0f} GB
```
"""),
])
return (
oom2, tokens_per_sec2, fusion,
gemm_ai2, gemm_reg2, ln_reg2, sm_reg2, ridge_pt2,
)
# ─── ACT II PREDICTION VS REALITY ────────────────────────────────────────────
@app.cell(hide_code=True)
def _(
mo, act2_pred,
gemm_ai2, gemm_reg2,
ln_reg2, sm_reg2,
ridge_pt2,
):
_pred2 = act2_pred.value
_feedback2 = {
"A": (
"**Incorrect.** GEMM can be the bottleneck at large batch sizes, but "
"it is *not always* so. At small batch sizes GEMM arithmetic intensity "
"is low enough to be memory-bound too. The elementwise ops (softmax, LN) "
"are *always* memory-bound at any batch size — their AI near 0.8 F/B "
"never approaches any accelerator's ridge point."
),
"B": (
"**Correct.** Softmax and layer norm have arithmetic intensities near "
"0.8 FLOP/byte — permanently below every accelerator's ridge point. "
"They are always memory-bound. GEMM's AI grows with batch × sequence "
"length, and can eventually cross the ridge to become compute-bound. "
"The correct strategy is to fuse the elementwise ops and push GEMM to "
"larger tiles/batches."
),
"C": (
"**Incorrect.** The spread is enormous: GEMM at large batches can reach "
"hundreds of FLOP/byte, while softmax stays at ~0.8 FLOP/byte regardless "
"of batch size. The difference is how much arithmetic reuse each kernel "
"can perform on each loaded byte."
),
"D": (
"**Incorrect.** The complexity of the arithmetic operations is irrelevant "
"to the bound regime. Layer norm reads each element, applies a few FLOPs, "
"and writes the result. The ratio of FLOPs to bytes is tiny regardless "
"of whether one of those FLOPs is a sqrt. Expensive single operations "
"contribute negligible throughput compared to the memory traffic."
),
}
_correct2 = _pred2 == "B"
mo.vstack([
mo.md("### Prediction vs Reality"),
mo.callout(
mo.md(
f"{_feedback2.get(_pred2, '')}\n\n"
f"**Actual:** GEMM AI = {gemm_ai2:.0f} FLOP/byte ({gemm_reg2}), "
f"LayerNorm = {ln_reg2}, Softmax = {sm_reg2}. "
f"Ridge = {ridge_pt2:.0f} FLOP/byte."
),
kind="success" if _correct2 else "warn",
),
])
return
# ─── ACT II REFLECTION ────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
act2_reflect = mo.ui.radio(
options={
"A) It reduces the number of FLOPs needed": "A",
"B) It eliminates redundant memory reads/writes — fused ops load data only once": "B",
"C) It increases clock frequency by reducing instruction overhead": "C",
"D) It allows using higher precision without accuracy loss": "D",
},
label="Reflection: why does kernel fusion improve performance for memory-bound operations?",
)
mo.vstack([
mo.md("---"),
mo.md("### Act II Reflection"),
act2_reflect,
])
return (act2_reflect,)
@app.cell(hide_code=True)
def _(mo, act2_reflect):
mo.stop(
act2_reflect.value is None,
mo.callout(mo.md("Select your answer to see the explanation."), kind="warn")
)
_reflect2_feedback = {
"A": (
"**Incorrect.** Fusion does not change the number of FLOPs executed. "
"The same arithmetic still runs. What changes is the *memory traffic* "
"between operations."
),
"B": (
"**Correct.** Without fusion, each elementwise kernel writes its output "
"to HBM and the next kernel reads it back — pure bandwidth round-trips "
"with almost no arithmetic reuse. A fused kernel keeps the tensor in "
"on-chip registers or SRAM across all operations, loading each element "
"once and storing once. This is why FlashAttention reports a 2-4x speedup "
"over naive attention: it fuses softmax + matrix multiply and avoids "
"materializing the full O(S^2) attention matrix in HBM."
),
"C": (
"**Incorrect.** Kernel fusion has no effect on clock frequency. "
"The hardware runs at the same frequency. The improvement is architectural "
"— eliminating unnecessary memory round-trips."
),
"D": (
"**Incorrect.** Precision is orthogonal to fusion. You can fuse FP32 or "
"FP16 or INT8 ops equally. The benefit of fusion is reducing memory traffic, "
"not changing the numeric format."
),
}
_correct2r = act2_reflect.value == "B"
mo.callout(
mo.md(_reflect2_feedback[act2_reflect.value]),
kind="success" if _correct2r else "warn",
)
return
# ─── ACT II MATHPEEK ─────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
mo.accordion({
"The governing equations": mo.md("""
**Elementwise operation AI (general):**
```
AI_elementwise = FLOPs_per_element / bytes_per_element
~ 1 FLOP / 2 bytes (FP16)
= 0.5 FLOP/byte
```
Elementwise ops are **always** memory-bound on any modern accelerator.
The ridge point is never below ~10 FLOP/byte; 0.5 < 10.
**Layer normalization AI:**
```
FLOPs per token = 5 x D (mean, variance, normalize, scale, shift)
Bytes per token = 3 x D x 2 (load input, intermediate, write output)
AI_layernorm = 5D / (3D x 2) ~ 0.83 FLOP/byte
```
**Softmax AI (over attention scores [B, H, S, S]):**
```
FLOPs per element ~ 5 (subtract max, exp, accumulate sum, divide)
Bytes per element = 3 x 2 (load, intermediate, store)
AI_softmax = 5 / (3 x 2) ~ 0.83 FLOP/byte
```
**Kernel fusion benefit:**
Unfused N-op pipeline on tensor T:
```
Memory traffic = N x (read(T) + write(T)) = 2N x |T| bytes
```
Fully fused single-pass kernel:
```
Memory traffic = read(T) + write(T) = 2 x |T| bytes
```
Theoretical speedup (memory-bound) = N, limited by on-chip SRAM capacity.
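A minimal sketch of this traffic model (the tensor shape and op count are
illustrative assumptions, not profiler data):
```python
def hbm_traffic_bytes(tensor_bytes, n_ops, fused=False):
    passes = 1 if fused else n_ops       # a fused kernel makes a single pass
    return 2 * passes * tensor_bytes     # one read + one write per pass

t = 8 * 512 * 4096 * 2                   # [B=8, S=512, D=4096] tensor at FP16
print(hbm_traffic_bytes(t, 3) / hbm_traffic_bytes(t, 3, fused=True))  # -> 3.0
```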
**KV-cache memory footprint:**
```
KV_cache = 2 x B x S x n_heads x d_head x n_layers x bytes_per_elem
```
For Llama-3-8B at FP16 (n_heads=32, d_head=128, layers=32):
```
KV_cache = B x S x 524288 bytes = B x S x 0.5 MB
```
At B=64, S=2048: the KV-cache alone is ~69 GB, over four times the Orin NX's 16 GB.
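The same check as a sketch, using this lab's Llama-3-8B constants:
```python
def kv_cache_gb(batch, seqlen, heads=32, d_head=128, layers=32, nbytes=2):
    return 2 * batch * seqlen * heads * d_head * layers * nbytes / 1e9

print(kv_cache_gb(64, 2048))   # ~68.7 GB, over 4x a 16 GB Orin NX
```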
Kernel fusion buys performance only if the design fits in RAM first.
"""),
})
return
# =============================================================================
# DESIGN LEDGER SAVE + HUD
# =============================================================================
@app.cell(hide_code=True)
def _(
mo, ledger, COLORS,
context_toggle,
act1_pred, act2_pred,
ai_gemm, is_mem_bound, mfu_pct,
oom2, tokens_per_sec2,
fusion,
):
_ctx = context_toggle.value
_p1 = act1_pred.value or "none"
_p2 = act2_pred.value or "none"
_fus = fusion if fusion else "none"
_act1_correct = _p1 == "C"
_act2_correct = _p2 == "B"
ledger.save(chapter=11, design={
"context": _ctx,
"operation": "gemm",
"arithmetic_intensity": float(ai_gemm),
"bound_type": "memory" if is_mem_bound else "compute",
"act1_prediction": _p1,
"act1_correct": _act1_correct,
"act2_result": float(mfu_pct),
"act2_decision": _fus,
"constraint_hit": bool(oom2),
"mfu_percent": float(mfu_pct),
})
_track = ledger.get_track() or _ctx
_color_map = {
"cloud": COLORS["Cloud"],
"edge": COLORS["Edge"],
"mobile": COLORS["Mobile"],
"tiny": COLORS["Tiny"],
"NONE": "#475569",
}
_hud_color = _color_map.get(_track, "#475569")
_p1_icon = "CORRECT" if _act1_correct else "WRONG"
_p2_icon = "CORRECT" if _act2_correct else "WRONG"
_oom_icon = "TRIGGERED" if oom2 else "CLEAR"
_tps_str = f"{tokens_per_sec2:,.0f}" if not oom2 else "OOM"
mo.vstack([
mo.md("---"),
mo.Html(f"""
<div style="display:flex; gap:22px; align-items:center; padding:12px 24px;
background:#0f172a; border-radius:10px; margin-top:16px; flex-wrap:wrap;
font-family:'SF Mono','Fira Code',monospace; font-size:0.79rem;
border:1px solid #1e293b;">
<div style="color:#475569; font-weight:600; letter-spacing:0.08em;">
DESIGN LEDGER
</div>
<div>
<span style="color:#475569;">Context: </span>
<span style="color:{_hud_color}; font-weight:700;">{_ctx.upper()}</span>
</div>
<div>
<span style="color:#475569;">Chapter: </span>
<span style="color:#e2e8f0;">11</span>
</div>
<div>
<span style="color:#475569;">AI (GEMM): </span>
<span style="color:#e2e8f0;">{ai_gemm:.0f} FLOP/byte</span>
</div>
<div>
<span style="color:#475569;">MFU: </span>
<span style="color:{'#4ade80' if mfu_pct > 50 else '#fbbf24'};">
{mfu_pct:.1f}%
</span>
</div>
<div>
<span style="color:#475569;">Act I: </span>
<span style="color:{'#4ade80' if _act1_correct else '#f87171'};">
{_p1_icon} [{_p1}]
</span>
</div>
<div>
<span style="color:#475569;">Act II: </span>
<span style="color:{'#4ade80' if _act2_correct else '#f87171'};">
{_p2_icon} [{_p2}]
</span>
</div>
<div>
<span style="color:#475569;">OOM: </span>
<span style="color:{'#f87171' if oom2 else '#4ade80'};">
{_oom_icon}
</span>
</div>
<div>
<span style="color:#475569;">Fusion: </span>
<span style="color:#e2e8f0;">{_fus}</span>
</div>
<div>
<span style="color:#475569;">Tokens/s: </span>
<span style="color:#e2e8f0;">{_tps_str}</span>
</div>
</div>
"""),
])
return
# ─── KEY TAKEAWAYS ────────────────────────────────────────────────────────────
@app.cell(hide_code=True)
def _(mo):
mo.vstack([
mo.md("---"),
mo.md("## Key Takeaways"),
mo.callout(
mo.md("""
**1. The roofline is a constraint, not a goal.**
Every kernel lives in one of two regimes. If arithmetic intensity is below
the ridge point, you are memory-bound and no amount of arithmetic optimization
will move throughput. Identify your ridge point first; then measure your AI.
The fix for a memory-bound kernel is always to increase data reuse — more
FLOPs per loaded byte — not to add more compute units.
"""),
kind="info",
),
mo.callout(
mo.md("""
**2. Elementwise ops are always memory-bound — fuse them.**
Layer norm, softmax, ReLU, and all single-pass elementwise operations
have arithmetic intensities below 1 FLOP/byte. They will never exceed the
ridge point on any silicon built in this decade. The only way to improve
their performance is kernel fusion: eliminate the redundant HBM round-trips
between consecutive ops. This is why FlashAttention, fused layer norm,
and operator fusion graphs exist. The specific accelerator will change
in five years. This principle will not.
"""),
kind="info",
),
])
return
if __name__ == "__main__":
app.run()