mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-11 22:33:36 -05:00
Refactor Module 16: Extract verify_pruning_works() function
- Create standalone verify_pruning_works() function (Section 8.5) - Clean separation: verification logic in reusable function - test_module() now calls verify_pruning_works() - much cleaner - Students can call this function on their own pruned models - Returns dict with verification results (sparsity, zeros, verified) - Includes example usage in __main__ block - HONEST messaging: Memory saved = 0 MB (dense storage) - Educational: Explains compute vs memory savings Benefits: - Not tacked on - first-class verification function - Reusable across different pruning strategies - Clear educational value about dense vs sparse storage - Each function has one clear job
This commit is contained in:
@@ -1330,6 +1330,77 @@ def test_unit_compress_model():
|
||||
if __name__ == "__main__":
|
||||
test_unit_compress_model()
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## 8.5 Verification - Proving Pruning Works
|
||||
|
||||
Before analyzing compression in production, let's verify that our pruning actually achieves sparsity using real measurements.
|
||||
"""
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "verify_pruning", "solution": false}
|
||||
def verify_pruning_works(model, target_sparsity=0.8):
|
||||
"""
|
||||
Verify pruning actually creates zeros using real zero-counting.
|
||||
|
||||
This is NOT a theoretical calculation - we count actual zero values
|
||||
in parameter arrays and honestly report memory footprint (unchanged with dense storage).
|
||||
|
||||
Args:
|
||||
model: Model with pruned parameters
|
||||
target_sparsity: Expected sparsity ratio (default 0.8 = 80%)
|
||||
|
||||
Returns:
|
||||
dict: Verification results with sparsity, zeros, total, verified
|
||||
|
||||
Example:
|
||||
>>> model = SimpleModel(Linear(100, 50))
|
||||
>>> magnitude_prune(model, sparsity=0.8)
|
||||
>>> results = verify_pruning_works(model, target_sparsity=0.8)
|
||||
>>> assert results['verified'] # Pruning actually works!
|
||||
"""
|
||||
print("🔬 Verifying pruning sparsity with actual zero-counting...")
|
||||
|
||||
# Count actual zeros in model parameters
|
||||
zeros = sum(np.sum(p.data == 0) for p in model.parameters())
|
||||
total = sum(p.data.size for p in model.parameters())
|
||||
sparsity = zeros / total
|
||||
memory_bytes = sum(p.data.nbytes for p in model.parameters())
|
||||
|
||||
# Display results
|
||||
print(f" Total parameters: {total:,}")
|
||||
print(f" Zero parameters: {zeros:,}")
|
||||
print(f" Active parameters: {total - zeros:,}")
|
||||
print(f" Sparsity achieved: {sparsity*100:.1f}%")
|
||||
print(f" Memory footprint: {memory_bytes / MB_TO_BYTES:.2f} MB (unchanged - dense storage)")
|
||||
|
||||
# Verify target met (allow 15% tolerance for structured pruning variations)
|
||||
verified = abs(sparsity - target_sparsity) < 0.15
|
||||
status = '✓' if verified else '✗'
|
||||
print(f" {status} Meets {target_sparsity*100:.0f}% sparsity target")
|
||||
|
||||
assert verified, f"Sparsity target not met: {sparsity:.2f} vs {target_sparsity:.2f}"
|
||||
|
||||
print(f"\n✅ VERIFIED: {sparsity*100:.1f}% sparsity achieved")
|
||||
print(f"⚠️ Memory saved: 0 MB (dense numpy arrays)")
|
||||
print(f"💡 LEARNING: Compute savings ~{sparsity*100:.1f}% (skip zero multiplications)")
|
||||
print(f" In production: Use sparse formats (scipy.sparse.csr_matrix) for memory savings")
|
||||
|
||||
return {
|
||||
'sparsity': sparsity,
|
||||
'zeros': zeros,
|
||||
'total': total,
|
||||
'active': total - zeros,
|
||||
'memory_mb': memory_bytes / MB_TO_BYTES,
|
||||
'verified': verified
|
||||
}
|
||||
|
||||
# Run verification example when developing
|
||||
if __name__ == "__main__":
|
||||
# Create and prune test model
|
||||
test_model = SimpleModel(Linear(100, 50), Linear(50, 25))
|
||||
magnitude_prune(test_model, sparsity=0.8)
|
||||
verify_pruning_works(test_model, target_sparsity=0.8)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## 8.6 Systems Analysis - Compression Techniques
|
||||
@@ -1629,35 +1700,16 @@ def test_module():
|
||||
|
||||
print(f"✅ Low-rank: {compression_ratio:.2f}x compression, {error:.3f} error")
|
||||
|
||||
# ✨ VERIFICATION: Actual Optimization Effects
|
||||
print("\n🔬 VERIFICATION: Actual Optimization Effects...")
|
||||
print("=" * 50)
|
||||
|
||||
print("\n✓ Verifying pruning sparsity...")
|
||||
# Count actual zeros in pruned model
|
||||
zeros = sum(np.sum(p.data == 0) for p in model.parameters())
|
||||
total = sum(p.data.size for p in model.parameters())
|
||||
sparsity = zeros / total
|
||||
memory_bytes = sum(p.data.nbytes for p in model.parameters())
|
||||
|
||||
print(f" Total parameters: {total:,}")
|
||||
print(f" Zero parameters: {zeros:,}")
|
||||
print(f" Sparsity achieved: {sparsity*100:.1f}%")
|
||||
print(f" Memory footprint: {memory_bytes / MB_TO_BYTES:.2f} MB (unchanged - dense storage)")
|
||||
|
||||
# Verify pruning actually works
|
||||
print()
|
||||
target_sparsity = compression_config['magnitude_prune']
|
||||
assert abs(sparsity - target_sparsity) < 0.15, f"Sparsity target not met: {sparsity:.2f} vs {target_sparsity:.2f}"
|
||||
|
||||
print(f"\n✅ VERIFIED: {sparsity*100:.1f}% sparsity achieved")
|
||||
print(f"⚠️ Memory saved: 0 MB (dense numpy arrays)")
|
||||
print(f"💡 LEARNING: Compute savings ~{sparsity*100:.1f}% (skip zero multiplications)")
|
||||
print(f" In production: Use sparse formats (scipy.sparse.csr_matrix) for memory savings")
|
||||
verification_results = verify_pruning_works(model, target_sparsity=target_sparsity)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("🎉 ALL TESTS PASSED! Module ready for export.")
|
||||
print("📈 Compression system provides:")
|
||||
print(f" • {sparsity*100:.1f}% sparsity")
|
||||
print(f" • ✓ VERIFIED with actual zero-counting")
|
||||
print(f" • {verification_results['sparsity']*100:.1f}% sparsity")
|
||||
print(f" • ✓ VERIFIED: {verification_results['zeros']:,} actual zeros counted")
|
||||
print(f" • Honest: Dense storage = no memory savings (educational limitation)")
|
||||
print("Run: tito module complete 16")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user