mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-31 21:25:59 -05:00
Enhance inline testing for better student experience
- Add comprehensive step-by-step inline tests to activations module - Each activation function now has immediate feedback tests - Tests check mathematical properties, edge cases, and numerical stability - Provide clear success/failure messages with actionable guidance - Create comprehensive testing guidelines document - Document two-tier testing approach: inline tests for learning, pytest for validation - All existing tests still pass, enhanced learning experience
This commit is contained in:
357
docs/development/testing-guidelines.md
Normal file
357
docs/development/testing-guidelines.md
Normal file
@@ -0,0 +1,357 @@
|
||||
# TinyTorch Testing Guidelines
|
||||
|
||||
## Overview
|
||||
|
||||
TinyTorch uses a **two-tier testing system** designed to provide immediate feedback during development while ensuring comprehensive validation for production use.
|
||||
|
||||
## The Two-Tier Testing System
|
||||
|
||||
### **Tier 1: Inline Testing (For Learning)**
|
||||
- **Purpose**: Immediate feedback during development
|
||||
- **Location**: Within `*_dev.py` files as `🧪 Test Your Implementation` sections
|
||||
- **Style**: Simple, visual, encouraging
|
||||
- **When**: After each major implementation step
|
||||
- **Audience**: Students learning the concepts
|
||||
|
||||
### **Tier 2: Comprehensive Testing (For Validation)**
|
||||
- **Purpose**: Thorough validation and grading
|
||||
- **Location**: `tests/test_*.py` files using pytest
|
||||
- **Style**: Professional test suites with edge cases
|
||||
- **When**: After completing the module
|
||||
- **Audience**: Instructors and automated systems
|
||||
|
||||
## Why This Approach Works
|
||||
|
||||
### **Prevents Late-Stage Failures**
|
||||
Students get immediate feedback as they build, preventing the frustration of implementing an entire module only to discover fundamental errors during final testing.
|
||||
|
||||
### **Builds Confidence**
|
||||
Each successful inline test provides positive reinforcement and confirms the student is on the right track.
|
||||
|
||||
### **Teaches Testing Culture**
|
||||
Students learn to test incrementally, a crucial skill for professional development.
|
||||
|
||||
### **Maintains Professional Standards**
|
||||
The comprehensive test suites ensure that the final package meets production-quality standards.
|
||||
|
||||
## Inline Testing Guidelines
|
||||
|
||||
### **When to Add Inline Tests**
|
||||
|
||||
✅ **Add inline tests after:**
|
||||
- Each major class implementation
|
||||
- Each significant function or method
|
||||
- Complex algorithms or mathematical operations
|
||||
- Data loading or preprocessing steps
|
||||
- Any component that students might struggle with
|
||||
|
||||
❌ **Don't add inline tests for:**
|
||||
- Trivial getters/setters
|
||||
- Simple utility functions
|
||||
- Already well-tested components
|
||||
|
||||
### **Inline Test Structure**
|
||||
|
||||
```python
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Test Your [Component] Implementation
|
||||
|
||||
Let's test your [Component] implementation to ensure it's working correctly:
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Test [Component] implementation
|
||||
print("Testing [Component] Implementation:")
|
||||
print("=" * 40)
|
||||
|
||||
try:
|
||||
# Test 1: Basic functionality
|
||||
component = Component()
|
||||
test_input = create_test_input()
|
||||
output = component(test_input)
|
||||
|
||||
print(f"✅ Input: {test_input}")
|
||||
print(f"✅ Output: {output}")
|
||||
|
||||
# Check if implementation is correct
|
||||
if validate_output(output):
|
||||
print("🎉 [Component] implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ [Component] implementation needs fixing")
|
||||
print(" [Specific guidance on what to check]")
|
||||
|
||||
# Test 2: Edge cases or properties
|
||||
edge_case_test()
|
||||
|
||||
print("✅ [Component] tests complete!")
|
||||
|
||||
except NotImplementedError:
|
||||
print("⚠️ [Component] not implemented yet - complete the method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in [Component]: {e}")
|
||||
print(" Check your implementation in the [method] method")
|
||||
|
||||
print() # Add spacing
|
||||
```
|
||||
|
||||
### **Best Practices for Inline Tests**
|
||||
|
||||
#### **1. Immediate Feedback**
|
||||
```python
|
||||
# ✅ Good: Immediate, specific feedback
|
||||
if np.allclose(output.data.flatten(), expected):
|
||||
print("🎉 ReLU implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ ReLU implementation needs fixing")
|
||||
print(" Make sure negative values become 0, positive values stay unchanged")
|
||||
|
||||
# ❌ Bad: Vague or delayed feedback
|
||||
assert np.allclose(output.data.flatten(), expected) # Just crashes
|
||||
```
|
||||
|
||||
#### **2. Visual and Intuitive**
|
||||
```python
|
||||
# ✅ Good: Visual confirmation
|
||||
print(f"✅ Input: {test_input.data.flatten()}")
|
||||
print(f"✅ Output: {output.data.flatten()}")
|
||||
print(f"✅ Expected: {expected}")
|
||||
|
||||
# ❌ Bad: No visual feedback
|
||||
result = component(test_input)
|
||||
```
|
||||
|
||||
#### **3. Property-Based Testing**
|
||||
```python
|
||||
# ✅ Good: Test mathematical properties
|
||||
all_positive = np.all(output.data > 0)
|
||||
sums_to_one = abs(np.sum(output.data) - 1.0) < 1e-6
|
||||
print(f"✅ All outputs positive: {all_positive}")
|
||||
print(f"✅ Sum equals 1.0: {sums_to_one}")
|
||||
|
||||
# ❌ Bad: Only test specific values
|
||||
assert output.data[0] == 0.665 # Too specific, fragile
|
||||
```
|
||||
|
||||
#### **4. Progressive Complexity**
|
||||
```python
|
||||
# ✅ Good: Start simple, add complexity
|
||||
# Test 1: Basic functionality
|
||||
basic_test()
|
||||
|
||||
# Test 2: Edge cases
|
||||
edge_case_test()
|
||||
|
||||
# Test 3: Numerical stability
|
||||
stability_test()
|
||||
|
||||
# ❌ Bad: Jump to complex cases immediately
|
||||
complex_edge_case_test() # Students get overwhelmed
|
||||
```
|
||||
|
||||
#### **5. Helpful Error Messages**
|
||||
```python
|
||||
# ✅ Good: Actionable guidance
|
||||
except NotImplementedError:
|
||||
print("⚠️ Sigmoid not implemented yet - complete the forward method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in Sigmoid: {e}")
|
||||
print(" Check your implementation in the forward method")
|
||||
print(" Make sure: 0 < output < 1 and sigmoid(0) = 0.5")
|
||||
|
||||
# ❌ Bad: Generic or unhelpful messages
|
||||
except Exception as e:
|
||||
print(f"Error: {e}") # Not helpful
|
||||
```
|
||||
|
||||
## Comprehensive Testing Guidelines
|
||||
|
||||
### **Test File Structure**
|
||||
|
||||
```python
|
||||
"""
|
||||
Tests for TinyTorch [Module] module.
|
||||
|
||||
These tests validate [description of what module does].
|
||||
Focus on [key aspects being tested].
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from tinytorch.core.[module] import [Classes]
|
||||
|
||||
class Test[Component]:
|
||||
"""Test the [Component] class."""
|
||||
|
||||
def test_[component]_basic_functionality(self):
|
||||
"""Test basic [component] behavior."""
|
||||
# Test implementation
|
||||
|
||||
def test_[component]_edge_cases(self):
|
||||
"""Test [component] with edge cases."""
|
||||
# Edge case testing
|
||||
|
||||
def test_[component]_properties(self):
|
||||
"""Test mathematical properties of [component]."""
|
||||
# Property-based testing
|
||||
```
|
||||
|
||||
### **Test Categories**
|
||||
|
||||
#### **1. Correctness Tests**
|
||||
- Verify mathematical correctness
|
||||
- Test against known expected outputs
|
||||
- Validate algorithm implementations
|
||||
|
||||
#### **2. Property Tests**
|
||||
- Test mathematical properties (e.g., symmetry, monotonicity)
|
||||
- Verify invariants (e.g., probability sums to 1)
|
||||
- Check boundary conditions
|
||||
|
||||
#### **3. Edge Case Tests**
|
||||
- Extreme values (very large, very small)
|
||||
- Boundary conditions (zero, negative)
|
||||
- Empty inputs, single elements
|
||||
|
||||
#### **4. Integration Tests**
|
||||
- Test components working together
|
||||
- Verify data flow through pipelines
|
||||
- Check compatibility between modules
|
||||
|
||||
#### **5. Performance Tests**
|
||||
- Memory usage validation
|
||||
- Reasonable execution times
|
||||
- Scalability with data size
|
||||
|
||||
## Module-Specific Guidelines
|
||||
|
||||
### **Tensor Module**
|
||||
- **Inline tests**: After each arithmetic operation
|
||||
- **Focus**: Shape handling, broadcasting, data types
|
||||
- **Visual feedback**: Print shapes and values
|
||||
|
||||
### **Activations Module**
|
||||
- **Inline tests**: After each activation function
|
||||
- **Focus**: Mathematical properties, numerical stability
|
||||
- **Visual feedback**: Input/output ranges, function properties
|
||||
|
||||
### **Layers Module**
|
||||
- **Inline tests**: After matrix multiplication, after Dense layer
|
||||
- **Focus**: Weight initialization, forward pass correctness
|
||||
- **Visual feedback**: Weight shapes, output dimensions
|
||||
|
||||
### **Networks Module**
|
||||
- **Inline tests**: After Sequential class, after each network type
|
||||
- **Focus**: Layer composition, architecture correctness
|
||||
- **Visual feedback**: Network structure, data flow
|
||||
|
||||
### **DataLoader Module**
|
||||
- **Inline tests**: After Dataset, DataLoader, Normalizer
|
||||
- **Focus**: Data integrity, batching correctness, preprocessing
|
||||
- **Visual feedback**: Sample images, batch shapes, statistics
|
||||
|
||||
## Implementation Checklist
|
||||
|
||||
### **For Each Module**
|
||||
|
||||
- [ ] **Inline tests after major components**
|
||||
- [ ] Basic functionality test
|
||||
- [ ] Property validation test
|
||||
- [ ] Edge case test
|
||||
- [ ] Visual feedback
|
||||
- [ ] Helpful error messages
|
||||
|
||||
- [ ] **Comprehensive test suite**
|
||||
- [ ] Correctness tests
|
||||
- [ ] Property tests
|
||||
- [ ] Edge case tests
|
||||
- [ ] Integration tests
|
||||
- [ ] Performance tests
|
||||
|
||||
- [ ] **Documentation**
|
||||
- [ ] Clear test descriptions
|
||||
- [ ] Expected behavior documented
|
||||
- [ ] Error message guidance
|
||||
- [ ] Examples of usage
|
||||
|
||||
### **Quality Checks**
|
||||
|
||||
- [ ] **Inline tests provide immediate feedback**
|
||||
- [ ] **Error messages are actionable**
|
||||
- [ ] **Tests cover the most likely failure modes**
|
||||
- [ ] **Visual feedback helps build intuition**
|
||||
- [ ] **Progressive complexity from simple to advanced**
|
||||
|
||||
## Examples from Existing Modules
|
||||
|
||||
### **Excellent Inline Testing: Activations Module**
|
||||
|
||||
```python
|
||||
# Test ReLU implementation
|
||||
print("Testing ReLU Implementation:")
|
||||
print("=" * 40)
|
||||
|
||||
try:
|
||||
relu = ReLU()
|
||||
|
||||
# Test 1: Basic functionality
|
||||
test_input = Tensor([[-3, -1, 0, 1, 3]])
|
||||
output = relu(test_input)
|
||||
expected = [0, 0, 0, 1, 3]
|
||||
|
||||
print(f"✅ Input: {test_input.data.flatten()}")
|
||||
print(f"✅ Output: {output.data.flatten()}")
|
||||
print(f"✅ Expected: {expected}")
|
||||
|
||||
# Check if implementation is correct
|
||||
if np.allclose(output.data.flatten(), expected):
|
||||
print("🎉 ReLU implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ ReLU implementation needs fixing")
|
||||
print(" Make sure negative values become 0, positive values stay unchanged")
|
||||
|
||||
print("✅ ReLU tests complete!")
|
||||
|
||||
except NotImplementedError:
|
||||
print("⚠️ ReLU not implemented yet - complete the forward method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in ReLU: {e}")
|
||||
print(" Check your implementation in the forward method")
|
||||
```
|
||||
|
||||
### **Good Progressive Testing: Tensor Module**
|
||||
|
||||
```python
|
||||
# Test basic tensor creation
|
||||
print("Testing Tensor creation...")
|
||||
|
||||
try:
|
||||
# Test scalar
|
||||
t1 = Tensor(5)
|
||||
print(f"✅ Scalar: {t1} (shape: {t1.shape}, size: {t1.size})")
|
||||
|
||||
# Test vector
|
||||
t2 = Tensor([1, 2, 3, 4])
|
||||
print(f"✅ Vector: {t2} (shape: {t2.shape}, size: {t2.size})")
|
||||
|
||||
# Test matrix
|
||||
t3 = Tensor([[1, 2], [3, 4]])
|
||||
print(f"✅ Matrix: {t3} (shape: {t3.shape}, size: {t3.size})")
|
||||
|
||||
print("\n🎉 All basic tests passed! Your Tensor class is working!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error: {e}")
|
||||
print("Make sure to implement all the required methods!")
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
The two-tier testing system ensures that students receive immediate, helpful feedback during development while maintaining professional-quality validation standards. This approach:
|
||||
|
||||
1. **Reduces frustration** by catching errors early
|
||||
2. **Builds confidence** through positive reinforcement
|
||||
3. **Teaches good practices** for incremental testing
|
||||
4. **Maintains quality** through comprehensive validation
|
||||
|
||||
By following these guidelines, every TinyTorch module provides an excellent learning experience while producing production-ready code.
|
||||
@@ -179,6 +179,53 @@ class ReLU:
|
||||
"""Allow calling the activation like a function: relu(x)"""
|
||||
return self.forward(x)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Test Your ReLU Implementation
|
||||
|
||||
Let's test your ReLU implementation right away to make sure it's working correctly:
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Test ReLU implementation
|
||||
print("Testing ReLU Implementation:")
|
||||
print("=" * 40)
|
||||
|
||||
try:
|
||||
relu = ReLU()
|
||||
|
||||
# Test 1: Basic functionality
|
||||
test_input = Tensor([[-3, -1, 0, 1, 3]])
|
||||
output = relu(test_input)
|
||||
expected = [0, 0, 0, 1, 3]
|
||||
|
||||
print(f"✅ Input: {test_input.data.flatten()}")
|
||||
print(f"✅ Output: {output.data.flatten()}")
|
||||
print(f"✅ Expected: {expected}")
|
||||
|
||||
# Check if implementation is correct
|
||||
if np.allclose(output.data.flatten(), expected):
|
||||
print("🎉 ReLU implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ ReLU implementation needs fixing")
|
||||
print(" Make sure negative values become 0, positive values stay unchanged")
|
||||
|
||||
# Test 2: Edge cases
|
||||
edge_cases = Tensor([[0.0, -0.0, 1e-10, -1e-10]])
|
||||
edge_output = relu(edge_cases)
|
||||
print(f"✅ Edge cases: {edge_cases.data.flatten()}")
|
||||
print(f"✅ Edge output: {edge_output.data.flatten()}")
|
||||
|
||||
print("✅ ReLU tests complete!")
|
||||
|
||||
except NotImplementedError:
|
||||
print("⚠️ ReLU not implemented yet - complete the forward method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in ReLU: {e}")
|
||||
print(" Check your implementation in the forward method")
|
||||
|
||||
print() # Add spacing
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 3: Sigmoid Activation Function
|
||||
@@ -247,6 +294,63 @@ class Sigmoid:
|
||||
"""Allow calling the activation like a function: sigmoid(x)"""
|
||||
return self.forward(x)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Test Your Sigmoid Implementation
|
||||
|
||||
Let's test your Sigmoid implementation to ensure it's working correctly:
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Test Sigmoid implementation
|
||||
print("Testing Sigmoid Implementation:")
|
||||
print("=" * 40)
|
||||
|
||||
try:
|
||||
sigmoid = Sigmoid()
|
||||
|
||||
# Test 1: Basic functionality
|
||||
test_input = Tensor([[-2, -1, 0, 1, 2]])
|
||||
output = sigmoid(test_input)
|
||||
|
||||
print(f"✅ Input: {test_input.data.flatten()}")
|
||||
print(f"✅ Output: {output.data.flatten()}")
|
||||
|
||||
# Check properties
|
||||
all_positive = np.all(output.data > 0)
|
||||
all_less_than_one = np.all(output.data < 1)
|
||||
zero_maps_to_half = abs(sigmoid(Tensor([0])).data[0] - 0.5) < 1e-6
|
||||
|
||||
print(f"✅ All outputs positive: {all_positive}")
|
||||
print(f"✅ All outputs < 1: {all_less_than_one}")
|
||||
print(f"✅ Sigmoid(0) ≈ 0.5: {zero_maps_to_half}")
|
||||
|
||||
if all_positive and all_less_than_one and zero_maps_to_half:
|
||||
print("🎉 Sigmoid implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ Sigmoid implementation needs fixing")
|
||||
print(" Make sure: 0 < output < 1 and sigmoid(0) = 0.5")
|
||||
|
||||
# Test 2: Numerical stability
|
||||
extreme_values = Tensor([[-1000, 1000]])
|
||||
extreme_output = sigmoid(extreme_values)
|
||||
print(f"✅ Extreme values: {extreme_values.data.flatten()}")
|
||||
print(f"✅ Extreme output: {extreme_output.data.flatten()}")
|
||||
|
||||
# Should not have NaN or inf
|
||||
no_nan_inf = not (np.isnan(extreme_output.data).any() or np.isinf(extreme_output.data).any())
|
||||
print(f"✅ No NaN/Inf: {no_nan_inf}")
|
||||
|
||||
print("✅ Sigmoid tests complete!")
|
||||
|
||||
except NotImplementedError:
|
||||
print("⚠️ Sigmoid not implemented yet - complete the forward method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in Sigmoid: {e}")
|
||||
print(" Check your implementation in the forward method")
|
||||
|
||||
print() # Add spacing
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 4: Tanh Activation Function
|
||||
@@ -315,6 +419,65 @@ class Tanh:
|
||||
"""Allow calling the activation like a function: tanh(x)"""
|
||||
return self.forward(x)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Test Your Tanh Implementation
|
||||
|
||||
Let's test your Tanh implementation to ensure it's working correctly:
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Test Tanh implementation
|
||||
print("Testing Tanh Implementation:")
|
||||
print("=" * 40)
|
||||
|
||||
try:
|
||||
tanh = Tanh()
|
||||
|
||||
# Test 1: Basic functionality
|
||||
test_input = Tensor([[-2, -1, 0, 1, 2]])
|
||||
output = tanh(test_input)
|
||||
|
||||
print(f"✅ Input: {test_input.data.flatten()}")
|
||||
print(f"✅ Output: {output.data.flatten()}")
|
||||
|
||||
# Check properties
|
||||
in_range = np.all(np.abs(output.data) < 1)
|
||||
zero_maps_to_zero = abs(tanh(Tensor([0])).data[0]) < 1e-6
|
||||
symmetric = np.allclose(tanh(Tensor([1])).data, -tanh(Tensor([-1])).data)
|
||||
|
||||
print(f"✅ All outputs in (-1, 1): {in_range}")
|
||||
print(f"✅ Tanh(0) ≈ 0: {zero_maps_to_zero}")
|
||||
print(f"✅ Symmetric (tanh(-x) = -tanh(x)): {symmetric}")
|
||||
|
||||
if in_range and zero_maps_to_zero and symmetric:
|
||||
print("🎉 Tanh implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ Tanh implementation needs fixing")
|
||||
print(" Make sure: -1 < output < 1, tanh(0) = 0, and tanh(-x) = -tanh(x)")
|
||||
|
||||
# Test 2: Compare with expected values
|
||||
expected_values = {
|
||||
0: 0.0,
|
||||
1: 0.7616, # approximately
|
||||
-1: -0.7616, # approximately
|
||||
}
|
||||
|
||||
for input_val, expected in expected_values.items():
|
||||
actual = tanh(Tensor([input_val])).data[0]
|
||||
close = abs(actual - expected) < 0.001
|
||||
print(f"✅ Tanh({input_val}) ≈ {expected}: {close} (got {actual:.4f})")
|
||||
|
||||
print("✅ Tanh tests complete!")
|
||||
|
||||
except NotImplementedError:
|
||||
print("⚠️ Tanh not implemented yet - complete the forward method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in Tanh: {e}")
|
||||
print(" Check your implementation in the forward method")
|
||||
|
||||
print() # Add spacing
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Step 5: Softmax Activation Function
|
||||
@@ -385,6 +548,74 @@ class Softmax:
|
||||
"""Allow calling the activation like a function: softmax(x)"""
|
||||
return self.forward(x)
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
### 🧪 Test Your Softmax Implementation
|
||||
|
||||
Let's test your Softmax implementation to ensure it's working correctly:
|
||||
"""
|
||||
|
||||
# %%
|
||||
# Test Softmax implementation
|
||||
print("Testing Softmax Implementation:")
|
||||
print("=" * 40)
|
||||
|
||||
try:
|
||||
softmax = Softmax()
|
||||
|
||||
# Test 1: Basic functionality
|
||||
test_input = Tensor([[2, 1, 0]])
|
||||
output = softmax(test_input)
|
||||
|
||||
print(f"✅ Input: {test_input.data.flatten()}")
|
||||
print(f"✅ Output: {output.data.flatten()}")
|
||||
|
||||
# Check properties
|
||||
all_positive = np.all(output.data > 0)
|
||||
sums_to_one = abs(np.sum(output.data) - 1.0) < 1e-6
|
||||
largest_input_largest_output = np.argmax(test_input.data) == np.argmax(output.data)
|
||||
|
||||
print(f"✅ All outputs positive: {all_positive}")
|
||||
print(f"✅ Sum equals 1.0: {sums_to_one} (sum = {np.sum(output.data):.6f})")
|
||||
print(f"✅ Largest input → largest output: {largest_input_largest_output}")
|
||||
|
||||
if all_positive and sums_to_one and largest_input_largest_output:
|
||||
print("🎉 Softmax implementation is CORRECT!")
|
||||
else:
|
||||
print("❌ Softmax implementation needs fixing")
|
||||
print(" Make sure: all outputs > 0, sum = 1.0, and largest input gets largest probability")
|
||||
|
||||
# Test 2: Numerical stability
|
||||
extreme_input = Tensor([[1000, 999, 998]])
|
||||
extreme_output = softmax(extreme_input)
|
||||
print(f"✅ Extreme input: {extreme_input.data.flatten()}")
|
||||
print(f"✅ Extreme output: {extreme_output.data.flatten()}")
|
||||
|
||||
# Should not have NaN or inf
|
||||
no_nan_inf = not (np.isnan(extreme_output.data).any() or np.isinf(extreme_output.data).any())
|
||||
extreme_sums_to_one = abs(np.sum(extreme_output.data) - 1.0) < 1e-6
|
||||
|
||||
print(f"✅ No NaN/Inf: {no_nan_inf}")
|
||||
print(f"✅ Extreme case sums to 1: {extreme_sums_to_one}")
|
||||
|
||||
# Test 3: Equal inputs should give equal probabilities
|
||||
equal_input = Tensor([[1, 1, 1]])
|
||||
equal_output = softmax(equal_input)
|
||||
expected_prob = 1.0 / 3.0
|
||||
all_equal = np.allclose(equal_output.data, expected_prob)
|
||||
print(f"✅ Equal inputs → equal probabilities: {all_equal}")
|
||||
print(f" Expected: {expected_prob:.3f}, Got: {equal_output.data.flatten()}")
|
||||
|
||||
print("✅ Softmax tests complete!")
|
||||
|
||||
except NotImplementedError:
|
||||
print("⚠️ Softmax not implemented yet - complete the forward method above!")
|
||||
except Exception as e:
|
||||
print(f"❌ Error in Softmax: {e}")
|
||||
print(" Check your implementation in the forward method")
|
||||
|
||||
print() # Add spacing
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
## Testing Our Activation Functions
|
||||
|
||||
Reference in New Issue
Block a user