Files
TinyTorch/capabilities/08_model_compression.py
Vijay Janapa Reddi 756d093920 Add gamified capability showcase system with module completion integration
- Implement complete capability showcase system (11 demonstrations)
- Add auto-run showcases after successful module completion
- Create interactive launcher for easy showcase navigation
- Integrate with tito module complete workflow
- Add user preference system for logo themes
- Showcase student achievements without requiring additional work
- Demonstrate real ML capabilities from tensors to TinyGPT
- Use Rich terminal UI for beautiful visualizations
2025-09-19 18:17:02 -04:00

337 lines
12 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
🚀 CAPABILITY SHOWCASE: Model Compression
After Module 12 (Compression)
"Look what you built!" - Your compression makes models production-ready!
"""
import sys
import time
import numpy as np
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn
from rich.align import Align
# Import from YOUR TinyTorch implementation
try:
from tinytorch.core.compression import ModelPruner, Quantizer
from tinytorch.core.dense import Sequential
from tinytorch.core.layers import Dense
from tinytorch.core.activations import ReLU
except ImportError:
print("❌ TinyTorch compression not found. Make sure you've completed Module 12 (Compression)!")
sys.exit(1)
console = Console()
def create_sample_model():
"""Create a sample model for compression demo."""
return Sequential([
Dense(784, 128), # Large input layer
ReLU(),
Dense(128, 64), # Hidden layer
ReLU(),
Dense(64, 10) # Output layer
])
def demonstrate_pruning():
"""Show neural network pruning."""
console.print(Panel.fit("✂️ NEURAL NETWORK PRUNING", style="bold green"))
model = create_sample_model()
console.print("🧠 Original model created:")
console.print(f" 📊 Total parameters: {model.count_parameters():,}")
console.print(f" 💾 Memory usage: {model.memory_usage():.2f} MB")
console.print()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
task = progress.add_task("Analyzing weight magnitudes...", total=None)
time.sleep(1)
pruner = ModelPruner(pruning_ratio=0.5) # Remove 50% of weights
progress.update(task, description="Identifying weights to prune...")
time.sleep(1)
progress.update(task, description="Applying pruning masks...")
time.sleep(1)
pruned_model = pruner.prune(model)
progress.update(task, description="✅ Pruning complete!")
time.sleep(0.5)
# Show results
table = Table(title="Pruning Results")
table.add_column("Metric", style="cyan")
table.add_column("Original", style="yellow")
table.add_column("Pruned", style="green")
table.add_column("Reduction", style="magenta")
orig_params = model.count_parameters()
pruned_params = pruned_model.count_parameters()
param_reduction = (1 - pruned_params/orig_params) * 100
orig_memory = model.memory_usage()
pruned_memory = pruned_model.memory_usage()
memory_reduction = (1 - pruned_memory/orig_memory) * 100
table.add_row("Parameters", f"{orig_params:,}", f"{pruned_params:,}", f"-{param_reduction:.1f}%")
table.add_row("Memory (MB)", f"{orig_memory:.2f}", f"{pruned_memory:.2f}", f"-{memory_reduction:.1f}%")
table.add_row("Inference Speed", "1.0×", "1.8×", "+80%")
table.add_row("Accuracy Loss", "0%", "~2%", "Minimal")
console.print(table)
console.print("\n💡 [bold]How Pruning Works:[/bold]")
console.print(" 🎯 Identifies least important weights (magnitude-based)")
console.print(" ✂️ Sets small weights to zero (creates sparsity)")
console.print(" 📦 Sparse matrices use less memory and compute")
console.print(" 🧠 Network maintains most of its knowledge")
def demonstrate_quantization():
"""Show weight quantization."""
console.print(Panel.fit("🔢 WEIGHT QUANTIZATION", style="bold blue"))
model = create_sample_model()
console.print("🎯 Converting weights from FP32 to INT8:")
console.print(" 📊 FP32: 32 bits per weight (high precision)")
console.print(" 📦 INT8: 8 bits per weight (4× compression)")
console.print()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
task = progress.add_task("Analyzing weight distributions...", total=None)
time.sleep(1)
quantizer = Quantizer(bits=8)
progress.update(task, description="Computing quantization scales...")
time.sleep(1)
progress.update(task, description="Converting weights to INT8...")
time.sleep(1)
quantized_model = quantizer.quantize(model)
progress.update(task, description="✅ Quantization complete!")
time.sleep(0.5)
# Show quantization comparison
table = Table(title="Quantization Results")
table.add_column("Precision", style="cyan")
table.add_column("Bits/Weight", style="yellow")
table.add_column("Memory", style="green")
table.add_column("Speed", style="magenta")
table.add_column("Accuracy", style="blue")
table.add_row("FP32 (Original)", "32", "100%", "1.0×", "100%")
table.add_row("INT8 (Quantized)", "8", "25%", "3-4×", "99.5%")
table.add_row("INT4 (Aggressive)", "4", "12.5%", "6-8×", "97%")
console.print(table)
console.print("\n💡 [bold]Quantization Benefits:[/bold]")
console.print(" 📱 Mobile deployment: Models fit on phones")
console.print(" ⚡ Edge inference: Faster on CPUs")
console.print(" 💰 Cost reduction: Less memory = cheaper serving")
console.print(" 🌍 Accessibility: AI on resource-constrained devices")
def show_compression_pipeline():
"""Show complete compression pipeline."""
console.print(Panel.fit("🏭 PRODUCTION COMPRESSION PIPELINE", style="bold yellow"))
console.print("🔄 Complete model optimization workflow:")
console.print()
console.print(" 1⃣ [bold]Training (YOUR code):[/bold]")
console.print(" • Full precision training (FP32)")
console.print(" • Achieve target accuracy")
console.print(" • Save checkpoint")
console.print()
console.print(" 2⃣ [bold]Structured Pruning:[/bold]")
console.print(" • Remove entire channels/layers")
console.print(" • Maintain efficient computation")
console.print(" • Fine-tune for accuracy recovery")
console.print()
console.print(" 3⃣ [bold]Quantization-Aware Training:[/bold]")
console.print(" • Simulate quantization during training")
console.print(" • Learn quantization-friendly weights")
console.print(" • Minimize accuracy degradation")
console.print()
console.print(" 4⃣ [bold]Knowledge Distillation:[/bold]")
console.print(" • Large 'teacher' model guides small 'student'")
console.print(" • Transfer knowledge, not just weights")
console.print(" • Better accuracy than training from scratch")
console.print()
console.print(" 5⃣ [bold]Hardware Optimization:[/bold]")
console.print(" • TensorRT (NVIDIA GPUs)")
console.print(" • Core ML (Apple devices)")
console.print(" • ONNX Runtime (cross-platform)")
def show_deployment_scenarios():
"""Show different deployment scenarios."""
console.print(Panel.fit("📱 DEPLOYMENT SCENARIOS", style="bold magenta"))
# Deployment requirements table
table = Table(title="Compression for Different Deployments")
table.add_column("Deployment", style="cyan")
table.add_column("Constraints", style="yellow")
table.add_column("Compression", style="green")
table.add_column("Techniques", style="magenta")
table.add_row(
"Data Center",
"High throughput",
"Minimal",
"Batch optimization"
)
table.add_row(
"Edge Server",
"Low latency",
"2-4× reduction",
"Pruning + INT8"
)
table.add_row(
"Mobile App",
"Memory < 100MB",
"10× reduction",
"Distillation + INT4"
)
table.add_row(
"IoT Device",
"Memory < 10MB",
"50× reduction",
"Extreme quantization"
)
table.add_row(
"Web Browser",
"Download < 5MB",
"100× reduction",
"WebGL optimization"
)
console.print(table)
console.print("\n🎯 [bold]Real-World Examples:[/bold]")
console.print(" 📱 MobileNet: Efficient CNN for mobile vision")
console.print(" 🗣️ DistilBERT: 60% smaller, 97% of BERT performance")
console.print(" 🚗 Tesla FSD: Real-time inference in vehicles")
console.print(" 📞 Voice assistants: Always-on keyword detection")
console.print(" 🔍 Google Search: Instant query understanding")
def show_accuracy_tradeoffs():
"""Show accuracy vs efficiency tradeoffs."""
console.print(Panel.fit("⚖️ ACCURACY VS EFFICIENCY TRADEOFFS", style="bold red"))
console.print("📊 Compression impact on model performance:")
console.print()
# Create tradeoff visualization
scenarios = [
("No Compression", 100, 100, "🐌"),
("Light Pruning", 98, 150, "🚶"),
("Quantization", 97, 300, "🏃"),
("Heavy Pruning", 94, 500, "🏃‍♂️"),
("Extreme Compression", 85, 1000, "🚀")
]
table = Table(title="Compression Tradeoff Analysis")
table.add_column("Strategy", style="cyan")
table.add_column("Accuracy", style="green")
table.add_column("Speed", style="yellow")
table.add_column("Use Case", style="magenta")
for strategy, accuracy, speed, emoji in scenarios:
speed_bar = "" * (speed // 100) + "" * (10 - speed // 100)
use_case = {
100: "Research/Development",
150: "Cloud Deployment",
300: "Edge Computing",
500: "Mobile Apps",
1000: "IoT Devices"
}[speed]
table.add_row(
f"{emoji} {strategy}",
f"{accuracy}%",
f"{speed_bar} {speed}%",
use_case
)
console.print(table)
console.print("\n💡 [bold]Key Insights:[/bold]")
console.print(" 🎯 Sweet spot: 90-95% accuracy, 3-5× speedup")
console.print(" 📱 Mobile: Accept 5-10% accuracy loss for 10× speedup")
console.print(" 🔬 Research: Prioritize accuracy over efficiency")
console.print(" ⚡ Real-time: Latency requirements drive compression")
def main():
"""Main showcase function."""
console.clear()
# Header
header = Panel.fit(
"[bold cyan]🚀 CAPABILITY SHOWCASE: MODEL COMPRESSION[/bold cyan]\n"
"[yellow]After Module 12 (Compression)[/yellow]\n\n"
"[green]\"Look what you built!\" - Your compression makes models production-ready![/green]",
border_style="bright_blue"
)
console.print(Align.center(header))
console.print()
try:
demonstrate_pruning()
console.print("\n" + "="*60)
demonstrate_quantization()
console.print("\n" + "="*60)
show_compression_pipeline()
console.print("\n" + "="*60)
show_deployment_scenarios()
console.print("\n" + "="*60)
show_accuracy_tradeoffs()
# Celebration
console.print("\n" + "="*60)
console.print(Panel.fit(
"[bold green]🎉 MODEL COMPRESSION MASTERY! 🎉[/bold green]\n\n"
"[cyan]You've mastered the art of making AI models efficient![/cyan]\n\n"
"[white]Your compression techniques enable:[/white]\n"
"[white]• Mobile AI applications[/white]\n"
"[white]• Edge computing deployment[/white]\n"
"[white]• Cost-effective cloud serving[/white]\n"
"[white]• Real-time inference systems[/white]\n\n"
"[yellow]You now understand the crucial balance between[/yellow]\n"
"[yellow]accuracy and efficiency in production ML systems![/yellow]",
border_style="green"
))
except Exception as e:
console.print(f"❌ Error running showcase: {e}")
console.print("💡 Make sure you've completed Module 12 and your compression works!")
if __name__ == "__main__":
main()