diff --git a/tito/commands/info.py b/tito/commands/info.py index 2ee4279a..d459cf9a 100644 --- a/tito/commands/info.py +++ b/tito/commands/info.py @@ -203,8 +203,9 @@ class InfoCommand(BaseCommand): def check_data_status(self): try: from tinytorch.core.dataloader import DataLoader + from tinytorch.core.tensor import Tensor import numpy as np - data = [(np.random.randn(3,32,32), 0) for _ in range(10)] + data = [(Tensor(np.random.randn(3,32,32)), Tensor(np.array(i % 10))) for i in range(10)] loader = DataLoader(data, batch_size=2, shuffle=True) _ = next(iter(loader)) return "✅ Implemented"