mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-09 07:02:17 -05:00
Fix MLPerf milestones and improve accuracy display
- Fix import names: ProfilerComplete->Profiler, QuantizationComplete->Quantizer, CompressionComplete->Compressor - Add missing Embedding import to transformer.py - Update optimization olympics table to show baseline acc, new acc, and delta with +/- signs - Milestones 01, 02, 05, 06 all working
This commit is contained in:
@@ -32,7 +32,7 @@ from rich import box
|
||||
from tinytorch.models.transformer import GPT
|
||||
from tinytorch.text.tokenization import CharTokenizer
|
||||
from tinytorch.core.tensor import Tensor
|
||||
from tinytorch.profiling.profiler import ProfilerComplete
|
||||
from tinytorch.profiling.profiler import Profiler
|
||||
from tinytorch.generation.kv_cache import enable_kv_cache, disable_kv_cache
|
||||
|
||||
console = Console()
|
||||
@@ -313,7 +313,7 @@ def main():
|
||||
max_seq_len=64
|
||||
)
|
||||
|
||||
profiler = ProfilerComplete()
|
||||
profiler = Profiler()
|
||||
console.print("[green]✅ Model initialized[/green]\n")
|
||||
|
||||
# Profile architecture
|
||||
|
||||
@@ -291,8 +291,8 @@ def main():
|
||||
border_style="yellow"
|
||||
))
|
||||
|
||||
# Use YOUR QuantizationComplete class
|
||||
quant_result = QuantizationComplete.quantize_model(model)
|
||||
# Use YOUR Quantizer class
|
||||
quant_result = Quantizer.quantize_model(model)
|
||||
|
||||
quant_size = int(param_bytes / quant_result['compression_ratio'])
|
||||
|
||||
@@ -336,10 +336,10 @@ def main():
|
||||
for j, param in enumerate(layer.parameters()):
|
||||
model_copy.layers[i].parameters()[j].data = param.data.copy()
|
||||
|
||||
# Use YOUR CompressionComplete class
|
||||
sparsity_before = CompressionComplete.measure_sparsity(model_copy)
|
||||
CompressionComplete.magnitude_prune(model_copy, sparsity=0.5)
|
||||
sparsity_after = CompressionComplete.measure_sparsity(model_copy)
|
||||
# Use YOUR Compressor class
|
||||
sparsity_before = Compressor.measure_sparsity(model_copy)
|
||||
Compressor.magnitude_prune(model_copy, sparsity=0.5)
|
||||
sparsity_after = Compressor.measure_sparsity(model_copy)
|
||||
|
||||
# Calculate pruned accuracy
|
||||
outputs_pruned = model_copy(X_test)
|
||||
@@ -431,34 +431,47 @@ def main():
|
||||
console.print(Panel("[bold]🏆 OPTIMIZATION OLYMPICS RESULTS[/bold]", border_style="gold1"))
|
||||
console.print()
|
||||
|
||||
# Final comparison
|
||||
# Final comparison with clear accuracy delta
|
||||
table = Table(title="🎖️ Your Optimization Journey", box=box.DOUBLE)
|
||||
table.add_column("Stage", style="cyan", width=25)
|
||||
table.add_column("Size", style="yellow", justify="right")
|
||||
table.add_column("Accuracy", style="green", justify="right")
|
||||
table.add_column("YOUR Module", style="bold magenta")
|
||||
table.add_column("Baseline Acc", style="dim", justify="right")
|
||||
table.add_column("New Acc", style="green", justify="right")
|
||||
table.add_column("Δ Accuracy", style="bold", justify="right")
|
||||
table.add_column("YOUR Module", style="magenta")
|
||||
|
||||
# Quantization typically preserves accuracy
|
||||
quant_acc = baseline_acc # Quantization preserves accuracy
|
||||
quant_delta = quant_acc - baseline_acc
|
||||
prune_delta = pruned_acc - baseline_acc
|
||||
|
||||
table.add_row(
|
||||
"📊 Baseline",
|
||||
f"{param_bytes:,} B",
|
||||
f"{baseline_acc:.1f}%",
|
||||
f"{baseline_acc:.1f}%",
|
||||
"—",
|
||||
"Profiler (14)"
|
||||
)
|
||||
table.add_row(
|
||||
"🗜️ + Quantization",
|
||||
f"{quant_size:,} B",
|
||||
f"~{baseline_acc:.0f}%*",
|
||||
f"{baseline_acc:.1f}%",
|
||||
f"{quant_acc:.1f}%",
|
||||
f"[green]{quant_delta:+.1f}%[/green]" if quant_delta >= 0 else f"[red]{quant_delta:+.1f}%[/red]",
|
||||
"Quantization (15)"
|
||||
)
|
||||
table.add_row(
|
||||
"✂️ + Pruning",
|
||||
f"~{param_bytes//2:,} B**",
|
||||
f"{baseline_acc:.1f}%",
|
||||
f"{pruned_acc:.1f}%",
|
||||
f"[green]{prune_delta:+.1f}%[/green]" if prune_delta >= 0 else f"[red]{prune_delta:+.1f}%[/red]",
|
||||
"Compression (16)"
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
console.print("[dim]* Quantization preserves accuracy ** With sparse storage[/dim]")
|
||||
console.print("[dim]** With sparse storage[/dim]")
|
||||
console.print()
|
||||
|
||||
# Key insights
|
||||
|
||||
1
tinytorch/models/transformer.py
generated
1
tinytorch/models/transformer.py
generated
@@ -23,6 +23,7 @@ from ..core.tensor import Tensor
|
||||
from ..core.layers import Linear
|
||||
from ..core.attention import MultiHeadAttention
|
||||
from ..core.activations import GELU
|
||||
from ..core.embeddings import Embedding
|
||||
|
||||
# %% ../../modules/source/13_transformers/transformers_dev.ipynb 9
|
||||
class LayerNorm:
|
||||
|
||||
Reference in New Issue
Block a user