"""
|
|
Iter-10 (Dean): two-process DDP via torch.multiprocessing + Gloo on localhost.
|
|
|
|
Picks micro-DLRM as the workload:
|
|
- 1M params at fp32 = 4 MB of gradients per AllReduce.
|
|
- Loopback Gloo handles this in ~0.5-1 ms.
|
|
- Iter-5.6 found micro-DLRM at small batch is compute-bound on the MLP,
|
|
so DDP overhead becomes the natural rate-limiter.
|
|
|
|
Smoke gate (Q4 in Dean's iter-10 spec):
|
|
| loss_ddp(step=50) - loss_gradacc(step=50) | / loss_gradacc(step=50) < 0.02
|
|
"""
from __future__ import annotations

import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

from reference.cloud.micro_dlrm import MicroDLRMWhiteBox


def _build_inputs(batch: int, seed: int = 42) -> tuple:
    """Deterministic synthetic batch: dense features, 3 sparse features, labels."""
    g = torch.Generator().manual_seed(seed)
    dense = torch.randn(batch, 16, generator=g)
    # One index per sample for each of the three embedding tables.
    sparse_indices = [
        torch.randint(0, 943, (batch,), generator=g),
        torch.randint(0, 1682, (batch,), generator=g),
        torch.randint(0, 21, (batch,), generator=g),
    ]
    # Per-feature offsets, one bag per sample (EmbeddingBag-style lookup).
    sparse_offsets = [torch.arange(batch) for _ in range(3)]
    targets = torch.randint(0, 2, (batch,), generator=g).float().unsqueeze(1)
    return dense, sparse_indices, sparse_offsets, targets


def _ddp_worker(rank: int, world_size: int,
                n_steps: int, micro_batch: int,
                result_queue: mp.Queue,
                init_method: str = "tcp://127.0.0.1:29500"):
    """One DDP rank: init Gloo, build model, run n_steps, push final loss."""
    # The tcp:// init_method already carries the rendezvous address and port,
    # so MASTER_ADDR/MASTER_PORT environment variables are not needed here.
    dist.init_process_group(backend="gloo", init_method=init_method,
                            rank=rank, world_size=world_size)
    torch.manual_seed(42)

    model = MicroDLRMWhiteBox()
    ddp_model = nn.parallel.DistributedDataParallel(model)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # Each rank gets a different shard of synthetic data, generated
    # deterministically from the rank's seed offset, so the effective batch
    # (world_size * micro_batch) matches the gradient-accumulation baseline.
    last_loss = 0.0
    allreduce_time_total = 0.0
    for step in range(n_steps):
        # Per-rank data shard.
        dense, sparse_indices, sparse_offsets, targets = _build_inputs(
            micro_batch, seed=42 + rank * 1000 + step
        )

        optimizer.zero_grad(set_to_none=True)
        # MicroDLRMWhiteBox applies its own sigmoid, so train with plain BCE
        # on the probability output rather than BCEWithLogitsLoss.
        out = ddp_model(dense, sparse_indices, sparse_offsets)
        loss = nn.functional.binary_cross_entropy(out, targets)

        t_back = time.perf_counter()
        loss.backward()
        # DDP's reducer runs bucketed AllReduce inside backward, so this
        # interval measures backward + comm -- an upper bound on AllReduce.
        allreduce_time_total += time.perf_counter() - t_back

        optimizer.step()
        last_loss = loss.item()

    if rank == 0:
        # Rank 0's local shard loss; DDP keeps parameters identical across
        # ranks, so it tracks the grad-acc baseline within the smoke gate.
        result_queue.put({
            "rank": rank,
            "final_loss": last_loss,
            "allreduce_time_per_step_ms": allreduce_time_total / n_steps * 1000,
        })

    dist.destroy_process_group()


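# A minimal, optional sketch (not part of Dean's iter-10 spec): isolate the
# pure AllReduce cost from backward by timing an explicit dist.all_reduce on
# a 4 MB fp32 buffer (1M params), the gradient size quoted in the module
# docstring. Assumes the caller's process group is already initialized.
def _time_allreduce_4mb(n_iters: int = 20) -> float:
    """Mean wall time in ms for one 4 MB fp32 AllReduce on the current group."""
    buf = torch.randn(1_000_000)  # 1M fp32 values = 4 MB, matching the model
    dist.all_reduce(buf)  # warm-up so connection setup is not measured
    t0 = time.perf_counter()
    for _ in range(n_iters):
        dist.all_reduce(buf)
    return (time.perf_counter() - t0) / n_iters * 1000

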
def run_gradacc_baseline(n_steps: int, micro_batch: int, world_size: int) -> dict:
    """Single-process gradient-accumulation baseline equivalent to DDP."""
    torch.manual_seed(42)
    model = MicroDLRMWhiteBox()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    last_loss = 0.0
    for step in range(n_steps):
        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0
        for rank in range(world_size):
            # Same per-rank seeding as _ddp_worker, so both runs see
            # identical data.
            dense, sparse_indices, sparse_offsets, targets = _build_inputs(
                micro_batch, seed=42 + rank * 1000 + step
            )
            out = model(dense, sparse_indices, sparse_offsets)
            # Dividing by world_size mirrors DDP's gradient averaging.
            loss = nn.functional.binary_cross_entropy(out, targets) / world_size
            loss.backward()
            accum_loss += loss.item()
        optimizer.step()
        last_loss = accum_loss
    return {"final_loss": last_loss}


def run_ddp(n_steps: int = 50, micro_batch: int = 64, world_size: int = 2) -> dict:
    """Spawn world_size DDP processes; return rank-0 metrics."""
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    procs = []
    for rank in range(world_size):
        p = ctx.Process(target=_ddp_worker,
                        args=(rank, world_size, n_steps, micro_batch, result_queue))
        p.start()
        procs.append(p)
    for p in procs:
        p.join(timeout=60)
        if p.is_alive():
            # Kill every straggler, not just this one, so a hung rendezvous
            # cannot leak processes.
            for q in procs:
                if q.is_alive():
                    q.terminate()
            return {"error": "DDP worker timed out"}
    if result_queue.empty():
        return {"error": "no result from rank 0"}
    return result_queue.get()


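# A minimal driver sketch implementing the smoke gate from the module
# docstring (Q4 threshold: relative loss gap < 0.02). The step count, batch
# size, and world size below are run_ddp's defaults, not mandated values.
if __name__ == "__main__":
    ddp = run_ddp(n_steps=50, micro_batch=64, world_size=2)
    if "error" in ddp:
        raise SystemExit(f"DDP run failed: {ddp['error']}")
    base = run_gradacc_baseline(n_steps=50, micro_batch=64, world_size=2)

    rel_gap = abs(ddp["final_loss"] - base["final_loss"]) / base["final_loss"]
    print(f"loss_ddp={ddp['final_loss']:.4f}  "
          f"loss_gradacc={base['final_loss']:.4f}  rel_gap={rel_gap:.4f}")
    print(f"allreduce_time_per_step_ms={ddp['allreduce_time_per_step_ms']:.2f}")
    assert rel_gap < 0.02, "smoke gate failed: DDP drifted from grad-acc baseline"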