cs249r_book/mlperf-edu/scripts/verify_training.py
Vijay Janapa Reddi efaa075ba8 mlperf-edu: sync iter-1 and iter-2 from standalone repo
Snapshots the autonomous-iteration work happening in the standalone
/Users/VJ/GitHub/mlperf-edu/ repo. Two iterations folded in:

  iter-1: code-defect cleanup (Patterson + Dean sign-off)
  - Remove the dead simulated_loss and load_real_wikitext_data code
    from nanogpt_train.py; align the NanoGPTWhiteBox vocab to
    char-level (50,257 -> 128), dropping 19.3M unused embedding
    params (arithmetic sketched below).
    - Fix two broken examples.{edge,mobile} imports in inference paths.
    - Reconcile README benchmark table with workloads.yaml (was wrong
      on 7 of 16 workloads).
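
  Back-of-envelope check of the 19.3M figure, as a hedged sketch
  (d_model=384 is an assumed width, chosen because it reproduces the
  quoted number, not a value read from the diff):

      old_vocab, new_vocab = 50_257, 128
      d_model = 384  # assumption: width consistent with the quoted 19.3M
      saved = (old_vocab - new_vocab) * d_model
      print(f"{saved / 1e6:.2f}M embedding params dropped")  # -> 19.25M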

  iter-2: DLRM DRAM-resident variant (Emer sign-off)
  - New MicroDLRMDRAM with a 2M-row hash-mapped virtual EmbeddingBag,
    sized so the per-batch byte transfer (8 MB at B=8192, m_spa=256;
    arithmetic sketched below) costs more than PyTorch's ~50 us
    dispatch floor and exhibits the bandwidth-bound regime that
    production DLRM operates in.
    - Smoke test asserts pure-lookup gap >= 3x; current host shows
      4.29x end-to-end and 3.49x lookup-only.
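
  Sanity arithmetic behind the 8 MB figure, as a hedged sketch (assumes
  fp32 rows and one lookup per sample into a single table; real traffic
  scales with table count and pooling factor):

      B, m_spa = 8192, 256
      bytes_per_float = 4  # assumption: fp32 embedding rows
      per_batch = B * m_spa * bytes_per_float
      print(f"{per_batch / 2**20:.0f} MiB per batch")  # -> 8 MiB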

Branch is parked; not for merge to dev. Iteration log lives in the
standalone repo under .iteration_log/ (gitignored locally).
2026-04-16 14:59:42 -04:00

364 lines
12 KiB
Python

#!/usr/bin/env python3
"""
MLPerf EDU: Training Verification Script

Runs each classical workload for a controlled number of epochs on real data,
captures train/val loss curves, and checks for overfitting.

This produces the baseline results that go into the paper.

Usage:
    python scripts/verify_training.py [--model MODEL_NAME] [--epochs N]
"""
import sys
import os
import json
import time
import argparse

# Ensure repo root is on the path
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, REPO_ROOT)

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
import torch.nn.functional as F
import torch.optim as optim

from reference.dataset_factory import get_dataloaders

# ---------------------------------------------------------------------------
# Model configs: (module_path, class_name, kwargs, lr, batch_size, epochs)
# ---------------------------------------------------------------------------
CLASSICAL_MODELS = {
    "nanogpt-12m": {
        "module": "reference.cloud.nanogpt_train",
        "class": "NanoGPTWhiteBox",
        "kwargs": {},
        "lr": 3e-4,
        "batch_size": 16,
        "epochs": 20,
        "batches_per_epoch": 100,
    },
    "nano-moe-12m": {
        "module": "reference.cloud.nano_moe",
        "class": "NanoMoEWhiteBox",
        "kwargs": {},
        "lr": 3e-4,
        "batch_size": 16,
        "epochs": 20,
        "batches_per_epoch": 100,
    },
    "resnet18": {
        "module": "reference.edge.resnet_train",
        "class": "ResNet18WhiteBox",
        "kwargs": {"num_classes": 100},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 15,
        "batches_per_epoch": 100,
    },
    "micro-dlrm-1m": {
        "module": "reference.cloud.micro_dlrm",
        "class": "MicroDLRMWhiteBox",
        "kwargs": {},
        "lr": 1e-3,
        "batch_size": 256,
        "epochs": 20,
        "batches_per_epoch": 50,
    },
    "micro-dlrm-dram-1m": {
        "module": "reference.cloud.micro_dlrm_dram",
        "class": "MicroDLRMDRAM",
        "kwargs": {},
        "lr": 1e-3,
        "batch_size": 1024,
        "epochs": 20,
        "batches_per_epoch": 50,
    },
    "micro-diffusion-32px": {
        "module": "reference.cloud.micro_diffusion",
        "class": "MicroDiffusionUNet",
        "kwargs": {},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 15,
        "batches_per_epoch": 100,
    },
    "dscnn-kws": {
        "module": "reference.tiny.dscnn_kws",
        "class": "DSCNN",
        "kwargs": {"num_classes": 12},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 15,
        "batches_per_epoch": 100,
    },
    "anomaly-ae": {
        "module": "reference.tiny.anomaly_detection_ae",
        "class": "AnomalyDetectionAE",
        "kwargs": {"input_dim": 784, "bottleneck_dim": 8},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 20,
        "batches_per_epoch": 50,
    },
}

def load_model(config, device):
    """Dynamically import and instantiate a model."""
    import importlib

    mod = importlib.import_module(config["module"])
    cls = getattr(mod, config["class"])
    model = cls(**config["kwargs"]).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    return model, n_params


def train_one_epoch(model, model_name, train_loader, optimizer, device, max_batches):
    """Train for one epoch, return average loss."""
    model.train()
    losses = []
    for batch_idx, batch in enumerate(train_loader):
        if batch_idx >= max_batches:
            break
        optimizer.zero_grad()
        if "dlrm" in model_name:
            dense, sparse_idx, sparse_off, labels = batch
            dense = dense.to(device)
            sparse_idx = [s.to(device) for s in sparse_idx]
            sparse_off = [s.to(device) for s in sparse_off]
            labels = labels.to(device)
            outputs = model(dense, sparse_idx, sparse_off)
            loss = F.binary_cross_entropy(outputs, labels)
        elif "resnet" in model_name:
            data_batch, target_batch = batch
            outputs = model(data_batch.to(device))
            loss = F.cross_entropy(outputs, target_batch.to(device))
        elif "diffusion" in model_name:
            data_batch, _ = batch
            data_batch = data_batch.to(device)
            outputs = model(data_batch)
            loss = F.mse_loss(outputs, data_batch)
        elif "dscnn" in model_name or "kws" in model_name:
            data_batch, target_batch = batch
            _, loss = model(data_batch.to(device), targets=target_batch.to(device))
        elif "anomaly" in model_name:
            data_batch, _ = batch  # labels not used for AE training
            _, loss = model(data_batch.to(device))
        else:
            # Language models
            data_batch, target_batch = batch
            _, loss = model(data_batch.to(device), targets=target_batch.to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / len(losses) if losses else float("inf")

@torch.no_grad()
def validate(model, model_name, val_loader, device, max_batches=20):
    """Evaluate on validation set, return average loss and accuracy (if applicable)."""
    model.eval()
    losses = []
    correct = 0
    total = 0
    for batch_idx, batch in enumerate(val_loader):
        if batch_idx >= max_batches:
            break
        if "dlrm" in model_name:
            dense, sparse_idx, sparse_off, labels = batch
            dense = dense.to(device)
            sparse_idx = [s.to(device) for s in sparse_idx]
            sparse_off = [s.to(device) for s in sparse_off]
            labels = labels.to(device)
            outputs = model(dense, sparse_idx, sparse_off)
            loss = F.binary_cross_entropy(outputs, labels)
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.numel()
        elif "resnet" in model_name:
            data_batch, target_batch = batch
            data_batch = data_batch.to(device)
            target_batch = target_batch.to(device)
            outputs = model(data_batch)
            loss = F.cross_entropy(outputs, target_batch)
            preds = outputs.argmax(dim=1)
            correct += (preds == target_batch).sum().item()
            total += target_batch.numel()
        elif "diffusion" in model_name:
            data_batch, _ = batch
            data_batch = data_batch.to(device)
            outputs = model(data_batch)
            loss = F.mse_loss(outputs, data_batch)
        elif "dscnn" in model_name or "kws" in model_name:
            data_batch, target_batch = batch
            data_batch = data_batch.to(device)
            target_batch = target_batch.to(device)
            logits, loss = model(data_batch, targets=target_batch)
            preds = logits.argmax(dim=1)
            correct += (preds == target_batch).sum().item()
            total += target_batch.numel()
        elif "anomaly" in model_name:
            data_batch, _ = batch
            _, loss = model(data_batch.to(device))
        else:
            data_batch, target_batch = batch
            _, loss = model(data_batch.to(device), targets=target_batch.to(device))
        losses.append(loss.item())
    avg_loss = sum(losses) / len(losses) if losses else float("inf")
    accuracy = correct / total if total > 0 else None
    return avg_loss, accuracy

def train_model(model_name, config, device):
    """Full training loop for one model."""
    print(f"\n{'='*60}")
    print(f" Training: {model_name}")
    print(f"{'='*60}")

    # Load model
    model, n_params = load_model(config, device)
    print(f" Parameters: {n_params/1e6:.2f}M")
    print(f" Device: {device}")

    # Load data
    train_loader, val_loader = get_dataloaders(model_name, batch_size=config["batch_size"])
    print(f" Train samples: {len(train_loader.dataset)}")
    print(f" Val samples: {len(val_loader.dataset)}")

    # Optimizer with weight decay for regularization
    optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-2)

    # LR scheduler for smooth convergence
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config["epochs"], eta_min=config["lr"] * 0.1
    )

    # Training loop
    results = {
        "model_name": model_name,
        "n_params": n_params,
        "device": str(device),
        "train_losses": [],
        "val_losses": [],
        "val_accuracies": [],
        "epoch_times": [],
    }
    start_time = time.time()
    for epoch in range(config["epochs"]):
        t0 = time.perf_counter()
        train_loss = train_one_epoch(
            model, model_name, train_loader, optimizer, device,
            max_batches=config["batches_per_epoch"]
        )
        val_loss, val_acc = validate(model, model_name, val_loader, device)
        scheduler.step()
        epoch_time = time.perf_counter() - t0

        results["train_losses"].append(train_loss)
        results["val_losses"].append(val_loss)
        results["val_accuracies"].append(val_acc)
        results["epoch_times"].append(epoch_time)

        # Print progress
        acc_str = f" | Acc: {val_acc:.3f}" if val_acc is not None else ""
        gap = val_loss - train_loss
        overfit = " ⚠️ OVERFIT" if gap > 0.5 * train_loss else ""
        print(f" Epoch {epoch:3d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}{acc_str} | "
              f"Gap: {gap:.4f}{overfit} | {epoch_time:.1f}s")

    total_time = time.time() - start_time
    results["total_time_s"] = total_time
    results["final_train_loss"] = results["train_losses"][-1]
    results["final_val_loss"] = results["val_losses"][-1]

    # Overfitting check
    gap = results["final_val_loss"] - results["final_train_loss"]
    if gap > 0.5 * results["final_train_loss"]:
        print(f" ⚠️ OVERFITTING: val-train gap = {gap:.4f}")
    else:
        print(f" ✅ Healthy: val-train gap = {gap:.4f}")
    print(f" Total time: {total_time:.1f}s")

    # Save checkpoint
    ckpt_dir = os.path.join(REPO_ROOT, "checkpoints", model_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "instructor_baseline.pt"))

    # Save results
    results_path = os.path.join(ckpt_dir, "results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    return results

def main():
    parser = argparse.ArgumentParser(description="MLPerf EDU Training Verification")
    parser.add_argument("--model", type=str, default=None,
                        help="Specific model to train (or 'all')")
    parser.add_argument("--epochs", type=int, default=None,
                        help="Override epoch count")
    args = parser.parse_args()

    device = ("cuda" if torch.cuda.is_available()
              else "mps" if torch.backends.mps.is_available()
              else "cpu")
    print("🚀 MLPerf EDU Training Verification")
    print(f" Device: {device}")

    models_to_train = {}
    if args.model and args.model != "all":
        if args.model not in CLASSICAL_MODELS:
            print(f"❌ Unknown model: {args.model}")
            print(f" Available: {list(CLASSICAL_MODELS.keys())}")
            sys.exit(1)
        models_to_train[args.model] = CLASSICAL_MODELS[args.model]
    else:
        models_to_train = CLASSICAL_MODELS

    if args.epochs:
        for cfg in models_to_train.values():
            cfg["epochs"] = args.epochs

    all_results = {}
    for model_name, config in models_to_train.items():
        results = train_model(model_name, config, device)
        all_results[model_name] = results

    # Summary table
    print(f"\n{'='*70}")
    print(" SUMMARY")
    print(f"{'='*70}")
    print(f" {'Model':<25s} {'Params':>8s} {'Train':>8s} {'Val':>8s} {'Gap':>8s} {'Time':>8s}")
    print(f" {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")
    for name, r in all_results.items():
        gap = r["final_val_loss"] - r["final_train_loss"]
        print(f" {name:<25s} {r['n_params']/1e6:>7.1f}M "
              f"{r['final_train_loss']:>8.4f} {r['final_val_loss']:>8.4f} "
              f"{gap:>8.4f} {r['total_time_s']:>7.1f}s")


if __name__ == "__main__":
    main()