mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-05-22 22:33:28 -05:00
80 lines
3.4 KiB
Python
80 lines
3.4 KiB
Python
"""
|
|
NanoGPT model definition for MLPerf EDU.

The canonical training entry point is `scripts/verify_training.py`, which
loads `NanoGPTWhiteBox` and trains it with the dataset_factory's
character-level TinyShakespeare loader. This file exports only the model
class and its tokenizer contract; do not add a `run_benchmark` here.

Vocab note: dataset_factory uses an ASCII char tokenizer over
TinyShakespeare with 65 unique chars. We use vocab_size=128 to give a
small safety margin while keeping the embedding matrix small (the GPT-2
vocab of 50,257 is BPE and does not apply to char-level data).
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from .gpt2_infer import GPTBlock
|
|
|
|
# Character vocabulary size: the full ASCII range. TinyShakespeare only
# uses 65 of these ids.
CHAR_VOCAB_SIZE = 128

# iter-6 (Han): GPT-2-Small canonical geometry as the "real LLM" stand-in.
# Quoted sizing: 124M params at d_model=768, n_head=12, n_layer=12; trains
# in ~30 min on M5 Max; weights ~248 MB at fp16. Per Han's iter-6 sizing
# math, this is the configuration that fills the
# (dram_bound, bandwidth_bound, device_saturated) cell when run with
# batch_size=32 ctx=2048.
# NOTE(review): 124M is the canonical GPT-2 Small count with its
# 50,257-token BPE vocabulary; with CHAR_VOCAB_SIZE=128 the embedding table
# is much smaller, so the real parameter count and fp16 footprint are
# presumably lower -- confirm before quoting these numbers.
GPT2_SMALL_CONFIG = {
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "max_seq_len": 4096,
}
|
|
|
|
|
|
class NanoGPTWhiteBox(nn.Module):
    """Char-level decoder-only transformer used for the NanoGPT workload.

    Default config (~11M params) is the iter-1/iter-3 small variant.
    Pass NanoGPTWhiteBox(**GPT2_SMALL_CONFIG) for the iter-6 124M
    variant used by the LLM serving workloads.

    Args:
        vocab_size: Number of token ids; sizes the token embedding and the
            output projection.
        n_embd: Width of the embeddings and of the final LayerNorm; also
            forwarded to every GPTBlock.
        n_head: Head count forwarded to every GPTBlock.
        n_layer: Number of stacked GPTBlock modules.
        max_seq_len: Number of rows in the position-embedding table, also
            forwarded to every GPTBlock. Cached plus new positions in one
            forward call must not exceed it.
    """

    def __init__(self, vocab_size=CHAR_VOCAB_SIZE, n_embd=384, n_head=6, n_layer=6, max_seq_len=2048):
        super().__init__()
        # Plain-dict record of the constructor arguments.
        self.config = {
            "vocab_size": vocab_size, "n_embd": n_embd,
            "n_head": n_head, "n_layer": n_layer, "max_seq_len": max_seq_len,
        }
        self.wte = nn.Embedding(vocab_size, n_embd)   # token embedding
        self.wpe = nn.Embedding(max_seq_len, n_embd)  # learned position embedding
        self.blocks = nn.ModuleList([
            GPTBlock(n_embd, n_head, max_seq_len=max_seq_len) for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the token embedding and the output projection share
        # one Parameter. The assignment direction matters -- the tensor that
        # survives is the one created by lm_head, not the one nn.Embedding
        # allocated.
        self.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None, use_kv_cache=False, past_key_values=None):
        """Unified forward path for training and inference (incl. KV-cache decode).

        Args:
            idx: 2-D integer tensor of token ids. The size of its second
                dimension is the number of new positions in this call.
            targets: Optional tensor of target token ids. It is flattened
                and compared against the flattened logits with
                cross-entropy.
            use_kv_cache: When True, the per-block cache entries returned
                by the blocks are collected and returned.
            past_key_values: Optional per-layer sequence of cache entries
                from a previous call, indexed in block order.
                NOTE(review): assumes `past_key_values[0][0]` is a tensor
                whose second-to-last dimension is the cached sequence
                length -- confirm against GPTBlock.

        Returns:
            (logits, loss) if targets is given (training). This takes
                precedence over use_kv_cache.
            (logits, present_kvs) if use_kv_cache=True (decode).
            (logits, None) otherwise (single-shot inference / prefill).

        Raises:
            ValueError: If cached plus new positions exceed the size of the
                position-embedding table (max_seq_len).
        """
        _, t = idx.size()  # unpacking into two names requires a 2-D idx
        past_length = past_key_values[0][0].size(-2) if past_key_values is not None else 0

        # Fail early with a readable message instead of an out-of-range
        # lookup inside the position embedding.
        max_positions = self.wpe.num_embeddings
        if past_length + t > max_positions:
            raise ValueError(
                f"Cannot process {t} new position(s) after {past_length} cached "
                f"position(s): max_seq_len is {max_positions}"
            )

        # New tokens continue the position numbering where the cache ends.
        pos = torch.arange(past_length, past_length + t, dtype=torch.long, device=idx.device)
        # The position embeddings broadcast across the leading dimension of idx.
        x = self.wte(idx) + self.wpe(pos)

        present_kvs = [] if use_kv_cache else None
        for i, block in enumerate(self.blocks):
            past_kv = past_key_values[i] if past_key_values is not None else None
            x, present_kv = block(x, use_kv_cache=use_kv_cache, past_key_value=past_kv)
            if use_kv_cache:
                present_kvs.append(present_kv)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is not None:
            # Flatten every position into one row so cross-entropy sees
            # (num_positions, vocab) against (num_positions,).
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss
        return logits, present_kvs
|