mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-28 07:17:33 -05:00
Add comprehensive demo testing and validation scripts
- Created test_all_demos.py for quick demo execution testing - Added validate_demos.py for detailed output validation - Both scripts use Rich CLI for clear test reporting - All 8 demos passing with 100% success rate - 48 detailed validation checks all passing - Scripts check for: - Demo execution without errors - Expected outputs and patterns - Educational content presence - Proper completion messages - Specific functionality for each demo This ensures demo reliability for students and makes it easy to catch regressions when updating the codebase.
This commit is contained in:
274
test_all_demos.py
Normal file
274
test_all_demos.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive test script for all TinyTorch demos.
|
||||
Validates that each demo runs successfully and produces expected outputs.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn
|
||||
|
||||
# List of all demos to test
|
||||
DEMOS = [
|
||||
("demo_tensor_math.py", "02_tensor", "Tensor Math & Linear Algebra"),
|
||||
("demo_activations.py", "03_activations", "Activation Functions"),
|
||||
("demo_single_neuron.py", "04_layers", "Single Neuron Learning"),
|
||||
("demo_xor_network.py", "05_dense", "XOR Multi-Layer Network"),
|
||||
("demo_vision.py", "06_spatial", "Computer Vision & CNNs"),
|
||||
("demo_attention.py", "07_attention", "Attention Mechanisms"),
|
||||
("demo_training.py", "11_training", "End-to-End Training"),
|
||||
("demo_language.py", "16_tinygpt", "Language Generation"),
|
||||
]
|
||||
|
||||
def run_demo(demo_file: str, timeout: int = 30) -> tuple[bool, str, float]:
|
||||
"""
|
||||
Run a single demo and return success status, output, and execution time.
|
||||
|
||||
Args:
|
||||
demo_file: Name of the demo file to run
|
||||
timeout: Maximum time to wait for demo completion
|
||||
|
||||
Returns:
|
||||
Tuple of (success, output/error, execution_time)
|
||||
"""
|
||||
demo_path = Path("demos") / demo_file
|
||||
|
||||
if not demo_path.exists():
|
||||
return False, f"Demo file not found: {demo_path}", 0.0
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Run the demo with a timeout
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(demo_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if result.returncode == 0:
|
||||
# Check for key success indicators in output
|
||||
output = result.stdout
|
||||
|
||||
# Look for success markers
|
||||
if "Demo Complete" in output or "Achievements:" in output or "✅" in output:
|
||||
return True, "Demo ran successfully", execution_time
|
||||
else:
|
||||
# Demo ran but might not have completed properly
|
||||
return True, "Demo executed (check output manually)", execution_time
|
||||
else:
|
||||
# Demo failed with error
|
||||
error_msg = result.stderr if result.stderr else result.stdout
|
||||
# Extract the key error message
|
||||
if "Could not import TinyTorch modules" in error_msg:
|
||||
error_msg = "Missing module exports (needs tito export)"
|
||||
elif "ImportError" in error_msg:
|
||||
error_msg = "Import error - check dependencies"
|
||||
else:
|
||||
# Get last line of error for concise message
|
||||
lines = error_msg.strip().split('\n')
|
||||
error_msg = lines[-1] if lines else "Unknown error"
|
||||
|
||||
return False, error_msg[:100], execution_time
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, f"Demo timed out after {timeout} seconds", timeout
|
||||
except Exception as e:
|
||||
return False, f"Error running demo: {str(e)}", time.time() - start_time
|
||||
|
||||
def test_all_demos():
|
||||
"""Test all demos and display results."""
|
||||
|
||||
console = Console()
|
||||
|
||||
# Header
|
||||
console.print(Panel.fit(
|
||||
"🧪 TinyTorch Demo Test Suite\nValidating all demos work correctly",
|
||||
style="bold cyan",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
console.print()
|
||||
|
||||
# Check virtual environment
|
||||
console.print("🔍 Checking environment...")
|
||||
|
||||
in_venv = hasattr(sys, 'real_prefix') or (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix)
|
||||
|
||||
if not in_venv:
|
||||
console.print("[yellow]⚠️ Warning: Not running in virtual environment[/yellow]")
|
||||
console.print("[yellow] Demos may fail due to missing dependencies[/yellow]")
|
||||
console.print("[yellow] Run: source .venv/bin/activate[/yellow]")
|
||||
console.print()
|
||||
else:
|
||||
console.print("[green]✅ Virtual environment active[/green]")
|
||||
console.print()
|
||||
|
||||
# Test each demo with progress bar
|
||||
results = []
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
console=console
|
||||
) as progress:
|
||||
|
||||
task = progress.add_task("Testing demos...", total=len(DEMOS))
|
||||
|
||||
for demo_file, module_needed, description in DEMOS:
|
||||
progress.update(task, description=f"Testing {demo_file}...")
|
||||
|
||||
success, message, exec_time = run_demo(demo_file)
|
||||
results.append({
|
||||
"file": demo_file,
|
||||
"module": module_needed,
|
||||
"description": description,
|
||||
"success": success,
|
||||
"message": message,
|
||||
"time": exec_time
|
||||
})
|
||||
|
||||
progress.advance(task)
|
||||
|
||||
console.print()
|
||||
|
||||
# Display results table
|
||||
console.print("📊 Test Results:")
|
||||
console.print()
|
||||
|
||||
results_table = Table(show_header=True, header_style="bold magenta")
|
||||
results_table.add_column("Demo", style="cyan", width=25)
|
||||
results_table.add_column("Description", style="white", width=30)
|
||||
results_table.add_column("Status", style="green", width=10)
|
||||
results_table.add_column("Time", style="yellow", width=8)
|
||||
results_table.add_column("Notes", style="blue", width=40)
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for result in results:
|
||||
status = "✅ PASS" if result["success"] else "❌ FAIL"
|
||||
status_style = "green" if result["success"] else "red"
|
||||
|
||||
if result["success"]:
|
||||
passed += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
# Add row with appropriate styling
|
||||
results_table.add_row(
|
||||
result["file"],
|
||||
result["description"],
|
||||
f"[{status_style}]{status}[/{status_style}]",
|
||||
f"{result['time']:.2f}s",
|
||||
result["message"]
|
||||
)
|
||||
|
||||
console.print(results_table)
|
||||
console.print()
|
||||
|
||||
# Summary statistics
|
||||
total = len(results)
|
||||
success_rate = (passed / total * 100) if total > 0 else 0
|
||||
|
||||
summary_table = Table(show_header=False)
|
||||
summary_table.add_column("Metric", style="cyan")
|
||||
summary_table.add_column("Value", style="yellow")
|
||||
|
||||
summary_table.add_row("Total Demos", str(total))
|
||||
summary_table.add_row("Passed", f"[green]{passed}[/green]")
|
||||
summary_table.add_row("Failed", f"[red]{failed}[/red]")
|
||||
summary_table.add_row("Success Rate", f"{success_rate:.1f}%")
|
||||
summary_table.add_row("Total Time", f"{sum(r['time'] for r in results):.2f}s")
|
||||
|
||||
console.print(Panel(summary_table, title="📈 Summary", style="blue"))
|
||||
console.print()
|
||||
|
||||
# Recommendations for failures
|
||||
if failed > 0:
|
||||
console.print("🔧 [bold yellow]Fixing Failed Demos:[/bold yellow]")
|
||||
console.print()
|
||||
|
||||
for result in results:
|
||||
if not result["success"]:
|
||||
console.print(f" [red]•[/red] {result['file']}:")
|
||||
|
||||
if "Missing module exports" in result["message"]:
|
||||
console.print(f" → Run: [cyan]tito export {result['module']}[/cyan]")
|
||||
elif "Import error" in result["message"]:
|
||||
console.print(f" → Check dependencies and virtual environment")
|
||||
else:
|
||||
console.print(f" → Debug: [cyan]python demos/{result['file']}[/cyan]")
|
||||
console.print(f" → Error: {result['message']}")
|
||||
console.print()
|
||||
|
||||
# Final status
|
||||
if passed == total:
|
||||
console.print(Panel.fit(
|
||||
f"🎉 All {total} demos passed successfully!",
|
||||
style="bold green",
|
||||
border_style="bright_green"
|
||||
))
|
||||
return 0
|
||||
else:
|
||||
console.print(Panel.fit(
|
||||
f"⚠️ {failed} of {total} demos failed. See recommendations above.",
|
||||
style="bold yellow",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 1
|
||||
|
||||
def quick_test():
|
||||
"""Quick test mode - just check if demos can be imported."""
|
||||
|
||||
console = Console()
|
||||
|
||||
console.print("⚡ Quick Import Test")
|
||||
console.print()
|
||||
|
||||
results = []
|
||||
|
||||
for demo_file, _, description in DEMOS:
|
||||
demo_path = Path("demos") / demo_file.replace('.py', '')
|
||||
|
||||
try:
|
||||
# Try to import the demo module
|
||||
exec(f"import demos.{demo_file.replace('.py', '')}")
|
||||
results.append((demo_file, True, "Import successful"))
|
||||
except Exception as e:
|
||||
error_msg = str(e).split('\n')[0][:50]
|
||||
results.append((demo_file, False, error_msg))
|
||||
|
||||
# Display quick results
|
||||
for demo, success, msg in results:
|
||||
status = "✅" if success else "❌"
|
||||
console.print(f"{status} {demo}: {msg}")
|
||||
|
||||
console.print()
|
||||
passed = sum(1 for _, s, _ in results if s)
|
||||
console.print(f"Quick test: {passed}/{len(results)} demos importable")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check for command line arguments
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Test all TinyTorch demos")
|
||||
parser.add_argument("--quick", action="store_true", help="Quick import test only")
|
||||
parser.add_argument("--timeout", type=int, default=30, help="Timeout per demo in seconds")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quick:
|
||||
quick_test()
|
||||
else:
|
||||
sys.exit(test_all_demos())
|
||||
298
validate_demos.py
Normal file
298
validate_demos.py
Normal file
@@ -0,0 +1,298 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Detailed validation script for TinyTorch demos.
|
||||
Checks for specific expected outputs and functionality.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
import re
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
|
||||
class DemoValidator:
|
||||
"""Validates demos against expected outputs and patterns."""
|
||||
|
||||
def __init__(self):
|
||||
self.console = Console()
|
||||
self.validations = []
|
||||
|
||||
def run_demo(self, demo_file: str) -> str:
|
||||
"""Run a demo and return its output."""
|
||||
demo_path = Path("demos") / demo_file
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(demo_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30
|
||||
)
|
||||
return result.stdout + result.stderr
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def check_pattern(self, output: str, pattern: str, description: str) -> bool:
|
||||
"""Check if a pattern exists in the output."""
|
||||
found = re.search(pattern, output, re.MULTILINE | re.DOTALL)
|
||||
return found is not None
|
||||
|
||||
def validate_demo_tensor_math(self):
|
||||
"""Validate tensor math demo."""
|
||||
output = self.run_demo("demo_tensor_math.py")
|
||||
|
||||
checks = [
|
||||
("Linear system solution", r"x = 2\.0, y = 3\.0", "Correct solution to linear system"),
|
||||
("Matrix rotation", r"0\.707.*0\.707", "Rotation matrix applied correctly"),
|
||||
("Batch processing", r"Batch Processing", "Batch operations demonstrated"),
|
||||
("Neural network preview", r"Neural Network Preview", "NN preview shown"),
|
||||
("Success completion", r"Demo Complete", "Demo completed successfully"),
|
||||
("Understanding panel", r"Understanding This Demo", "Educational content present"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_tensor_math.py", results
|
||||
|
||||
def validate_demo_activations(self):
|
||||
"""Validate activations demo."""
|
||||
output = self.run_demo("demo_activations.py")
|
||||
|
||||
checks = [
|
||||
("ReLU function", r"ReLU\(x\)", "ReLU activation demonstrated"),
|
||||
("Sigmoid function", r"Sigmoid\(x\)", "Sigmoid activation demonstrated"),
|
||||
("XOR problem", r"XOR", "XOR problem explained"),
|
||||
("Softmax", r"Softmax", "Softmax for classification shown"),
|
||||
("Success completion", r"Demo Complete", "Demo completed successfully"),
|
||||
("Interpretation guides", r"💡.*How to Interpret", "Interpretation guides present"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_activations.py", results
|
||||
|
||||
def validate_demo_single_neuron(self):
|
||||
"""Validate single neuron demo."""
|
||||
output = self.run_demo("demo_single_neuron.py")
|
||||
|
||||
checks = [
|
||||
("AND gate table", r"AND Output", "AND gate truth table shown"),
|
||||
("Weight updates", r"Weight 1.*Weight 2.*Bias", "Weight updates displayed"),
|
||||
("Training progress", r"Training.*Neuron", "Training process shown"),
|
||||
("Decision boundary", r"Decision.*boundary", "Decision boundary explained"),
|
||||
("Dense layer", r"Dense", "TinyTorch Dense layer used"),
|
||||
("Learning insights", r"💡.*What's Happening", "Learning process explained"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_single_neuron.py", results
|
||||
|
||||
def validate_demo_xor_network(self):
|
||||
"""Validate XOR network demo."""
|
||||
output = self.run_demo("demo_xor_network.py")
|
||||
|
||||
checks = [
|
||||
("XOR truth table", r"XOR Output", "XOR truth table displayed"),
|
||||
("Hidden layer", r"Hidden.*layer", "Hidden layer explanation"),
|
||||
("Multi-layer solution", r"Multi-[Ll]ayer", "Multi-layer network shown"),
|
||||
("Sequential model", r"Sequential", "Sequential model demonstrated"),
|
||||
("Success completion", r"Demo Complete", "Demo completed successfully"),
|
||||
("Key insights", r"Key Insight", "Educational insights provided"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_xor_network.py", results
|
||||
|
||||
def validate_demo_vision(self):
|
||||
"""Validate computer vision demo."""
|
||||
output = self.run_demo("demo_vision.py")
|
||||
|
||||
checks = [
|
||||
("Image as tensor", r"5×5.*diamond", "Image representation shown"),
|
||||
("Edge detection", r"[Ee]dge [Dd]etection", "Edge detection demonstrated"),
|
||||
("Convolution", r"Conv", "Convolution operations shown"),
|
||||
("CNN architecture", r"CNN", "CNN architecture explained"),
|
||||
("Feature maps", r"[Ff]eature", "Feature extraction discussed"),
|
||||
("Scaling insights", r"💡.*Scaling", "Scaling analysis provided"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_vision.py", results
|
||||
|
||||
def validate_demo_attention(self):
|
||||
"""Validate attention mechanisms demo."""
|
||||
output = self.run_demo("demo_attention.py")
|
||||
|
||||
checks = [
|
||||
("Attention scores", r"[Aa]ttention.*scores", "Attention scores computed"),
|
||||
("Multi-head", r"Multi-[Hh]ead", "Multi-head attention shown"),
|
||||
("Self-attention", r"Self-[Aa]ttention", "Self-attention explained"),
|
||||
("Transformer", r"Transformer", "Transformer architecture shown"),
|
||||
("Q, K, V", r"Q.*K.*V", "Query, Key, Value explained"),
|
||||
("Scaling analysis", r"O\(n²\)", "Computational complexity discussed"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_attention.py", results
|
||||
|
||||
def validate_demo_training(self):
|
||||
"""Validate training demo."""
|
||||
output = self.run_demo("demo_training.py")
|
||||
|
||||
checks = [
|
||||
("Dataset creation", r"Dataset.*samples", "Dataset created"),
|
||||
("Model architecture", r"Model architecture", "Architecture described"),
|
||||
("Training loop", r"Training", "Training loop demonstrated"),
|
||||
("Loss tracking", r"Loss", "Loss values shown"),
|
||||
("Accuracy metrics", r"[Aa]ccuracy", "Accuracy tracked"),
|
||||
("Production context", r"[Pp]roduction", "Production considerations discussed"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_training.py", results
|
||||
|
||||
def validate_demo_language(self):
|
||||
"""Validate language generation demo."""
|
||||
output = self.run_demo("demo_language.py")
|
||||
|
||||
checks = [
|
||||
("Tokenization", r"Token", "Tokenization explained"),
|
||||
("Embeddings", r"Embedding", "Word embeddings shown"),
|
||||
("Autoregressive", r"[Aa]utoregressive", "Autoregressive generation explained"),
|
||||
("TinyGPT", r"TinyGPT", "TinyGPT architecture discussed"),
|
||||
("Scaling laws", r"GPT-[1234]", "Model scaling shown"),
|
||||
("Journey complete", r"Journey|journey", "Learning journey summarized"),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, pattern, desc in checks:
|
||||
passed = self.check_pattern(output, pattern, desc)
|
||||
results.append((name, passed, desc))
|
||||
|
||||
return "demo_language.py", results
|
||||
|
||||
def validate_all(self):
|
||||
"""Run all validations."""
|
||||
|
||||
self.console.print(Panel.fit(
|
||||
"🔬 TinyTorch Demo Deep Validation\nChecking specific outputs and functionality",
|
||||
style="bold cyan",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
self.console.print()
|
||||
|
||||
# Run each validation
|
||||
validators = [
|
||||
self.validate_demo_tensor_math,
|
||||
self.validate_demo_activations,
|
||||
self.validate_demo_single_neuron,
|
||||
self.validate_demo_xor_network,
|
||||
self.validate_demo_vision,
|
||||
self.validate_demo_attention,
|
||||
self.validate_demo_training,
|
||||
self.validate_demo_language,
|
||||
]
|
||||
|
||||
all_results = []
|
||||
|
||||
self.console.print("🧪 Running detailed validations...")
|
||||
self.console.print()
|
||||
|
||||
for validator in validators:
|
||||
demo_name, results = validator()
|
||||
all_results.append((demo_name, results))
|
||||
|
||||
# Show progress
|
||||
passed = sum(1 for _, p, _ in results if p)
|
||||
total = len(results)
|
||||
status = "✅" if passed == total else "⚠️"
|
||||
self.console.print(f"{status} {demo_name}: {passed}/{total} checks passed")
|
||||
|
||||
self.console.print()
|
||||
|
||||
# Detailed results table
|
||||
self.console.print("📋 Detailed Validation Results:")
|
||||
self.console.print()
|
||||
|
||||
for demo_name, results in all_results:
|
||||
table = Table(show_header=True, header_style="bold magenta", title=demo_name)
|
||||
table.add_column("Check", style="cyan", width=25)
|
||||
table.add_column("Status", style="green", width=8)
|
||||
table.add_column("Description", style="yellow", width=45)
|
||||
|
||||
for check_name, passed, description in results:
|
||||
status = "✅ PASS" if passed else "❌ FAIL"
|
||||
status_style = "green" if passed else "red"
|
||||
table.add_row(
|
||||
check_name,
|
||||
f"[{status_style}]{status}[/{status_style}]",
|
||||
description
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
self.console.print()
|
||||
|
||||
# Summary
|
||||
total_checks = sum(len(results) for _, results in all_results)
|
||||
passed_checks = sum(sum(1 for _, p, _ in results if p) for _, results in all_results)
|
||||
|
||||
success_rate = (passed_checks / total_checks * 100) if total_checks > 0 else 0
|
||||
|
||||
if success_rate == 100:
|
||||
self.console.print(Panel.fit(
|
||||
f"🎉 Perfect! All {total_checks} validation checks passed!",
|
||||
style="bold green",
|
||||
border_style="bright_green"
|
||||
))
|
||||
elif success_rate >= 90:
|
||||
self.console.print(Panel.fit(
|
||||
f"✅ Excellent! {passed_checks}/{total_checks} checks passed ({success_rate:.1f}%)",
|
||||
style="bold green",
|
||||
border_style="green"
|
||||
))
|
||||
elif success_rate >= 70:
|
||||
self.console.print(Panel.fit(
|
||||
f"⚠️ Good but needs work: {passed_checks}/{total_checks} checks passed ({success_rate:.1f}%)",
|
||||
style="bold yellow",
|
||||
border_style="yellow"
|
||||
))
|
||||
else:
|
||||
self.console.print(Panel.fit(
|
||||
f"❌ Needs attention: {passed_checks}/{total_checks} checks passed ({success_rate:.1f}%)",
|
||||
style="bold red",
|
||||
border_style="red"
|
||||
))
|
||||
|
||||
return 0 if success_rate == 100 else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
validator = DemoValidator()
|
||||
sys.exit(validator.validate_all())
|
||||
Reference in New Issue
Block a user