mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-28 21:12:46 -05:00
Complete comprehensive testing for API simplification
Added full test suite following TinyTorch testing conventions: ✅ UNIT TESTS (test_api_simplification.py): - 23 comprehensive tests covering all API components - Tests Parameter function, Module base class, Linear/Conv2d layers - Tests functional interface (F.relu, F.flatten, F.max_pool2d) - Tests optimizer integration and backward compatibility - Tests complete model workflows (MLP, CNN) ✅ INTEGRATION TESTS (test_api_simplification_integration.py): - Cross-component integration testing - Complete workflow validation (model → optimizer → training setup) - PyTorch compatibility verification - Nested module parameter collection testing ✅ EXAMPLE FIXES: - Fixed optimizer parameter names (lr → learning_rate) - Examples demonstrate real-world usage patterns - Show dramatic code simplification vs old API 🎯 TEST RESULTS: - Unit Tests: 23/23 PASS ✅ - Integration Tests: 8/8 PASS ✅ - API simplification validated with comprehensive coverage The testing validates that the API simplification maintains educational value while providing clean PyTorch-compatible interfaces.
This commit is contained in:
388
tests/integration/test_api_simplification_integration.py
Normal file
388
tests/integration/test_api_simplification_integration.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
Integration test for API Simplification
|
||||
|
||||
Validates that the new PyTorch-compatible API integrates correctly across all components:
|
||||
- nn module with Module, Linear, Conv2d
|
||||
- nn.functional with relu, flatten, max_pool2d
|
||||
- optim module with Adam, SGD
|
||||
- Complete workflow integration (model creation → optimizer → training setup)
|
||||
|
||||
This follows TinyTorch testing conventions:
|
||||
1. Unit tests in test_api_simplification.py
|
||||
2. Integration tests here (cross-component)
|
||||
3. Examples as ultimate integration validation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_modern_api_integration():
|
||||
"""Test complete modern API integration across all components."""
|
||||
|
||||
# Suppress warnings for cleaner test output
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
results = {
|
||||
"module_name": "api_simplification",
|
||||
"integration_type": "modern_api_validation",
|
||||
"tests": [],
|
||||
"success": True,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
try:
|
||||
# Test 1: Modern imports work
|
||||
try:
|
||||
import tinytorch.nn as nn
|
||||
import tinytorch.nn.functional as F
|
||||
import tinytorch.optim as optim
|
||||
from tinytorch.core.tensor import Tensor, Parameter
|
||||
|
||||
results["tests"].append({
|
||||
"name": "modern_imports",
|
||||
"status": "✅ PASS",
|
||||
"description": "Modern PyTorch-like imports work"
|
||||
})
|
||||
except ImportError as e:
|
||||
results["tests"].append({
|
||||
"name": "modern_imports",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"Modern imports failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"Import error: {e}")
|
||||
return results
|
||||
|
||||
# Test 2: Complete MLP workflow integration
|
||||
try:
|
||||
class SimpleMLP(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(4, 8)
|
||||
self.fc2 = nn.Linear(8, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.fc1(x))
|
||||
return self.fc2(x)
|
||||
|
||||
# Create model and optimizer
|
||||
model = SimpleMLP()
|
||||
optimizer = optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
|
||||
# Test forward pass
|
||||
x = Tensor([[1.0, 2.0, 3.0, 4.0]])
|
||||
output = model(x)
|
||||
|
||||
# Verify integration
|
||||
assert len(list(model.parameters())) == 4, "Should have 4 parameters"
|
||||
assert output.shape == (1, 2), f"Expected (1, 2), got {output.shape}"
|
||||
assert len(optimizer.parameters) == 4, "Optimizer should have 4 parameters"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "mlp_workflow_integration",
|
||||
"status": "✅ PASS",
|
||||
"description": "Complete MLP workflow integrates correctly"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "mlp_workflow_integration",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"MLP workflow failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"MLP workflow error: {e}")
|
||||
|
||||
# Test 3: Complete CNN workflow integration
|
||||
try:
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 8, (3, 3))
|
||||
self.fc1 = nn.Linear(200, 10) # Simplified calculation
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.flatten(x)
|
||||
return self.fc1(x)
|
||||
|
||||
# Create model and optimizer
|
||||
model = SimpleCNN()
|
||||
optimizer = optim.SGD(model.parameters(), learning_rate=0.01)
|
||||
|
||||
# Verify CNN integration
|
||||
params = list(model.parameters())
|
||||
assert len(params) == 4, f"Expected 4 parameters, got {len(params)}"
|
||||
assert len(optimizer.parameters) == 4, "Optimizer should have 4 parameters"
|
||||
|
||||
# Test parameter shapes
|
||||
conv_weight = model.conv1.weight
|
||||
conv_bias = model.conv1.bias
|
||||
fc_weight = model.fc1.weights
|
||||
fc_bias = model.fc1.bias
|
||||
|
||||
assert conv_weight.shape == (8, 3, 3, 3), f"Conv weight shape: {conv_weight.shape}"
|
||||
assert conv_bias.shape == (8,), f"Conv bias shape: {conv_bias.shape}"
|
||||
assert fc_weight.shape == (200, 10), f"FC weight shape: {fc_weight.shape}"
|
||||
assert fc_bias.shape == (10,), f"FC bias shape: {fc_bias.shape}"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "cnn_workflow_integration",
|
||||
"status": "✅ PASS",
|
||||
"description": "Complete CNN workflow integrates correctly"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "cnn_workflow_integration",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"CNN workflow failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"CNN workflow error: {e}")
|
||||
|
||||
# Test 4: Functional interface integration
|
||||
try:
|
||||
x = Tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]])
|
||||
|
||||
# Test relu
|
||||
relu_out = F.relu(x)
|
||||
expected_relu = np.array([[0.0, 0.0, 0.0, 1.0, 2.0]])
|
||||
np.testing.assert_array_equal(relu_out.data, expected_relu)
|
||||
|
||||
# Test flatten
|
||||
x2 = Tensor([[[[1, 2], [3, 4]]]]) # (1, 1, 2, 2)
|
||||
flat_out = F.flatten(x2)
|
||||
assert flat_out.shape == (1, 4), f"Flatten shape: {flat_out.shape}"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "functional_interface_integration",
|
||||
"status": "✅ PASS",
|
||||
"description": "Functional interface integrates correctly"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "functional_interface_integration",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"Functional interface failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"Functional interface error: {e}")
|
||||
|
||||
# Test 5: Backward compatibility integration
|
||||
try:
|
||||
# Test old names still work
|
||||
from tinytorch.core.layers import Dense
|
||||
from tinytorch.core.spatial import MultiChannelConv2D
|
||||
|
||||
dense = Dense(5, 3)
|
||||
conv = MultiChannelConv2D(3, 8, (3, 3))
|
||||
|
||||
# Should be the same classes as new names
|
||||
assert type(dense).__name__ == 'Linear', f"Dense should be Linear, got {type(dense).__name__}"
|
||||
assert type(conv).__name__ == 'Conv2d', f"MultiChannelConv2D should be Conv2d, got {type(conv).__name__}"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "backward_compatibility_integration",
|
||||
"status": "✅ PASS",
|
||||
"description": "Backward compatibility maintained"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "backward_compatibility_integration",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"Backward compatibility failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"Backward compatibility error: {e}")
|
||||
|
||||
# Test 6: Cross-module parameter integration
|
||||
try:
|
||||
# Test that parameters flow correctly across modules
|
||||
class ComplexModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv_block = nn.Module()
|
||||
self.conv_block.conv1 = nn.Conv2d(3, 16, (3, 3))
|
||||
self.conv_block.conv2 = nn.Conv2d(16, 32, (3, 3))
|
||||
|
||||
self.classifier = nn.Module()
|
||||
self.classifier.fc1 = nn.Linear(512, 128)
|
||||
self.classifier.fc2 = nn.Linear(128, 10)
|
||||
|
||||
def forward(self, x):
|
||||
return x # Stub
|
||||
|
||||
model = ComplexModel()
|
||||
params = list(model.parameters())
|
||||
|
||||
# Should collect from nested modules
|
||||
assert len(params) == 8, f"Expected 8 parameters from nested modules, got {len(params)}"
|
||||
|
||||
# Test optimizer works with nested parameters
|
||||
optimizer = optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
assert len(optimizer.parameters) == 8, "Optimizer should get nested parameters"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "cross_module_parameter_integration",
|
||||
"status": "✅ PASS",
|
||||
"description": "Cross-module parameter collection works"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "cross_module_parameter_integration",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"Cross-module parameters failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"Cross-module error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
results["success"] = False
|
||||
results["errors"].append(f"Unexpected error: {e}")
|
||||
results["tests"].append({
|
||||
"name": "unexpected_error",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"Unexpected error: {e}"
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_pytorch_api_compatibility():
|
||||
"""Test that the API closely matches PyTorch patterns."""
|
||||
|
||||
results = {
|
||||
"module_name": "api_simplification",
|
||||
"integration_type": "pytorch_compatibility",
|
||||
"tests": [],
|
||||
"success": True,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
try:
|
||||
# Test PyTorch-like import patterns
|
||||
import tinytorch.nn as nn
|
||||
import tinytorch.nn.functional as F
|
||||
import tinytorch.optim as optim
|
||||
|
||||
# Test 1: PyTorch-like model definition
|
||||
try:
|
||||
class PyTorchLikeModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.features = nn.Module()
|
||||
self.features.conv1 = nn.Conv2d(3, 64, (3, 3))
|
||||
self.features.conv2 = nn.Conv2d(64, 128, (3, 3))
|
||||
|
||||
self.classifier = nn.Module()
|
||||
self.classifier.fc1 = nn.Linear(2048, 512)
|
||||
self.classifier.fc2 = nn.Linear(512, 10)
|
||||
|
||||
def forward(self, x):
|
||||
# Conv features
|
||||
x = F.relu(self.features.conv1(x))
|
||||
x = F.max_pool2d(x, (2, 2))
|
||||
x = F.relu(self.features.conv2(x))
|
||||
x = F.max_pool2d(x, (2, 2))
|
||||
|
||||
# Classifier
|
||||
x = F.flatten(x)
|
||||
x = F.relu(self.classifier.fc1(x))
|
||||
x = self.classifier.fc2(x)
|
||||
return x
|
||||
|
||||
model = PyTorchLikeModel()
|
||||
optimizer = optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
|
||||
# Should work exactly like PyTorch
|
||||
assert callable(model), "Model should be callable"
|
||||
assert len(list(model.parameters())) > 0, "Should have parameters"
|
||||
assert hasattr(optimizer, 'parameters'), "Optimizer should have parameters"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "pytorch_like_model_definition",
|
||||
"status": "✅ PASS",
|
||||
"description": "PyTorch-like model definition works"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "pytorch_like_model_definition",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"PyTorch-like definition failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"PyTorch compatibility error: {e}")
|
||||
|
||||
# Test 2: PyTorch-like training setup pattern
|
||||
try:
|
||||
# This should look exactly like PyTorch code
|
||||
model = nn.Linear(784, 10)
|
||||
optimizer = optim.SGD(model.parameters(), learning_rate=0.01)
|
||||
|
||||
# Test that syntax matches PyTorch
|
||||
params = model.parameters()
|
||||
assert hasattr(params, '__iter__'), "parameters() should be iterable"
|
||||
|
||||
param_list = list(model.parameters())
|
||||
assert len(param_list) == 2, "Linear should have weight + bias"
|
||||
|
||||
results["tests"].append({
|
||||
"name": "pytorch_training_setup",
|
||||
"status": "✅ PASS",
|
||||
"description": "PyTorch-like training setup works"
|
||||
})
|
||||
except Exception as e:
|
||||
results["tests"].append({
|
||||
"name": "pytorch_training_setup",
|
||||
"status": "❌ FAIL",
|
||||
"description": f"Training setup failed: {e}"
|
||||
})
|
||||
results["success"] = False
|
||||
results["errors"].append(f"Training setup error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
results["success"] = False
|
||||
results["errors"].append(f"PyTorch compatibility test error: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("🧪 TinyTorch API Simplification Integration Tests")
|
||||
print("=" * 60)
|
||||
|
||||
# Run integration tests
|
||||
api_results = test_modern_api_integration()
|
||||
pytorch_results = test_pytorch_api_compatibility()
|
||||
|
||||
# Print results
|
||||
all_results = [api_results, pytorch_results]
|
||||
total_tests = sum(len(r["tests"]) for r in all_results)
|
||||
total_passed = sum(len([t for t in r["tests"] if t["status"] == "✅ PASS"]) for r in all_results)
|
||||
total_failed = total_tests - total_passed
|
||||
|
||||
print(f"\n📊 Integration Test Summary:")
|
||||
print(f" Total tests: {total_tests}")
|
||||
print(f" ✅ Passed: {total_passed}")
|
||||
print(f" ❌ Failed: {total_failed}")
|
||||
|
||||
# Detailed results
|
||||
for results in all_results:
|
||||
print(f"\n🔍 {results['integration_type']}:")
|
||||
for test in results["tests"]:
|
||||
print(f" {test['status']} {test['name']}: {test['description']}")
|
||||
|
||||
if results["errors"]:
|
||||
print(f" Errors: {results['errors']}")
|
||||
|
||||
# Overall success
|
||||
overall_success = all(r["success"] for r in all_results)
|
||||
if overall_success:
|
||||
print("\n✅ All integration tests passed!")
|
||||
print("🎉 API simplification integration is successful!")
|
||||
else:
|
||||
print("\n❌ Some integration tests failed!")
|
||||
print("🔧 Fix integration issues before deployment.")
|
||||
|
||||
sys.exit(0 if overall_success else 1)
|
||||
443
tests/test_api_simplification.py
Normal file
443
tests/test_api_simplification.py
Normal file
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Suite for TinyTorch API Simplification
|
||||
|
||||
This test suite validates all the new PyTorch-compatible API features
|
||||
introduced in the API simplification project.
|
||||
|
||||
Test Hierarchy:
|
||||
1. Unit Tests: Individual components (Parameter, Module, Linear, Conv2d)
|
||||
2. Integration Tests: Components working together
|
||||
3. End-to-End Tests: Complete workflows (model creation, training setup)
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
class TestParameterFunction(unittest.TestCase):
|
||||
"""Unit tests for the Parameter() helper function."""
|
||||
|
||||
def setUp(self):
|
||||
from tinytorch.core.tensor import Parameter, Tensor
|
||||
self.Parameter = Parameter
|
||||
self.Tensor = Tensor
|
||||
|
||||
def test_parameter_creation(self):
|
||||
"""Test basic Parameter creation."""
|
||||
param = self.Parameter([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
self.assertTrue(param.requires_grad, "Parameter should have requires_grad=True")
|
||||
self.assertEqual(param.shape, (2, 2), "Parameter should preserve shape")
|
||||
self.assertEqual(param.dtype, np.float32, "Parameter should default to float32")
|
||||
|
||||
def test_parameter_vs_tensor(self):
|
||||
"""Test Parameter vs Tensor differences."""
|
||||
data = [[1.0, 2.0]]
|
||||
|
||||
tensor = self.Tensor(data)
|
||||
param = self.Parameter(data)
|
||||
|
||||
self.assertFalse(tensor.requires_grad, "Tensor should default requires_grad=False")
|
||||
self.assertTrue(param.requires_grad, "Parameter should have requires_grad=True")
|
||||
|
||||
def test_parameter_with_dtype(self):
|
||||
"""Test Parameter with explicit dtype."""
|
||||
param = self.Parameter([1, 2, 3], dtype='int32')
|
||||
self.assertEqual(param.dtype, np.int32, "Parameter should respect dtype")
|
||||
self.assertTrue(param.requires_grad, "Parameter should still have requires_grad=True")
|
||||
|
||||
|
||||
class TestModuleBaseClass(unittest.TestCase):
|
||||
"""Unit tests for the Module base class."""
|
||||
|
||||
def setUp(self):
|
||||
import tinytorch.nn as nn
|
||||
from tinytorch.core.tensor import Parameter
|
||||
self.nn = nn
|
||||
self.Parameter = Parameter
|
||||
|
||||
def test_module_creation(self):
|
||||
"""Test basic Module creation."""
|
||||
module = self.nn.Module()
|
||||
|
||||
self.assertEqual(len(module._parameters), 0, "New module should have no parameters")
|
||||
self.assertEqual(len(module._modules), 0, "New module should have no submodules")
|
||||
self.assertEqual(len(list(module.parameters())), 0, "parameters() should return empty list")
|
||||
|
||||
def test_parameter_registration(self):
|
||||
"""Test automatic parameter registration."""
|
||||
Parameter = self.Parameter
|
||||
|
||||
class TestModule(self.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = Parameter([[1.0, 2.0]])
|
||||
self.bias = Parameter([0.5])
|
||||
self.non_param = "not a parameter"
|
||||
|
||||
module = TestModule()
|
||||
params = list(module.parameters())
|
||||
|
||||
self.assertEqual(len(params), 2, "Should register 2 parameters")
|
||||
self.assertTrue(all(p.requires_grad for p in params), "All parameters should require gradients")
|
||||
|
||||
def test_submodule_registration(self):
|
||||
"""Test automatic submodule registration."""
|
||||
Parameter = self.Parameter
|
||||
|
||||
class SubModule(self.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = Parameter([[1.0]])
|
||||
|
||||
class MainModule(self.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sub1 = SubModule()
|
||||
self.sub2 = SubModule()
|
||||
self.weight = Parameter([[2.0]])
|
||||
|
||||
module = MainModule()
|
||||
params = list(module.parameters())
|
||||
|
||||
self.assertEqual(len(params), 3, "Should collect parameters from submodules: 2 + 1 = 3")
|
||||
self.assertEqual(len(module._modules), 2, "Should register 2 submodules")
|
||||
|
||||
def test_callable_interface(self):
|
||||
"""Test that modules are callable via __call__."""
|
||||
class TestModule(self.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 1
|
||||
|
||||
module = TestModule()
|
||||
result = module(5) # Should call forward(5)
|
||||
|
||||
self.assertEqual(result, 6, "Module should be callable and delegate to forward()")
|
||||
|
||||
|
||||
class TestLinearLayer(unittest.TestCase):
|
||||
"""Unit tests for the Linear layer (renamed from Dense)."""
|
||||
|
||||
def setUp(self):
|
||||
import tinytorch.nn as nn
|
||||
from tinytorch.core.tensor import Tensor
|
||||
self.nn = nn
|
||||
self.Tensor = Tensor
|
||||
|
||||
def test_linear_creation(self):
|
||||
"""Test Linear layer creation."""
|
||||
linear = self.nn.Linear(5, 3)
|
||||
|
||||
self.assertEqual(linear.input_size, 5, "Should store input size")
|
||||
self.assertEqual(linear.output_size, 3, "Should store output size")
|
||||
self.assertEqual(linear.weights.shape, (5, 3), "Weights should have correct shape")
|
||||
self.assertEqual(linear.bias.shape, (3,), "Bias should have correct shape")
|
||||
self.assertTrue(linear.weights.requires_grad, "Weights should require gradients")
|
||||
self.assertTrue(linear.bias.requires_grad, "Bias should require gradients")
|
||||
|
||||
def test_linear_no_bias(self):
|
||||
"""Test Linear layer without bias."""
|
||||
linear = self.nn.Linear(5, 3, use_bias=False)
|
||||
|
||||
self.assertIsNone(linear.bias, "Should have no bias when use_bias=False")
|
||||
self.assertEqual(len(list(linear.parameters())), 1, "Should only have weight parameter")
|
||||
|
||||
def test_linear_forward(self):
|
||||
"""Test Linear layer forward pass."""
|
||||
linear = self.nn.Linear(3, 2)
|
||||
x = self.Tensor([[1.0, 2.0, 3.0]])
|
||||
|
||||
output = linear(x)
|
||||
|
||||
self.assertEqual(output.shape, (1, 2), "Output should have correct shape")
|
||||
self.assertIsInstance(output, self.Tensor, "Output should be a Tensor")
|
||||
|
||||
def test_linear_parameter_collection(self):
|
||||
"""Test that Linear parameters are properly collected."""
|
||||
linear = self.nn.Linear(4, 2)
|
||||
params = list(linear.parameters())
|
||||
|
||||
self.assertEqual(len(params), 2, "Should have 2 parameters (weight + bias)")
|
||||
shapes = [p.shape for p in params]
|
||||
self.assertIn((4, 2), shapes, "Should include weight shape")
|
||||
self.assertIn((2,), shapes, "Should include bias shape")
|
||||
|
||||
|
||||
class TestConv2dLayer(unittest.TestCase):
|
||||
"""Unit tests for the Conv2d layer (renamed from MultiChannelConv2D)."""
|
||||
|
||||
def setUp(self):
|
||||
import tinytorch.nn as nn
|
||||
from tinytorch.core.tensor import Tensor
|
||||
self.nn = nn
|
||||
self.Tensor = Tensor
|
||||
|
||||
def test_conv2d_creation(self):
|
||||
"""Test Conv2d layer creation."""
|
||||
conv = self.nn.Conv2d(3, 16, (3, 3))
|
||||
|
||||
self.assertEqual(conv.in_channels, 3, "Should store input channels")
|
||||
self.assertEqual(conv.out_channels, 16, "Should store output channels")
|
||||
self.assertEqual(conv.kernel_size, (3, 3), "Should store kernel size")
|
||||
self.assertEqual(conv.weight.shape, (16, 3, 3, 3), "Weight should have correct shape")
|
||||
self.assertEqual(conv.bias.shape, (16,), "Bias should have correct shape")
|
||||
|
||||
def test_conv2d_parameter_collection(self):
|
||||
"""Test that Conv2d parameters are properly collected."""
|
||||
conv = self.nn.Conv2d(3, 8, (3, 3))
|
||||
params = list(conv.parameters())
|
||||
|
||||
self.assertEqual(len(params), 2, "Should have 2 parameters (weight + bias)")
|
||||
weight_params = [p for p in params if len(p.shape) == 4]
|
||||
bias_params = [p for p in params if len(p.shape) == 1]
|
||||
|
||||
self.assertEqual(len(weight_params), 1, "Should have 1 weight parameter")
|
||||
self.assertEqual(len(bias_params), 1, "Should have 1 bias parameter")
|
||||
|
||||
|
||||
class TestFunctionalInterface(unittest.TestCase):
|
||||
"""Unit tests for the functional interface (F.relu, F.flatten, etc.)."""
|
||||
|
||||
def setUp(self):
|
||||
import tinytorch.nn.functional as F
|
||||
from tinytorch.core.tensor import Tensor
|
||||
self.F = F
|
||||
self.Tensor = Tensor
|
||||
|
||||
def test_relu_function(self):
|
||||
"""Test F.relu function."""
|
||||
x = self.Tensor([[-2.0, -1.0, 0.0, 1.0, 2.0]])
|
||||
output = self.F.relu(x)
|
||||
|
||||
expected = np.array([[0.0, 0.0, 0.0, 1.0, 2.0]])
|
||||
np.testing.assert_array_equal(output.data, expected, "ReLU should zero negative values")
|
||||
|
||||
def test_flatten_function(self):
|
||||
"""Test F.flatten function."""
|
||||
x = self.Tensor([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]) # Shape: (1, 2, 2, 2)
|
||||
output = self.F.flatten(x)
|
||||
|
||||
self.assertEqual(output.shape, (1, 8), "Flatten should preserve batch dimension")
|
||||
expected = np.array([[1, 2, 3, 4, 5, 6, 7, 8]])
|
||||
np.testing.assert_array_equal(output.data, expected, "Flatten should preserve order")
|
||||
|
||||
def test_flatten_start_dim(self):
|
||||
"""Test F.flatten with custom start_dim."""
|
||||
x = self.Tensor([[[1, 2], [3, 4]]]) # Shape: (1, 2, 2) = 4 elements
|
||||
output = self.F.flatten(x, start_dim=0) # Flatten everything
|
||||
|
||||
self.assertEqual(output.shape, (4,), "Should flatten from dimension 0")
|
||||
expected = np.array([1, 2, 3, 4])
|
||||
np.testing.assert_array_equal(output.data, expected, "Should flatten all dimensions")
|
||||
|
||||
|
||||
class TestOptimizerIntegration(unittest.TestCase):
|
||||
"""Integration tests for optimizers with the new API."""
|
||||
|
||||
def setUp(self):
|
||||
import tinytorch.nn as nn
|
||||
import tinytorch.optim as optim
|
||||
self.nn = nn
|
||||
self.optim = optim
|
||||
|
||||
def test_adam_with_model_parameters(self):
|
||||
"""Test Adam optimizer with model.parameters()."""
|
||||
model = self.nn.Linear(5, 3)
|
||||
optimizer = self.optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
|
||||
# Check that optimizer received the parameters
|
||||
self.assertEqual(len(optimizer.parameters), 2, "Adam should receive 2 parameters")
|
||||
|
||||
# Check parameter types
|
||||
for param in optimizer.parameters:
|
||||
self.assertTrue(param.requires_grad, "All optimizer parameters should require gradients")
|
||||
|
||||
def test_sgd_with_model_parameters(self):
|
||||
"""Test SGD optimizer with model.parameters()."""
|
||||
nn = self.nn
|
||||
|
||||
class SimpleNet(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(4, 8)
|
||||
self.fc2 = nn.Linear(8, 2)
|
||||
|
||||
model = SimpleNet()
|
||||
optimizer = self.optim.SGD(model.parameters(), learning_rate=0.01)
|
||||
|
||||
# Should collect parameters from both layers
|
||||
self.assertEqual(len(optimizer.parameters), 4, "SGD should receive 4 parameters (2 layers × 2 params)")
|
||||
|
||||
|
||||
class TestBackwardCompatibility(unittest.TestCase):
|
||||
"""Integration tests for backward compatibility."""
|
||||
|
||||
def test_dense_alias_works(self):
|
||||
"""Test that Dense alias still works."""
|
||||
from tinytorch.core.layers import Dense
|
||||
|
||||
dense = Dense(5, 3)
|
||||
self.assertEqual(dense.input_size, 5, "Dense alias should work")
|
||||
self.assertEqual(dense.output_size, 3, "Dense alias should work")
|
||||
|
||||
def test_multichannel_conv2d_alias_works(self):
|
||||
"""Test that MultiChannelConv2D alias still works."""
|
||||
from tinytorch.core.spatial import MultiChannelConv2D
|
||||
|
||||
conv = MultiChannelConv2D(3, 16, (3, 3))
|
||||
self.assertEqual(conv.in_channels, 3, "MultiChannelConv2D alias should work")
|
||||
self.assertEqual(conv.out_channels, 16, "MultiChannelConv2D alias should work")
|
||||
|
||||
|
||||
class TestCompleteModelWorkflow(unittest.TestCase):
|
||||
"""End-to-end integration tests for complete model workflows."""
|
||||
|
||||
def setUp(self):
|
||||
import tinytorch.nn as nn
|
||||
import tinytorch.nn.functional as F
|
||||
import tinytorch.optim as optim
|
||||
from tinytorch.core.tensor import Tensor
|
||||
|
||||
self.nn = nn
|
||||
self.F = F
|
||||
self.optim = optim
|
||||
self.Tensor = Tensor
|
||||
|
||||
def test_complete_mlp_workflow(self):
|
||||
"""Test complete MLP creation and setup."""
|
||||
nn = self.nn
|
||||
F = self.F
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(4, 8)
|
||||
self.fc2 = nn.Linear(8, 2)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.fc1(x))
|
||||
return self.fc2(x)
|
||||
|
||||
# Create model
|
||||
model = MLP()
|
||||
|
||||
# Test parameter collection
|
||||
params = list(model.parameters())
|
||||
self.assertEqual(len(params), 4, "Should collect all parameters")
|
||||
|
||||
# Test optimizer creation
|
||||
optimizer = self.optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
self.assertIsNotNone(optimizer, "Should create optimizer")
|
||||
|
||||
# Test forward pass
|
||||
x = self.Tensor([[1.0, 2.0, 3.0, 4.0]])
|
||||
output = model(x)
|
||||
self.assertEqual(output.shape, (1, 2), "Should produce correct output shape")
|
||||
|
||||
def test_complete_cnn_workflow(self):
|
||||
"""Test complete CNN creation and setup."""
|
||||
nn = self.nn
|
||||
F = self.F
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 8, (3, 3))
|
||||
self.fc1 = nn.Linear(128, 10) # Simplified size
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.flatten(x)
|
||||
return self.fc1(x)
|
||||
|
||||
# Create model
|
||||
model = CNN()
|
||||
|
||||
# Test parameter collection
|
||||
params = list(model.parameters())
|
||||
self.assertEqual(len(params), 4, "Should collect conv + fc parameters")
|
||||
|
||||
# Test optimizer creation
|
||||
optimizer = self.optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
self.assertIsNotNone(optimizer, "Should create optimizer")
|
||||
|
||||
def test_pytorch_like_syntax(self):
|
||||
"""Test that the syntax matches PyTorch patterns."""
|
||||
# This test verifies the API feels like PyTorch
|
||||
import tinytorch.nn as nn
|
||||
import tinytorch.nn.functional as F
|
||||
import tinytorch.optim as optim
|
||||
|
||||
# Should be able to write PyTorch-like code
|
||||
model = nn.Linear(10, 1)
|
||||
optimizer = optim.Adam(model.parameters(), learning_rate=0.001)
|
||||
|
||||
# Test that this workflow doesn't crash
|
||||
x = self.Tensor([[1.0] * 10])
|
||||
output = model(x)
|
||||
|
||||
# Should be able to chain operations
|
||||
output = F.relu(output)
|
||||
|
||||
self.assertIsNotNone(output, "PyTorch-like workflow should work")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Run the test suite
|
||||
print("🧪 TinyTorch API Simplification Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
# Create test suite
|
||||
suite = unittest.TestSuite()
|
||||
|
||||
# Add all test classes
|
||||
test_classes = [
|
||||
TestParameterFunction,
|
||||
TestModuleBaseClass,
|
||||
TestLinearLayer,
|
||||
TestConv2dLayer,
|
||||
TestFunctionalInterface,
|
||||
TestOptimizerIntegration,
|
||||
TestBackwardCompatibility,
|
||||
TestCompleteModelWorkflow
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print(f"🎯 Test Summary:")
|
||||
print(f" Tests run: {result.testsRun}")
|
||||
print(f" Failures: {len(result.failures)}")
|
||||
print(f" Errors: {len(result.errors)}")
|
||||
|
||||
if result.failures:
|
||||
print("\n❌ Failures:")
|
||||
for test, traceback in result.failures:
|
||||
print(f" {test}: {traceback.split('AssertionError:')[-1].strip()}")
|
||||
|
||||
if result.errors:
|
||||
print("\n💥 Errors:")
|
||||
for test, traceback in result.errors:
|
||||
print(f" {test}: {traceback.split('Exception:')[-1].strip()}")
|
||||
|
||||
if result.wasSuccessful():
|
||||
print("\n✅ All API simplification tests passed!")
|
||||
print("🎉 PyTorch-compatible API is working correctly!")
|
||||
else:
|
||||
print("\n❌ Some tests failed!")
|
||||
print("🔧 API needs fixes before deployment.")
|
||||
|
||||
# Exit with proper code
|
||||
sys.exit(0 if result.wasSuccessful() else 1)
|
||||
Reference in New Issue
Block a user