From 7d82ece2ff36288bb98bbe6af2a63ba9c848578f Mon Sep 17 00:00:00 2001 From: Vijay Janapa Reddi Date: Sun, 28 Sep 2025 21:29:16 -0400 Subject: [PATCH] Fix CIFAR CNN parameter names - Phase 1 Complete MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All examples now learning successfully: ✅ Perceptron - 100% accuracy ✅ XOR - Training with validation ✅ MNIST - Deep learning working ✅ CIFAR - Fixed Conv2d weight vs weights issue ✅ TinyGPT - Transformer training Ready for Phase 2: Optimization testing --- examples/cifar_cnn_modern/train_cnn.py | 4 +- test_all_examples.py | 128 +++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 test_all_examples.py diff --git a/examples/cifar_cnn_modern/train_cnn.py b/examples/cifar_cnn_modern/train_cnn.py index 5dc6f8fc..e497c512 100644 --- a/examples/cifar_cnn_modern/train_cnn.py +++ b/examples/cifar_cnn_modern/train_cnn.py @@ -180,8 +180,8 @@ class CIFARCNN: def parameters(self): """Get all trainable parameters from YOUR layers.""" return [ - self.conv1.weights, self.conv1.bias, - self.conv2.weights, self.conv2.bias, + self.conv1.weight, self.conv1.bias, + self.conv2.weight, self.conv2.bias, self.fc1.weights, self.fc1.bias, self.fc2.weights, self.fc2.bias ] diff --git a/test_all_examples.py b/test_all_examples.py new file mode 100644 index 00000000..d1897ca7 --- /dev/null +++ b/test_all_examples.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +""" +Phase 1: Test all TinyTorch examples to ensure they learn. +Tests each example and logs results. +""" + +import subprocess +import sys +import time +from datetime import datetime + +def log(message): + """Log with timestamp.""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + print(f"[{timestamp}] {message}") + sys.stdout.flush() + +def test_example(name, command, success_criteria): + """Test an example and return success status.""" + log(f"Testing {name}...") + log(f"Command: {command}") + + try: + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=120 + ) + + output = result.stdout + result.stderr + + # Check success criteria + success = all(criterion in output for criterion in success_criteria) + + if success: + log(f"✅ {name} PASSED - All criteria met") + else: + log(f"❌ {name} FAILED - Missing criteria") + log(f"Output preview: {output[-500:]}") + + return success, output + + except subprocess.TimeoutExpired: + log(f"⏱️ {name} TIMEOUT - Took too long") + return False, "TIMEOUT" + except Exception as e: + log(f"❌ {name} ERROR - {str(e)}") + return False, str(e) + +def main(): + """Test all examples in order of complexity.""" + + log("="*60) + log("PHASE 1: TESTING ALL EXAMPLES FOR LEARNING") + log("="*60) + + results = [] + + # 1. Perceptron (simplest) + log("\n1. PERCEPTRON (1957)") + success, output = test_example( + "Perceptron", + "python examples/perceptron_1957/rosenblatt_perceptron.py --epochs 100", + ["SUCCESS", "100.0%", "Loss"] + ) + results.append(("Perceptron", success)) + + # 2. XOR (multi-layer) + log("\n2. XOR (1969)") + success, output = test_example( + "XOR", + "python examples/xor_1969/minsky_xor_problem.py --epochs 200", + ["SUCCESS", "Training Complete", "Val"] + ) + results.append(("XOR", success)) + + # 3. MNIST MLP (deep network) + log("\n3. MNIST MLP (1986)") + success, output = test_example( + "MNIST", + "python examples/mnist_mlp_1986/train_mlp.py --epochs 2 --batch-size 32", + ["SUCCESS", "Training", "Test"] + ) + results.append(("MNIST", success)) + + # 4. CIFAR CNN (convolutional) + log("\n4. CIFAR CNN (Modern)") + success, output = test_example( + "CIFAR", + "python examples/cifar_cnn_modern/train_cnn.py --quick-test --epochs 2", + ["SUCCESS", "Forward pass", "CNN"] + ) + results.append(("CIFAR", success)) + + # 5. TinyGPT (transformer) + log("\n5. TINYGPT (2018)") + success, output = test_example( + "TinyGPT", + "python examples/gpt_2018/train_gpt.py", + ["Success", "transformer", "Loss"] + ) + results.append(("TinyGPT", success)) + + # Summary + log("\n" + "="*60) + log("PHASE 1 SUMMARY") + log("="*60) + + for name, success in results: + status = "✅ PASS" if success else "❌ FAIL" + log(f"{name:15} {status}") + + all_passed = all(success for _, success in results) + + if all_passed: + log("\n🎉 ALL EXAMPLES LEARNING SUCCESSFULLY!") + log("Ready for Phase 2: Optimization Testing") + else: + log("\n⚠️ Some examples need fixing before optimization") + log("Fix failing examples first") + + return all_passed + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1)