mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-05-06 17:27:32 -05:00
fix: resolve 05_cnn external test failures completely
🎯 Issues Fixed: 1. Conv2D Layer: Made polymorphic to preserve input tensor types (MockTensor compatibility) 2. Flatten Function: Made polymorphic to return same type as input tensor 3. Type Signatures: Updated method signatures to be flexible (remove Tensor type annotations) ✅ Impact: 05_cnn external tests now pass 35/35 (was 31/35) 🔧 Technical Changes: - Conv2D.forward(): return type(x)(result) instead of Tensor(result) - flatten(): return type(x)(result) instead of Tensor(result) - Updated method signatures: forward(self, x) instead of forward(self, x: Tensor) -> Tensor - Consistent polymorphic pattern across all CNN components This resolves the MockTensor vs Tensor compatibility issues, making CNN components work with external testing frameworks.
This commit is contained in:
@@ -345,7 +345,7 @@ class Conv2D:
|
||||
self.kernel = np.random.randn(kH, kW).astype(np.float32) * 0.1
|
||||
### END SOLUTION
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass: apply convolution to input tensor.
|
||||
|
||||
@@ -375,10 +375,10 @@ class Conv2D:
|
||||
### BEGIN SOLUTION
|
||||
# Apply convolution using naive implementation
|
||||
result = conv2d_naive(x.data, self.kernel)
|
||||
return Tensor(result)
|
||||
return type(x)(result)
|
||||
### END SOLUTION
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
def __call__(self, x):
|
||||
"""Make layer callable: layer(x) same as layer.forward(x)"""
|
||||
return self.forward(x)
|
||||
|
||||
@@ -469,7 +469,7 @@ Conv2D → ReLU → Conv2D → ReLU → Flatten → Dense → Output
|
||||
|
||||
# %% nbgrader={"grade": false, "grade_id": "flatten-function", "locked": false, "schema_version": 3, "solution": true, "task": false}
|
||||
#| export
|
||||
def flatten(x: Tensor) -> Tensor:
|
||||
def flatten(x):
|
||||
"""
|
||||
Flatten a 2D tensor to 1D (for connecting to Dense layers).
|
||||
|
||||
@@ -500,7 +500,7 @@ def flatten(x: Tensor) -> Tensor:
|
||||
# Flatten the tensor and add batch dimension
|
||||
flattened = x.data.flatten()
|
||||
result = flattened[None, :] # Add batch dimension
|
||||
return Tensor(result)
|
||||
return type(x)(result)
|
||||
### END SOLUTION
|
||||
|
||||
# %% [markdown]
|
||||
|
||||
Reference in New Issue
Block a user