mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-05-03 16:18:49 -05:00
style: apply consistent whitespace and formatting across codebase
This commit is contained in:
@@ -13,29 +13,29 @@ from pathlib import Path
|
||||
def setup_integration_test():
|
||||
"""
|
||||
Set up the environment for integration testing.
|
||||
|
||||
|
||||
This function ensures:
|
||||
1. The TinyTorch package is importable
|
||||
2. NumPy random seed is set for reproducibility
|
||||
3. Warning filters are set appropriately
|
||||
|
||||
|
||||
Call this at the top of integration test files before importing TinyTorch.
|
||||
"""
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Ensure tinytorch is on the path (from project root)
|
||||
project_root = Path(__file__).parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
# Set random seed for reproducibility
|
||||
np.random.seed(42)
|
||||
|
||||
|
||||
# Suppress certain warnings during tests
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
|
||||
|
||||
# Set quiet mode for tinytorch imports during tests
|
||||
os.environ['TINYTORCH_QUIET'] = '1'
|
||||
|
||||
@@ -53,21 +53,21 @@ def get_test_data_path() -> Path:
|
||||
def create_test_tensor(shape, requires_grad=True, seed=None):
|
||||
"""
|
||||
Create a test tensor with random data.
|
||||
|
||||
|
||||
Args:
|
||||
shape: Tuple specifying tensor shape
|
||||
requires_grad: Whether tensor should track gradients
|
||||
seed: Optional random seed for reproducibility
|
||||
|
||||
|
||||
Returns:
|
||||
Tensor with random data
|
||||
"""
|
||||
import numpy as np
|
||||
from tinytorch.core.tensor import Tensor
|
||||
|
||||
|
||||
if seed is not None:
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
data = np.random.randn(*shape).astype(np.float32)
|
||||
return Tensor(data, requires_grad=requires_grad)
|
||||
|
||||
@@ -75,7 +75,7 @@ def create_test_tensor(shape, requires_grad=True, seed=None):
|
||||
def assert_tensors_close(t1, t2, rtol=1e-5, atol=1e-8, msg=""):
|
||||
"""
|
||||
Assert that two tensors are element-wise close.
|
||||
|
||||
|
||||
Args:
|
||||
t1: First tensor
|
||||
t2: Second tensor
|
||||
@@ -84,11 +84,11 @@ def assert_tensors_close(t1, t2, rtol=1e-5, atol=1e-8, msg=""):
|
||||
msg: Optional message for assertion error
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Extract data from tensors if needed
|
||||
data1 = t1.data if hasattr(t1, 'data') else t1
|
||||
data2 = t2.data if hasattr(t2, 'data') else t2
|
||||
|
||||
|
||||
if not np.allclose(data1, data2, rtol=rtol, atol=atol):
|
||||
diff = np.abs(data1 - data2)
|
||||
max_diff = np.max(diff)
|
||||
@@ -111,4 +111,3 @@ def skip_if_no_tinytorch():
|
||||
return pytest.mark.skipif(False, reason="TinyTorch available")
|
||||
except ImportError:
|
||||
return pytest.mark.skip(reason="TinyTorch not installed")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user