mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-03-11 21:03:34 -05:00
Fix module dependency chain - clean imports now work
Critical fixes to resolve module import issues: 1. Module 01 (tensor_dev.py): - Wrapped all test calls in if __name__ == '__main__': guards - Tests no longer execute during import - Clean imports now work: from tensor_dev import Tensor 2. Module 08 (dataloader_dev.py): - REMOVED redefined Tensor class (was breaking dependency chain) - Now imports real Tensor from Module 01 - DataLoader uses actual Tensor with full gradient support Impact: - Modules properly build on previous work (no isolated implementations) - Clean dependency chain: each module imports from previous modules - No test execution during imports = fast, clean module loading This resolves the root cause where DataLoader had to redefine Tensor because importing tensor_dev.py would execute all test code.
This commit is contained in:
@@ -70,23 +70,11 @@ import os
|
||||
import gzip
|
||||
import urllib.request
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
# Simplified Tensor class for DataLoader module
|
||||
# This avoids importing the full tensor_dev.py which executes all tests
|
||||
class Tensor:
|
||||
"""
|
||||
Simplified Tensor class for DataLoader module.
|
||||
Contains only the functionality needed for data loading.
|
||||
"""
|
||||
def __init__(self, data):
|
||||
self.data = np.array(data)
|
||||
self.shape = self.data.shape
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __repr__(self):
|
||||
return f"Tensor({self.data})"
|
||||
# Import real Tensor class from Module 01
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '01_tensor'))
|
||||
from tensor_dev import Tensor
|
||||
|
||||
# %% [markdown]
|
||||
"""
|
||||
@@ -221,7 +209,7 @@ def test_unit_dataset():
|
||||
|
||||
print("✅ Dataset interface works correctly!")
|
||||
|
||||
test_unit_dataset()
|
||||
# test_unit_dataset() # Moved to main block
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
@@ -400,7 +388,7 @@ def test_unit_tensordataset():
|
||||
|
||||
print("✅ TensorDataset works correctly!")
|
||||
|
||||
test_unit_tensordataset()
|
||||
# test_unit_tensordataset() # Moved to main block
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
@@ -627,7 +615,7 @@ def test_unit_dataloader():
|
||||
|
||||
print("✅ DataLoader works correctly!")
|
||||
|
||||
test_unit_dataloader()
|
||||
# test_unit_dataloader() # Moved to main block
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
@@ -840,7 +828,7 @@ def test_unit_download_functions():
|
||||
|
||||
print("✅ Download functions work correctly!")
|
||||
|
||||
test_unit_download_functions()
|
||||
# test_unit_download_functions() # Moved to main block
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
@@ -991,7 +979,7 @@ def analyze_dataloader_performance():
|
||||
print("• Memory usage scales linearly with batch size")
|
||||
print("🚀 Production tip: Balance batch size with GPU memory limits")
|
||||
|
||||
analyze_dataloader_performance()
|
||||
# analyze_dataloader_performance() # Moved to main block
|
||||
|
||||
|
||||
def analyze_memory_usage():
|
||||
@@ -1035,7 +1023,7 @@ def analyze_memory_usage():
|
||||
print(f" Large batch (512×784): {large_bytes / 1024:.1f} KB")
|
||||
print(f" Ratio: {large_bytes / small_bytes:.1f}×")
|
||||
|
||||
analyze_memory_usage()
|
||||
# analyze_memory_usage() # Moved to main block
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
@@ -1116,7 +1104,7 @@ def test_training_integration():
|
||||
|
||||
print("✅ Training integration works correctly!")
|
||||
|
||||
test_training_integration()
|
||||
# test_training_integration() # Moved to main block
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
@@ -1176,13 +1164,29 @@ def test_module():
|
||||
print("Run: tito module complete 08")
|
||||
|
||||
# Call before module summary
|
||||
test_module()
|
||||
# test_module() # Moved to main block
|
||||
|
||||
|
||||
# %%
|
||||
if __name__ == "__main__":
|
||||
print("🚀 Running DataLoader module...")
|
||||
|
||||
# Run all unit tests
|
||||
test_unit_dataset()
|
||||
test_unit_tensordataset()
|
||||
test_unit_dataloader()
|
||||
test_unit_download_functions()
|
||||
|
||||
# Run performance analysis
|
||||
analyze_dataloader_performance()
|
||||
analyze_memory_usage()
|
||||
|
||||
# Run integration test
|
||||
test_training_integration()
|
||||
|
||||
# Run final module test
|
||||
test_module()
|
||||
|
||||
print("✅ Module validation complete!")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user