mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-12 03:03:37 -05:00
MAJOR FEATURE: Multi-channel convolutions for real CNN architectures

Key additions:
- MultiChannelConv2D class with in_channels/out_channels support
- Handles RGB images (3 channels) and arbitrary channel counts
- He initialization for stable training
- Optional bias parameters
- Batch processing support

Testing & Validation:
- Comprehensive unit tests for single/multi-channel
- Integration tests for complete CNN pipelines
- Memory profiling and parameter scaling analysis
- QA approved: All mandatory tests passing

CIFAR-10 CNN Example:
- Updated train_cnn.py to use MultiChannelConv2D
- Architecture: Conv(3→32) → Pool → Conv(32→64) → Pool → Dense
- Demonstrates why convolutions matter for vision
- Shows parameter reduction vs MLPs (18KB vs 12MB)

Systems Analysis:
- Parameter scaling: O(in_channels × out_channels × kernel²)
- Memory profiling shows efficient scaling
- Performance characteristics documented
- Production context with PyTorch comparisons

This enables proper CNN training on CIFAR-10 with a ~60% accuracy target.
45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
#!/usr/bin/env python3
|
|
"""Test the final 15-module structure."""
|
|
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
def test_module(module_path):
|
|
"""Test a single module."""
|
|
py_files = list(module_path.glob("*_dev.py"))
|
|
if not py_files:
|
|
return None
|
|
result = subprocess.run([sys.executable, str(py_files[0])],
|
|
capture_output=True, timeout=10, cwd=Path.cwd())
|
|
return result.returncode == 0
|
|
|
|
# Top-level report: run every module of the three course parts once and
# print a pass/fail/missing line for each.
print("="*60)
print("TinyTorch 15-Module Structure Test")
print("="*60)

modules_dir = Path("modules/source")

# (part title, ordered module directory names under modules_dir)
parts = [
    ("Part I: MLPs (XORNet)", ["01_setup", "02_tensor", "03_activations", "04_layers", "05_networks"]),
    ("Part II: CNNs (CIFAR-10)", ["06_spatial", "07_dataloader", "08_autograd", "09_optimizers", "10_training"]),
    ("Part III: Transformers (TinyGPT)", ["11_embeddings", "12_attention", "13_normalization", "14_transformers", "15_generation"])
]

for part_name, modules in parts:
    print(f"\n{part_name}")
    print("-"*40)
    for module in modules:
        path = modules_dir / module
        if not path.exists():
            print(f" ⚠️ {module:20} Missing")
            continue
        # Run the module exactly once and reuse the result. Calling
        # test_module() in each branch re-executed every non-passing module
        # and could give inconsistent labels for flaky modules.
        outcome = test_module(path)
        # Check None first: it is falsy, so it must not be confused with False.
        if outcome is None:
            print(f" ⚠️ {module:20} No implementation")
        elif outcome:
            print(f" ✅ {module:20} Passes")
        else:
            print(f" ❌ {module:20} Failed")

print("\n" + "="*60)
print("✨ Clean 15-module structure ready!")
print("Each part: 5 modules, 1 innovation, 1 capstone")
|