Mirror of https://github.com/harvard-edge/cs249r_book.git, synced 2026-05-06 01:28:35 -05:00.
Snapshots the autonomous-iteration work happening in the standalone
/Users/VJ/GitHub/mlperf-edu/ repo. Two iterations folded in:
iter-1: code-defect cleanup (Patterson + Dean sign-off)
- Remove dead simulated_loss + load_real_wikitext_data from
nanogpt_train.py; align NanoGPTWhiteBox vocab to char-level
(50,257 -> 128, dropping 19.3M unused embedding params; see the
arithmetic sketch after this list).
- Fix two broken examples.{edge,mobile} imports in inference paths.
- Reconcile README benchmark table with workloads.yaml (was wrong
on 7 of 16 workloads).
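For orientation, a back-of-envelope check of the dropped-parameter
figure (a sketch, not from the diff: it assumes n_embd = 384,
nanoGPT's usual char-scale width):

    n_embd = 384                  # assumed, not stated in the diff
    old_table = 50_257 * n_embd   # 19,298,688 ~ 19.3M params
    new_table = 128 * n_embd      #     49,152 ~ 0.05M params
    print(f"~{(old_table - new_table) / 1e6:.2f}M dropped")  # ~19.25M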
iter-2: DLRM DRAM-resident variant (Emer sign-off)
- New MicroDLRMDRAM with 2M-row hash-mapped virtual EmbeddingBag,
sized so per-batch byte transfer (8 MB at B=8192, m_spa=256)
exceeds PyTorch's ~50 us dispatch floor and exhibits the
bandwidth-bound regime production DLRM lives in (sizing sketched
below).
- Smoke test asserts pure-lookup gap >= 3x; current host shows
4.29x end-to-end and 3.49x lookup-only.
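The sizing arithmetic behind the variant, plus the shape of the
>= 3x bound (a sketch; the 50 GB/s bandwidth is an assumed
illustration, not a measurement from this host):

    B, m_spa, fp32_bytes = 8192, 256, 4
    per_batch = B * m_spa * fp32_bytes   # 8,388,608 B = 8 MiB per batch
    t_xfer_us = per_batch / 50e9 * 1e6   # ~168 us at an assumed 50 GB/s
    assert t_xfer_us > 50                # clears the ~50 us dispatch floor
    assert 3.49 >= 3.0                   # measured lookup-only gap vs. bound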
Branch is parked; not for merge to dev. Iteration log lives in the
standalone repo under .iteration_log/ (gitignored locally).
364 lines · 12 KiB · Python

#!/usr/bin/env python3
"""
MLPerf EDU: Training Verification Script

Runs each classical workload for a controlled number of epochs on real data,
captures train/val loss curves, and checks for overfitting.

This produces the baseline results that go into the paper.

Usage:
    python scripts/verify_training.py [--model MODEL_NAME] [--epochs N]
"""

import sys
import os
import json
import time
import argparse

# Ensure repo root is on the path so `reference.*` imports resolve
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, REPO_ROOT)
# Fall back to CPU for ops the Apple MPS backend does not implement
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
import torch.nn.functional as F
import torch.optim as optim

from reference.dataset_factory import get_dataloaders


# ---------------------------------------------------------------------------
# Model configs: (module_path, class_name, kwargs, lr, batch_size, epochs,
# batches_per_epoch)
# ---------------------------------------------------------------------------

CLASSICAL_MODELS = {
    "nanogpt-12m": {
        "module": "reference.cloud.nanogpt_train",
        "class": "NanoGPTWhiteBox",
        "kwargs": {},
        "lr": 3e-4,
        "batch_size": 16,
        "epochs": 20,
        "batches_per_epoch": 100,
    },
    "nano-moe-12m": {
        "module": "reference.cloud.nano_moe",
        "class": "NanoMoEWhiteBox",
        "kwargs": {},
        "lr": 3e-4,
        "batch_size": 16,
        "epochs": 20,
        "batches_per_epoch": 100,
    },
    "resnet18": {
        "module": "reference.edge.resnet_train",
        "class": "ResNet18WhiteBox",
        "kwargs": {"num_classes": 100},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 15,
        "batches_per_epoch": 100,
    },
    "micro-dlrm-1m": {
        "module": "reference.cloud.micro_dlrm",
        "class": "MicroDLRMWhiteBox",
        "kwargs": {},
        "lr": 1e-3,
        "batch_size": 256,
        "epochs": 20,
        "batches_per_epoch": 50,
    },
    "micro-dlrm-dram-1m": {
        "module": "reference.cloud.micro_dlrm_dram",
        "class": "MicroDLRMDRAM",
        "kwargs": {},
        "lr": 1e-3,
        "batch_size": 1024,
        "epochs": 20,
        "batches_per_epoch": 50,
    },
    "micro-diffusion-32px": {
        "module": "reference.cloud.micro_diffusion",
        "class": "MicroDiffusionUNet",
        "kwargs": {},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 15,
        "batches_per_epoch": 100,
    },
    "dscnn-kws": {
        "module": "reference.tiny.dscnn_kws",
        "class": "DSCNN",
        "kwargs": {"num_classes": 12},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 15,
        "batches_per_epoch": 100,
    },
    "anomaly-ae": {
        "module": "reference.tiny.anomaly_detection_ae",
        "class": "AnomalyDetectionAE",
        "kwargs": {"input_dim": 784, "bottleneck_dim": 8},
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 20,
        "batches_per_epoch": 50,
    },
}


def load_model(config, device):
    """Dynamically import and instantiate a model."""
    import importlib
    mod = importlib.import_module(config["module"])
    cls = getattr(mod, config["class"])
    model = cls(**config["kwargs"]).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    return model, n_params


def train_one_epoch(model, model_name, train_loader, optimizer, device, max_batches):
    """Train for one epoch, return average loss."""
    model.train()
    losses = []

    for batch_idx, batch in enumerate(train_loader):
        if batch_idx >= max_batches:
            break

        optimizer.zero_grad()

        # Each workload family packs its batches differently, so dispatch
        # on substrings of the config key.
        if "dlrm" in model_name:
            dense, sparse_idx, sparse_off, labels = batch
            dense = dense.to(device)
            sparse_idx = [s.to(device) for s in sparse_idx]
            sparse_off = [s.to(device) for s in sparse_off]
            labels = labels.to(device)
            outputs = model(dense, sparse_idx, sparse_off)
            loss = F.binary_cross_entropy(outputs, labels)
        elif "resnet" in model_name:
            data_batch, target_batch = batch
            outputs = model(data_batch.to(device))
            loss = F.cross_entropy(outputs, target_batch.to(device))
        elif "diffusion" in model_name:
            # Verified here with a plain reconstruction (MSE) objective
            data_batch, _ = batch
            data_batch = data_batch.to(device)
            outputs = model(data_batch)
            loss = F.mse_loss(outputs, data_batch)
        elif "dscnn" in model_name or "kws" in model_name:
            data_batch, target_batch = batch
            _, loss = model(data_batch.to(device), targets=target_batch.to(device))
        elif "anomaly" in model_name:
            data_batch, _ = batch  # labels not used for AE training
            _, loss = model(data_batch.to(device))
        else:
            # Language models
            data_batch, target_batch = batch
            _, loss = model(data_batch.to(device), targets=target_batch.to(device))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        losses.append(loss.item())

    return sum(losses) / len(losses) if losses else float("inf")


@torch.no_grad()
def validate(model, model_name, val_loader, device, max_batches=20):
    """Evaluate on validation set, return average loss and accuracy (if applicable)."""
    model.eval()
    losses = []
    correct = 0
    total = 0

    for batch_idx, batch in enumerate(val_loader):
        if batch_idx >= max_batches:
            break

        if "dlrm" in model_name:
            dense, sparse_idx, sparse_off, labels = batch
            dense = dense.to(device)
            sparse_idx = [s.to(device) for s in sparse_idx]
            sparse_off = [s.to(device) for s in sparse_off]
            labels = labels.to(device)
            outputs = model(dense, sparse_idx, sparse_off)
            loss = F.binary_cross_entropy(outputs, labels)
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.numel()
        elif "resnet" in model_name:
            data_batch, target_batch = batch
            data_batch = data_batch.to(device)
            target_batch = target_batch.to(device)
            outputs = model(data_batch)
            loss = F.cross_entropy(outputs, target_batch)
            preds = outputs.argmax(dim=1)
            correct += (preds == target_batch).sum().item()
            total += target_batch.numel()
        elif "diffusion" in model_name:
            data_batch, _ = batch
            data_batch = data_batch.to(device)
            outputs = model(data_batch)
            loss = F.mse_loss(outputs, data_batch)
        elif "dscnn" in model_name or "kws" in model_name:
            data_batch, target_batch = batch
            data_batch = data_batch.to(device)
            target_batch = target_batch.to(device)
            logits, loss = model(data_batch, targets=target_batch)
            preds = logits.argmax(dim=1)
            correct += (preds == target_batch).sum().item()
            total += target_batch.numel()
        elif "anomaly" in model_name:
            data_batch, _ = batch
            _, loss = model(data_batch.to(device))
        else:
            data_batch, target_batch = batch
            _, loss = model(data_batch.to(device), targets=target_batch.to(device))

        losses.append(loss.item())

    avg_loss = sum(losses) / len(losses) if losses else float("inf")
    accuracy = correct / total if total > 0 else None
    return avg_loss, accuracy


def train_model(model_name, config, device):
    """Full training loop for one model."""
    print(f"\n{'='*60}")
    print(f" Training: {model_name}")
    print(f"{'='*60}")

    # Load model
    model, n_params = load_model(config, device)
    print(f" Parameters: {n_params/1e6:.2f}M")
    print(f" Device: {device}")

    # Load data
    train_loader, val_loader = get_dataloaders(model_name, batch_size=config["batch_size"])
    print(f" Train samples: {len(train_loader.dataset)}")
    print(f" Val samples: {len(val_loader.dataset)}")

    # Optimizer with weight decay for regularization
    optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-2)

    # LR scheduler for smooth convergence
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config["epochs"], eta_min=config["lr"] * 0.1
    )

    # Training loop
    results = {
        "model_name": model_name,
        "n_params": n_params,
        "device": str(device),
        "train_losses": [],
        "val_losses": [],
        "val_accuracies": [],
        "epoch_times": [],
    }

    start_time = time.time()

    for epoch in range(config["epochs"]):
        t0 = time.perf_counter()

        train_loss = train_one_epoch(
            model, model_name, train_loader, optimizer, device,
            max_batches=config["batches_per_epoch"]
        )
        val_loss, val_acc = validate(model, model_name, val_loader, device)
        scheduler.step()

        epoch_time = time.perf_counter() - t0

        results["train_losses"].append(train_loss)
        results["val_losses"].append(val_loss)
        results["val_accuracies"].append(val_acc)
        results["epoch_times"].append(epoch_time)

        # Print progress; flag epochs where the val-train gap exceeds
        # 50% of the train loss
        acc_str = f" | Acc: {val_acc:.3f}" if val_acc is not None else ""
        gap = val_loss - train_loss
        overfit = " ⚠️ OVERFIT" if gap > 0.5 * train_loss else ""
        print(f" Epoch {epoch:3d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}{acc_str} | "
              f"Gap: {gap:.4f}{overfit} | {epoch_time:.1f}s")

    total_time = time.time() - start_time
    results["total_time_s"] = total_time
    results["final_train_loss"] = results["train_losses"][-1]
    results["final_val_loss"] = results["val_losses"][-1]

    # Overfitting check on the final epoch, same 50% relative-gap heuristic
    gap = results["final_val_loss"] - results["final_train_loss"]
    if gap > 0.5 * results["final_train_loss"]:
        print(f" ⚠️ OVERFITTING: val-train gap = {gap:.4f}")
    else:
        print(f" ✅ Healthy: val-train gap = {gap:.4f}")

    print(f" Total time: {total_time:.1f}s")

    # Save checkpoint
    ckpt_dir = os.path.join(REPO_ROOT, "checkpoints", model_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(ckpt_dir, "instructor_baseline.pt"))

    # Save results
    results_path = os.path.join(ckpt_dir, "results.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)

    return results


def main():
    parser = argparse.ArgumentParser(description="MLPerf EDU Training Verification")
    parser.add_argument("--model", type=str, default=None,
                        help="Specific model to train (or 'all')")
    parser.add_argument("--epochs", type=int, default=None,
                        help="Override epoch count")
    args = parser.parse_args()

    # Prefer CUDA, then Apple MPS, then CPU
    device = ("cuda" if torch.cuda.is_available()
              else "mps" if torch.backends.mps.is_available()
              else "cpu")
    print(f"🚀 MLPerf EDU Training Verification")
    print(f" Device: {device}")

    models_to_train = {}
    if args.model and args.model != "all":
        if args.model not in CLASSICAL_MODELS:
            print(f"❌ Unknown model: {args.model}")
            print(f" Available: {list(CLASSICAL_MODELS.keys())}")
            sys.exit(1)
        models_to_train[args.model] = CLASSICAL_MODELS[args.model]
    else:
        models_to_train = CLASSICAL_MODELS

    if args.epochs:
        for cfg in models_to_train.values():
            cfg["epochs"] = args.epochs

    all_results = {}
    for model_name, config in models_to_train.items():
        results = train_model(model_name, config, device)
        all_results[model_name] = results

    # Summary table
    print(f"\n{'='*70}")
    print(f" SUMMARY")
    print(f"{'='*70}")
    print(f" {'Model':<25s} {'Params':>8s} {'Train':>8s} {'Val':>8s} {'Gap':>8s} {'Time':>8s}")
    print(f" {'-'*25} {'-'*8} {'-'*8} {'-'*8} {'-'*8} {'-'*8}")
    for name, r in all_results.items():
        gap = r["final_val_loss"] - r["final_train_loss"]
        print(f" {name:<25s} {r['n_params']/1e6:>7.1f}M "
              f"{r['final_train_loss']:>8.4f} {r['final_val_loss']:>8.4f} "
              f"{gap:>8.4f} {r['total_time_s']:>7.1f}s")


if __name__ == "__main__":
    main()
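
Each run writes a results.json next to its checkpoint; a minimal sketch
of reading one back (keys match the results dict in train_model):

    import json
    with open("checkpoints/resnet18/results.json") as f:
        r = json.load(f)
    print(r["final_val_loss"], r["total_time_s"], len(r["train_losses"]))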