Mirror of https://github.com/MLSysBook/TinyTorch.git (synced 2025-12-05 19:17:52 -06:00)
Add educational test output with Rich CLI
- Create pytest_tinytorch.py plugin for educational test output
- Update test_tensor_core.py with WHAT/WHY/STUDENT LEARNING docstrings
- Show test purpose on pass, detailed context on failure
- Use --tinytorch flag to enable educational mode

Students can now understand what each test checks and why it matters.
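The educational output is opt-in. Going by the usage notes in the new plugin's own docstring (shown in the diff below), typical invocations are:

    pytest --tinytorch       # enable educational output
    pytest --tinytorch -v    # verbose educational output
    tito test --edu          # the same mode through the tito CLI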
@@ -1,9 +1,27 @@
"""
Module 01: Tensor - Core Functionality Tests
Tests fundamental tensor operations and memory management
=============================================

These tests verify that Tensor, the fundamental data structure of TinyTorch, works correctly.

WHY TENSORS MATTER:
------------------
Tensors are the foundation of ALL deep learning:
- Every input (images, text, audio) becomes a tensor
- Every weight and bias in a neural network is a tensor
- Every gradient computed during training is a tensor

If Tensor doesn't work, nothing else will. This is Module 01 for a reason.

WHAT STUDENTS LEARN:
-------------------
1. How data is represented in deep learning frameworks
2. Why NumPy is the backbone of Python ML
3. How operations like broadcasting save memory and compute
"""

import numpy as np
import pytest
import sys
from pathlib import Path

@@ -12,28 +30,59 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))


class TestTensorCreation:
    """Test tensor creation and initialization."""
    """
    Test tensor creation and initialization.

    CONCEPT: A Tensor wraps a NumPy array and adds deep learning capabilities
    (like gradient tracking). Creating tensors is the first step in any ML pipeline.
    """

    def test_tensor_from_list(self):
        """Test creating tensor from Python list."""
        """
        WHAT: Create tensors from Python lists.

        WHY: Students often start with raw Python data (lists of numbers,
        nested lists for matrices). TinyTorch must accept this natural input
        and convert it to the internal NumPy representation.

        STUDENT LEARNING: Data can enter the framework in different forms,
        but internally it's always a NumPy array.
        """
        try:
            from tinytorch.core.tensor import Tensor

            # 1D tensor
            # 1D tensor (vector) - like a single data sample's features
            t1 = Tensor([1, 2, 3])
            assert t1.shape == (3,)
            assert t1.shape == (3,), (
                f"1D tensor has wrong shape.\n"
                f" Input: [1, 2, 3] (3 elements)\n"
                f" Expected shape: (3,)\n"
                f" Got: {t1.shape}"
            )
            assert np.array_equal(t1.data, [1, 2, 3])

            # 2D tensor
            # 2D tensor (matrix) - like a batch of samples or weight matrix
            t2 = Tensor([[1, 2], [3, 4]])
            assert t2.shape == (2, 2)
            assert np.array_equal(t2.data, [[1, 2], [3, 4]])
            assert t2.shape == (2, 2), (
                f"2D tensor has wrong shape.\n"
                f" Input: [[1,2], [3,4]] (2 rows, 2 cols)\n"
                f" Expected shape: (2, 2)\n"
                f" Got: {t2.shape}"
            )

        except ImportError:
            assert True, "Tensor not implemented yet"
            pytest.skip("Tensor not implemented yet")

    def test_tensor_from_numpy(self):
        """Test creating tensor from numpy array."""
        """
        WHAT: Create tensors from NumPy arrays.

        WHY: Real ML data comes from NumPy (pandas, scikit-learn, image loaders).
        TinyTorch must seamlessly accept NumPy arrays.

        STUDENT LEARNING: TinyTorch uses float32 by default (like PyTorch)
        because it's faster and uses half the memory of float64.
        """
        try:
            from tinytorch.core.tensor import Tensor

@@ -41,111 +90,211 @@ class TestTensorCreation:
            t = Tensor(arr)

            assert t.shape == (2, 2)
            # TinyTorch uses float32 for efficiency
            assert t.dtype == np.float32
            assert t.dtype == np.float32, (
                f"Tensor should use float32 for efficiency.\n"
                f" Expected dtype: np.float32\n"
                f" Got: {t.dtype}\n"
                "float32 is half the memory of float64 and faster on GPUs."
            )
            assert np.allclose(t.data, arr)

        except ImportError:
            assert True, "Tensor not implemented yet"
            pytest.skip("Tensor not implemented yet")

    def test_tensor_shapes(self):
        """Test tensor shape handling."""
        """
        WHAT: Handle tensors of various dimensions.

        WHY: Deep learning uses many tensor shapes:
        - 1D: feature vectors, biases
        - 2D: weight matrices, batch of 1D samples
        - 3D: sequences (batch, seq_len, features)
        - 4D: images (batch, height, width, channels)

        STUDENT LEARNING: Shape is critical. Most bugs are shape mismatches.
        """
        try:
            from tinytorch.core.tensor import Tensor

            # Test different shapes
            shapes = [(5,), (3, 4), (2, 3, 4), (1, 28, 28, 3)]
            test_cases = [
                ((5,), "1D: feature vector"),
                ((3, 4), "2D: weight matrix"),
                ((2, 3, 4), "3D: sequence data"),
                ((1, 28, 28, 3), "4D: single RGB image"),
            ]

            for shape in shapes:
            for shape, description in test_cases:
                data = np.random.randn(*shape)
                t = Tensor(data)
                assert t.shape == shape
                assert t.shape == shape, (
                    f"Shape mismatch for {description}.\n"
                    f" Expected: {shape}\n"
                    f" Got: {t.shape}"
                )

        except ImportError:
            assert True, "Tensor not implemented yet"
            pytest.skip("Tensor not implemented yet")


class TestTensorOperations:
    """Test tensor arithmetic and operations."""
    """
    Test tensor arithmetic and operations.

    CONCEPT: Neural networks are just sequences of mathematical operations
    on tensors. If these operations don't work, training is impossible.
    """

    def test_tensor_addition(self):
        """Test tensor addition."""
        """
        WHAT: Element-wise tensor addition.

        WHY: Addition is used everywhere in neural networks:
        - Adding bias to layer output: y = Wx + b
        - Residual connections: output = layer(x) + x
        - Gradient accumulation

        STUDENT LEARNING: Operations return new Tensors (functional style).
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([1, 2, 3])
            t2 = Tensor([4, 5, 6])

            # Element-wise addition
            result = t1 + t2
            expected = np.array([5, 7, 9])

            assert isinstance(result, Tensor)
            assert np.array_equal(result.data, expected)
            assert isinstance(result, Tensor), (
                "Addition should return a Tensor, not numpy array.\n"
                "This maintains the computation graph for backpropagation."
            )
            assert np.array_equal(result.data, expected), (
                f"Element-wise addition failed.\n"
                f" {t1.data} + {t2.data}\n"
                f" Expected: {expected}\n"
                f" Got: {result.data}"
            )

        except (ImportError, TypeError):
            assert True, "Tensor addition not implemented yet"
            pytest.skip("Tensor addition not implemented yet")

    def test_tensor_multiplication(self):
        """Test tensor multiplication."""
        """
        WHAT: Element-wise tensor multiplication.

        WHY: Element-wise multiplication (Hadamard product) is used for:
        - Applying masks (setting values to zero)
        - Gating mechanisms (LSTM, attention)
        - Dropout during training

        STUDENT LEARNING: This is NOT matrix multiplication. It's element-by-element.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([1, 2, 3])
            t2 = Tensor([2, 3, 4])

            # Element-wise multiplication
            result = t1 * t2
            expected = np.array([2, 6, 12])

            assert isinstance(result, Tensor)
            assert np.array_equal(result.data, expected)
            assert np.array_equal(result.data, expected), (
                f"Element-wise multiplication failed.\n"
                f" {t1.data} * {t2.data} (element-wise)\n"
                f" Expected: {expected}\n"
                f" Got: {result.data}\n"
                "Remember: * is element-wise, @ is matrix multiplication."
            )

        except (ImportError, TypeError):
            assert True, "Tensor multiplication not implemented yet"
            pytest.skip("Tensor multiplication not implemented yet")

    def test_matrix_multiplication(self):
        """Test matrix multiplication."""
        """
        WHAT: Matrix multiplication (the @ operator).

        WHY: Matrix multiplication is THE core operation of neural networks:
        - Linear layers: y = x @ W
        - Attention: scores = Q @ K^T
        - Every fully-connected layer uses it

        STUDENT LEARNING: Matrix dimensions must be compatible.
        (m×n) @ (n×p) = (m×p) - inner dimensions must match.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([[1, 2], [3, 4]])
            t2 = Tensor([[5, 6], [7, 8]])
            t1 = Tensor([[1, 2], [3, 4]])  # 2×2
            t2 = Tensor([[5, 6], [7, 8]])  # 2×2

            # Matrix multiplication
            # Matrix multiplication using @ operator
            if hasattr(t1, '__matmul__'):
                result = t1 @ t2
            else:
                # Fallback to manual matmul
                result = Tensor(t1.data @ t2.data)

            # Manual calculation:
            # [1*5+2*7, 1*6+2*8] = [19, 22]
            # [3*5+4*7, 3*6+4*8] = [43, 50]
            expected = np.array([[19, 22], [43, 50]])
            assert np.array_equal(result.data, expected)

            assert np.array_equal(result.data, expected), (
                f"Matrix multiplication failed.\n"
                f" {t1.data}\n @\n {t2.data}\n"
                f" Expected:\n {expected}\n"
                f" Got:\n {result.data}"
            )

        except (ImportError, TypeError):
            assert True, "Matrix multiplication not implemented yet"
            pytest.skip("Matrix multiplication not implemented yet")


class TestTensorMemory:
    """Test tensor memory management."""
    """
    Test tensor memory management.

    CONCEPT: Efficient memory use is critical for deep learning.
    Large models can use 10s of GB. Understanding memory helps debug OOM errors.
    """

    def test_tensor_data_access(self):
        """Test accessing tensor data."""
        """
        WHAT: Access the underlying NumPy array.

        WHY: Sometimes you need the raw data for:
        - Visualization (matplotlib expects NumPy)
        - Debugging (print values)
        - Integration with other libraries

        STUDENT LEARNING: .data gives you the NumPy array inside the Tensor.
        """
        try:
            from tinytorch.core.tensor import Tensor

            data = np.array([1, 2, 3, 4])
            t = Tensor(data)

            # Should be able to access underlying data
            assert hasattr(t, 'data')
            assert hasattr(t, 'data'), (
                "Tensor must have a .data attribute.\n"
                "This gives access to the underlying NumPy array."
            )
            assert np.array_equal(t.data, data)

        except ImportError:
            assert True, "Tensor not implemented yet"
            pytest.skip("Tensor not implemented yet")

    def test_tensor_copy_semantics(self):
        """Test tensor copying behavior."""
        """
        WHAT: Verify tensors don't share memory unexpectedly.

        WHY: Shared memory can cause subtle bugs:
        - Modifying one tensor accidentally changes another
        - Gradient corruption during backprop
        - Non-reproducible results

        STUDENT LEARNING: TinyTorch should copy data by default for safety.
        """
        try:
            from tinytorch.core.tensor import Tensor

@@ -159,127 +308,225 @@ class TestTensorMemory:
            # Modifying original shouldn't affect t2
            original_data[0] = 999
            if not np.shares_memory(t2.data, original_data):
                assert t2.data[0] == 1  # Should be unchanged
                assert t2.data[0] == 1, (
                    "Tensor should not share memory with input!\n"
                    "Modifying the original array changed the tensor.\n"
                    "This can cause hard-to-debug issues."
                )

        except ImportError:
            assert True, "Tensor not implemented yet"
            pytest.skip("Tensor not implemented yet")

    def test_tensor_memory_efficiency(self):
        """Test tensor memory usage is reasonable."""
        """
        WHAT: Handle large tensors efficiently.

        WHY: Real models have millions of parameters:
        - ResNet-50: 25 million parameters
        - GPT-2: 1.5 billion parameters
        - LLaMA: 7-65 billion parameters

        STUDENT LEARNING: Memory efficiency matters at scale.
        """
        try:
            from tinytorch.core.tensor import Tensor

            # Large tensor test
            # Create a 1000×1000 tensor (1 million elements)
            data = np.random.randn(1000, 1000)
            t = Tensor(data)

            # Should not create unnecessary copies
            assert t.shape == (1000, 1000)
            assert t.data.size == 1000000
            assert t.data.size == 1000000, (
                f"Tensor should have 1M elements.\n"
                f" Got: {t.data.size} elements"
            )

        except ImportError:
            assert True, "Tensor not implemented yet"
            pytest.skip("Tensor not implemented yet")


class TestTensorReshaping:
    """Test tensor reshaping and view operations."""
    """
    Test tensor reshaping and view operations.

    CONCEPT: Reshaping changes how we interpret the same data.
    The underlying values don't change, just their arrangement.
    """

    def test_tensor_reshape(self):
        """Test tensor reshaping."""
        """
        WHAT: Reshape tensor to different dimensions.

        WHY: Reshaping is constantly needed:
        - Flattening images for dense layers
        - Rearranging for batch processing
        - Preparing data for specific layer types

        STUDENT LEARNING: Total elements must stay the same.
        [12 elements] can become (3,4) or (2,6) or (2,2,3), but not (5,3).
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor(np.arange(12))  # [0, 1, 2, ..., 11]

            # Test reshape
            if hasattr(t, 'reshape'):
                reshaped = t.reshape(3, 4)
                assert reshaped.shape == (3, 4)
                assert reshaped.shape == (3, 4), (
                    f"Reshape failed.\n"
                    f" Original: {t.shape} (12 elements)\n"
                    f" Requested: (3, 4) (12 elements)\n"
                    f" Got: {reshaped.shape}"
                )
                assert reshaped.data.size == 12
            else:
                # Manual reshape test
                reshaped_data = t.data.reshape(3, 4)
                assert reshaped_data.shape == (3, 4)

        except ImportError:
            assert True, "Tensor reshape not implemented yet"
            pytest.skip("Tensor reshape not implemented yet")

    def test_tensor_flatten(self):
        """Test tensor flattening."""
        """
        WHAT: Flatten tensor to 1D.

        WHY: Flattening is required to connect:
        - Conv layers (4D) to Dense layers (2D)
        - Image data to classification heads

        STUDENT LEARNING: flatten() is shorthand for reshape(-1)
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor(np.random.randn(2, 3, 4))
            t = Tensor(np.random.randn(2, 3, 4))  # 2×3×4 = 24 elements

            if hasattr(t, 'flatten'):
                flat = t.flatten()
                assert flat.shape == (24,)
                assert flat.shape == (24,), (
                    f"Flatten failed.\n"
                    f" Original: {t.shape} = {2*3*4} elements\n"
                    f" Expected: (24,)\n"
                    f" Got: {flat.shape}"
                )
            else:
                # Manual flatten test
                flat_data = t.data.flatten()
                assert flat_data.shape == (24,)

        except ImportError:
            assert True, "Tensor flatten not implemented yet"
            pytest.skip("Tensor flatten not implemented yet")

    def test_tensor_transpose(self):
        """Test tensor transpose."""
        """
        WHAT: Transpose tensor (swap dimensions).

        WHY: Transpose is used for:
        - Matrix multiplication compatibility
        - Attention: K^T in Q @ K^T
        - Rearranging data layouts

        STUDENT LEARNING: Transpose swaps rows and columns.
        (m×n) becomes (n×m).
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor([[1, 2, 3], [4, 5, 6]])  # 2x3
            t = Tensor([[1, 2, 3], [4, 5, 6]])  # 2×3

            if hasattr(t, 'T') or hasattr(t, 'transpose'):
                if hasattr(t, 'T'):
                    transposed = t.T
                else:
                    transposed = t.transpose()

                assert transposed.shape == (3, 2)
                transposed = t.T if hasattr(t, 'T') else t.transpose()

                assert transposed.shape == (3, 2), (
                    f"Transpose failed.\n"
                    f" Original: {t.shape}\n"
                    f" Expected: (3, 2)\n"
                    f" Got: {transposed.shape}"
                )
                expected = np.array([[1, 4], [2, 5], [3, 6]])
                assert np.array_equal(transposed.data, expected)
            else:
                # Manual transpose test
                transposed_data = t.data.T
                assert transposed_data.shape == (3, 2)

        except ImportError:
            assert True, "Tensor transpose not implemented yet"
            pytest.skip("Tensor transpose not implemented yet")


class TestTensorBroadcasting:
    """Test tensor broadcasting operations."""
    """
    Test tensor broadcasting operations.

    CONCEPT: Broadcasting lets you operate on tensors of different shapes
    by automatically expanding the smaller one. This saves memory and code.
    """

    def test_scalar_broadcasting(self):
        """Test broadcasting with scalars."""
        """
        WHAT: Add a scalar to every element.

        WHY: Scalar operations are common:
        - Adding bias: output + bias
        - Normalization: (x - mean) / std
        - Scaling: x * 0.1

        STUDENT LEARNING: Scalars broadcast to match any shape.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor([1, 2, 3])

            # Test scalar addition
            if hasattr(t, '__add__'):
                result = t + 5
                expected = np.array([6, 7, 8])
                assert np.array_equal(result.data, expected)
                assert np.array_equal(result.data, expected), (
                    f"Scalar broadcasting failed.\n"
                    f" {t.data} + 5\n"
                    f" Expected: {expected}\n"
                    f" Got: {result.data}\n"
                    "The scalar 5 should be added to every element."
                )

        except (ImportError, TypeError):
            assert True, "Scalar broadcasting not implemented yet"
            pytest.skip("Scalar broadcasting not implemented yet")

    def test_vector_broadcasting(self):
        """Test broadcasting between different shapes."""
        """
        WHAT: Broadcast a vector across a matrix.

        WHY: Vector broadcasting is used for:
        - Adding bias to batch output: (batch, features) + (features,)
        - Normalizing channels: (batch, H, W, C) / (C,)

        STUDENT LEARNING: Broadcasting aligns from the RIGHT.
        (2,3) + (3,) works because 3 aligns with 3.
        (2,3) + (2,) fails because 2 doesn't align with 3.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([[1, 2, 3], [4, 5, 6]])  # 2x3
            t1 = Tensor([[1, 2, 3], [4, 5, 6]])  # 2×3
            t2 = Tensor([10, 20, 30])  # 3,

            # Should broadcast to same shape
            if hasattr(t1, '__add__'):
                result = t1 + t2
                assert result.shape == (2, 3)
                assert result.shape == (2, 3), (
                    f"Broadcasting produced wrong shape.\n"
                    f" (2,3) + (3,) should give (2,3)\n"
                    f" Got: {result.shape}"
                )
                expected = np.array([[11, 22, 33], [14, 25, 36]])
                assert np.array_equal(result.data, expected)
                assert np.array_equal(result.data, expected), (
                    f"Vector broadcasting failed.\n"
                    f" [[1,2,3], [4,5,6]] + [10,20,30]\n"
                    f" Expected: {expected}\n"
                    f" Got: {result.data}\n"
                    "Each row should have [10,20,30] added to it."
                )

        except (ImportError, TypeError):
            assert True, "Vector broadcasting not implemented yet"
            pytest.skip("Vector broadcasting not implemented yet")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

@@ -2,11 +2,17 @@
Pytest configuration for TinyTorch tests.

This file is automatically loaded by pytest and sets up the test environment.
It also provides a Rich-based educational test output that helps students
understand what each test does and why it matters.
"""

import sys
import os
import re
from pathlib import Path
from typing import Optional

import pytest

# Add tests directory to Python path so test_utils can be imported
tests_dir = Path(__file__).parent
@@ -27,3 +33,226 @@ try:
except ImportError:
    pass  # test_utils not yet created or has issues

# Register the TinyTorch educational test plugin
pytest_plugins = ['tests.pytest_tinytorch']


# =============================================================================
# Educational Test Output Plugin
# =============================================================================

def extract_test_purpose(docstring: Optional[str]) -> dict:
    """
    Extract WHAT/WHY/HOW from test docstrings.

    Returns dict with keys: 'what', 'why', 'learning', 'raw'
    """
    if not docstring:
        return {'what': None, 'why': None, 'learning': None, 'raw': None}

    result = {'raw': docstring.strip()}

    # Extract WHAT section
    what_match = re.search(r'WHAT:\s*(.+?)(?=\n\s*\n|WHY:|$)', docstring, re.DOTALL | re.IGNORECASE)
    if what_match:
        result['what'] = what_match.group(1).strip()

    # Extract WHY section
    why_match = re.search(r'WHY:\s*(.+?)(?=\n\s*\n|STUDENT|HOW:|$)', docstring, re.DOTALL | re.IGNORECASE)
    if why_match:
        result['why'] = why_match.group(1).strip()

    # Extract STUDENT LEARNING section
    learning_match = re.search(r'STUDENT LEARNING:\s*(.+?)(?=\n\s*\n|$)', docstring, re.DOTALL | re.IGNORECASE)
    if learning_match:
        result['learning'] = learning_match.group(1).strip()

    return result


def get_module_from_path(path: str) -> Optional[str]:
    """Extract module number from test file path."""
    match = re.search(r'/(\d{2})_(\w+)/', str(path))
    if match:
        return f"Module {match.group(1)}: {match.group(2).title()}"
    return None


class TinyTorchTestReporter:
    """Rich-based test reporter for educational output."""

    def __init__(self):
        self.current_module = None
        self.passed = 0
        self.failed = 0
        self.skipped = 0
        self.use_rich = False

        try:
            from rich.console import Console
            from rich.panel import Panel
            from rich.text import Text
            self.console = Console()
            self.use_rich = True
        except ImportError:
            self.console = None

    def print_test_start(self, nodeid: str, docstring: Optional[str]):
        """Print when a test starts (only in verbose mode)."""
        if not self.use_rich:
            return

        # Extract test name
        parts = nodeid.split("::")
        test_name = parts[-1] if parts else nodeid

        # Get module info
        module = get_module_from_path(nodeid)
        if module and module != self.current_module:
            self.current_module = module
            self.console.print(f"\n[bold blue]━━━ {module} ━━━[/bold blue]")

        # Get purpose from docstring
        purpose = extract_test_purpose(docstring)
        what = purpose.get('what')

        if what:
            # Truncate to first line/sentence
            what_short = what.split('\n')[0][:60]
            self.console.print(f" [dim]⏳[/dim] {test_name}: {what_short}...")
        else:
            self.console.print(f" [dim]⏳[/dim] {test_name}...")

    def print_test_result(self, nodeid: str, outcome: str, docstring: Optional[str] = None,
                          longrepr=None):
        """Print test result with educational context."""
        if not self.use_rich:
            return

        parts = nodeid.split("::")
        test_name = parts[-1] if parts else nodeid

        if outcome == "passed":
            self.passed += 1
            self.console.print(f" [green]✓[/green] {test_name}")
        elif outcome == "skipped":
            self.skipped += 1
            self.console.print(f" [yellow]⊘[/yellow] {test_name} [dim](skipped)[/dim]")
        elif outcome == "failed":
            self.failed += 1
            self.console.print(f" [red]✗[/red] {test_name}")

            # Show educational context on failure
            purpose = extract_test_purpose(docstring)
            if purpose.get('what') or purpose.get('why'):
                from rich.panel import Panel
                from rich.text import Text

                content = Text()
                if purpose.get('what'):
                    content.append("WHAT: ", style="bold cyan")
                    content.append(purpose['what'][:200] + "\n\n")
                if purpose.get('why'):
                    content.append("WHY THIS MATTERS: ", style="bold yellow")
                    content.append(purpose['why'][:300])

                self.console.print(Panel(content, title="[red]Test Failed[/red]",
                                         border_style="red", padding=(0, 1)))

    def print_summary(self):
        """Print final summary."""
        if not self.use_rich:
            return

        total = self.passed + self.failed + self.skipped

        self.console.print("\n" + "━" * 50)
        status = "[green]ALL PASSED[/green]" if self.failed == 0 else f"[red]{self.failed} FAILED[/red]"
        self.console.print(f"[bold]{status}[/bold] | {self.passed} passed, {self.skipped} skipped, {total} total")


# Global reporter instance
_reporter = TinyTorchTestReporter()


# =============================================================================
# Pytest Hooks
# =============================================================================

def pytest_configure(config):
    """Configure pytest with TinyTorch-specific settings."""
    # Register custom markers
    config.addinivalue_line(
        "markers", "module(name): mark test as belonging to a specific module"
    )
    config.addinivalue_line(
        "markers", "slow: mark test as slow running"
    )
    config.addinivalue_line(
        "markers", "integration: mark test as integration test"
    )


def pytest_collection_modifyitems(session, config, items):
    """Modify test collection to add educational metadata."""
    for item in items:
        # Auto-detect module from path
        module = get_module_from_path(str(item.fspath))
        if module:
            # Store module info for later use
            item._tinytorch_module = module


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Hook to capture test results for educational output."""
    outcome = yield
    report = outcome.get_result()

    # Only process the "call" phase (not setup/teardown)
    if report.when == "call":
        # Get docstring from test function
        docstring = item.function.__doc__ if hasattr(item, 'function') else None

        # Store for later use if needed
        report._tinytorch_docstring = docstring


def pytest_terminal_summary(terminalreporter, exitstatus, config):
    """Add educational summary at the end of test run."""
    # Check if we should show educational summary
    if hasattr(config, '_tinytorch_show_summary') and config._tinytorch_show_summary:
        _reporter.print_summary()


# =============================================================================
# Custom Test Runner Command (for tito test)
# =============================================================================

def run_tests_with_rich_output(test_path: str = None, verbose: bool = True):
    """
    Run tests with Rich educational output.

    This can be called from tito CLI to provide a better student experience.
    """
    from rich.console import Console
    from rich.panel import Panel

    console = Console()

    # Header
    console.print(Panel(
        "[bold]🧪 TinyTorch Test Runner[/bold]\n"
        "Running tests with educational context...",
        border_style="blue"
    ))

    # Build pytest args
    args = ["-v", "--tb=short"]
    if test_path:
        args.append(test_path)

    # Run pytest
    exit_code = pytest.main(args)

    return exit_code

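A hypothetical sketch of how a CLI entry point might drive the helper above; this is not part of the commit, and the import path, test path, and __main__ wrapper are illustrative assumptions only:

    # Hypothetical wiring (not in this commit); import path and test path are assumptions.
    from conftest import run_tests_with_rich_output

    if __name__ == "__main__":
        raise SystemExit(run_tests_with_rich_output(test_path="tests", verbose=True))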
tests/pytest_tinytorch.py (new file, 266 lines)
@@ -0,0 +1,266 @@
"""
TinyTorch Educational Test Plugin for Pytest
=============================================

This plugin provides Rich-formatted output that helps students understand
what tests are checking and why they matter.

USAGE:
    pytest --tinytorch       # Enable educational output
    pytest --tinytorch -v    # Verbose educational output

Or run through tito:
    tito test --edu          # Educational mode
"""

import re
from typing import Optional, Dict, Any
import pytest


def pytest_addoption(parser):
    """Add TinyTorch-specific command line options."""
    group = parser.getgroup('tinytorch', 'TinyTorch educational testing')
    group.addoption(
        '--tinytorch',
        action='store_true',
        dest='tinytorch_edu',
        default=False,
        help='Enable TinyTorch educational test output'
    )


def pytest_configure(config):
    """Configure the plugin."""
    if config.getoption('tinytorch_edu', False):
        config.pluginmanager.register(TinyTorchReporter(config), 'tinytorch_reporter')


class TinyTorchReporter:
    """
    Rich-based reporter that shows educational context for tests.

    Features:
    - Module grouping with descriptions
    - WHAT/WHY extraction from docstrings
    - Clear pass/fail indicators
    - Educational failure messages
    """

    def __init__(self, config):
        self.config = config
        self.current_module = None
        self.stats = {'passed': 0, 'failed': 0, 'skipped': 0, 'error': 0}
        self.failures = []

        try:
            from rich.console import Console
            from rich.panel import Panel
            from rich.table import Table
            from rich.text import Text
            self.console = Console()
            self.rich_available = True
        except ImportError:
            self.rich_available = False

    def _extract_purpose(self, docstring: Optional[str]) -> Dict[str, Optional[str]]:
        """Extract WHAT/WHY/LEARNING from docstring."""
        if not docstring:
            return {'what': None, 'why': None, 'learning': None}

        result = {}

        # Extract WHAT
        what_match = re.search(r'WHAT:\s*(.+?)(?=\n\s*\n|WHY:|$)', docstring, re.DOTALL | re.IGNORECASE)
        result['what'] = what_match.group(1).strip() if what_match else None

        # Extract WHY
        why_match = re.search(r'WHY:\s*(.+?)(?=\n\s*\n|STUDENT|HOW:|$)', docstring, re.DOTALL | re.IGNORECASE)
        result['why'] = why_match.group(1).strip() if why_match else None

        # Extract STUDENT LEARNING
        learning_match = re.search(r'STUDENT LEARNING:\s*(.+?)(?=\n\s*\n|$)', docstring, re.DOTALL)
        result['learning'] = learning_match.group(1).strip() if learning_match else None

        return result

    def _get_module_info(self, nodeid: str) -> Optional[str]:
        """Extract module name from test path."""
        match = re.search(r'/(\d{2})_(\w+)/', nodeid)
        if match:
            num, name = match.groups()
            return f"Module {num}: {name.replace('_', ' ').title()}"

        # Check for other test categories
        if '/integration/' in nodeid:
            return "Integration Tests"
        if '/regression/' in nodeid:
            return "Regression Tests"
        if '/e2e/' in nodeid:
            return "End-to-End Tests"

        return None

    @pytest.hookimpl(hookwrapper=True)
    def pytest_collection_finish(self, session):
        """Called after collection, show what we're testing."""
        yield

        if not self.rich_available:
            return

        from rich.panel import Panel
        from rich.table import Table

        # Group tests by module
        modules = {}
        for item in session.items:
            module = self._get_module_info(item.nodeid) or "Other Tests"
            if module not in modules:
                modules[module] = []
            modules[module].append(item.name)

        # Create summary table
        table = Table(show_header=True, header_style="bold blue")
        table.add_column("Module", style="cyan")
        table.add_column("Tests", justify="right")
        table.add_column("Sample Tests", style="dim")

        for module, tests in sorted(modules.items()):
            sample = ", ".join(tests[:2])
            if len(tests) > 2:
                sample += f", ... (+{len(tests)-2} more)"
            table.add_row(module, str(len(tests)), sample)

        self.console.print(Panel(
            table,
            title="[bold]🧪 TinyTorch Test Suite[/bold]",
            subtitle=f"[dim]{len(session.items)} tests to run[/dim]",
            border_style="blue"
        ))
        self.console.print()

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_protocol(self, item):
        """Called for each test."""
        # Check if we're entering a new module
        module = self._get_module_info(item.nodeid)
        if self.rich_available and module and module != self.current_module:
            self.current_module = module
            self.console.print(f"\n[bold blue]━━━ {module} ━━━[/bold blue]")

        yield

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_makereport(self, item, call):
        """Process test results."""
        outcome = yield
        report = outcome.get_result()

        if report.when != "call":
            return

        if not self.rich_available:
            return

        # Get test info
        test_name = item.name
        docstring = item.function.__doc__ if hasattr(item, 'function') else None
        purpose = self._extract_purpose(docstring)

        # Format output based on result
        if report.passed:
            self.stats['passed'] += 1
            what = purpose.get('what', '')
            if what:
                what_short = what.split('\n')[0][:50]
                self.console.print(f" [green]✓[/green] {test_name} [dim]- {what_short}[/dim]")
            else:
                self.console.print(f" [green]✓[/green] {test_name}")

        elif report.skipped:
            self.stats['skipped'] += 1
            self.console.print(f" [yellow]⊘[/yellow] {test_name} [dim](skipped)[/dim]")

        elif report.failed:
            self.stats['failed'] += 1
            self.console.print(f" [red]✗[/red] {test_name}")

            # Store failure info for detailed output
            self.failures.append({
                'name': test_name,
                'nodeid': item.nodeid,
                'purpose': purpose,
                'longrepr': report.longreprtext
            })

    def pytest_sessionfinish(self, session, exitstatus):
        """Called at the end of the session."""
        if not self.rich_available:
            return

        from rich.panel import Panel
        from rich.text import Text

        self.console.print()

        # Show failure details with educational context
        if self.failures:
            self.console.print("[bold red]━━━ Failed Tests ━━━[/bold red]\n")

            for failure in self.failures:
                # Create educational failure panel
                content = Text()

                purpose = failure['purpose']
                if purpose.get('what'):
                    content.append("📋 WHAT: ", style="bold cyan")
                    content.append(purpose['what'][:200] + "\n\n", style="white")

                if purpose.get('why'):
                    content.append("❓ WHY: ", style="bold yellow")
                    content.append(purpose['why'][:300] + "\n\n", style="white")

                if purpose.get('learning'):
                    content.append("💡 TIP: ", style="bold green")
                    content.append(purpose['learning'][:200] + "\n\n", style="white")

                # Add error excerpt
                error_lines = failure['longrepr'].split('\n')
                error_excerpt = '\n'.join(error_lines[-10:])  # Last 10 lines
                content.append("🔍 Error:\n", style="bold red")
                content.append(error_excerpt[:500], style="dim")

                self.console.print(Panel(
                    content,
                    title=f"[red]✗ {failure['name']}[/red]",
                    border_style="red",
                    padding=(1, 2)
                ))
                self.console.print()

        # Summary
        total = sum(self.stats.values())
        passed = self.stats['passed']
        failed = self.stats['failed']
        skipped = self.stats['skipped']

        if failed == 0:
            status_style = "green"
            status_text = "ALL TESTS PASSED"
            emoji = "🎉"
        else:
            status_style = "red"
            status_text = f"{failed} TESTS FAILED"
            emoji = "❌"

        summary = Text()
        summary.append(f"\n{emoji} ", style="bold")
        summary.append(status_text, style=f"bold {status_style}")
        summary.append(f"\n\n Passed: {passed}", style="green")
        summary.append(f" Failed: {failed}", style="red")
        summary.append(f" Skipped: {skipped}", style="yellow")
        summary.append(f" Total: {total}", style="dim")

        self.console.print(Panel(summary, border_style=status_style))
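As a quick illustration (an editor's sketch, not part of the commit) of the docstring convention the WHAT/WHY/STUDENT LEARNING regexes above expect, a test written against it, and roughly what _extract_purpose would return for it, look like this:

    def test_example():
        """
        WHAT: One-line summary of the behavior under test.

        WHY: Why that behavior matters in a real network.

        STUDENT LEARNING: The takeaway a student should remember.
        """
        ...

    # _extract_purpose(test_example.__doc__) would return approximately:
    # {'what': 'One-line summary of the behavior under test.',
    #  'why': 'Why that behavior matters in a real network.',
    #  'learning': 'The takeaway a student should remember.'}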