mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-05-21 21:43:04 -05:00
123 lines
4.7 KiB
Python
123 lines
4.7 KiB
Python
"""
|
|
MLPerf EDU iter-7: LoRA adapter and injection helper.
|
|
|
|
Pure PyTorch, no `peft` library, no HF dependency. Targets the fused
|
|
`c_attn` (QKV) projection in `CausalSelfAttention`.
|
|
|
|
Per Han's iter-7 spec:
|
|
- rank=8, alpha=16 (scale = alpha/rank = 2.0)
|
|
- target_modules: c_attn only
|
|
- Trainable params for GPT-2-Small (12 layers, d_model=768):
|
|
12 * (768*8 + 8*3*768) = 12 * (6144 + 18432) = 294,912 params
|
|
= 0.33% of the 88M base.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class LoRAAdapter(nn.Module):
|
|
"""Low-rank adapter A @ B injected as a residual on a Linear layer.
|
|
|
|
Forward: y = base(x) + (alpha/rank) * (x @ A) @ B
|
|
|
|
Where:
|
|
base : the frozen Linear (d_in -> d_out) we are adapting.
|
|
A : Linear(d_in -> rank), Kaiming-initialized, trainable.
|
|
B : Linear(rank -> d_out), zero-initialized, trainable.
|
|
"""
|
|
|
|
def __init__(self, base: nn.Linear, rank: int = 8, alpha: int = 16):
|
|
super().__init__()
|
|
if not isinstance(base, nn.Linear):
|
|
raise TypeError(f"LoRAAdapter expects nn.Linear, got {type(base)}")
|
|
self.base = base
|
|
self.rank = rank
|
|
self.alpha = alpha
|
|
self.scale = alpha / rank
|
|
|
|
# Freeze base.
|
|
for p in self.base.parameters():
|
|
p.requires_grad = False
|
|
|
|
# Init A with Kaiming, B with zero so initial delta = 0 (training
|
|
# starts at the base model's behavior — standard LoRA convention).
|
|
# Build the adapters on the same device + dtype as the base layer
|
|
# so injection works regardless of where the base lives (MPS/CUDA/CPU).
|
|
device = base.weight.device
|
|
dtype = base.weight.dtype
|
|
self.lora_A = nn.Linear(base.in_features, rank, bias=False).to(device=device, dtype=dtype)
|
|
self.lora_B = nn.Linear(rank, base.out_features, bias=False).to(device=device, dtype=dtype)
|
|
nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
|
|
nn.init.zeros_(self.lora_B.weight)
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
return self.base(x) + self.scale * self.lora_B(self.lora_A(x))
|
|
|
|
def merge(self) -> nn.Linear:
|
|
"""Return a new nn.Linear with W' = W_base + scale * B @ A baked in.
|
|
|
|
Used for deployment-time inference: drop the LoRA branch entirely.
|
|
Critical for the merged-vs-unmerged latency comparison.
|
|
"""
|
|
merged = nn.Linear(self.base.in_features, self.base.out_features,
|
|
bias=self.base.bias is not None)
|
|
with torch.no_grad():
|
|
delta = self.scale * (self.lora_B.weight @ self.lora_A.weight)
|
|
merged.weight.copy_(self.base.weight + delta)
|
|
if self.base.bias is not None:
|
|
merged.bias.copy_(self.base.bias)
|
|
return merged.to(self.base.weight.device).to(self.base.weight.dtype)
|
|
|
|
|
|
def inject_lora(model: nn.Module, rank: int = 8, alpha: int = 16,
|
|
target: str = "c_attn") -> tuple[int, int]:
|
|
"""Walk the model, replace every `target`-named Linear with a LoRAAdapter.
|
|
|
|
Returns (n_adapters_injected, total_trainable_lora_params).
|
|
|
|
After injection, freeze every non-LoRA parameter. The model's `forward`
|
|
path is unchanged because LoRAAdapter shares the base layer's call shape.
|
|
"""
|
|
n_injected = 0
|
|
n_trainable = 0
|
|
for name, module in list(model.named_modules()):
|
|
for child_name, child in list(module.named_children()):
|
|
if child_name == target and isinstance(child, nn.Linear):
|
|
adapter = LoRAAdapter(child, rank=rank, alpha=alpha)
|
|
setattr(module, child_name, adapter)
|
|
n_injected += 1
|
|
n_trainable += sum(p.numel() for p in adapter.lora_A.parameters())
|
|
n_trainable += sum(p.numel() for p in adapter.lora_B.parameters())
|
|
|
|
# Freeze every parameter that is NOT inside a LoRAAdapter.
|
|
for name, p in model.named_parameters():
|
|
if "lora_A" in name or "lora_B" in name:
|
|
p.requires_grad = True
|
|
else:
|
|
p.requires_grad = False
|
|
return n_injected, n_trainable
|
|
|
|
|
|
def base_grad_norm(model: nn.Module) -> float:
|
|
"""Sum of squared gradient norms over base (non-LoRA) parameters."""
|
|
total = 0.0
|
|
for name, p in model.named_parameters():
|
|
if "lora_" in name:
|
|
continue
|
|
if p.grad is not None:
|
|
total += p.grad.detach().pow(2).sum().item()
|
|
return total ** 0.5
|
|
|
|
|
|
def lora_grad_norm(model: nn.Module) -> float:
|
|
"""Sum of squared gradient norms over LoRA-only parameters."""
|
|
total = 0.0
|
|
for name, p in model.named_parameters():
|
|
if "lora_" not in name:
|
|
continue
|
|
if p.grad is not None:
|
|
total += p.grad.detach().pow(2).sum().item()
|
|
return total ** 0.5
|