cs249r_book/mlperf-edu/examples/open_resnet_quantized.py
Vijay Janapa Reddi a9878ad6bd feat: import mlperf-edu pedagogical benchmark suite
Snapshot of the standalone /Users/VJ/GitHub/mlperf-edu/ repo as of
2026-04-16, brought into MLSysBook as a parked feature branch for
backup and iteration. Not for merge to dev.

Contents (88 files, ~2.3 MB):
- 16 reference workloads (cloud / edge / tiny / agent divisions)
- LoadGen proxy harness + SUT plugin protocol (see the sketch after this list)
- Compliance checker, autograder, hardware fingerprint
- Paper draft (paper.tex) with TikZ/SVG figure sources
- Three lab examples + practitioner workflow configs
- Workload + dataset YAML registries (single source of truth)
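
For orientation, a minimal sketch of the SUT plugin surface as it can be
inferred from the example file below; the real mlperf.sut.SUT_Interface may
differ, so treat every detail here as an assumption:

# Hypothetical reconstruction of the plugin protocol (inferred, not copied):
from abc import ABC, abstractmethod

class SUT_Interface(ABC):
    def __init__(self, config: dict):
        self.config = config  # assumed: harness passes workload settings here

    @abstractmethod
    async def process_queries(self, samples: list) -> dict:
        """Answer one batch of queries; returns predictions and targets."""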

Excluded (per mlperf-edu/.gitignore + size constraints):
- Datasets (6.6 GB), checkpoints (260 MB), gpt2 weights (523 MB)
- Generated PDFs, .venv, build artifacts

import torch
import torch.nn as nn
from mlperf.sut import SUT_Interface
from reference.edge.resnet_train import ResNet18WhiteBox


class OpenResNetQuantized(SUT_Interface):
    """
    OPEN DIVISION: extreme architectural transformations allowed!

    Student optimization notes:
    In the Open division, you do not have to hit the 99% accuracy target. You
    can aggressively quantize (INT8), prune entire layers, or inject sparse
    convolutions to dramatically increase speed (QPS) and cut power draw.
    """
    def __init__(self, config: dict):
        super().__init__(config)
        # Note: PyTorch native dynamic quantization currently runs on CPU.
        self.device = torch.device('cpu')

        print("[Submitter:Open] 🧠 Loading native ResNet18WhiteBox parameters...")
        self.model = ResNet18WhiteBox(num_classes=100)
        self.model.eval()

        # 1. STUDENT OPTIMIZATION: dynamic INT8 quantization.
        print("[Submitter:Open] ⚡ Dropping weight precision dynamically to INT8!")
        # quantize_dynamic only supports nn.Linear and recurrent layers;
        # nn.Conv2d has no dynamic-quantization mapping and would be silently
        # ignored, so we target the Linear bottlenecks only.
        self.optimized_model = torch.quantization.quantize_dynamic(
            self.model, {nn.Linear}, dtype=torch.qint8
        )
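
        # --- Editor's sketch (hypothetical, not part of the original file) ---
        # The docstring also allows pruning entire layers. A minimal variant
        # using torch.nn.utils.prune could zero out half of each Linear
        # layer's weights before quantizing, e.g.:
        #
        #   import torch.nn.utils.prune as prune
        #   for module in self.model.modules():
        #       if isinstance(module, nn.Linear):
        #           prune.l1_unstructured(module, name="weight", amount=0.5)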

    async def process_queries(self, samples: list):
        batch_size = len(samples)
        # Placeholder data: this example ignores the contents of `samples`
        # and synthesizes random CIFAR-100-sized inputs (3x32x32) instead of
        # decoding real images.
        input_tensor = torch.randn((batch_size, 3, 32, 32))

        # 2. Run inference with the quantized model.
        with torch.inference_mode():
            logits = self.optimized_model(input_tensor)

        # Dynamic INT8 rounding pulls accuracy well below the 99% closed-
        # division threshold, but throughput rises sharply in exchange.
        # The random targets are likewise stand-ins for real labels.
        return {
            "predictions": logits,
            "targets": torch.randint(0, 100, (batch_size,))
        }
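
A minimal smoke test for this SUT outside the LoadGen proxy harness (an
editor's sketch: the empty config dict, the import path, and the placeholder
sample IDs are all assumptions; the real harness supplies each of these):

import asyncio
from open_resnet_quantized import OpenResNetQuantized  # hypothetical path

sut = OpenResNetQuantized(config={})   # assumes the base class accepts {}
out = asyncio.run(sut.process_queries(list(range(8))))  # 8 placeholder IDs
print(out["predictions"].shape)        # expected: torch.Size([8, 100])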