mirror of
https://github.com/harvard-edge/cs249r_book.git
synced 2026-05-05 00:58:56 -05:00
fix(tinytorch): add missing exports and fix benchmark tests
- Module 19: Add #| export to import block so dataclass is exported - Fix benchmark tests to use correct Benchmark API (requires models/datasets)
This commit is contained in:
@@ -134,11 +134,14 @@ class TestBenchmarkMetrics:
|
||||
|
||||
model = SimpleModel()
|
||||
x = Tensor(np.random.randn(10))
|
||||
datasets = [[(x, None)]]
|
||||
|
||||
bench = Benchmark()
|
||||
latency = bench.measure_latency(model, x)
|
||||
bench = Benchmark([model], datasets)
|
||||
results = bench.run_latency_benchmark(input_shape=(10,))
|
||||
|
||||
assert latency > 0, "Latency must be positive"
|
||||
assert len(results) > 0, "Should produce results"
|
||||
for name, result in results.items():
|
||||
assert result.mean > 0, "Latency must be positive"
|
||||
|
||||
def test_multiple_runs_are_consistent(self):
|
||||
"""
|
||||
@@ -161,24 +164,21 @@ class TestBenchmarkMetrics:
|
||||
|
||||
model = SimpleModel()
|
||||
x = Tensor(np.random.randn(1, 10))
|
||||
datasets = [[(x, None)]]
|
||||
|
||||
bench = Benchmark()
|
||||
bench = Benchmark([model], datasets, measurement_runs=10)
|
||||
results = bench.run_latency_benchmark(input_shape=(1, 10))
|
||||
|
||||
# Run 3 times
|
||||
latencies = [
|
||||
bench.measure_latency(model, x, iterations=10)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
# Check variance is reasonable (within 2x of each other)
|
||||
max_latency = max(latencies)
|
||||
min_latency = min(latencies)
|
||||
|
||||
assert max_latency < min_latency * 3, (
|
||||
f"Benchmark results too variable!\n"
|
||||
f" Latencies: {latencies}\n"
|
||||
"Results should be within 3x of each other."
|
||||
)
|
||||
# Check that we get results with reasonable variance
|
||||
for name, result in results.items():
|
||||
# Coefficient of variation should be reasonable (std/mean < 100%)
|
||||
if result.mean > 0:
|
||||
cv = result.std / result.mean
|
||||
assert cv < 1.0, (
|
||||
f"Benchmark results too variable!\n"
|
||||
f" Mean: {result.mean}, Std: {result.std}, CV: {cv}\n"
|
||||
"Coefficient of variation should be < 100%."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user