mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-05-22 05:53:13 -05:00
20 of 20 workloads now schema-valid; 9 of 11 measurable workloads have evidence-bound regime values backed by sidecars in roofline/. The linter passes --verify-against-sidecars across the suite. 13 prior guess-classifications were corrected by measurement; the surprises (DLRM compute-bound, ResNet bandwidth-bound, Diffusion bandwidth-bound) will inform paper prose. Branch parked.
150 lines
6.2 KiB
Python
150 lines
6.2 KiB
Python
"""
|
|
MLPerf EDU: NanoGPT-Decode workload (Cloud Division)
|
|
|
|
Autoregressive decode with a real KV cache. Each step appends one
|
|
token's K and V, and attention re-reads the entire cached K, V from
|
|
DRAM -- the canonical bandwidth-bound regime that dominates LLM
|
|
serving cost in production.
|
|
|
|
Pair with nanogpt-prefill (same checkpoint) to observe the
|
|
prefill-vs-decode bottleneck split.
|
|
"""
|
|
import statistics
|
|
import time
|
|
import torch
|
|
|
|
from .nanogpt_train import NanoGPTWhiteBox
|
|
|
|
|
|
def _sync():
|
|
if torch.backends.mps.is_available():
|
|
torch.mps.synchronize()
|
|
elif torch.cuda.is_available():
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
def kv_cache_bytes(past_key_values, dtype_bytes: int = 4) -> int:
|
|
"""Total bytes held in the KV cache across all layers."""
|
|
total = 0
|
|
for k, v in past_key_values:
|
|
total += k.numel() * dtype_bytes + v.numel() * dtype_bytes
|
|
return total
|
|
|
|
|
|
class NanoGPTDecode:
|
|
"""Warms the KV cache to `prefill_ctx`, then times `decode_steps` single-token steps.
|
|
|
|
Reports time-to-first-token (TTFT), median + p99 inter-token latency
|
|
(ITL), final KV-cache bytes, and an achieved-bandwidth estimate
|
|
derived from streaming the cached K,V each step.
|
|
"""
|
|
|
|
def __init__(self, model: NanoGPTWhiteBox,
|
|
prefill_ctx: int = 1792, decode_steps: int = 64, batch_size: int = 1):
|
|
max_ctx = prefill_ctx + decode_steps
|
|
if max_ctx > model.config["max_seq_len"]:
|
|
raise ValueError(
|
|
f"prefill_ctx + decode_steps = {max_ctx} exceeds model "
|
|
f"max_seq_len={model.config['max_seq_len']}; bump it."
|
|
)
|
|
self.model = model.eval()
|
|
self.prefill_ctx = prefill_ctx
|
|
self.decode_steps = decode_steps
|
|
self.batch = batch_size
|
|
self.vocab = model.config["vocab_size"]
|
|
|
|
def _sample(self, logits):
|
|
# Argmax keeps the test deterministic; replace with multinomial
|
|
# if students need temperature/top-p exploration.
|
|
return logits.argmax(dim=-1, keepdim=True)
|
|
|
|
def run(self, emit_sidecar: bool = True) -> dict:
|
|
device = next(self.model.parameters()).device
|
|
prompt = torch.randint(0, self.vocab, (self.batch, self.prefill_ctx), device=device)
|
|
n_params = sum(p.numel() for p in self.model.parameters())
|
|
cfg = self.model.config
|
|
head_dim = cfg["n_embd"] // cfg["n_head"]
|
|
# Per-step bytes during decode: full weight reread + full KV stream.
|
|
kv_bytes_per_step = 2 * cfg["n_layer"] * cfg["n_head"] * head_dim * self.prefill_ctx * 4
|
|
bytes_per_step = n_params * 4 + kv_bytes_per_step
|
|
# Per-step FLOPs: one new token through all weights + attention over ctx.
|
|
flops_per_step = 2 * n_params + 4 * cfg["n_layer"] * cfg["n_head"] * head_dim * self.prefill_ctx
|
|
|
|
with torch.no_grad():
|
|
# Warm the cache and get the last-step logits.
|
|
_sync()
|
|
t_prefill_start = time.perf_counter()
|
|
logits, kv = self.model(prompt, use_kv_cache=True)
|
|
_sync()
|
|
prefill_time = time.perf_counter() - t_prefill_start
|
|
|
|
# First decode step (TTFT measured here; the prefill is the
|
|
# "prompt processing" phase, not part of TTFT in serving SLOs).
|
|
_sync()
|
|
t0 = time.perf_counter()
|
|
next_tok = self._sample(logits[:, -1, :])
|
|
logits, kv = self.model(next_tok, use_kv_cache=True, past_key_values=kv)
|
|
_sync()
|
|
ttft = time.perf_counter() - t0
|
|
|
|
per_step = []
|
|
n_loop = self.decode_steps - 1
|
|
if emit_sidecar and n_loop > 0:
|
|
from mlperf.roofline import measure_roofline
|
|
with measure_roofline(
|
|
"nanogpt-decode",
|
|
analytic_flops=lambda: flops_per_step * n_loop,
|
|
analytic_bytes=lambda: bytes_per_step * n_loop,
|
|
n_iter=n_loop,
|
|
):
|
|
for _ in range(n_loop):
|
|
next_tok = self._sample(logits[:, -1, :])
|
|
_sync()
|
|
t = time.perf_counter()
|
|
logits, kv = self.model(next_tok, use_kv_cache=True, past_key_values=kv)
|
|
_sync()
|
|
per_step.append(time.perf_counter() - t)
|
|
else:
|
|
for _ in range(n_loop):
|
|
next_tok = self._sample(logits[:, -1, :])
|
|
_sync()
|
|
t = time.perf_counter()
|
|
logits, kv = self.model(next_tok, use_kv_cache=True, past_key_values=kv)
|
|
_sync()
|
|
per_step.append(time.perf_counter() - t)
|
|
|
|
kv_bytes = kv_cache_bytes(kv)
|
|
median_itl = statistics.median(per_step) if per_step else float("nan")
|
|
p99_itl = sorted(per_step)[int(len(per_step) * 0.99) - 1] if per_step else float("nan")
|
|
# Achieved bandwidth: each decode step re-reads the full KV cache
|
|
# (the model also re-reads weights, but those usually live in LLC
|
|
# after warmup). KV stream is the *additive* per-step cost.
|
|
achieved_bw_gbps = kv_bytes / median_itl / 1e9 if per_step else 0.0
|
|
|
|
return {
|
|
"phase": "decode",
|
|
"prefill_ctx": self.prefill_ctx,
|
|
"decode_steps": self.decode_steps,
|
|
"batch_size": self.batch,
|
|
"prefill_warm_s": prefill_time,
|
|
"ttft_s": ttft,
|
|
"itl_median_s": median_itl,
|
|
"itl_p99_s": p99_itl,
|
|
"kv_cache_bytes": kv_bytes,
|
|
"achieved_bw_gbps": achieved_bw_gbps,
|
|
"output_tokens_per_sec": 1.0 / median_itl if per_step else 0.0,
|
|
}
|
|
|
|
|
|
def run_benchmark(checkpoint_path: str = None, scenario: str = "Offline",
|
|
prefill_ctx: int = 1792, decode_steps: int = 64,
|
|
batch_size: int = 1) -> dict:
|
|
device = ("cuda" if torch.cuda.is_available()
|
|
else "mps" if torch.backends.mps.is_available() else "cpu")
|
|
model = NanoGPTWhiteBox().to(device)
|
|
if checkpoint_path:
|
|
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
|
return NanoGPTDecode(
|
|
model, prefill_ctx=prefill_ctx, decode_steps=decode_steps, batch_size=batch_size
|
|
).run()
|