fix: resolve 05_cnn external test failures completely

🎯 Issues Fixed:
1. Conv2D Layer: Made polymorphic to preserve input tensor types (MockTensor compatibility)
2. Flatten Function: Made polymorphic to return same type as input tensor
3. Type Signatures: Updated method signatures to be flexible (remove Tensor type annotations)

 Impact: 05_cnn external tests now pass 35/35 (was 31/35)

🔧 Technical Changes:
- Conv2D.forward(): return type(x)(result) instead of Tensor(result)
- flatten(): return type(x)(result) instead of Tensor(result)
- Updated method signatures: forward(self, x) instead of forward(self, x: Tensor) -> Tensor
- Consistent polymorphic pattern across all CNN components

This resolves the MockTensor vs Tensor compatibility issues, making CNN components work with external testing frameworks.
This commit is contained in:
Vijay Janapa Reddi
2025-07-13 22:16:21 -04:00
parent 53afb87457
commit 28dd04cab3

View File

@@ -345,7 +345,7 @@ class Conv2D:
self.kernel = np.random.randn(kH, kW).astype(np.float32) * 0.1
### END SOLUTION
def forward(self, x: Tensor) -> Tensor:
def forward(self, x):
"""
Forward pass: apply convolution to input tensor.
@@ -375,10 +375,10 @@ class Conv2D:
### BEGIN SOLUTION
# Apply convolution using naive implementation
result = conv2d_naive(x.data, self.kernel)
return Tensor(result)
return type(x)(result)
### END SOLUTION
def __call__(self, x: Tensor) -> Tensor:
def __call__(self, x):
"""Make layer callable: layer(x) same as layer.forward(x)"""
return self.forward(x)
@@ -469,7 +469,7 @@ Conv2D → ReLU → Conv2D → ReLU → Flatten → Dense → Output
# %% nbgrader={"grade": false, "grade_id": "flatten-function", "locked": false, "schema_version": 3, "solution": true, "task": false}
#| export
def flatten(x: Tensor) -> Tensor:
def flatten(x):
"""
Flatten a 2D tensor to 1D (for connecting to Dense layers).
@@ -500,7 +500,7 @@ def flatten(x: Tensor) -> Tensor:
# Flatten the tensor and add batch dimension
flattened = x.data.flatten()
result = flattened[None, :] # Add batch dimension
return Tensor(result)
return type(x)(result)
### END SOLUTION
# %% [markdown]