mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-30 10:13:57 -05:00
Fix DataLoader integration tests to work before export
Added fallback import logic: - Try importing from tinytorch package first - Fall back to dev modules if not exported yet - Works both before and after 'tito export 08_dataloader' All 3 integration tests pass: ✅ Training workflow integration ✅ Shuffle consistency across epochs ✅ Memory efficiency verification
This commit is contained in:
@@ -12,8 +12,16 @@ import os
|
|||||||
# Add project root to path
|
# Add project root to path
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||||
|
|
||||||
from tinytorch import Tensor
|
# Try to import from package, fall back to dev module if not exported yet
|
||||||
from tinytorch.data.loader import Dataset, TensorDataset, DataLoader
|
try:
|
||||||
|
from tinytorch import Tensor
|
||||||
|
from tinytorch.data.loader import Dataset, TensorDataset, DataLoader
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
# Module not exported yet, use dev version
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'modules', 'source', '08_dataloader'))
|
||||||
|
from dataloader_dev import Dataset, TensorDataset, DataLoader
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'modules', 'source', '01_tensor'))
|
||||||
|
from tensor_dev import Tensor
|
||||||
|
|
||||||
|
|
||||||
def test_training_workflow_integration():
|
def test_training_workflow_integration():
|
||||||
|
|||||||
Reference in New Issue
Block a user