mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-03-11 17:49:25 -05:00
280 lines
9.4 KiB
Python
280 lines
9.4 KiB
Python
"""
|
|
Shared testing infrastructure for TinyTorch modules.
|
|
|
|
This module provides a standardized testing framework that ensures consistent
|
|
output format and behavior across all TinyTorch modules.
|
|
"""
|
|
|
|
import sys
|
|
import traceback
|
|
import inspect
|
|
from typing import List, Callable, Tuple, Optional
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.table import Table
|
|
from rich.text import Text
|
|
|
|
class ModuleTestRunner:
    """
    Standardized test runner for TinyTorch modules.

    Holds a list of ``(display_name, callable)`` tests that are either
    registered explicitly or discovered from a module, runs each one with
    stdout/stderr silenced, and prints a uniform pass/fail report.
    """

    # Prefix that marks a function as a test and is stripped for display.
    _TEST_PREFIX = 'test_'

    # Display names for functions whose generic title-cased form would be
    # wrong (e.g. "Relu") or unnecessarily long ("... Comprehensive").
    _NAME_OVERRIDES = {
        'tensor_creation_comprehensive': 'Tensor Creation',
        'tensor_properties_comprehensive': 'Tensor Properties',
        'tensor_arithmetic_comprehensive': 'Tensor Arithmetic',
        'tensor_integration': 'Tensor Integration',
        'relu_activation': 'ReLU Activation',
        'sigmoid_activation': 'Sigmoid Activation',
        'tanh_activation': 'Tanh Activation',
        'softmax_activation': 'Softmax Activation',
        'activations_integration': 'Activations Integration',
    }

    def __init__(self, module_name: str):
        """Initialize the test runner for a specific module."""
        self.module_name = module_name
        # Registered tests as (display name, callable), in execution order.
        self.tests: List[Tuple[str, Callable]] = []
        self.console = Console()
        # Outcome of the most recent run as (display name, passed, message).
        self.results: List[Tuple[str, bool, str]] = []

    def register_test(self, test_name: str, test_function: Callable) -> None:
        """Register a test function with a descriptive name."""
        self.tests.append((test_name, test_function))

    def auto_discover_tests(self, calling_module=None) -> None:
        """
        Automatically discover and register test functions from a module.

        Every plain function whose attribute name starts with ``test_`` is
        registered. Tests are ordered ``test_unit_*`` first, ``test_module_*``
        second, everything else last; within a group they keep the
        alphabetical order produced by ``inspect.getmembers``.

        Args:
            calling_module: The module to search for tests. When omitted, the
                call stack is walked outward past frames belonging to this
                file, and the first foreign frame's module is used.
        """
        if calling_module is None:
            frame = inspect.currentframe()
            caller = None
            module_frame = None
            try:
                if frame is not None:
                    # Skip frames defined in this file (for example
                    # run_module_tests_auto) so discovery works whether it is
                    # called directly or through a helper.
                    own_globals = globals()
                    caller = frame.f_back
                    module_frame = caller
                    while caller is not None and caller.f_globals is own_globals:
                        module_frame = caller
                        caller = caller.f_back
                    if caller is not None:
                        module_frame = caller
                    # If every frame belongs to this file, module_frame is the
                    # outermost one, so tests defined alongside the runner are
                    # still found.
                    if module_frame is not None:
                        calling_module = inspect.getmodule(module_frame)
            finally:
                # Drop frame references promptly to avoid reference cycles.
                del frame, caller, module_frame

        if calling_module is None:
            print("⚠️ Could not auto-discover tests - no calling module found")
            return

        discovered = [
            (name, obj)
            for name, obj in inspect.getmembers(calling_module)
            if self._is_valid_test_function(name, obj)
        ]

        def priority(entry):
            # Rank by the name the function was discovered under, so aliased
            # or wrapped functions are grouped by their test_ name.
            name = entry[0]
            if name.startswith('test_unit_'):
                return 0
            if name.startswith('test_module_'):
                return 1
            return 2

        # list.sort is stable, so alphabetical order survives within a group.
        discovered.sort(key=priority)

        for name, function in discovered:
            self.register_test(self._function_name_to_test_name(name), function)

        print(f"🔍 Auto-discovered {len(discovered)} test functions")

    def _is_valid_test_function(self, name: str, obj: object) -> bool:
        """
        Check if an object is a valid test function.

        Args:
            name: Name of the object
            obj: The object to check

        Returns:
            bool: True if ``obj`` is a plain function (not a class or other
            callable) and ``name`` starts with ``test_``
        """
        return inspect.isfunction(obj) and name.startswith(self._TEST_PREFIX)

    def _function_name_to_test_name(self, function_name: str) -> str:
        """
        Convert a function name to a readable test name.

        Args:
            function_name: The function name (e.g., 'test_tensor_creation_comprehensive')

        Returns:
            str: Human-readable test name (e.g., 'Tensor Creation')
        """
        # Strip only the leading prefix; a later 'test_' inside the name is
        # part of the description and must be kept.
        name = function_name
        if name.startswith(self._TEST_PREFIX):
            name = name[len(self._TEST_PREFIX):]

        override = self._NAME_OVERRIDES.get(name)
        if override is not None:
            return override

        # Generic conversion: replace underscores with spaces and title case
        return name.replace('_', ' ').title()

    def run_all_tests(self) -> bool:
        """
        Run all registered tests and return overall success.

        ``self.results`` is replaced on every call, so it always describes the
        latest run only.

        Returns:
            bool: True if all tests passed, False otherwise (including when
            no tests are registered)
        """
        if not self.tests:
            print(f"⚠️ No tests registered for {self.module_name}")
            return False

        print(f"\n🧪 Running {self.module_name} Module Tests...")
        print("=" * 50)

        # Start from a clean slate so repeated runs do not inflate the counts.
        self.results = []

        for test_name, test_function in self.tests:
            success, output = self._run_single_test(test_name, test_function)
            self.results.append((test_name, success, output))

            function_name = test_function.__name__
            if success:
                print(f"✅ {test_name} ({function_name}): PASSED")
            else:
                print(f"❌ {test_name} ({function_name}): FAILED")
                if output:
                    print(f" Error: {output}")

        print("=" * 50)

        # Final summary
        total_tests = len(self.tests)
        passed_tests = sum(1 for _, success, _ in self.results if success)
        all_passed = passed_tests == total_tests

        if all_passed:
            print(f"🎉 All tests passed! ({passed_tests}/{total_tests})")
            print(f"✅ {self.module_name} module is working correctly!")
        else:
            print(f"❌ {passed_tests}/{total_tests} tests passed")
            print(f"🔧 {self.module_name} module needs fixes")

        return all_passed

    def _run_single_test(self, test_name: str, test_function: Callable) -> Tuple[bool, str]:
        """
        Run a single test function and capture its result.

        Anything the test writes to stdout or stderr is swallowed so the
        report stays uniform.

        Args:
            test_name: Name of the test for reporting
            test_function: The test function to execute

        Returns:
            Tuple of (success, error_message); the message is empty on success
        """
        import contextlib
        import io

        sink_out, sink_err = io.StringIO(), io.StringIO()
        try:
            with contextlib.redirect_stdout(sink_out), contextlib.redirect_stderr(sink_err):
                test_function()
        except AssertionError as e:
            # A bare `assert cond` has an empty message; substitute a
            # placeholder so the report never shows a failure with no reason.
            return False, str(e) or "Assertion failed (no message provided)"
        except Exception as e:
            return False, f"{type(e).__name__}: {e}"

        return True, ""
|
|
|
|
def create_test_runner(module_name: str) -> ModuleTestRunner:
    """
    Build a fresh test runner bound to the given module name.

    Args:
        module_name: Name of the module being tested; used in report output

    Returns:
        ModuleTestRunner: A new runner with no tests registered yet
    """
    runner = ModuleTestRunner(module_name)
    return runner
|
|
|
|
def run_module_tests_auto(module_name: str) -> bool:
    """
    Discover every test in the caller's module and execute them all.

    Args:
        module_name: Name of the module being tested

    Returns:
        bool: True if all tests passed, False otherwise
    """
    runner = create_test_runner(module_name)
    # Discovery locates the caller's module through the call stack, so keep
    # this call directly in this function's frame rather than behind a helper.
    runner.auto_discover_tests()
    all_passed = runner.run_all_tests()
    return all_passed
|
|
|
|
# Legacy compatibility function
|
|
def run_module_tests(module_name: str, test_functions: List[Tuple[str, Callable]]) -> bool:
    """
    Run an explicit list of tests (kept for backward compatibility).

    Args:
        module_name: Name of the module being tested
        test_functions: List of (test_name, test_function) tuples, executed
            in the order given

    Returns:
        bool: True if all tests passed, False otherwise
    """
    runner = create_test_runner(module_name)

    for entry in test_functions:
        # Each entry must unpack into exactly (name, callable).
        label, function = entry
        runner.register_test(label, function)

    outcome = runner.run_all_tests()
    return outcome
|