Add educational test output with Rich CLI

- Create pytest_tinytorch.py plugin for educational test output
- Update test_tensor_core.py with WHAT/WHY/STUDENT LEARNING docstrings
- Show test purpose on pass, detailed context on failure
- Use --tinytorch flag to enable educational mode

Students can now understand what each test checks and why it matters.
This commit is contained in:
Vijay Janapa Reddi
2025-12-02 22:37:25 -08:00
parent a622e2c200
commit 36dd05ef62
3 changed files with 823 additions and 81 deletions

View File

@@ -1,9 +1,27 @@
"""
Module 01: Tensor - Core Functionality Tests
Tests fundamental tensor operations and memory management
=============================================
These tests verify that Tensor, the fundamental data structure of TinyTorch, works correctly.
WHY TENSORS MATTER:
------------------
Tensors are the foundation of ALL deep learning:
- Every input (images, text, audio) becomes a tensor
- Every weight and bias in a neural network is a tensor
- Every gradient computed during training is a tensor
If Tensor doesn't work, nothing else will. This is Module 01 for a reason.
WHAT STUDENTS LEARN:
-------------------
1. How data is represented in deep learning frameworks
2. Why NumPy is the backbone of Python ML
3. How operations like broadcasting save memory and compute
"""
import numpy as np
import pytest
import sys
from pathlib import Path
@@ -12,28 +30,59 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
class TestTensorCreation:
"""Test tensor creation and initialization."""
"""
Test tensor creation and initialization.
CONCEPT: A Tensor wraps a NumPy array and adds deep learning capabilities
(like gradient tracking). Creating tensors is the first step in any ML pipeline.
"""
def test_tensor_from_list(self):
"""Test creating tensor from Python list."""
"""
WHAT: Create tensors from Python lists.
WHY: Students often start with raw Python data (lists of numbers,
nested lists for matrices). TinyTorch must accept this natural input
and convert it to the internal NumPy representation.
STUDENT LEARNING: Data can enter the framework in different forms,
but internally it's always a NumPy array.
"""
try:
from tinytorch.core.tensor import Tensor
# 1D tensor
# 1D tensor (vector) - like a single data sample's features
t1 = Tensor([1, 2, 3])
assert t1.shape == (3,)
assert t1.shape == (3,), (
f"1D tensor has wrong shape.\n"
f" Input: [1, 2, 3] (3 elements)\n"
f" Expected shape: (3,)\n"
f" Got: {t1.shape}"
)
assert np.array_equal(t1.data, [1, 2, 3])
# 2D tensor
# 2D tensor (matrix) - like a batch of samples or weight matrix
t2 = Tensor([[1, 2], [3, 4]])
assert t2.shape == (2, 2)
assert np.array_equal(t2.data, [[1, 2], [3, 4]])
assert t2.shape == (2, 2), (
f"2D tensor has wrong shape.\n"
f" Input: [[1,2], [3,4]] (2 rows, 2 cols)\n"
f" Expected shape: (2, 2)\n"
f" Got: {t2.shape}"
)
except ImportError:
assert True, "Tensor not implemented yet"
pytest.skip("Tensor not implemented yet")
def test_tensor_from_numpy(self):
"""Test creating tensor from numpy array."""
"""
WHAT: Create tensors from NumPy arrays.
WHY: Real ML data comes from NumPy (pandas, scikit-learn, image loaders).
TinyTorch must seamlessly accept NumPy arrays.
STUDENT LEARNING: TinyTorch uses float32 by default (like PyTorch)
because it's faster and uses half the memory of float64.
"""
try:
from tinytorch.core.tensor import Tensor
@@ -41,111 +90,211 @@ class TestTensorCreation:
t = Tensor(arr)
assert t.shape == (2, 2)
# TinyTorch uses float32 for efficiency
assert t.dtype == np.float32
assert t.dtype == np.float32, (
f"Tensor should use float32 for efficiency.\n"
f" Expected dtype: np.float32\n"
f" Got: {t.dtype}\n"
"float32 is half the memory of float64 and faster on GPUs."
)
assert np.allclose(t.data, arr)
except ImportError:
assert True, "Tensor not implemented yet"
pytest.skip("Tensor not implemented yet")
def test_tensor_shapes(self):
"""Test tensor shape handling."""
"""
WHAT: Handle tensors of various dimensions.
WHY: Deep learning uses many tensor shapes:
- 1D: feature vectors, biases
- 2D: weight matrices, batch of 1D samples
- 3D: sequences (batch, seq_len, features)
- 4D: images (batch, height, width, channels)
STUDENT LEARNING: Shape is critical. Most bugs are shape mismatches.
"""
try:
from tinytorch.core.tensor import Tensor
# Test different shapes
shapes = [(5,), (3, 4), (2, 3, 4), (1, 28, 28, 3)]
test_cases = [
((5,), "1D: feature vector"),
((3, 4), "2D: weight matrix"),
((2, 3, 4), "3D: sequence data"),
((1, 28, 28, 3), "4D: single RGB image"),
]
for shape in shapes:
for shape, description in test_cases:
data = np.random.randn(*shape)
t = Tensor(data)
assert t.shape == shape
assert t.shape == shape, (
f"Shape mismatch for {description}.\n"
f" Expected: {shape}\n"
f" Got: {t.shape}"
)
except ImportError:
assert True, "Tensor not implemented yet"
pytest.skip("Tensor not implemented yet")
class TestTensorOperations:
    """
    Test tensor arithmetic and operations.

    CONCEPT: Neural networks are just sequences of mathematical operations
    on tensors. If these operations don't work, training is impossible.
    """

    def test_tensor_addition(self):
        """
        WHAT: Element-wise tensor addition.

        WHY: Addition is used everywhere in neural networks:
        - Adding bias to layer output: y = Wx + b
        - Residual connections: output = layer(x) + x
        - Gradient accumulation

        STUDENT LEARNING: Operations return new Tensors (functional style).
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([1, 2, 3])
            t2 = Tensor([4, 5, 6])

            # Element-wise addition
            result = t1 + t2
            expected = np.array([5, 7, 9])

            assert isinstance(result, Tensor), (
                "Addition should return a Tensor, not numpy array.\n"
                "This maintains the computation graph for backpropagation."
            )
            assert np.array_equal(result.data, expected), (
                f"Element-wise addition failed.\n"
                f" {t1.data} + {t2.data}\n"
                f" Expected: {expected}\n"
                f" Got: {result.data}"
            )
        except (ImportError, TypeError):
            # A missing module or an unsupported operator means the student
            # has not implemented this yet: report a skip, not a pass.
            pytest.skip("Tensor addition not implemented yet")

    def test_tensor_multiplication(self):
        """
        WHAT: Element-wise tensor multiplication.

        WHY: Element-wise multiplication (Hadamard product) is used for:
        - Applying masks (setting values to zero)
        - Gating mechanisms (LSTM, attention)
        - Dropout during training

        STUDENT LEARNING: This is NOT matrix multiplication. It's element-by-element.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([1, 2, 3])
            t2 = Tensor([2, 3, 4])

            # Element-wise multiplication
            result = t1 * t2
            expected = np.array([2, 6, 12])

            assert isinstance(result, Tensor)
            assert np.array_equal(result.data, expected), (
                f"Element-wise multiplication failed.\n"
                f" {t1.data} * {t2.data} (element-wise)\n"
                f" Expected: {expected}\n"
                f" Got: {result.data}\n"
                "Remember: * is element-wise, @ is matrix multiplication."
            )
        except (ImportError, TypeError):
            pytest.skip("Tensor multiplication not implemented yet")

    def test_matrix_multiplication(self):
        """
        WHAT: Matrix multiplication (the @ operator).

        WHY: Matrix multiplication is THE core operation of neural networks:
        - Linear layers: y = x @ W
        - Attention: scores = Q @ K^T
        - Every fully-connected layer uses it

        STUDENT LEARNING: Matrix dimensions must be compatible.
        (m×n) @ (n×p) = (m×p) - inner dimensions must match.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([[1, 2], [3, 4]])  # 2×2
            t2 = Tensor([[5, 6], [7, 8]])  # 2×2

            # Matrix multiplication using @ operator
            if hasattr(t1, '__matmul__'):
                result = t1 @ t2
            else:
                # Fallback to manual matmul on the underlying arrays
                result = Tensor(t1.data @ t2.data)

            # Manual calculation:
            # [1*5+2*7, 1*6+2*8] = [19, 22]
            # [3*5+4*7, 3*6+4*8] = [43, 50]
            expected = np.array([[19, 22], [43, 50]])

            assert np.array_equal(result.data, expected), (
                f"Matrix multiplication failed.\n"
                f" {t1.data}\n @\n {t2.data}\n"
                f" Expected:\n {expected}\n"
                f" Got:\n {result.data}"
            )
        except (ImportError, TypeError):
            pytest.skip("Matrix multiplication not implemented yet")
class TestTensorMemory:
"""Test tensor memory management."""
"""
Test tensor memory management.
CONCEPT: Efficient memory use is critical for deep learning.
Large models can use 10s of GB. Understanding memory helps debug OOM errors.
"""
def test_tensor_data_access(self):
"""Test accessing tensor data."""
"""
WHAT: Access the underlying NumPy array.
WHY: Sometimes you need the raw data for:
- Visualization (matplotlib expects NumPy)
- Debugging (print values)
- Integration with other libraries
STUDENT LEARNING: .data gives you the NumPy array inside the Tensor.
"""
try:
from tinytorch.core.tensor import Tensor
data = np.array([1, 2, 3, 4])
t = Tensor(data)
# Should be able to access underlying data
assert hasattr(t, 'data')
assert hasattr(t, 'data'), (
"Tensor must have a .data attribute.\n"
"This gives access to the underlying NumPy array."
)
assert np.array_equal(t.data, data)
except ImportError:
assert True, "Tensor not implemented yet"
pytest.skip("Tensor not implemented yet")
def test_tensor_copy_semantics(self):
"""Test tensor copying behavior."""
"""
WHAT: Verify tensors don't share memory unexpectedly.
WHY: Shared memory can cause subtle bugs:
- Modifying one tensor accidentally changes another
- Gradient corruption during backprop
- Non-reproducible results
STUDENT LEARNING: TinyTorch should copy data by default for safety.
"""
try:
from tinytorch.core.tensor import Tensor
@@ -159,127 +308,225 @@ class TestTensorMemory:
# Modifying original shouldn't affect t2
original_data[0] = 999
if not np.shares_memory(t2.data, original_data):
assert t2.data[0] == 1 # Should be unchanged
assert t2.data[0] == 1, (
"Tensor should not share memory with input!\n"
"Modifying the original array changed the tensor.\n"
"This can cause hard-to-debug issues."
)
except ImportError:
assert True, "Tensor not implemented yet"
pytest.skip("Tensor not implemented yet")
def test_tensor_memory_efficiency(self):
"""Test tensor memory usage is reasonable."""
"""
WHAT: Handle large tensors efficiently.
WHY: Real models have millions of parameters:
- ResNet-50: 25 million parameters
- GPT-2: 1.5 billion parameters
- LLaMA: 7-65 billion parameters
STUDENT LEARNING: Memory efficiency matters at scale.
"""
try:
from tinytorch.core.tensor import Tensor
# Large tensor test
# Create a 1000×1000 tensor (1 million elements)
data = np.random.randn(1000, 1000)
t = Tensor(data)
# Should not create unnecessary copies
assert t.shape == (1000, 1000)
assert t.data.size == 1000000
assert t.data.size == 1000000, (
f"Tensor should have 1M elements.\n"
f" Got: {t.data.size} elements"
)
except ImportError:
assert True, "Tensor not implemented yet"
pytest.skip("Tensor not implemented yet")
class TestTensorReshaping:
    """
    Test tensor reshaping and view operations.

    CONCEPT: Reshaping changes how we interpret the same data.
    The underlying values don't change, just their arrangement.
    """

    def test_tensor_reshape(self):
        """
        WHAT: Reshape tensor to different dimensions.

        WHY: Reshaping is constantly needed:
        - Flattening images for dense layers
        - Rearranging for batch processing
        - Preparing data for specific layer types

        STUDENT LEARNING: Total elements must stay the same.
        [12 elements] can become (3,4) or (2,6) or (2,2,3), but not (5,3).
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor(np.arange(12))  # [0, 1, 2, ..., 11]

            if hasattr(t, 'reshape'):
                reshaped = t.reshape(3, 4)
                assert reshaped.shape == (3, 4), (
                    f"Reshape failed.\n"
                    f" Original: {t.shape} (12 elements)\n"
                    f" Requested: (3, 4) (12 elements)\n"
                    f" Got: {reshaped.shape}"
                )
                assert reshaped.data.size == 12
            else:
                # Tensor.reshape is not available yet: check the wrapped array
                reshaped_data = t.data.reshape(3, 4)
                assert reshaped_data.shape == (3, 4)
        except ImportError:
            pytest.skip("Tensor reshape not implemented yet")

    def test_tensor_flatten(self):
        """
        WHAT: Flatten tensor to 1D.

        WHY: Flattening is required to connect:
        - Conv layers (4D) to Dense layers (2D)
        - Image data to classification heads

        STUDENT LEARNING: flatten() is shorthand for reshape(-1)
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor(np.random.randn(2, 3, 4))  # 2×3×4 = 24 elements

            if hasattr(t, 'flatten'):
                flat = t.flatten()
                assert flat.shape == (24,), (
                    f"Flatten failed.\n"
                    f" Original: {t.shape} = {2*3*4} elements\n"
                    f" Expected: (24,)\n"
                    f" Got: {flat.shape}"
                )
            else:
                # Tensor.flatten is not available yet: check the wrapped array
                flat_data = t.data.flatten()
                assert flat_data.shape == (24,)
        except ImportError:
            pytest.skip("Tensor flatten not implemented yet")

    def test_tensor_transpose(self):
        """
        WHAT: Transpose tensor (swap dimensions).

        WHY: Transpose is used for:
        - Matrix multiplication compatibility
        - Attention: K^T in Q @ K^T
        - Rearranging data layouts

        STUDENT LEARNING: Transpose swaps rows and columns.
        (m×n) becomes (n×m).
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor([[1, 2, 3], [4, 5, 6]])  # 2×3

            if hasattr(t, 'T') or hasattr(t, 'transpose'):
                # Prefer the .T property; fall back to the transpose() method
                transposed = t.T if hasattr(t, 'T') else t.transpose()
                assert transposed.shape == (3, 2), (
                    f"Transpose failed.\n"
                    f" Original: {t.shape}\n"
                    f" Expected: (3, 2)\n"
                    f" Got: {transposed.shape}"
                )
                expected = np.array([[1, 4], [2, 5], [3, 6]])
                assert np.array_equal(transposed.data, expected)
            else:
                # Neither API is available yet: check the wrapped array
                transposed_data = t.data.T
                assert transposed_data.shape == (3, 2)
        except ImportError:
            pytest.skip("Tensor transpose not implemented yet")
class TestTensorBroadcasting:
    """
    Test tensor broadcasting operations.

    CONCEPT: Broadcasting lets you operate on tensors of different shapes
    by automatically expanding the smaller one. This saves memory and code.
    """

    def test_scalar_broadcasting(self):
        """
        WHAT: Add a scalar to every element.

        WHY: Scalar operations are common:
        - Adding bias: output + bias
        - Normalization: (x - mean) / std
        - Scaling: x * 0.1

        STUDENT LEARNING: Scalars broadcast to match any shape.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t = Tensor([1, 2, 3])

            # Test scalar addition
            if hasattr(t, '__add__'):
                result = t + 5
                expected = np.array([6, 7, 8])
                assert np.array_equal(result.data, expected), (
                    f"Scalar broadcasting failed.\n"
                    f" {t.data} + 5\n"
                    f" Expected: {expected}\n"
                    f" Got: {result.data}\n"
                    "The scalar 5 should be added to every element."
                )
        except (ImportError, TypeError):
            # A missing module or an unsupported operand type means the
            # feature is not implemented yet: report a skip, not a pass.
            pytest.skip("Scalar broadcasting not implemented yet")

    def test_vector_broadcasting(self):
        """
        WHAT: Broadcast a vector across a matrix.

        WHY: Vector broadcasting is used for:
        - Adding bias to batch output: (batch, features) + (features,)
        - Normalizing channels: (batch, H, W, C) / (C,)

        STUDENT LEARNING: Broadcasting aligns from the RIGHT.
        (2,3) + (3,) works because 3 aligns with 3.
        (2,3) + (2,) fails because 2 doesn't align with 3.
        """
        try:
            from tinytorch.core.tensor import Tensor

            t1 = Tensor([[1, 2, 3], [4, 5, 6]])  # 2×3
            t2 = Tensor([10, 20, 30])  # 3,

            # Should broadcast to same shape
            if hasattr(t1, '__add__'):
                result = t1 + t2
                assert result.shape == (2, 3), (
                    f"Broadcasting produced wrong shape.\n"
                    f" (2,3) + (3,) should give (2,3)\n"
                    f" Got: {result.shape}"
                )
                expected = np.array([[11, 22, 33], [14, 25, 36]])
                assert np.array_equal(result.data, expected), (
                    f"Vector broadcasting failed.\n"
                    f" [[1,2,3], [4,5,6]] + [10,20,30]\n"
                    f" Expected: {expected}\n"
                    f" Got: {result.data}\n"
                    "Each row should have [10,20,30] added to it."
                )
        except (ImportError, TypeError):
            pytest.skip("Vector broadcasting not implemented yet")
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -2,11 +2,17 @@
Pytest configuration for TinyTorch tests.
This file is automatically loaded by pytest and sets up the test environment.
It also provides a Rich-based educational test output that helps students
understand what each test does and why it matters.
"""
import sys
import os
import re
from pathlib import Path
from typing import Optional
import pytest
# Add tests directory to Python path so test_utils can be imported
tests_dir = Path(__file__).parent
@@ -27,3 +33,226 @@ try:
except ImportError:
pass # test_utils not yet created or has issues
# Register the TinyTorch educational test plugin
pytest_plugins = ['tests.pytest_tinytorch']
# =============================================================================
# Educational Test Output Plugin
# =============================================================================
def extract_test_purpose(docstring: Optional[str]) -> dict:
    """
    Extract the WHAT / WHY / STUDENT LEARNING sections of a test docstring.

    Args:
        docstring: The raw docstring of a test function, or None.

    Returns:
        A dict that always has the keys 'what', 'why', 'learning' and 'raw'.
        A section that is not present maps to None; 'raw' is the stripped
        docstring (None when no docstring was given).
    """
    result = {'what': None, 'why': None, 'learning': None, 'raw': None}
    if not docstring:
        return result
    result['raw'] = docstring.strip()

    # A section stops at a blank line, at the next section header that starts
    # a line, or at the end of the docstring. Requiring the full header (and
    # not just the word "STUDENT") keeps ordinary prose such as
    # "helps students ..." from cutting the WHY text short.
    section_end = r'(?=\n\s*\n|\n\s*(?:WHAT|WHY|HOW|STUDENT LEARNING):|$)'
    headers = (
        ('what', 'WHAT'),
        ('why', 'WHY'),
        ('learning', 'STUDENT LEARNING'),
    )
    for key, header in headers:
        match = re.search(
            header + r':\s*(.+?)' + section_end,
            docstring,
            re.DOTALL | re.IGNORECASE,
        )
        if match:
            result[key] = match.group(1).strip()
    return result
def get_module_from_path(path: str) -> Optional[str]:
    """
    Derive a display label such as "Module 01: Tensor" from a test path.

    The label comes from the first directory component shaped like
    ``NN_name`` (two digits, underscore, name).

    Args:
        path: A file path or pytest node id; it is passed through ``str()``.

    Returns:
        The label, or None when no ``NN_name`` directory is found.
    """
    # Normalise Windows separators so one pattern covers every platform.
    normalized = str(path).replace('\\', '/')
    # (?:^|/) also accepts a relative path that starts with the module dir.
    match = re.search(r'(?:^|/)(\d{2})_(\w+)/', normalized)
    if match is None:
        return None
    number, name = match.groups()
    # Underscores become spaces so "loss_functions" reads "Loss Functions".
    return f"Module {number}: {name.replace('_', ' ').title()}"
class TinyTorchTestReporter:
    """
    Rich-based test reporter for educational output.

    Outcome counters (``passed``, ``failed``, ``skipped``) are always kept
    up to date. Printing only happens when the optional ``rich`` package can
    be imported; otherwise every print method is silent.
    """

    def __init__(self):
        self.current_module = None  # label of the module header printed last
        self.passed = 0
        self.failed = 0
        self.skipped = 0
        self.use_rich = False
        self.console = None
        try:
            from rich.console import Console
        except ImportError:
            return  # rich is optional; stay in silent mode
        self.console = Console()
        self.use_rich = True

    def print_test_start(self, nodeid: str, docstring: Optional[str]):
        """Print a "running" line for a test, preceded by a module header
        whenever the test belongs to a different module than the last one."""
        if not self.use_rich:
            return
        from rich.markup import escape

        test_name = nodeid.split("::")[-1]

        module = get_module_from_path(nodeid)
        if module and module != self.current_module:
            self.current_module = module
            self.console.print(f"\n[bold blue]━━━ {module} ━━━[/bold blue]")

        what = extract_test_purpose(docstring).get('what')
        # escape() keeps bracketed text (e.g. parametrize ids such as
        # "test_x[cpu]") from being parsed as Rich markup tags.
        if what:
            # Only the first line, capped at 60 characters
            what_short = what.split('\n')[0][:60]
            self.console.print(
                f" [dim]⏳[/dim] {escape(test_name)}: {escape(what_short)}..."
            )
        else:
            self.console.print(f" [dim]⏳[/dim] {escape(test_name)}...")

    def print_test_result(self, nodeid: str, outcome: str, docstring: Optional[str] = None,
                          longrepr=None):
        """
        Record a test outcome and print it with educational context.

        Args:
            nodeid: pytest node id; the part after the last "::" is shown.
            outcome: "passed", "skipped" or "failed"; other values are ignored.
            docstring: test docstring used for the WHAT/WHY failure panel.
            longrepr: accepted for interface compatibility; not used.
        """
        # Count first so the totals are right even when nothing is printed.
        if outcome == "passed":
            self.passed += 1
        elif outcome == "skipped":
            self.skipped += 1
        elif outcome == "failed":
            self.failed += 1
        else:
            return

        if not self.use_rich:
            return
        from rich.markup import escape

        test_name = escape(nodeid.split("::")[-1])

        if outcome == "passed":
            self.console.print(f" [green]✓[/green] {test_name}")
            return
        if outcome == "skipped":
            self.console.print(f" [yellow]⊘[/yellow] {test_name} [dim](skipped)[/dim]")
            return

        self.console.print(f" [red]✗[/red] {test_name}")

        # Show educational context on failure
        purpose = extract_test_purpose(docstring)
        if not (purpose.get('what') or purpose.get('why')):
            return
        from rich.panel import Panel
        from rich.text import Text

        content = Text()
        if purpose.get('what'):
            content.append("WHAT: ", style="bold cyan")
            content.append(purpose['what'][:200] + "\n\n")
        if purpose.get('why'):
            content.append("WHY THIS MATTERS: ", style="bold yellow")
            content.append(purpose['why'][:300])
        self.console.print(Panel(content, title="[red]Test Failed[/red]",
                                 border_style="red", padding=(0, 1)))

    def print_summary(self):
        """Print the final pass/fail/skip totals."""
        if not self.use_rich:
            return
        total = self.passed + self.failed + self.skipped
        # Same rule character as the module headers
        self.console.print("\n" + "━" * 50)
        status = "[green]ALL PASSED[/green]" if self.failed == 0 else f"[red]{self.failed} FAILED[/red]"
        self.console.print(f"[bold]{status}[/bold] | {self.passed} passed, {self.skipped} skipped, {total} total")
# Global reporter instance
_reporter = TinyTorchTestReporter()
# =============================================================================
# Pytest Hooks
# =============================================================================
def pytest_configure(config):
    """Register the custom markers used by the TinyTorch test suite."""
    marker_specs = (
        "module(name): mark test as belonging to a specific module",
        "slow: mark test as slow running",
        "integration: mark test as integration test",
    )
    for spec in marker_specs:
        config.addinivalue_line("markers", spec)
def pytest_collection_modifyitems(session, config, items):
    """Tag each collected item with the module label derived from its path."""
    for test_item in items:
        # Auto-detect module from the item's file path
        detected = get_module_from_path(str(test_item.fspath))
        if not detected:
            continue
        # Kept on the item so later hooks can read it
        test_item._tinytorch_module = detected
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    """Attach the test function's docstring to its call-phase report."""
    hook_result = yield
    report = hook_result.get_result()

    # Setup and teardown reports are left untouched
    if report.when != "call":
        return

    # Items without a ``function`` attribute have no docstring to offer
    if hasattr(item, 'function'):
        doc = item.function.__doc__
    else:
        doc = None

    # Stored on the report for later use
    report._tinytorch_docstring = doc
def pytest_terminal_summary(terminalreporter, exitstatus, config):
    """Print the educational summary when the config asks for it."""
    # The flag is optional on the config object; absent means "off"
    show_summary = getattr(config, '_tinytorch_show_summary', False)
    if show_summary:
        _reporter.print_summary()
# =============================================================================
# Custom Test Runner Command (for tito test)
# =============================================================================
def run_tests_with_rich_output(test_path: Optional[str] = None, verbose: bool = True):
    """
    Run tests with Rich educational output.

    This can be called from tito CLI to provide a better student experience.

    Args:
        test_path: Path or node id handed to pytest; when omitted, pytest
            uses its default discovery.
        verbose: Pass ``-v`` to pytest when True (the default).

    Returns:
        The exit code returned by ``pytest.main``.
    """
    # Header. rich is optional elsewhere in this file, so fall back to a
    # plain banner instead of failing when it is not installed.
    try:
        from rich.console import Console
        from rich.panel import Panel
    except ImportError:
        print("TinyTorch Test Runner")
        print("Running tests with educational context...")
    else:
        Console().print(Panel(
            "[bold]🧪 TinyTorch Test Runner[/bold]\n"
            "Running tests with educational context...",
            border_style="blue"
        ))

    # Build pytest args; -v is only added when the caller wants verbosity
    args = ["--tb=short"]
    if verbose:
        args.insert(0, "-v")
    if test_path:
        args.append(test_path)

    return pytest.main(args)

266
tests/pytest_tinytorch.py Normal file
View File

@@ -0,0 +1,266 @@
"""
TinyTorch Educational Test Plugin for Pytest
=============================================
This plugin provides Rich-formatted output that helps students understand
what tests are checking and why they matter.
USAGE:
pytest --tinytorch # Enable educational output
pytest --tinytorch -v # Verbose educational output
Or run through tito:
tito test --edu # Educational mode
"""
import re
from typing import Optional, Dict, Any
import pytest
def pytest_addoption(parser):
    """Register the ``--tinytorch`` flag in its own option group."""
    edu_group = parser.getgroup('tinytorch', 'TinyTorch educational testing')
    flag_settings = {
        'action': 'store_true',
        'dest': 'tinytorch_edu',
        'default': False,
        'help': 'Enable TinyTorch educational test output',
    }
    edu_group.addoption('--tinytorch', **flag_settings)
def pytest_configure(config):
    """Install the educational reporter when ``--tinytorch`` was given."""
    edu_enabled = config.getoption('tinytorch_edu', False)
    if not edu_enabled:
        return
    reporter = TinyTorchReporter(config)
    config.pluginmanager.register(reporter, 'tinytorch_reporter')
class TinyTorchReporter:
    """
    Rich-based reporter that shows educational context for tests.

    Features:
    - Module grouping with descriptions
    - WHAT/WHY extraction from docstrings
    - Clear pass/fail indicators
    - Educational failure messages

    Outcome statistics are collected whether or not ``rich`` is installed;
    all printing is skipped when it cannot be imported.
    """

    def __init__(self, config):
        self.config = config
        self.current_module = None  # label of the module header printed last
        self.stats = {'passed': 0, 'failed': 0, 'skipped': 0, 'error': 0}
        self.failures = []  # one dict per failed/errored test, shown at the end
        self.console = None
        self.rich_available = False
        try:
            from rich.console import Console
        except ImportError:
            return  # rich is optional; stay in silent mode
        self.console = Console()
        self.rich_available = True

    def _extract_purpose(self, docstring: Optional[str]) -> Dict[str, Optional[str]]:
        """Extract WHAT/WHY/LEARNING from docstring.

        Returns a dict with the keys 'what', 'why' and 'learning'; a section
        that is not present maps to None.
        """
        result = {'what': None, 'why': None, 'learning': None}
        if not docstring:
            return result

        # A section stops at a blank line, at the next section header that
        # starts a line, or at the end of the docstring. Requiring the full
        # header keeps prose such as "helps students ..." from truncating
        # the WHY text.
        section_end = r'(?=\n\s*\n|\n\s*(?:WHAT|WHY|HOW|STUDENT LEARNING):|$)'
        headers = (
            ('what', 'WHAT'),
            ('why', 'WHY'),
            ('learning', 'STUDENT LEARNING'),
        )
        for key, header in headers:
            match = re.search(
                header + r':\s*(.+?)' + section_end,
                docstring,
                re.DOTALL | re.IGNORECASE,
            )
            if match:
                result[key] = match.group(1).strip()
        return result

    def _get_module_info(self, nodeid: str) -> Optional[str]:
        """Extract module name from test path.

        Returns "Module NN: Name" for an ``NN_name`` directory, a fixed label
        for the integration/regression/e2e directories, otherwise None.
        """
        # A leading slash lets node ids that start with the directory name
        # match the same patterns as nested ones.
        path = '/' + nodeid.replace('\\', '/')
        match = re.search(r'/(\d{2})_(\w+)/', path)
        if match:
            num, name = match.groups()
            return f"Module {num}: {name.replace('_', ' ').title()}"

        # Check for other test categories
        categories = (
            ('/integration/', "Integration Tests"),
            ('/regression/', "Regression Tests"),
            ('/e2e/', "End-to-End Tests"),
        )
        for marker, label in categories:
            if marker in path:
                return label
        return None

    @staticmethod
    def _classify(report) -> Optional[str]:
        """Map a phase report to a stats key, or None if it is not counted.

        The call phase yields passed/skipped/failed. A skip raised during
        setup counts as skipped, and a failure during setup or teardown
        counts as an error. A successful setup or teardown is not an outcome
        of its own.
        """
        if report.when == "call":
            if report.passed:
                return 'passed'
            if report.skipped:
                return 'skipped'
            if report.failed:
                return 'failed'
            return None
        if report.when == "setup" and report.skipped:
            return 'skipped'
        if report.failed:
            return 'error'
        return None

    @pytest.hookimpl(hookwrapper=True)
    def pytest_collection_finish(self, session):
        """Called after collection, show what we're testing."""
        yield

        if not self.rich_available:
            return

        from rich.markup import escape
        from rich.panel import Panel
        from rich.table import Table

        # Group tests by module
        modules = {}
        for item in session.items:
            module = self._get_module_info(item.nodeid) or "Other Tests"
            modules.setdefault(module, []).append(item.name)

        # Create summary table
        table = Table(show_header=True, header_style="bold blue")
        table.add_column("Module", style="cyan")
        table.add_column("Tests", justify="right")
        table.add_column("Sample Tests", style="dim")

        for module, tests in sorted(modules.items()):
            sample = ", ".join(tests[:2])
            if len(tests) > 2:
                sample += f", ... (+{len(tests)-2} more)"
            # escape() keeps parametrize ids like "[cpu]" from being read
            # as Rich markup tags.
            table.add_row(module, str(len(tests)), escape(sample))

        self.console.print(Panel(
            table,
            title="[bold]🧪 TinyTorch Test Suite[/bold]",
            subtitle=f"[dim]{len(session.items)} tests to run[/dim]",
            border_style="blue"
        ))
        self.console.print()

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_protocol(self, item):
        """Called for each test; prints a header on entering a new module."""
        module = self._get_module_info(item.nodeid)
        if self.rich_available and module and module != self.current_module:
            self.current_module = module
            self.console.print(f"\n[bold blue]━━━ {module} ━━━[/bold blue]")

        yield

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_makereport(self, item, call):
        """Process test results: update the stats and print one line."""
        outcome = yield
        report = outcome.get_result()

        category = self._classify(report)
        if category is None:
            return
        self.stats[category] += 1

        if not self.rich_available:
            return

        from rich.markup import escape

        # Get test info
        test_name = item.name
        shown_name = escape(test_name)
        docstring = item.function.__doc__ if hasattr(item, 'function') else None
        purpose = self._extract_purpose(docstring)

        # Format output based on result
        if category == 'passed':
            what = purpose.get('what')
            if what:
                what_short = escape(what.split('\n')[0][:50])
                self.console.print(f" [green]✓[/green] {shown_name} [dim]- {what_short}[/dim]")
            else:
                self.console.print(f" [green]✓[/green] {shown_name}")
            return

        if category == 'skipped':
            self.console.print(f" [yellow]⊘[/yellow] {shown_name} [dim](skipped)[/dim]")
            return

        if category == 'failed':
            self.console.print(f" [red]✗[/red] {shown_name}")
        else:
            self.console.print(
                f" [red]✗[/red] {shown_name} [dim](error during {report.when})[/dim]"
            )

        # Store failure info for detailed output
        self.failures.append({
            'name': test_name,
            'nodeid': item.nodeid,
            'purpose': purpose,
            'longrepr': report.longreprtext
        })

    def pytest_sessionfinish(self, session, exitstatus):
        """Called at the end of the session: failure details, then totals."""
        if not self.rich_available:
            return

        from rich.markup import escape
        from rich.panel import Panel
        from rich.text import Text

        self.console.print()

        # Show failure details with educational context
        if self.failures:
            self.console.print("[bold red]━━━ Failed Tests ━━━[/bold red]\n")

            for failure in self.failures:
                # Text.append() does not parse markup, so no escaping here
                content = Text()

                purpose = failure['purpose']
                if purpose.get('what'):
                    content.append("📋 WHAT: ", style="bold cyan")
                    content.append(purpose['what'][:200] + "\n\n", style="white")

                if purpose.get('why'):
                    content.append("❓ WHY: ", style="bold yellow")
                    content.append(purpose['why'][:300] + "\n\n", style="white")

                if purpose.get('learning'):
                    content.append("💡 TIP: ", style="bold green")
                    content.append(purpose['learning'][:200] + "\n\n", style="white")

                # Add error excerpt
                error_lines = failure['longrepr'].split('\n')
                error_excerpt = '\n'.join(error_lines[-10:])  # Last 10 lines
                content.append("🔍 Error:\n", style="bold red")
                content.append(error_excerpt[:500], style="dim")

                self.console.print(Panel(
                    content,
                    title=f"[red]✗ {escape(failure['name'])}[/red]",
                    border_style="red",
                    padding=(1, 2)
                ))
                self.console.print()

        # Summary
        total = sum(self.stats.values())
        passed = self.stats['passed']
        failed = self.stats['failed']
        skipped = self.stats['skipped']
        errors = self.stats['error']

        if failed == 0 and errors == 0:
            status_style = "green"
            status_text = "ALL TESTS PASSED"
            emoji = "🎉"
        else:
            status_style = "red"
            parts = []
            if failed:
                parts.append(f"{failed} TEST{'S' if failed != 1 else ''} FAILED")
            if errors:
                parts.append(f"{errors} ERROR{'S' if errors != 1 else ''}")
            status_text = ", ".join(parts)
            emoji = "❌"

        summary = Text()
        summary.append(f"\n{emoji} ", style="bold")
        summary.append(status_text, style=f"bold {status_style}")
        summary.append(f"\n\n Passed: {passed}", style="green")
        summary.append(f" Failed: {failed}", style="red")
        if errors:
            summary.append(f" Errors: {errors}", style="red")
        summary.append(f" Skipped: {skipped}", style="yellow")
        summary.append(f" Total: {total}", style="dim")

        self.console.print(Panel(summary, border_style=status_style))