mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2025-12-05 19:17:52 -06:00
Clean up TITO CLI: remove dead commands and consolidate duplicates
Removed 14 dead/unused command files that were not registered: - book.py, check.py, checkpoint.py, clean_workspace.py - demo.py, help.py, leaderboard.py, milestones.py (duplicate) - module_reset.py, module_workflow.py (duplicates) - protect.py, report.py, version.py, view.py Simplified olympics.py to "Coming Soon" feature with ASCII branding: - Reduced from 885 lines to 107 lines - Added inspiring Olympics logo and messaging for future competitions - Registered in main.py as student-facing command The module/ package directory structure is the source of truth: - module/workflow.py (active, has auth/submission handling) - module/reset.py (active) - module/test.py (active) All deleted commands either: 1. Had functionality superseded by other commands 2. Were duplicate implementations 3. Were never registered in main.py 4. Were incomplete/abandoned features 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,396 +0,0 @@
|
||||
"""
|
||||
Book command for TinyTorch CLI: builds and manages the Jupyter Book.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from rich.panel import Panel
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
NOTEBOOKS_DIR = "modules"
|
||||
|
||||
class BookCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "book"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Build and manage the TinyTorch Jupyter Book"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='book_command',
|
||||
help='Book management commands',
|
||||
metavar='COMMAND'
|
||||
)
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser(
|
||||
'build',
|
||||
help='Build the Jupyter Book locally'
|
||||
)
|
||||
|
||||
# Publish command
|
||||
publish_parser = subparsers.add_parser(
|
||||
'publish',
|
||||
help='Generate content, commit, and publish to GitHub'
|
||||
)
|
||||
publish_parser.add_argument(
|
||||
'--message',
|
||||
type=str,
|
||||
default='📚 Update book content',
|
||||
help='Commit message (default: "📚 Update book content")'
|
||||
)
|
||||
publish_parser.add_argument(
|
||||
'--branch',
|
||||
type=str,
|
||||
default='main',
|
||||
help='Branch to push to (default: main)'
|
||||
)
|
||||
|
||||
# Clean command
|
||||
clean_parser = subparsers.add_parser(
|
||||
'clean',
|
||||
help='Clean built book files'
|
||||
)
|
||||
|
||||
# Serve command
|
||||
serve_parser = subparsers.add_parser(
|
||||
'serve',
|
||||
help='Build and serve the Jupyter Book locally'
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default=8001,
|
||||
help='Port to serve on (default: 8001)'
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
'--no-build',
|
||||
action='store_true',
|
||||
help='Skip building and serve existing files'
|
||||
)
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
console = self.console
|
||||
|
||||
# Check if we're in the right directory
|
||||
if not Path("site").exists():
|
||||
console.print(Panel(
|
||||
"[red]❌ site/ directory not found. Run this command from the TinyTorch root directory.[/red]",
|
||||
title="Error",
|
||||
border_style="red"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Handle subcommands
|
||||
if not hasattr(args, 'book_command') or not args.book_command:
|
||||
console.print(Panel(
|
||||
"[bold cyan]📚 TinyTorch Book Management[/bold cyan]\n\n"
|
||||
"[bold]Available Commands:[/bold]\n"
|
||||
" [bold green]build[/bold green] - Build the complete Jupyter Book\n"
|
||||
" [bold green]serve[/bold green] - Build and serve the Jupyter Book locally\n"
|
||||
" [bold green]publish[/bold green] - Generate content, commit, and publish to GitHub\n"
|
||||
" [bold green]clean[/bold green] - Clean built book files\n\n"
|
||||
"[bold]Quick Start:[/bold]\n"
|
||||
" [dim]tito book publish[/dim] - Generate, commit, and publish to GitHub\n"
|
||||
" [dim]tito book clean[/dim] - Clean built book files",
|
||||
title="Book Commands",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
return 0
|
||||
|
||||
if args.book_command == 'build':
|
||||
return self._build_book(args)
|
||||
elif args.book_command == 'serve':
|
||||
return self._serve_book(args)
|
||||
elif args.book_command == 'publish':
|
||||
return self._publish_book(args)
|
||||
elif args.book_command == 'clean':
|
||||
return self._clean_book()
|
||||
else:
|
||||
console.print(f"[red]Unknown book command: {args.book_command}[/red]")
|
||||
return 1
|
||||
|
||||
def _generate_overview(self) -> int:
|
||||
"""Generate overview pages from modules."""
|
||||
console = self.console
|
||||
console.print("🔄 Generating overview pages from modules...")
|
||||
|
||||
try:
|
||||
os.chdir("site")
|
||||
result = subprocess.run(
|
||||
["python3", "convert_readmes.py"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
console.print("✅ Overview pages generated successfully")
|
||||
# Show summary from the output
|
||||
for line in result.stdout.split('\n'):
|
||||
if "✅ Created" in line or "🎉 Converted" in line:
|
||||
console.print(f" {line.strip()}")
|
||||
return 0
|
||||
else:
|
||||
console.print(f"[red]❌ Failed to generate overview pages: {result.stderr}[/red]")
|
||||
return 1
|
||||
|
||||
except FileNotFoundError:
|
||||
console.print("[red]❌ Python3 not found or convert_readmes.py missing[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error generating overview pages: {e}[/red]")
|
||||
return 1
|
||||
finally:
|
||||
os.chdir("..")
|
||||
|
||||
def _generate_all(self) -> int:
|
||||
"""Verify that all book chapters exist."""
|
||||
console = self.console
|
||||
console.print("📝 Verifying book chapters...")
|
||||
|
||||
# Check that the chapters directory exists
|
||||
chapters_dir = Path("site/chapters")
|
||||
if not chapters_dir.exists():
|
||||
console.print("[red]❌ site/chapters directory not found[/red]")
|
||||
return 1
|
||||
|
||||
# Count markdown files in chapters directory
|
||||
chapter_files = list(chapters_dir.glob("*.md"))
|
||||
if chapter_files:
|
||||
console.print(f"✅ Found {len(chapter_files)} chapter files")
|
||||
else:
|
||||
console.print("[yellow]⚠️ No chapter files found in site/chapters/[/yellow]")
|
||||
|
||||
return 0
|
||||
|
||||
def _build_book(self, args: Namespace) -> int:
|
||||
"""Build the Jupyter Book locally."""
|
||||
console = self.console
|
||||
|
||||
# First generate all content (notebooks + overview pages)
|
||||
console.print("📄 Step 1: Generating all content...")
|
||||
if self._generate_all() != 0:
|
||||
return 1
|
||||
|
||||
# Then build the book
|
||||
console.print("📚 Step 2: Building Jupyter Book...")
|
||||
|
||||
try:
|
||||
os.chdir("site")
|
||||
result = subprocess.run(
|
||||
["jupyter-book", "build", "."],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
console.print("✅ Book built successfully!")
|
||||
|
||||
# Extract and show the file path
|
||||
if "file://" in result.stdout:
|
||||
for line in result.stdout.split('\n'):
|
||||
if "file://" in line:
|
||||
console.print(f"🌐 View at: {line.strip()}")
|
||||
break
|
||||
|
||||
console.print("📁 HTML files available in: site/_build/html/")
|
||||
return 0
|
||||
else:
|
||||
console.print(f"[red]❌ Failed to build book[/red]")
|
||||
if result.stderr:
|
||||
console.print(f"Error details: {result.stderr}")
|
||||
return 1
|
||||
|
||||
except FileNotFoundError:
|
||||
console.print("[red]❌ jupyter-book not found. Install with: pip install jupyter-book[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error building book: {e}[/red]")
|
||||
return 1
|
||||
finally:
|
||||
os.chdir("..")
|
||||
|
||||
def _serve_book(self, args: Namespace) -> int:
|
||||
"""Build and serve the Jupyter Book locally."""
|
||||
console = self.console
|
||||
|
||||
# Build the book first unless --no-build is specified
|
||||
if not args.no_build:
|
||||
console.print("📚 Step 1: Building the book...")
|
||||
if self._build_book(args) != 0:
|
||||
return 1
|
||||
console.print()
|
||||
|
||||
# Start the HTTP server
|
||||
console.print("🌐 Step 2: Starting development server...")
|
||||
console.print(f"📖 Open your browser to: [bold blue]http://localhost:{args.port}[/bold blue]")
|
||||
console.print("🛑 Press [bold]Ctrl+C[/bold] to stop the server")
|
||||
console.print()
|
||||
|
||||
book_dir = Path("site/_build/html")
|
||||
if not book_dir.exists():
|
||||
console.print("[red]❌ Built book not found. Run with --no-build=False to build first.[/red]")
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Use Python's built-in HTTP server
|
||||
subprocess.run([
|
||||
"python3", "-m", "http.server", str(args.port),
|
||||
"--directory", str(book_dir)
|
||||
])
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n🛑 Development server stopped")
|
||||
except FileNotFoundError:
|
||||
console.print("[red]❌ Python3 not found in PATH[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error starting server: {e}[/red]")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
def _clean_book(self) -> int:
|
||||
"""Clean built book files."""
|
||||
console = self.console
|
||||
console.print("🧹 Cleaning book build files...")
|
||||
|
||||
try:
|
||||
os.chdir("site")
|
||||
result = subprocess.run(
|
||||
["jupyter-book", "clean", "."],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
console.print("✅ Book files cleaned successfully")
|
||||
return 0
|
||||
else:
|
||||
console.print(f"[red]❌ Failed to clean book files: {result.stderr}[/red]")
|
||||
return 1
|
||||
|
||||
except FileNotFoundError:
|
||||
console.print("[red]❌ jupyter-book not found[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error cleaning book: {e}[/red]")
|
||||
return 1
|
||||
finally:
|
||||
os.chdir("..")
|
||||
|
||||
def _publish_book(self, args: Namespace) -> int:
|
||||
"""Generate content, commit, and publish to GitHub."""
|
||||
console = self.console
|
||||
|
||||
console.print("🚀 Starting book publishing workflow...")
|
||||
|
||||
# Step 1: Generate all content
|
||||
console.print("📝 Step 1: Generating all content...")
|
||||
if self._generate_all() != 0:
|
||||
console.print("[red]❌ Failed to generate content. Aborting publish.[/red]")
|
||||
return 1
|
||||
|
||||
# Step 2: Check git status
|
||||
console.print("🔍 Step 2: Checking git status...")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd="."
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print("[red]❌ Git not available or not a git repository[/red]")
|
||||
return 1
|
||||
|
||||
changes = result.stdout.strip()
|
||||
if not changes:
|
||||
console.print("✅ No changes to publish")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error checking git status: {e}[/red]")
|
||||
return 1
|
||||
|
||||
# Step 3: Add and commit changes
|
||||
console.print("📦 Step 3: Committing changes...")
|
||||
try:
|
||||
# Add all changes
|
||||
subprocess.run(["git", "add", "."], check=True, cwd=".")
|
||||
|
||||
# Commit with message
|
||||
subprocess.run([
|
||||
"git", "commit", "-m", args.message
|
||||
], check=True, cwd=".")
|
||||
|
||||
console.print(f"✅ Committed with message: {args.message}")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]❌ Failed to commit changes: {e}[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error during commit: {e}[/red]")
|
||||
return 1
|
||||
|
||||
# Step 4: Push to GitHub
|
||||
console.print(f"⬆️ Step 4: Pushing to {args.branch} branch...")
|
||||
try:
|
||||
result = subprocess.run([
|
||||
"git", "push", "origin", args.branch
|
||||
], capture_output=True, text=True, cwd=".")
|
||||
|
||||
if result.returncode == 0:
|
||||
console.print(f"✅ Successfully pushed to {args.branch}")
|
||||
else:
|
||||
console.print(f"[red]❌ Failed to push: {result.stderr}[/red]")
|
||||
return 1
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"[red]❌ Error during push: {e}[/red]")
|
||||
return 1
|
||||
|
||||
# Step 5: Show deployment info
|
||||
console.print("🌐 Step 5: Deployment initiated...")
|
||||
console.print("✅ GitHub Actions will now:")
|
||||
console.print(" 📚 Build the Jupyter Book")
|
||||
console.print(" 🚀 Deploy to GitHub Pages")
|
||||
console.print(" 🔗 Update live website")
|
||||
|
||||
# Try to get repository info for deployment URL
|
||||
try:
|
||||
result = subprocess.run([
|
||||
"git", "remote", "get-url", "origin"
|
||||
], capture_output=True, text=True, cwd=".")
|
||||
|
||||
if result.returncode == 0:
|
||||
remote_url = result.stdout.strip()
|
||||
if "github.com" in remote_url:
|
||||
# Extract owner/repo from git URL
|
||||
if remote_url.endswith(".git"):
|
||||
remote_url = remote_url[:-4]
|
||||
if remote_url.startswith("git@github.com:"):
|
||||
repo_path = remote_url.replace("git@github.com:", "")
|
||||
elif remote_url.startswith("https://github.com/"):
|
||||
repo_path = remote_url.replace("https://github.com/", "")
|
||||
else:
|
||||
repo_path = None
|
||||
|
||||
if repo_path:
|
||||
console.print(f"\n🔗 Monitor deployment: https://github.com/{repo_path}/actions")
|
||||
console.print(f"📖 Live website: https://{repo_path.split('/')[0]}.github.io/{repo_path.split('/')[1]}/")
|
||||
|
||||
except Exception:
|
||||
# Don't fail the whole command if we can't get repo info
|
||||
pass
|
||||
|
||||
console.print("\n🎉 Publishing workflow complete!")
|
||||
console.print("💡 Check GitHub Actions for deployment status")
|
||||
|
||||
return 0
|
||||
@@ -1,181 +0,0 @@
|
||||
"""
|
||||
Check command for TinyTorch CLI: comprehensive environment validation.
|
||||
|
||||
Runs 60+ automated tests to validate the entire TinyTorch environment.
|
||||
Perfect for students to share with TAs when something isn't working.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import subprocess
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
class CheckCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "check"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Run comprehensive environment validation (60+ tests)"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
'--verbose',
|
||||
'-v',
|
||||
action='store_true',
|
||||
help='Show detailed test output'
|
||||
)
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Run comprehensive validation tests with rich output."""
|
||||
console = self.console
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
"🧪 Running Comprehensive Environment Validation\n\n"
|
||||
"This will test 60+ aspects of your TinyTorch environment.\n"
|
||||
"Perfect for sharing with TAs if something isn't working!",
|
||||
title="TinyTorch Environment Check",
|
||||
border_style="bright_cyan"
|
||||
))
|
||||
console.print()
|
||||
|
||||
# Check if tests directory exists
|
||||
tests_dir = Path("tests/environment")
|
||||
if not tests_dir.exists():
|
||||
console.print(Panel(
|
||||
"[red]❌ Validation tests not found![/red]\n\n"
|
||||
f"Expected location: {tests_dir.absolute()}\n\n"
|
||||
"Please ensure you're running this from the TinyTorch root directory.",
|
||||
title="Error",
|
||||
border_style="red"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Run the validation tests with pytest
|
||||
test_files = [
|
||||
"tests/environment/test_setup_validation.py",
|
||||
"tests/environment/test_all_requirements.py"
|
||||
]
|
||||
|
||||
console.print("[bold cyan]Running validation tests...[/bold cyan]")
|
||||
console.print()
|
||||
|
||||
# Build pytest command
|
||||
pytest_args = [
|
||||
sys.executable, "-m", "pytest"
|
||||
] + test_files + [
|
||||
"-v" if args.verbose else "-q",
|
||||
"--tb=short",
|
||||
"--color=yes",
|
||||
"-p", "no:warnings" # Suppress warnings for cleaner output
|
||||
]
|
||||
|
||||
# Run pytest and capture output
|
||||
result = subprocess.run(
|
||||
pytest_args,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
# Parse test results from output
|
||||
output_lines = result.stdout.split('\n')
|
||||
|
||||
# Count results
|
||||
passed = failed = skipped = 0
|
||||
|
||||
for line in output_lines:
|
||||
if 'passed' in line.lower():
|
||||
# Extract numbers from pytest summary
|
||||
import re
|
||||
match = re.search(r'(\d+) passed', line)
|
||||
if match:
|
||||
passed = int(match.group(1))
|
||||
match = re.search(r'(\d+) failed', line)
|
||||
if match:
|
||||
failed = int(match.group(1))
|
||||
match = re.search(r'(\d+) skipped', line)
|
||||
if match:
|
||||
skipped = int(match.group(1))
|
||||
|
||||
# Display results with rich formatting
|
||||
console.print()
|
||||
|
||||
# Summary table
|
||||
results_table = Table(title="Test Results Summary", show_header=True, header_style="bold magenta")
|
||||
results_table.add_column("Category", style="cyan", width=30)
|
||||
results_table.add_column("Count", justify="right", width=10)
|
||||
results_table.add_column("Status", width=20)
|
||||
|
||||
if passed > 0:
|
||||
results_table.add_row("Tests Passed", str(passed), "[green]✅ OK[/green]")
|
||||
if failed > 0:
|
||||
results_table.add_row("Tests Failed", str(failed), "[red]❌ Issues Found[/red]")
|
||||
if skipped > 0:
|
||||
results_table.add_row("Tests Skipped", str(skipped), "[yellow]⏭️ Optional[/yellow]")
|
||||
|
||||
console.print(results_table)
|
||||
console.print()
|
||||
|
||||
# Overall health status
|
||||
if failed == 0:
|
||||
status_panel = Panel(
|
||||
"[bold green]✅ Environment is HEALTHY![/bold green]\n\n"
|
||||
f"All {passed} required checks passed.\n"
|
||||
f"{skipped} optional checks skipped.\n\n"
|
||||
"Your TinyTorch environment is ready to use! 🎉\n\n"
|
||||
"[dim]Next: [/dim][cyan]tito module 01[/cyan]",
|
||||
title="Environment Status",
|
||||
border_style="green"
|
||||
)
|
||||
else:
|
||||
status_panel = Panel(
|
||||
f"[bold red]❌ Found {failed} issue(s)[/bold red]\n\n"
|
||||
f"{passed} checks passed, but some components need attention.\n\n"
|
||||
"[bold]What to share with your TA:[/bold]\n"
|
||||
"1. Copy the output above\n"
|
||||
"2. Include the error messages below\n"
|
||||
"3. Mention what you were trying to do\n\n"
|
||||
"[dim]Or try:[/dim] [cyan]tito setup[/cyan] [dim]to reinstall[/dim]",
|
||||
title="Environment Status",
|
||||
border_style="red"
|
||||
)
|
||||
|
||||
console.print(status_panel)
|
||||
|
||||
# Show detailed output if verbose or if there are failures
|
||||
if args.verbose or failed > 0:
|
||||
console.print()
|
||||
console.print(Panel("📋 Detailed Test Output", border_style="blue"))
|
||||
console.print()
|
||||
console.print(result.stdout)
|
||||
|
||||
if result.stderr:
|
||||
console.print()
|
||||
console.print(Panel("⚠️ Error Messages", border_style="yellow"))
|
||||
console.print()
|
||||
console.print(result.stderr)
|
||||
|
||||
# Add helpful hints for common failures
|
||||
if failed > 0:
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
"[bold]Common Solutions:[/bold]\n\n"
|
||||
"• Missing packages: [cyan]pip install -r requirements.txt[/cyan]\n"
|
||||
"• Jupyter issues: [cyan]pip install --upgrade jupyterlab[/cyan]\n"
|
||||
"• Import errors: [cyan]pip install -e .[/cyan] [dim](reinstall TinyTorch)[/dim]\n"
|
||||
"• Still stuck: Run [cyan]tito system check --verbose[/cyan]\n\n"
|
||||
"[dim]Then share the full output with your TA[/dim]",
|
||||
title="💡 Quick Fixes",
|
||||
border_style="yellow"
|
||||
))
|
||||
|
||||
console.print()
|
||||
|
||||
# Return appropriate exit code
|
||||
return 0 if failed == 0 else 1
|
||||
@@ -1,690 +0,0 @@
|
||||
"""
|
||||
Checkpoint tracking and visualization command for TinyTorch CLI.
|
||||
|
||||
Provides capability-based progress tracking through the ML systems engineering journey:
|
||||
Foundation → Architecture → Training → Inference → Serving
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, BarColumn, TextColumn, SpinnerColumn
|
||||
from rich.table import Table
|
||||
from rich.tree import Tree
|
||||
from rich.text import Text
|
||||
from rich.layout import Layout
|
||||
from rich.columns import Columns
|
||||
from rich.status import Status
|
||||
|
||||
from .base import BaseCommand
|
||||
from ..core.config import CLIConfig
|
||||
from ..core.console import get_console, print_error, print_success
|
||||
|
||||
|
||||
class CheckpointSystem:
|
||||
"""Core checkpoint tracking system."""
|
||||
|
||||
# Define the 20-checkpoint structure for complete ML systems engineering journey
|
||||
CHECKPOINTS = {
|
||||
"00": {
|
||||
"name": "Environment",
|
||||
"description": "Development environment setup and configuration",
|
||||
"test_file": "checkpoint_00_environment.py",
|
||||
"capability": "Can I configure my TinyTorch development environment?"
|
||||
},
|
||||
"01": {
|
||||
"name": "Foundation",
|
||||
"description": "Basic tensor operations and ML building blocks",
|
||||
"test_file": "checkpoint_01_foundation.py",
|
||||
"capability": "Can I create and manipulate the building blocks of ML?"
|
||||
},
|
||||
"02": {
|
||||
"name": "Intelligence",
|
||||
"description": "Nonlinear activation functions",
|
||||
"test_file": "checkpoint_02_intelligence.py",
|
||||
"capability": "Can I add nonlinearity - the key to neural network intelligence?"
|
||||
},
|
||||
"03": {
|
||||
"name": "Components",
|
||||
"description": "Fundamental neural network building blocks",
|
||||
"test_file": "checkpoint_03_components.py",
|
||||
"capability": "Can I build the fundamental building blocks of neural networks?"
|
||||
},
|
||||
"04": {
|
||||
"name": "Networks",
|
||||
"description": "Complete multi-layer neural networks",
|
||||
"test_file": "checkpoint_04_networks.py",
|
||||
"capability": "Can I build complete multi-layer neural networks?"
|
||||
},
|
||||
"05": {
|
||||
"name": "Learning",
|
||||
"description": "Spatial data processing with convolutional operations",
|
||||
"test_file": "checkpoint_05_learning.py",
|
||||
"capability": "Can I process spatial data like images with convolutional operations?"
|
||||
},
|
||||
"06": {
|
||||
"name": "Attention",
|
||||
"description": "Attention mechanisms for sequence understanding",
|
||||
"test_file": "checkpoint_06_attention.py",
|
||||
"capability": "Can I build attention mechanisms for sequence understanding?"
|
||||
},
|
||||
"07": {
|
||||
"name": "Stability",
|
||||
"description": "Training stabilization with normalization",
|
||||
"test_file": "checkpoint_07_stability.py",
|
||||
"capability": "Can I stabilize training with normalization techniques?"
|
||||
},
|
||||
"08": {
|
||||
"name": "Differentiation",
|
||||
"description": "Automatic gradient computation for learning",
|
||||
"test_file": "checkpoint_08_differentiation.py",
|
||||
"capability": "Can I automatically compute gradients for learning?"
|
||||
},
|
||||
"09": {
|
||||
"name": "Optimization",
|
||||
"description": "Sophisticated optimization algorithms",
|
||||
"test_file": "checkpoint_09_optimization.py",
|
||||
"capability": "Can I optimize neural networks with sophisticated algorithms?"
|
||||
},
|
||||
"10": {
|
||||
"name": "Training",
|
||||
"description": "Complete training loops for end-to-end learning",
|
||||
"test_file": "checkpoint_10_training.py",
|
||||
"capability": "Can I build complete training loops for end-to-end learning?"
|
||||
},
|
||||
"11": {
|
||||
"name": "Regularization",
|
||||
"description": "Overfitting prevention and robust model building",
|
||||
"test_file": "checkpoint_11_regularization.py",
|
||||
"capability": "Can I prevent overfitting and build robust models?"
|
||||
},
|
||||
"12": {
|
||||
"name": "Kernels",
|
||||
"description": "High-performance computational kernels",
|
||||
"test_file": "checkpoint_12_kernels.py",
|
||||
"capability": "Can I implement high-performance computational kernels?"
|
||||
},
|
||||
"13": {
|
||||
"name": "Benchmarking",
|
||||
"description": "Performance analysis and bottleneck identification",
|
||||
"test_file": "checkpoint_13_benchmarking.py",
|
||||
"capability": "Can I analyze performance and identify bottlenecks in ML systems?"
|
||||
},
|
||||
"14": {
|
||||
"name": "Deployment",
|
||||
"description": "Production deployment and monitoring",
|
||||
"test_file": "checkpoint_14_deployment.py",
|
||||
"capability": "Can I deploy and monitor ML systems in production?"
|
||||
},
|
||||
"15": {
|
||||
"name": "Acceleration",
|
||||
"description": "Algorithmic optimization and acceleration techniques",
|
||||
"test_file": "checkpoint_15_acceleration.py",
|
||||
"capability": "Can I accelerate computations through algorithmic optimization?"
|
||||
},
|
||||
"16": {
|
||||
"name": "Quantization",
|
||||
"description": "Trading precision for speed with INT8 quantization",
|
||||
"test_file": "checkpoint_16_quantization.py",
|
||||
"capability": "Can I trade precision for speed with INT8 quantization?"
|
||||
},
|
||||
"17": {
|
||||
"name": "Compression",
|
||||
"description": "Neural network pruning for edge deployment",
|
||||
"test_file": "checkpoint_17_compression.py",
|
||||
"capability": "Can I remove 70% of parameters while maintaining accuracy?"
|
||||
},
|
||||
"18": {
|
||||
"name": "Caching",
|
||||
"description": "KV caching for transformer inference optimization",
|
||||
"test_file": "checkpoint_18_caching.py",
|
||||
"capability": "Can I transform O(N²) to O(N) complexity with intelligent caching?"
|
||||
},
|
||||
"19": {
|
||||
"name": "Competition",
|
||||
"description": "TinyMLPerf competition system for optimization mastery",
|
||||
"test_file": "checkpoint_19_competition.py",
|
||||
"capability": "Can I build competition-grade benchmarking infrastructure?"
|
||||
},
|
||||
"20": {
|
||||
"name": "TinyGPT Capstone",
|
||||
"description": "Complete language model demonstrating ML systems mastery",
|
||||
"test_file": "checkpoint_20_capstone.py",
|
||||
"capability": "Can I build a complete language model that generates coherent text from scratch?"
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, config: CLIConfig):
|
||||
"""Initialize checkpoint system."""
|
||||
self.config = config
|
||||
self.console = get_console()
|
||||
self.modules_dir = config.project_root / "modules" / "source"
|
||||
self.checkpoints_dir = config.project_root / "tests" / "checkpoints"
|
||||
|
||||
def get_checkpoint_test_status(self, checkpoint_id: str) -> Dict[str, bool]:
|
||||
"""Get the status of a checkpoint test file."""
|
||||
if checkpoint_id not in self.CHECKPOINTS:
|
||||
return {"exists": False, "tested": False, "passed": False}
|
||||
|
||||
test_file = self.CHECKPOINTS[checkpoint_id]["test_file"]
|
||||
test_path = self.checkpoints_dir / test_file
|
||||
|
||||
return {
|
||||
"exists": test_path.exists(),
|
||||
"tested": False, # Will be set when we run tests
|
||||
"passed": False # Will be set based on test results
|
||||
}
|
||||
|
||||
def get_checkpoint_status(self, checkpoint_id: str) -> Dict:
|
||||
"""Get status information for a checkpoint."""
|
||||
checkpoint = self.CHECKPOINTS[checkpoint_id]
|
||||
test_status = self.get_checkpoint_test_status(checkpoint_id)
|
||||
|
||||
return {
|
||||
"checkpoint": checkpoint,
|
||||
"test_status": test_status,
|
||||
"is_available": test_status["exists"],
|
||||
"is_complete": test_status.get("passed", False),
|
||||
"checkpoint_id": checkpoint_id
|
||||
}
|
||||
|
||||
def get_overall_progress(self) -> Dict:
|
||||
"""Get overall progress across all checkpoints."""
|
||||
checkpoints_status = {}
|
||||
current_checkpoint = None
|
||||
total_complete = 0
|
||||
total_checkpoints = len(self.CHECKPOINTS)
|
||||
|
||||
for checkpoint_id in self.CHECKPOINTS.keys():
|
||||
status = self.get_checkpoint_status(checkpoint_id)
|
||||
checkpoints_status[checkpoint_id] = status
|
||||
|
||||
if status["is_complete"]:
|
||||
total_complete += 1
|
||||
elif current_checkpoint is None and status["is_available"]:
|
||||
# First available but incomplete checkpoint is current
|
||||
current_checkpoint = checkpoint_id
|
||||
|
||||
# If all are complete, set current to last checkpoint
|
||||
if current_checkpoint is None and total_complete == total_checkpoints:
|
||||
current_checkpoint = list(self.CHECKPOINTS.keys())[-1]
|
||||
# If none are complete, start with first
|
||||
elif current_checkpoint is None:
|
||||
current_checkpoint = "00"
|
||||
|
||||
# Calculate overall percentage
|
||||
overall_percent = (total_complete / total_checkpoints * 100) if total_checkpoints > 0 else 0
|
||||
|
||||
return {
|
||||
"checkpoints": checkpoints_status,
|
||||
"current": current_checkpoint,
|
||||
"overall_progress": overall_percent,
|
||||
"total_complete": total_complete,
|
||||
"total_checkpoints": total_checkpoints
|
||||
}
|
||||
|
||||
def run_checkpoint_test(self, checkpoint_id: str) -> Dict:
|
||||
"""Run a specific checkpoint test and return results."""
|
||||
if checkpoint_id not in self.CHECKPOINTS:
|
||||
return {"success": False, "error": f"Unknown checkpoint: {checkpoint_id}"}
|
||||
|
||||
checkpoint = self.CHECKPOINTS[checkpoint_id]
|
||||
test_file = checkpoint["test_file"]
|
||||
test_path = self.checkpoints_dir / test_file
|
||||
|
||||
if not test_path.exists():
|
||||
return {"success": False, "error": f"Test file not found: {test_file}"}
|
||||
|
||||
try:
|
||||
# Run the test using subprocess to capture output
|
||||
result = subprocess.run(
|
||||
[sys.executable, str(test_path)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=self.config.project_root,
|
||||
timeout=30 # 30 second timeout
|
||||
)
|
||||
|
||||
return {
|
||||
"success": result.returncode == 0,
|
||||
"returncode": result.returncode,
|
||||
"stdout": result.stdout,
|
||||
"stderr": result.stderr,
|
||||
"checkpoint_name": checkpoint["name"],
|
||||
"capability": checkpoint["capability"]
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return {"success": False, "error": "Test timed out after 30 seconds"}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Test execution failed: {str(e)}"}
|
||||
|
||||
|
||||
class CheckpointCommand(BaseCommand):
|
||||
"""Checkpoint tracking and visualization command."""
|
||||
|
||||
name = "checkpoint"
|
||||
description = "Track and visualize ML systems engineering progress through checkpoints"
|
||||
|
||||
def add_arguments(self, parser: argparse.ArgumentParser) -> None:
|
||||
"""Add checkpoint-specific arguments."""
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='checkpoint_command',
|
||||
help='Checkpoint operations',
|
||||
metavar='COMMAND'
|
||||
)
|
||||
|
||||
# Status command
|
||||
status_parser = subparsers.add_parser(
|
||||
'status',
|
||||
help='Show current checkpoint progress'
|
||||
)
|
||||
status_parser.add_argument(
|
||||
'--detailed', '-d',
|
||||
action='store_true',
|
||||
help='Show detailed module-level progress'
|
||||
)
|
||||
|
||||
# Timeline command
|
||||
timeline_parser = subparsers.add_parser(
|
||||
'timeline',
|
||||
help='Show visual progress timeline'
|
||||
)
|
||||
timeline_parser.add_argument(
|
||||
'--horizontal',
|
||||
action='store_true',
|
||||
help='Show horizontal timeline (default: vertical)'
|
||||
)
|
||||
|
||||
# Test command
|
||||
test_parser = subparsers.add_parser(
|
||||
'test',
|
||||
help='Test checkpoint capabilities'
|
||||
)
|
||||
test_parser.add_argument(
|
||||
'checkpoint_id',
|
||||
nargs='?',
|
||||
help='Checkpoint ID to test (00-20, current checkpoint if not specified)'
|
||||
)
|
||||
|
||||
# Run command (new)
|
||||
run_parser = subparsers.add_parser(
|
||||
'run',
|
||||
help='Run specific checkpoint tests with progress tracking'
|
||||
)
|
||||
run_parser.add_argument(
|
||||
'checkpoint_id',
|
||||
help='Checkpoint ID to run (00-20)'
|
||||
)
|
||||
run_parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Show detailed test output'
|
||||
)
|
||||
|
||||
# Unlock command
|
||||
unlock_parser = subparsers.add_parser(
|
||||
'unlock',
|
||||
help='Attempt to unlock next checkpoint'
|
||||
)
|
||||
|
||||
def run(self, args: argparse.Namespace) -> int:
|
||||
"""Execute checkpoint command."""
|
||||
checkpoint_system = CheckpointSystem(self.config)
|
||||
|
||||
if not args.checkpoint_command:
|
||||
return self._show_help(args)
|
||||
|
||||
if args.checkpoint_command == 'status':
|
||||
return self._show_status(checkpoint_system, args)
|
||||
elif args.checkpoint_command == 'timeline':
|
||||
return self._show_timeline(checkpoint_system, args)
|
||||
elif args.checkpoint_command == 'test':
|
||||
return self._test_checkpoint(checkpoint_system, args)
|
||||
elif args.checkpoint_command == 'run':
|
||||
return self._run_checkpoint(checkpoint_system, args)
|
||||
elif args.checkpoint_command == 'unlock':
|
||||
return self._unlock_checkpoint(checkpoint_system, args)
|
||||
else:
|
||||
print_error(f"Unknown checkpoint command: {args.checkpoint_command}")
|
||||
return 1
|
||||
|
||||
def _show_help(self, args: argparse.Namespace) -> int:
|
||||
"""Show checkpoint command help."""
|
||||
console = get_console()
|
||||
console.print(Panel(
|
||||
"[bold cyan]TinyTorch Checkpoint System[/bold cyan]\n\n"
|
||||
"[bold]Track your progress through 20 capability checkpoints:[/bold]\n"
|
||||
" 00-04: Foundation → Environment, tensors, networks\n"
|
||||
" 05-09: Architecture → Spatial, attention, autograd, optimization\n"
|
||||
" 10-14: Systems → Training, kernels, benchmarking, deployment\n"
|
||||
" 15-19: Optimization → Acceleration, quantization, compression, caching, competition\n"
|
||||
" 20: Capstone → Complete TinyGPT language model\n\n"
|
||||
"[bold]Available Commands:[/bold]\n"
|
||||
" [green]status[/green] - Show current progress and capabilities\n"
|
||||
" [green]timeline[/green] - Visual progress timeline\n"
|
||||
" [green]test[/green] - Test checkpoint capabilities\n"
|
||||
" [green]run[/green] - Run specific checkpoint with progress\n"
|
||||
" [green]unlock[/green] - Attempt to unlock next checkpoint\n\n"
|
||||
"[bold]Examples:[/bold]\n"
|
||||
" [dim]tito checkpoint status --detailed[/dim]\n"
|
||||
" [dim]tito checkpoint timeline --horizontal[/dim]\n"
|
||||
" [dim]tito checkpoint test 16[/dim]\n"
|
||||
" [dim]tito checkpoint run 20 --verbose[/dim]",
|
||||
title="Checkpoint System (20 Checkpoints)",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
return 0
|
||||
|
||||
def _show_status(self, checkpoint_system: CheckpointSystem, args: argparse.Namespace) -> int:
|
||||
"""Show checkpoint status."""
|
||||
console = get_console()
|
||||
progress_data = checkpoint_system.get_overall_progress()
|
||||
|
||||
# Header
|
||||
console.print(Panel(
|
||||
"[bold cyan]🚀 TinyTorch Framework Capabilities[/bold cyan]",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
|
||||
# Overall progress
|
||||
overall_percent = progress_data["overall_progress"]
|
||||
console.print(f"\n[bold]Overall Progress:[/bold] {overall_percent:.0f}% ({progress_data['total_complete']}/{progress_data['total_checkpoints']} checkpoints)")
|
||||
|
||||
# Current status summary
|
||||
current = progress_data["current"]
|
||||
if current:
|
||||
current_status = progress_data["checkpoints"][current]
|
||||
current_name = current_status["checkpoint"]["name"]
|
||||
|
||||
console.print(f"[bold]Current Checkpoint:[/bold] {current:0>2} - {current_name}")
|
||||
|
||||
if current_status["is_complete"]:
|
||||
console.print(f"[bold green]✅ {current_name} checkpoint achieved![/bold green]")
|
||||
console.print(f"[dim]Capability unlocked: {current_status['checkpoint']['capability']}[/dim]")
|
||||
else:
|
||||
console.print(f"[bold yellow]🎯 Ready to test {current_name} capabilities[/bold yellow]")
|
||||
console.print(f"[dim]Goal: {current_status['checkpoint']['capability']}[/dim]")
|
||||
|
||||
console.print()
|
||||
|
||||
# Checkpoint progress
|
||||
for checkpoint_id, checkpoint_data in progress_data["checkpoints"].items():
|
||||
checkpoint = checkpoint_data["checkpoint"]
|
||||
|
||||
# Checkpoint header
|
||||
if checkpoint_data["is_complete"]:
|
||||
status_icon = "✅"
|
||||
status_color = "green"
|
||||
elif checkpoint_id == current:
|
||||
status_icon = "🎯"
|
||||
status_color = "yellow"
|
||||
else:
|
||||
status_icon = "⏳"
|
||||
status_color = "dim"
|
||||
|
||||
console.print(f"[bold]{status_icon} {checkpoint_id:0>2}: {checkpoint['name']}[/bold] [{status_color}]{'COMPLETE' if checkpoint_data['is_complete'] else 'PENDING'}[/{status_color}]")
|
||||
|
||||
if args.detailed:
|
||||
# Show test file and availability
|
||||
test_status = checkpoint_data["test_status"]
|
||||
test_available = "✅" if test_status["exists"] else "❌"
|
||||
console.print(f" {test_available} Test: {checkpoint['test_file']}")
|
||||
|
||||
console.print(f" [dim]{checkpoint['capability']}[/dim]\n")
|
||||
|
||||
return 0
|
||||
|
||||
def _show_timeline(self, checkpoint_system: CheckpointSystem, args: argparse.Namespace) -> int:
|
||||
"""Show visual timeline with Rich progress bar."""
|
||||
console = get_console()
|
||||
progress_data = checkpoint_system.get_overall_progress()
|
||||
|
||||
console.print("\n[bold cyan]🚀 TinyTorch Framework Progress Timeline[/bold cyan]\n")
|
||||
|
||||
if args.horizontal:
|
||||
# Enhanced horizontal timeline with progress line
|
||||
overall_percent = progress_data["overall_progress"]
|
||||
total_checkpoints = progress_data["total_checkpoints"]
|
||||
complete_checkpoints = progress_data["total_complete"]
|
||||
|
||||
# Create a visual progress bar
|
||||
filled = int(overall_percent / 2) # 50 characters total width
|
||||
bar = "█" * filled + "░" * (50 - filled)
|
||||
console.print(f"[bold]Overall:[/bold] [{bar}] {overall_percent:.0f}%")
|
||||
console.print(f"[dim]{complete_checkpoints}/{total_checkpoints} checkpoints complete[/dim]\n")
|
||||
|
||||
# Show checkpoint progression - group in rows of 8
|
||||
checkpoints_list = list(progress_data["checkpoints"].items())
|
||||
|
||||
for row_start in range(0, len(checkpoints_list), 8):
|
||||
row_checkpoints = checkpoints_list[row_start:row_start + 8]
|
||||
|
||||
# Build the checkpoint line for this row
|
||||
checkpoint_line = ""
|
||||
names_line = ""
|
||||
|
||||
for i, (checkpoint_id, checkpoint_data) in enumerate(row_checkpoints):
|
||||
checkpoint = checkpoint_data["checkpoint"]
|
||||
|
||||
# Checkpoint status
|
||||
if checkpoint_data["is_complete"]:
|
||||
checkpoint_marker = f"[green]●[/green]"
|
||||
name_color = "green"
|
||||
elif checkpoint_id == progress_data["current"]:
|
||||
checkpoint_marker = f"[yellow]◉[/yellow]"
|
||||
name_color = "yellow"
|
||||
else:
|
||||
checkpoint_marker = f"[dim]○[/dim]"
|
||||
name_color = "dim"
|
||||
|
||||
# Add checkpoint with ID
|
||||
checkpoint_line += f"{checkpoint_marker}{checkpoint_id}"
|
||||
names_line += f"[{name_color}]{checkpoint['name'][:9]:^9}[/{name_color}]"
|
||||
|
||||
# Add spacing (except for last in row)
|
||||
if i < len(row_checkpoints) - 1:
|
||||
if checkpoint_data["is_complete"]:
|
||||
checkpoint_line += "[green]━━[/green]"
|
||||
else:
|
||||
checkpoint_line += "[dim]━━[/dim]"
|
||||
names_line += " "
|
||||
|
||||
console.print(checkpoint_line)
|
||||
console.print(names_line)
|
||||
console.print() # Empty line between rows
|
||||
|
||||
else:
|
||||
# Vertical timeline (tree structure)
|
||||
tree = Tree("ML Systems Engineering Journey (20 Checkpoints)")
|
||||
|
||||
for checkpoint_id, checkpoint_data in progress_data["checkpoints"].items():
|
||||
checkpoint = checkpoint_data["checkpoint"]
|
||||
|
||||
if checkpoint_data["is_complete"]:
|
||||
checkpoint_text = f"[green]✅ {checkpoint_id}: {checkpoint['name']}[/green]"
|
||||
elif checkpoint_id == progress_data["current"]:
|
||||
checkpoint_text = f"[yellow]🎯 {checkpoint_id}: {checkpoint['name']} (CURRENT)[/yellow]"
|
||||
else:
|
||||
checkpoint_text = f"[dim]⏳ {checkpoint_id}: {checkpoint['name']}[/dim]"
|
||||
|
||||
checkpoint_node = tree.add(checkpoint_text)
|
||||
checkpoint_node.add(f"[dim]{checkpoint['capability']}[/dim]")
|
||||
|
||||
console.print(tree)
|
||||
|
||||
console.print()
|
||||
return 0
|
||||
|
||||
def _test_checkpoint(self, checkpoint_system: CheckpointSystem, args: argparse.Namespace) -> int:
|
||||
"""Test checkpoint capabilities."""
|
||||
console = get_console()
|
||||
|
||||
# Determine which checkpoint to test
|
||||
checkpoint_id = args.checkpoint_id
|
||||
if not checkpoint_id:
|
||||
progress_data = checkpoint_system.get_overall_progress()
|
||||
checkpoint_id = progress_data["current"]
|
||||
|
||||
# Validate checkpoint ID
|
||||
if checkpoint_id not in checkpoint_system.CHECKPOINTS:
|
||||
print_error(f"Unknown checkpoint: {checkpoint_id}")
|
||||
console.print(f"[dim]Available checkpoints: {', '.join(checkpoint_system.CHECKPOINTS.keys())}[/dim]")
|
||||
return 1
|
||||
|
||||
checkpoint = checkpoint_system.CHECKPOINTS[checkpoint_id]
|
||||
|
||||
# Show what we're testing
|
||||
console.print(f"\n[bold cyan]Testing Checkpoint {checkpoint_id}: {checkpoint['name']}[/bold cyan]")
|
||||
console.print(f"[bold]Capability Question:[/bold] {checkpoint['capability']}\n")
|
||||
|
||||
# Run the test
|
||||
with console.status(f"[bold green]Running checkpoint {checkpoint_id} test...", spinner="dots") as status:
|
||||
result = checkpoint_system.run_checkpoint_test(checkpoint_id)
|
||||
|
||||
# Display results
|
||||
if result["success"]:
|
||||
console.print(f"[bold green]✅ Checkpoint {checkpoint_id} PASSED![/bold green]")
|
||||
console.print(f"[green]Capability achieved: {checkpoint['capability']}[/green]\n")
|
||||
|
||||
# Show brief output
|
||||
if result.get("stdout") and "🎉" in result["stdout"]:
|
||||
# Extract the completion message
|
||||
lines = result["stdout"].split('\n')
|
||||
for line in lines:
|
||||
if "🎉" in line or "📝" in line or "🎯" in line:
|
||||
console.print(f"[dim]{line}[/dim]")
|
||||
|
||||
print_success(f"Checkpoint {checkpoint_id} test completed successfully!")
|
||||
return 0
|
||||
else:
|
||||
console.print(f"[bold red]❌ Checkpoint {checkpoint_id} FAILED[/bold red]\n")
|
||||
|
||||
# Show error details
|
||||
if "error" in result:
|
||||
console.print(f"[red]Error: {result['error']}[/red]")
|
||||
elif result.get("stderr"):
|
||||
console.print(f"[red]Error output:[/red]")
|
||||
console.print(f"[dim]{result['stderr']}[/dim]")
|
||||
elif result.get("stdout"):
|
||||
console.print(f"[yellow]Test output:[/yellow]")
|
||||
console.print(f"[dim]{result['stdout']}[/dim]")
|
||||
|
||||
print_error(f"Checkpoint {checkpoint_id} test failed")
|
||||
return 1
|
||||
|
||||
def _run_checkpoint(self, checkpoint_system: CheckpointSystem, args: argparse.Namespace) -> int:
|
||||
"""Run specific checkpoint test with detailed progress tracking."""
|
||||
console = get_console()
|
||||
checkpoint_id = args.checkpoint_id
|
||||
|
||||
# Validate checkpoint ID
|
||||
if checkpoint_id not in checkpoint_system.CHECKPOINTS:
|
||||
print_error(f"Unknown checkpoint: {checkpoint_id}")
|
||||
console.print(f"[dim]Available checkpoints: {', '.join(checkpoint_system.CHECKPOINTS.keys())}[/dim]")
|
||||
return 1
|
||||
|
||||
checkpoint = checkpoint_system.CHECKPOINTS[checkpoint_id]
|
||||
|
||||
# Show detailed information
|
||||
console.print(Panel(
|
||||
f"[bold cyan]Checkpoint {checkpoint_id}: {checkpoint['name']}[/bold cyan]\n\n"
|
||||
f"[bold]Capability Question:[/bold]\n{checkpoint['capability']}\n\n"
|
||||
f"[bold]Test File:[/bold] {checkpoint['test_file']}\n"
|
||||
f"[bold]Description:[/bold] {checkpoint['description']}",
|
||||
title=f"Running Checkpoint {checkpoint_id}",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
|
||||
# Check if test file exists
|
||||
test_path = checkpoint_system.checkpoints_dir / checkpoint["test_file"]
|
||||
if not test_path.exists():
|
||||
print_error(f"Test file not found: {checkpoint['test_file']}")
|
||||
return 1
|
||||
|
||||
console.print(f"\n[bold]Executing test...[/bold]")
|
||||
|
||||
# Run the test with status feedback
|
||||
with console.status(f"[bold green]Running checkpoint {checkpoint_id} test...", spinner="dots"):
|
||||
result = checkpoint_system.run_checkpoint_test(checkpoint_id)
|
||||
|
||||
console.print()
|
||||
|
||||
# Display detailed results
|
||||
if result["success"]:
|
||||
console.print(Panel(
|
||||
f"[bold green]✅ SUCCESS![/bold green]\n\n"
|
||||
f"[green]Checkpoint {checkpoint_id} completed successfully![/green]\n"
|
||||
f"[green]Capability achieved: {checkpoint['capability']}[/green]",
|
||||
title="Test Results",
|
||||
border_style="green"
|
||||
))
|
||||
|
||||
# Show test output if verbose or if it contains key markers
|
||||
if args.verbose or (result.get("stdout") and any(marker in result["stdout"] for marker in ["🎉", "✅", "📝", "🎯"])):
|
||||
console.print(f"\n[bold]Test Output:[/bold]")
|
||||
if result.get("stdout"):
|
||||
console.print(result["stdout"])
|
||||
|
||||
return 0
|
||||
else:
|
||||
console.print(Panel(
|
||||
f"[bold red]❌ FAILED[/bold red]\n\n"
|
||||
f"[red]Checkpoint {checkpoint_id} test failed[/red]\n"
|
||||
f"[yellow]This indicates the required capabilities are not yet implemented.[/yellow]",
|
||||
title="Test Results",
|
||||
border_style="red"
|
||||
))
|
||||
|
||||
# Show error details
|
||||
if "error" in result:
|
||||
console.print(f"\n[bold red]Error:[/bold red] {result['error']}")
|
||||
|
||||
if args.verbose or "error" in result:
|
||||
if result.get("stdout"):
|
||||
console.print(f"\n[bold]Standard Output:[/bold]")
|
||||
console.print(result["stdout"])
|
||||
if result.get("stderr"):
|
||||
console.print(f"\n[bold]Error Output:[/bold]")
|
||||
console.print(result["stderr"])
|
||||
|
||||
return 1
|
||||
|
||||
def _unlock_checkpoint(self, checkpoint_system: CheckpointSystem, args: argparse.Namespace) -> int:
|
||||
"""Attempt to unlock next checkpoint."""
|
||||
console = get_console()
|
||||
progress_data = checkpoint_system.get_overall_progress()
|
||||
current = progress_data["current"]
|
||||
|
||||
if not current:
|
||||
console.print("[green]All checkpoints completed! 🎉[/green]")
|
||||
return 0
|
||||
|
||||
current_status = progress_data["checkpoints"][current]
|
||||
|
||||
if current_status["is_complete"]:
|
||||
console.print(f"[green]✅ Checkpoint {current} ({current_status['checkpoint']['name']}) already complete![/green]")
|
||||
|
||||
# Find next checkpoint
|
||||
checkpoint_ids = list(checkpoint_system.CHECKPOINTS.keys())
|
||||
try:
|
||||
current_index = checkpoint_ids.index(current)
|
||||
if current_index < len(checkpoint_ids) - 1:
|
||||
next_id = checkpoint_ids[current_index + 1]
|
||||
next_checkpoint = checkpoint_system.CHECKPOINTS[next_id]
|
||||
console.print(f"[bold]Next checkpoint:[/bold] {next_id} - {next_checkpoint['name']}")
|
||||
console.print(f"[dim]Goal: {next_checkpoint['capability']}[/dim]")
|
||||
else:
|
||||
console.print("[bold]🎉 All checkpoints completed![/bold]")
|
||||
except ValueError:
|
||||
console.print("[yellow]Cannot determine next checkpoint[/yellow]")
|
||||
else:
|
||||
console.print(f"[yellow]Test checkpoint {current} to unlock your next capability:[/yellow]")
|
||||
console.print(f"[bold]Goal:[/bold] {current_status['checkpoint']['capability']}")
|
||||
console.print(f"[dim]Run: tito checkpoint run {current}[/dim]")
|
||||
|
||||
return 0
|
||||
@@ -1,232 +0,0 @@
|
||||
"""
|
||||
Clean command for TinyTorch CLI: clean up generated files and caches.
|
||||
"""
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.prompt import Confirm
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
class CleanWorkspaceCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "clean"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Clean up generated files, caches, and temporary files"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Clean everything including build artifacts'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Show what would be deleted without actually deleting'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-y', '--yes',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt'
|
||||
)
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
console = self.console
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
"🧹 Cleaning TinyTorch Workspace",
|
||||
title="Workspace Cleanup",
|
||||
border_style="bright_yellow"
|
||||
))
|
||||
console.print()
|
||||
|
||||
# Define patterns to clean
|
||||
patterns = {
|
||||
'__pycache__': ('__pycache__/', 'Python bytecode cache'),
|
||||
'.pytest_cache': ('.pytest_cache/', 'Pytest cache'),
|
||||
'.ipynb_checkpoints': ('.ipynb_checkpoints/', 'Jupyter checkpoints'),
|
||||
'*.pyc': ('*.pyc', 'Compiled Python files'),
|
||||
'*.pyo': ('*.pyo', 'Optimized Python files'),
|
||||
'*.pyd': ('*.pyd', 'Python extension modules'),
|
||||
}
|
||||
|
||||
if args.all:
|
||||
# Additional patterns for --all
|
||||
patterns.update({
|
||||
'.coverage': ('.coverage', 'Coverage data'),
|
||||
'htmlcov': ('htmlcov/', 'Coverage HTML report'),
|
||||
'.tox': ('.tox/', 'Tox environments'),
|
||||
'dist': ('dist/', 'Distribution files'),
|
||||
'build': ('build/', 'Build files'),
|
||||
'*.egg-info': ('*.egg-info/', 'Egg info directories'),
|
||||
})
|
||||
|
||||
# Scan for files to delete
|
||||
console.print("[bold cyan]🔍 Scanning for files to clean...[/bold cyan]")
|
||||
console.print()
|
||||
|
||||
files_to_delete = []
|
||||
total_size = 0
|
||||
|
||||
# Find __pycache__ directories
|
||||
for pycache_dir in Path.cwd().rglob('__pycache__'):
|
||||
for file in pycache_dir.iterdir():
|
||||
if file.is_file():
|
||||
files_to_delete.append(file)
|
||||
total_size += file.stat().st_size
|
||||
files_to_delete.append(pycache_dir)
|
||||
|
||||
# Find .pytest_cache directories
|
||||
for cache_dir in Path.cwd().rglob('.pytest_cache'):
|
||||
for file in cache_dir.rglob('*'):
|
||||
if file.is_file():
|
||||
files_to_delete.append(file)
|
||||
total_size += file.stat().st_size
|
||||
files_to_delete.append(cache_dir)
|
||||
|
||||
# Find .ipynb_checkpoints directories
|
||||
for checkpoint_dir in Path.cwd().rglob('.ipynb_checkpoints'):
|
||||
for file in checkpoint_dir.rglob('*'):
|
||||
if file.is_file():
|
||||
files_to_delete.append(file)
|
||||
total_size += file.stat().st_size
|
||||
files_to_delete.append(checkpoint_dir)
|
||||
|
||||
# Find .pyc, .pyo, .pyd files
|
||||
for ext in ['*.pyc', '*.pyo', '*.pyd']:
|
||||
for file in Path.cwd().rglob(ext):
|
||||
if file.is_file():
|
||||
files_to_delete.append(file)
|
||||
total_size += file.stat().st_size
|
||||
|
||||
if args.all:
|
||||
# Additional cleanups for --all flag
|
||||
for pattern in ['.coverage', 'htmlcov', '.tox', 'dist', 'build']:
|
||||
target = Path.cwd() / pattern
|
||||
if target.exists():
|
||||
if target.is_file():
|
||||
files_to_delete.append(target)
|
||||
total_size += target.stat().st_size
|
||||
elif target.is_dir():
|
||||
for file in target.rglob('*'):
|
||||
if file.is_file():
|
||||
files_to_delete.append(file)
|
||||
total_size += file.stat().st_size
|
||||
files_to_delete.append(target)
|
||||
|
||||
# Find .egg-info directories
|
||||
for egg_info in Path.cwd().rglob('*.egg-info'):
|
||||
if egg_info.is_dir():
|
||||
for file in egg_info.rglob('*'):
|
||||
if file.is_file():
|
||||
files_to_delete.append(file)
|
||||
total_size += file.stat().st_size
|
||||
files_to_delete.append(egg_info)
|
||||
|
||||
if not files_to_delete:
|
||||
console.print(Panel(
|
||||
"[green]✅ Workspace is already clean![/green]\n\n"
|
||||
"No temporary files or caches found.",
|
||||
title="Clean Workspace",
|
||||
border_style="green"
|
||||
))
|
||||
return 0
|
||||
|
||||
# Count file types
|
||||
file_count = len([f for f in files_to_delete if f.is_file()])
|
||||
dir_count = len([f for f in files_to_delete if f.is_dir()])
|
||||
|
||||
# Summary table
|
||||
summary_table = Table(title="Files Found", show_header=True, header_style="bold yellow")
|
||||
summary_table.add_column("Type", style="cyan", width=30)
|
||||
summary_table.add_column("Count", justify="right", width=15)
|
||||
summary_table.add_column("Size", width=20)
|
||||
|
||||
# Count by type
|
||||
pycache_count = len([f for f in files_to_delete if '__pycache__' in str(f)])
|
||||
pytest_count = len([f for f in files_to_delete if '.pytest_cache' in str(f)])
|
||||
checkpoint_count = len([f for f in files_to_delete if '.ipynb_checkpoints' in str(f)])
|
||||
pyc_count = len([f for f in files_to_delete if str(f).endswith(('.pyc', '.pyo', '.pyd'))])
|
||||
|
||||
if pycache_count > 0:
|
||||
summary_table.add_row("__pycache__/", str(pycache_count), "—")
|
||||
if pytest_count > 0:
|
||||
summary_table.add_row(".pytest_cache/", str(pytest_count), "—")
|
||||
if checkpoint_count > 0:
|
||||
summary_table.add_row(".ipynb_checkpoints/", str(checkpoint_count), "—")
|
||||
if pyc_count > 0:
|
||||
summary_table.add_row("*.pyc/*.pyo/*.pyd", str(pyc_count), f"{total_size / 1024 / 1024:.2f} MB")
|
||||
|
||||
if args.all:
|
||||
other_count = file_count - pycache_count - pytest_count - checkpoint_count - pyc_count
|
||||
if other_count > 0:
|
||||
summary_table.add_row("Build artifacts", str(other_count), "—")
|
||||
|
||||
summary_table.add_row("[bold]Total[/bold]", f"[bold]{file_count} files, {dir_count} dirs[/bold]", f"[bold]{total_size / 1024 / 1024:.2f} MB[/bold]")
|
||||
|
||||
console.print(summary_table)
|
||||
console.print()
|
||||
|
||||
# Dry run mode
|
||||
if args.dry_run:
|
||||
console.print(Panel(
|
||||
"[yellow]🔍 DRY RUN MODE[/yellow]\n\n"
|
||||
f"Would delete {file_count} files and {dir_count} directories ({total_size / 1024 / 1024:.2f} MB)\n\n"
|
||||
"[dim]Remove --dry-run flag to actually delete these files[/dim]",
|
||||
title="Dry Run",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 0
|
||||
|
||||
# Confirm deletion
|
||||
if not args.yes:
|
||||
confirmed = Confirm.ask(
|
||||
f"\n[yellow]⚠️ Delete {file_count} files and {dir_count} directories ({total_size / 1024 / 1024:.2f} MB)?[/yellow]",
|
||||
default=False
|
||||
)
|
||||
if not confirmed:
|
||||
console.print("\n[dim]Cleanup cancelled.[/dim]")
|
||||
return 0
|
||||
|
||||
# Perform cleanup
|
||||
console.print()
|
||||
console.print("[bold cyan]🗑️ Cleaning workspace...[/bold cyan]")
|
||||
|
||||
deleted_files = 0
|
||||
deleted_dirs = 0
|
||||
freed_space = 0
|
||||
|
||||
# Delete files first, then directories
|
||||
for item in files_to_delete:
|
||||
try:
|
||||
if item.is_file():
|
||||
size = item.stat().st_size
|
||||
item.unlink()
|
||||
deleted_files += 1
|
||||
freed_space += size
|
||||
elif item.is_dir() and not any(item.iterdir()): # Only delete if empty
|
||||
shutil.rmtree(item, ignore_errors=True)
|
||||
deleted_dirs += 1
|
||||
except Exception as e:
|
||||
console.print(f"[dim red] ✗ Failed to delete {item}: {e}[/dim red]")
|
||||
|
||||
# Success message
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
f"[bold green]✅ Workspace Cleaned![/bold green]\n\n"
|
||||
f"🗑️ Deleted: {deleted_files} files, {deleted_dirs} directories\n"
|
||||
f"💾 Freed: {freed_space / 1024 / 1024:.2f} MB\n"
|
||||
f"⏱️ Time: < 1 second",
|
||||
title="Cleanup Complete",
|
||||
border_style="green"
|
||||
))
|
||||
|
||||
return 0
|
||||
@@ -1,263 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tito Demo Command - Show off your AI capabilities!
|
||||
Runs progressive demos showing what TinyTorch can do at each stage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
console = Console()
|
||||
|
||||
class TinyTorchDemoMatrix:
|
||||
"""Tracks and displays TinyTorch AI demo capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.demos = {
|
||||
'math': {
|
||||
'name': 'Mathematical Operations',
|
||||
'file': 'demo_tensor_math.py',
|
||||
'requires': ['02_tensor'],
|
||||
'description': 'Linear algebra, matrix operations, transformations'
|
||||
},
|
||||
'logic': {
|
||||
'name': 'Logical Reasoning',
|
||||
'file': 'demo_activations.py',
|
||||
'requires': ['02_tensor', '03_activations'],
|
||||
'description': 'Boolean functions, XOR problem, decision boundaries'
|
||||
},
|
||||
'neuron': {
|
||||
'name': 'Single Neuron Learning',
|
||||
'file': 'demo_single_neuron.py',
|
||||
'requires': ['02_tensor', '03_activations', '04_layers'],
|
||||
'description': 'Watch a neuron learn the AND gate'
|
||||
},
|
||||
'network': {
|
||||
'name': 'Multi-Layer Networks',
|
||||
'file': 'demo_xor_network.py',
|
||||
'requires': ['02_tensor', '03_activations', '04_layers', '05_dense'],
|
||||
'description': 'Solve the famous XOR problem'
|
||||
},
|
||||
'vision': {
|
||||
'name': 'Computer Vision',
|
||||
'file': 'demo_vision.py',
|
||||
'requires': ['02_tensor', '03_activations', '04_layers', '05_dense', '06_spatial'],
|
||||
'description': 'Image processing and pattern recognition'
|
||||
},
|
||||
'attention': {
|
||||
'name': 'Attention Mechanisms',
|
||||
'file': 'demo_attention.py',
|
||||
'requires': ['02_tensor', '03_activations', '04_layers', '05_dense', '07_attention'],
|
||||
'description': 'Sequence processing and attention'
|
||||
},
|
||||
'training': {
|
||||
'name': 'End-to-End Training',
|
||||
'file': 'demo_training.py',
|
||||
'requires': ['02_tensor', '03_activations', '04_layers', '05_dense', '11_training'],
|
||||
'description': 'Complete training pipelines'
|
||||
},
|
||||
'language': {
|
||||
'name': 'Language Generation',
|
||||
'file': 'demo_language.py',
|
||||
'requires': ['02_tensor', '03_activations', '04_layers', '05_dense', '07_attention', '16_tinygpt'],
|
||||
'description': 'AI text generation and language models'
|
||||
}
|
||||
}
|
||||
|
||||
def check_module_exported(self, module_name):
|
||||
"""Check if a module has been exported to the package"""
|
||||
try:
|
||||
if module_name == '02_tensor':
|
||||
import tinytorch.core.tensor
|
||||
return True
|
||||
elif module_name == '03_activations':
|
||||
import tinytorch.core.activations
|
||||
return True
|
||||
elif module_name == '04_layers':
|
||||
import tinytorch.core.layers
|
||||
return True
|
||||
elif module_name == '05_dense':
|
||||
import tinytorch.core.dense
|
||||
return True
|
||||
elif module_name == '06_spatial':
|
||||
import tinytorch.core.spatial
|
||||
return True
|
||||
elif module_name == '07_attention':
|
||||
import tinytorch.core.attention
|
||||
return True
|
||||
elif module_name == '11_training':
|
||||
import tinytorch.core.training
|
||||
return True
|
||||
elif module_name == '16_tinygpt':
|
||||
import tinytorch.tinygpt
|
||||
return True
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def get_demo_status(self, demo_name):
|
||||
"""Get status of a demo: available, partial, or unavailable"""
|
||||
demo = self.demos[demo_name]
|
||||
required_modules = demo['requires']
|
||||
|
||||
available_count = sum(1 for module in required_modules if self.check_module_exported(module))
|
||||
total_count = len(required_modules)
|
||||
|
||||
if available_count == total_count:
|
||||
return '✅' # Fully available
|
||||
elif available_count > 0:
|
||||
return '⚡' # Partially available
|
||||
else:
|
||||
return '❌' # Not available
|
||||
|
||||
def show_matrix(self):
|
||||
"""Display the demo capability matrix"""
|
||||
console.print("\n🤖 TinyTorch Demo Matrix", style="bold cyan")
|
||||
console.print("=" * 50)
|
||||
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("Demo", style="cyan", width=20)
|
||||
table.add_column("Status", justify="center", width=8)
|
||||
table.add_column("Description", style="dim")
|
||||
|
||||
available_demos = []
|
||||
|
||||
for demo_name, demo_info in self.demos.items():
|
||||
status = self.get_demo_status(demo_name)
|
||||
table.add_row(demo_info['name'], status, demo_info['description'])
|
||||
|
||||
if status == '✅':
|
||||
available_demos.append(demo_name)
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
if available_demos:
|
||||
console.print("🎯 Available Demos:", style="bold green")
|
||||
for demo in available_demos:
|
||||
console.print(f" • tito demo {demo}")
|
||||
console.print()
|
||||
|
||||
console.print("Legend: ✅ Ready ⚡ Partial ❌ Not Available")
|
||||
console.print()
|
||||
|
||||
def run_demo(self, demo_name):
|
||||
"""Run a specific demo"""
|
||||
if demo_name not in self.demos:
|
||||
console.print(f"❌ Unknown demo: {demo_name}", style="red")
|
||||
console.print("Available demos:", ', '.join(self.demos.keys()))
|
||||
return False
|
||||
|
||||
demo = self.demos[demo_name]
|
||||
status = self.get_demo_status(demo_name)
|
||||
|
||||
if status == '❌':
|
||||
console.print(f"❌ Demo '{demo_name}' not available", style="red")
|
||||
missing_modules = [m for m in demo['requires'] if not self.check_module_exported(m)]
|
||||
console.print(f"Missing modules: {', '.join(missing_modules)}")
|
||||
console.print(f"Run: tito export {' '.join(missing_modules)}")
|
||||
return False
|
||||
|
||||
if status == '⚡':
|
||||
console.print(f"⚠️ Demo '{demo_name}' partially available", style="yellow")
|
||||
console.print("Some features may not work correctly.")
|
||||
|
||||
# Find the demo file
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
demo_file = project_root / "demos" / demo['file']
|
||||
|
||||
if not demo_file.exists():
|
||||
console.print(f"❌ Demo file not found: {demo_file}", style="red")
|
||||
return False
|
||||
|
||||
console.print(f"🚀 Running {demo['name']} Demo...", style="bold green")
|
||||
console.print()
|
||||
|
||||
# Run the demo
|
||||
try:
|
||||
result = subprocess.run([sys.executable, str(demo_file)],
|
||||
capture_output=False,
|
||||
text=True)
|
||||
return result.returncode == 0
|
||||
except Exception as e:
|
||||
console.print(f"❌ Demo failed: {e}", style="red")
|
||||
return False
|
||||
|
||||
class DemoCommand(BaseCommand):
|
||||
"""Command for running TinyTorch AI capability demos"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.matrix = TinyTorchDemoMatrix()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "demo"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Run AI capability demos"
|
||||
|
||||
def add_arguments(self, parser):
|
||||
"""Add demo command arguments"""
|
||||
parser.add_argument('demo_name', nargs='?',
|
||||
help='Name of demo to run (math, logic, neuron, network, etc.)')
|
||||
parser.add_argument('--all', action='store_true',
|
||||
help='Run all available demos')
|
||||
parser.add_argument('--matrix', action='store_true',
|
||||
help='Show capability matrix only')
|
||||
|
||||
def run(self, args):
|
||||
"""Execute the demo command"""
|
||||
# Just show matrix if no args or --matrix flag
|
||||
if not args.demo_name and not args.all or args.matrix:
|
||||
self.matrix.show_matrix()
|
||||
return
|
||||
|
||||
# Run all available demos
|
||||
if args.all:
|
||||
self.matrix.show_matrix()
|
||||
available_demos = [name for name in self.matrix.demos.keys()
|
||||
if self.matrix.get_demo_status(name) == '✅']
|
||||
|
||||
if not available_demos:
|
||||
console.print("❌ No demos available. Export some modules first!", style="red")
|
||||
return
|
||||
|
||||
console.print(f"🚀 Running {len(available_demos)} available demos...", style="bold green")
|
||||
console.print()
|
||||
|
||||
for demo_name in available_demos:
|
||||
console.print(f"\n{'='*60}")
|
||||
success = self.matrix.run_demo(demo_name)
|
||||
if not success:
|
||||
console.print(f"❌ Demo {demo_name} failed", style="red")
|
||||
|
||||
console.print(f"\n{'='*60}")
|
||||
console.print("🏆 All available demos completed!", style="bold green")
|
||||
return
|
||||
|
||||
# Run specific demo
|
||||
if args.demo_name:
|
||||
self.matrix.run_demo(args.demo_name)
|
||||
|
||||
def main():
|
||||
"""Standalone entry point for development"""
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
DemoCommand.add_parser(parser._subparsers_action.add_parser if hasattr(parser, '_subparsers_action') else parser.add_subparser)
|
||||
args = parser.parse_args()
|
||||
|
||||
cmd = DemoCommand()
|
||||
cmd.execute(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,469 +0,0 @@
|
||||
"""
|
||||
Tiny🔥Torch Interactive Help System
|
||||
|
||||
Provides contextual, progressive guidance for new and experienced users.
|
||||
"""
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Optional, List, Dict, Any
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from .base import BaseCommand
|
||||
from ..core.config import CLIConfig
|
||||
from ..core.console import get_console
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.columns import Columns
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
from rich.prompt import Prompt, Confirm
|
||||
|
||||
|
||||
class HelpCommand(BaseCommand):
|
||||
"""Interactive help and onboarding system."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "help"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Interactive help system with guided onboarding"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
"""Add help command arguments."""
|
||||
parser.add_argument(
|
||||
'topic',
|
||||
nargs='?',
|
||||
help='Specific help topic (getting-started, commands, workflow, etc.)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--interactive', '-i',
|
||||
action='store_true',
|
||||
help='Launch interactive onboarding wizard'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--quick', '-q',
|
||||
action='store_true',
|
||||
help='Show quick reference card'
|
||||
)
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Execute help command."""
|
||||
console = get_console()
|
||||
|
||||
# Interactive onboarding wizard
|
||||
if args.interactive:
|
||||
return self._interactive_onboarding()
|
||||
|
||||
# Quick reference
|
||||
if args.quick:
|
||||
return self._show_quick_reference()
|
||||
|
||||
# Topic-specific help
|
||||
if args.topic:
|
||||
return self._show_topic_help(args.topic)
|
||||
|
||||
# Default: Show main help with user context
|
||||
return self._show_contextual_help()
|
||||
|
||||
def _interactive_onboarding(self) -> int:
|
||||
"""Launch interactive onboarding wizard."""
|
||||
console = get_console()
|
||||
|
||||
# Welcome screen
|
||||
console.print(Panel.fit(
|
||||
"[bold blue]🚀 Welcome to Tiny🔥Torch![/bold blue]\n\n"
|
||||
"Let's get you started on your ML systems engineering journey.\n"
|
||||
"This quick wizard will help you understand what Tiny🔥Torch is\n"
|
||||
"and guide you to the right starting point.",
|
||||
title="Tiny🔥Torch Onboarding Wizard",
|
||||
border_style="blue"
|
||||
))
|
||||
|
||||
# User experience assessment
|
||||
experience = self._assess_user_experience()
|
||||
|
||||
# Learning goal identification
|
||||
goals = self._identify_learning_goals()
|
||||
|
||||
# Time commitment assessment
|
||||
time_commitment = self._assess_time_commitment()
|
||||
|
||||
# Generate personalized recommendations
|
||||
recommendations = self._generate_recommendations(experience, goals, time_commitment)
|
||||
|
||||
# Show personalized path
|
||||
self._show_personalized_path(recommendations)
|
||||
|
||||
# Offer to start immediately
|
||||
if Confirm.ask("\n[bold green]Ready to start your first steps?[/bold green]"):
|
||||
self._launch_first_steps(recommendations)
|
||||
|
||||
return 0
|
||||
|
||||
def _assess_user_experience(self) -> str:
|
||||
"""Assess user's ML and programming experience."""
|
||||
console = get_console()
|
||||
|
||||
console.print("\n[bold cyan]📋 Quick Experience Assessment[/bold cyan]")
|
||||
|
||||
choices = [
|
||||
"New to ML and Python - need fundamentals",
|
||||
"Know Python, new to ML - want to learn systems",
|
||||
"Use PyTorch/TensorFlow - want to understand internals",
|
||||
"ML Engineer - need to debug/optimize production systems",
|
||||
"Instructor - want to teach this course"
|
||||
]
|
||||
|
||||
console.print("\nWhat best describes your background?")
|
||||
for i, choice in enumerate(choices, 1):
|
||||
console.print(f" {i}. {choice}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
selection = int(Prompt.ask("\nEnter your choice (1-5)"))
|
||||
if 1 <= selection <= 5:
|
||||
return ['beginner', 'python_user', 'framework_user', 'ml_engineer', 'instructor'][selection-1]
|
||||
else:
|
||||
console.print("[red]Please enter a number between 1-5[/red]")
|
||||
except ValueError:
|
||||
console.print("[red]Please enter a valid number[/red]")
|
||||
|
||||
def _identify_learning_goals(self) -> List[str]:
|
||||
"""Identify user's learning goals."""
|
||||
console = get_console()
|
||||
|
||||
console.print("\n[bold cyan]🎯 Learning Goals[/bold cyan]")
|
||||
console.print("What do you want to achieve? (Select all that apply)")
|
||||
|
||||
goals = [
|
||||
("understand_internals", "Understand how PyTorch/TensorFlow work internally"),
|
||||
("build_networks", "Build neural networks from scratch"),
|
||||
("optimize_performance", "Learn to optimize ML system performance"),
|
||||
("debug_production", "Debug production ML systems"),
|
||||
("teach_course", "Teach ML systems to others"),
|
||||
("career_transition", "Transition from software engineering to ML"),
|
||||
("research_custom", "Implement custom operations for research")
|
||||
]
|
||||
|
||||
selected_goals = []
|
||||
for key, description in goals:
|
||||
if Confirm.ask(f" • {description}?"):
|
||||
selected_goals.append(key)
|
||||
|
||||
return selected_goals
|
||||
|
||||
def _assess_time_commitment(self) -> str:
|
||||
"""Assess available time commitment."""
|
||||
console = get_console()
|
||||
|
||||
console.print("\n[bold cyan]⏰ Time Commitment[/bold cyan]")
|
||||
|
||||
choices = [
|
||||
("15_minutes", "15 minutes - just want a quick taste"),
|
||||
("2_hours", "2 hours - explore a few modules"),
|
||||
("weekend", "Weekend project - build something substantial"),
|
||||
("semester", "8-12 weeks - complete learning journey"),
|
||||
("teaching", "Teaching timeline - need instructor resources")
|
||||
]
|
||||
|
||||
console.print("How much time can you dedicate?")
|
||||
for i, (key, description) in enumerate(choices, 1):
|
||||
console.print(f" {i}. {description}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
selection = int(Prompt.ask("\nEnter your choice (1-5)"))
|
||||
if 1 <= selection <= 5:
|
||||
return choices[selection-1][0]
|
||||
else:
|
||||
console.print("[red]Please enter a number between 1-5[/red]")
|
||||
except ValueError:
|
||||
console.print("[red]Please enter a valid number[/red]")
|
||||
|
||||
def _generate_recommendations(self, experience: str, goals: List[str], time: str) -> Dict[str, Any]:
|
||||
"""Generate personalized recommendations."""
|
||||
|
||||
# Learning path mapping
|
||||
path_mapping = {
|
||||
'beginner': 'foundation_first',
|
||||
'python_user': 'guided_learning',
|
||||
'framework_user': 'systems_focus',
|
||||
'ml_engineer': 'optimization_focus',
|
||||
'instructor': 'teaching_resources'
|
||||
}
|
||||
|
||||
# Starting point mapping
|
||||
start_mapping = {
|
||||
'15_minutes': 'quick_demo',
|
||||
'2_hours': 'first_module',
|
||||
'weekend': 'milestone_project',
|
||||
'semester': 'full_curriculum',
|
||||
'teaching': 'instructor_setup'
|
||||
}
|
||||
|
||||
return {
|
||||
'learning_path': path_mapping.get(experience, 'guided_learning'),
|
||||
'starting_point': start_mapping.get(time, 'first_module'),
|
||||
'experience_level': experience,
|
||||
'goals': goals,
|
||||
'time_commitment': time
|
||||
}
|
||||
|
||||
def _show_personalized_path(self, recommendations: Dict[str, Any]) -> None:
|
||||
"""Show personalized learning path."""
|
||||
console = get_console()
|
||||
|
||||
# Path descriptions
|
||||
paths = {
|
||||
'foundation_first': {
|
||||
'title': '🌱 Foundation First Path',
|
||||
'description': 'Build fundamentals step-by-step with extra explanations',
|
||||
'next_steps': ['Module 1: Setup & Environment', 'Python fundamentals review', 'Linear algebra primer']
|
||||
},
|
||||
'guided_learning': {
|
||||
'title': '🎯 Guided Learning Path',
|
||||
'description': 'Structured progression through all major concepts',
|
||||
'next_steps': ['Module 1: Setup', 'Module 2: Tensors', 'Track progress with checkpoints']
|
||||
},
|
||||
'systems_focus': {
|
||||
'title': '⚡ Systems Focus Path',
|
||||
'description': 'Understand internals of frameworks you already use',
|
||||
'next_steps': ['Compare PyTorch vs your code', 'Profile memory usage', 'Optimization modules']
|
||||
},
|
||||
'optimization_focus': {
|
||||
'title': '🚀 Optimization Focus Path',
|
||||
'description': 'Performance debugging and production optimization',
|
||||
'next_steps': ['Profiling module', 'Benchmarking module', 'TinyMLPerf competition']
|
||||
},
|
||||
'teaching_resources': {
|
||||
'title': '🎓 Teaching Resources Path',
|
||||
'description': 'Instructor guides and classroom setup',
|
||||
'next_steps': ['Instructor guide', 'NBGrader setup', 'Student progress tracking']
|
||||
}
|
||||
}
|
||||
|
||||
path_info = paths[recommendations['learning_path']]
|
||||
|
||||
console.print(f"\n[bold green]✨ Your Personalized Learning Path[/bold green]")
|
||||
console.print(Panel(
|
||||
f"[bold]{path_info['title']}[/bold]\n\n"
|
||||
f"{path_info['description']}\n\n"
|
||||
f"[bold cyan]Your Next Steps:[/bold cyan]\n" +
|
||||
"\n".join(f" • {step}" for step in path_info['next_steps']),
|
||||
border_style="green"
|
||||
))
|
||||
|
||||
def _launch_first_steps(self, recommendations: Dict[str, Any]) -> None:
|
||||
"""Launch appropriate first steps based on recommendations."""
|
||||
console = get_console()
|
||||
|
||||
starting_point = recommendations['starting_point']
|
||||
|
||||
if starting_point == 'quick_demo':
|
||||
console.print("\n[bold blue]🚀 Launching Quick Demo...[/bold blue]")
|
||||
console.print("Running: [code]tito demo quick[/code]")
|
||||
os.system("tito demo quick")
|
||||
|
||||
elif starting_point == 'first_module':
|
||||
console.print("\n[bold blue]🛠️ Setting up Module 1...[/bold blue]")
|
||||
console.print("Next commands:")
|
||||
console.print(" [code]cd modules/01_setup[/code]")
|
||||
console.print(" [code]jupyter lab setup.py[/code]")
|
||||
|
||||
elif starting_point == 'milestone_project':
|
||||
console.print("\n[bold blue]🎯 Weekend Project Recommendations...[/bold blue]")
|
||||
console.print("Suggested goal: Build XOR solver (Modules 1-6)")
|
||||
console.print("Time estimate: 6-8 hours")
|
||||
|
||||
elif starting_point == 'full_curriculum':
|
||||
console.print("\n[bold blue]📚 Full Curriculum Setup...[/bold blue]")
|
||||
console.print("Running checkpoint system initialization...")
|
||||
os.system("tito checkpoint status")
|
||||
|
||||
elif starting_point == 'instructor_setup':
|
||||
console.print("\n[bold blue]🎓 Instructor Resources...[/bold blue]")
|
||||
console.print("Opening instructor guide...")
|
||||
console.print("Check: [code]book/usage-paths/classroom-use.html[/code]")
|
||||
|
||||
def _show_quick_reference(self) -> int:
|
||||
"""Show quick reference card."""
|
||||
console = get_console()
|
||||
|
||||
# Essential commands table
|
||||
table = Table(title="🚀 TinyTorch Quick Reference", show_header=True, header_style="bold cyan")
|
||||
table.add_column("Command", style="bold", width=25)
|
||||
table.add_column("Description", width=40)
|
||||
table.add_column("Example", style="dim", width=30)
|
||||
|
||||
essential_commands = [
|
||||
("tito help --interactive", "Launch onboarding wizard", "First time users"),
|
||||
("tito checkpoint status", "See your progress", "Track learning journey"),
|
||||
("tito module complete 02", "Finish a module", "Export & test your code"),
|
||||
("tito demo quick", "See framework in action", "5-minute demonstration"),
|
||||
("tito leaderboard join", "Join community", "Connect with learners"),
|
||||
("tito system doctor", "Check environment", "Troubleshoot issues")
|
||||
]
|
||||
|
||||
for cmd, desc, example in essential_commands:
|
||||
table.add_row(cmd, desc, example)
|
||||
|
||||
console.print(table)
|
||||
|
||||
# Common workflows
|
||||
console.print("\n[bold cyan]📋 Common Workflows:[/bold cyan]")
|
||||
workflows = [
|
||||
("New User", "tito help -i → tito checkpoint status → cd modules/01_setup"),
|
||||
("Continue Learning", "tito checkpoint status → work on next module → tito module complete XX"),
|
||||
("Join Community", "tito leaderboard join → submit progress → see global rankings"),
|
||||
("Get Help", "tito system doctor → check docs/FAQ → ask community")
|
||||
]
|
||||
|
||||
for workflow, commands in workflows:
|
||||
console.print(f" [bold]{workflow}:[/bold] {commands}")
|
||||
|
||||
return 0
|
||||
|
||||
def _show_topic_help(self, topic: str) -> int:
|
||||
"""Show help for specific topic."""
|
||||
console = get_console()
|
||||
|
||||
topics = {
|
||||
'getting-started': self._help_getting_started,
|
||||
'commands': self._help_commands,
|
||||
'workflow': self._help_workflow,
|
||||
'modules': self._help_modules,
|
||||
'checkpoints': self._help_checkpoints,
|
||||
'community': self._help_community,
|
||||
'troubleshooting': self._help_troubleshooting
|
||||
}
|
||||
|
||||
if topic in topics:
|
||||
topics[topic]()
|
||||
return 0
|
||||
else:
|
||||
console.print(f"[red]Unknown help topic: {topic}[/red]")
|
||||
console.print("Available topics: " + ", ".join(topics.keys()))
|
||||
return 1
|
||||
|
||||
def _show_contextual_help(self) -> int:
|
||||
"""Show contextual help based on user progress."""
|
||||
console = get_console()
|
||||
|
||||
# Check user progress to provide contextual guidance
|
||||
progress = self._assess_user_progress()
|
||||
|
||||
if progress['is_new_user']:
|
||||
self._show_new_user_help()
|
||||
elif progress['current_module']:
|
||||
self._show_in_progress_help(progress['current_module'])
|
||||
else:
|
||||
self._show_experienced_user_help()
|
||||
|
||||
return 0
|
||||
|
||||
def _assess_user_progress(self) -> Dict[str, Any]:
|
||||
"""Assess user's current progress."""
|
||||
# Check for checkpoint files, completed modules, etc.
|
||||
# This would integrate with the checkpoint system
|
||||
|
||||
# Simplified implementation for now
|
||||
checkpoints_dir = Path("tests/checkpoints")
|
||||
modules_dir = Path("modules")
|
||||
|
||||
return {
|
||||
'is_new_user': not checkpoints_dir.exists(),
|
||||
'current_module': None, # Would be determined by checkpoint status
|
||||
'completed_modules': [], # Would be populated from checkpoint results
|
||||
'has_joined_community': False # Would check leaderboard status
|
||||
}
|
||||
|
||||
def _show_new_user_help(self) -> None:
|
||||
"""Show help optimized for new users."""
|
||||
console = get_console()
|
||||
|
||||
console.print(Panel.fit(
|
||||
"[bold blue]👋 Welcome to Tiny🔥Torch![/bold blue]\n\n"
|
||||
"You're about to build a complete ML framework from scratch.\n"
|
||||
"Here's how to get started:\n\n"
|
||||
"[bold cyan]Next Steps:[/bold cyan]\n"
|
||||
"1. [code]tito help --interactive[/code] - Personalized onboarding\n"
|
||||
"2. [code]tito system doctor[/code] - Check your environment\n"
|
||||
"3. [code]tito checkpoint status[/code] - See the learning journey\n\n"
|
||||
"[bold yellow]New to ML systems?[/bold yellow] Run the interactive wizard!",
|
||||
title="Getting Started",
|
||||
border_style="blue"
|
||||
))
|
||||
|
||||
def _help_getting_started(self) -> None:
|
||||
"""Detailed getting started help."""
|
||||
console = get_console()
|
||||
|
||||
console.print("[bold blue]🚀 Getting Started with Tiny🔥Torch[/bold blue]\n")
|
||||
|
||||
# Installation steps
|
||||
install_panel = Panel(
|
||||
"[bold]1. Environment Setup[/bold]\n"
|
||||
"```bash\n"
|
||||
"git clone https://github.com/mlsysbook/Tiny🔥Torch.git\n"
|
||||
"cd Tiny🔥Torch\n"
|
||||
f"python -m venv {self.venv_path}\n"
|
||||
f"source {self.venv_path}/bin/activate # Windows: .venv\\Scripts\\activate\n"
|
||||
"pip install -r requirements.txt\n"
|
||||
"pip install -e .\n"
|
||||
"```",
|
||||
title="Installation",
|
||||
border_style="green"
|
||||
)
|
||||
|
||||
# First steps
|
||||
first_steps_panel = Panel(
|
||||
"[bold]2. First Steps[/bold]\n"
|
||||
"• [code]tito system doctor[/code] - Verify installation\n"
|
||||
"• [code]tito help --interactive[/code] - Personalized guidance\n"
|
||||
"• [code]tito checkpoint status[/code] - See learning path\n"
|
||||
"• [code]cd modules/01_setup[/code] - Start first module",
|
||||
title="First Steps",
|
||||
border_style="blue"
|
||||
)
|
||||
|
||||
# Learning path
|
||||
learning_panel = Panel(
|
||||
"[bold]3. Learning Journey[/bold]\n"
|
||||
"📚 [bold]Modules 1-8:[/bold] Neural Network Foundations\n"
|
||||
"🔬 [bold]Modules 9-10:[/bold] Computer Vision (CNNs)\n"
|
||||
"🤖 [bold]Modules 11-14:[/bold] Language Models (Transformers)\n"
|
||||
"⚡ [bold]Modules 15-20:[/bold] System Optimization\n\n"
|
||||
"[dim]Each module: Build → Test → Export → Checkpoint[/dim]",
|
||||
title="Learning Path",
|
||||
border_style="yellow"
|
||||
)
|
||||
|
||||
console.print(Columns([install_panel, first_steps_panel, learning_panel]))
|
||||
|
||||
# Additional help methods would be implemented here...
|
||||
def _help_commands(self) -> None:
|
||||
"""Show comprehensive command reference."""
|
||||
pass
|
||||
|
||||
def _help_workflow(self) -> None:
|
||||
"""Show common workflow patterns."""
|
||||
pass
|
||||
|
||||
def _help_modules(self) -> None:
|
||||
"""Show module system explanation."""
|
||||
pass
|
||||
|
||||
def _help_checkpoints(self) -> None:
|
||||
"""Show checkpoint system explanation."""
|
||||
pass
|
||||
|
||||
def _help_community(self) -> None:
|
||||
"""Show community features and leaderboard."""
|
||||
pass
|
||||
|
||||
def _help_troubleshooting(self) -> None:
|
||||
"""Show troubleshooting guide."""
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,159 +0,0 @@
|
||||
"""
|
||||
Milestones command for TinyTorch CLI: track progress through ML history.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
|
||||
class MilestonesCommand(BaseCommand):
|
||||
"""Track and run milestone verification tests."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "milestones"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Track progress through ML history milestones"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
subparsers = parser.add_subparsers(dest="milestone_action", help="Milestone actions")
|
||||
|
||||
# Progress command
|
||||
progress_parser = subparsers.add_parser("progress", help="Show milestone progress")
|
||||
|
||||
# List command
|
||||
list_parser = subparsers.add_parser("list", help="List unlocked milestone tests")
|
||||
|
||||
# Run command
|
||||
run_parser = subparsers.add_parser("run", help="Run a milestone verification test")
|
||||
run_parser.add_argument("milestone", help="Milestone ID (perceptron, xor, mlp_digits, cnn, transformer)")
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
console = self.console
|
||||
|
||||
# Import milestone tracker
|
||||
milestone_tracker_path = Path(__file__).parent.parent.parent / "tests" / "milestones"
|
||||
if str(milestone_tracker_path) not in sys.path:
|
||||
sys.path.insert(0, str(milestone_tracker_path))
|
||||
|
||||
try:
|
||||
from milestone_tracker import MilestoneTracker, MILESTONES
|
||||
except ImportError:
|
||||
console.print("[red]❌ Milestone tracker not available[/red]")
|
||||
return 1
|
||||
|
||||
tracker = MilestoneTracker()
|
||||
|
||||
if not hasattr(args, 'milestone_action') or args.milestone_action is None:
|
||||
# Default: show progress
|
||||
tracker.show_progress()
|
||||
return 0
|
||||
|
||||
if args.milestone_action == "progress":
|
||||
tracker.show_progress()
|
||||
return 0
|
||||
|
||||
elif args.milestone_action == "list":
|
||||
tracker.list_unlocked_tests()
|
||||
return 0
|
||||
|
||||
elif args.milestone_action == "run":
|
||||
milestone_id = args.milestone
|
||||
|
||||
if milestone_id not in MILESTONES:
|
||||
console.print(f"[red]❌ Unknown milestone: {milestone_id}[/red]")
|
||||
console.print(f"[yellow]Available: {', '.join(MILESTONES.keys())}[/yellow]")
|
||||
return 1
|
||||
|
||||
if not tracker.can_run_milestone(milestone_id):
|
||||
milestone = MILESTONES[milestone_id]
|
||||
console.print(f"[yellow]🔒 Milestone locked: {milestone['name']}[/yellow]")
|
||||
console.print(f"\n[bold]Complete these modules first:[/bold]")
|
||||
for req in milestone["requires"]:
|
||||
status = "✅" if req in tracker.progress["completed_modules"] else "❌"
|
||||
console.print(f" {status} {req}")
|
||||
return 1
|
||||
|
||||
# Run the test
|
||||
import subprocess
|
||||
milestone = MILESTONES[milestone_id]
|
||||
test_name = milestone["test"]
|
||||
|
||||
console.print(f"[bold cyan]🧪 Running {milestone['name']}[/bold cyan]")
|
||||
console.print(f"[dim]{milestone['description']}[/dim]\n")
|
||||
|
||||
# Run pytest
|
||||
test_file = Path(__file__).parent.parent.parent / "tests" / "milestones" / "test_learning_verification.py"
|
||||
|
||||
result = subprocess.run([
|
||||
"pytest",
|
||||
f"{test_file}::{test_name}",
|
||||
"-v",
|
||||
"--tb=short"
|
||||
], cwd=Path.cwd())
|
||||
|
||||
if result.returncode == 0:
|
||||
tracker.mark_milestone_complete(milestone_id)
|
||||
|
||||
# Show what's next
|
||||
console.print()
|
||||
self._show_next_milestone(tracker, milestone_id)
|
||||
else:
|
||||
console.print()
|
||||
console.print("[yellow]💡 The test didn't pass. Check your implementation and try again.[/yellow]")
|
||||
|
||||
return result.returncode
|
||||
|
||||
else:
|
||||
console.print(f"[red]❌ Unknown action: {args.milestone_action}[/red]")
|
||||
return 1
|
||||
|
||||
def _show_next_milestone(self, tracker, completed_id):
|
||||
"""Show the next milestone after completing one."""
|
||||
from rich.panel import Panel
|
||||
|
||||
milestone_order = ["perceptron", "xor", "mlp_digits", "cnn", "transformer"]
|
||||
|
||||
try:
|
||||
current_index = milestone_order.index(completed_id)
|
||||
if current_index < len(milestone_order) - 1:
|
||||
next_id = milestone_order[current_index + 1]
|
||||
|
||||
if next_id in tracker.progress.get("unlocked_milestones", []):
|
||||
from milestone_tracker import MILESTONES
|
||||
next_milestone = MILESTONES[next_id]
|
||||
|
||||
self.console.print(Panel(
|
||||
f"[bold cyan]🎯 Next Milestone Available![/bold cyan]\n\n"
|
||||
f"[bold]{next_milestone['name']}[/bold]\n"
|
||||
f"{next_milestone['description']}\n\n"
|
||||
f"[bold]Run it now:[/bold]\n"
|
||||
f"[yellow]tito milestones run {next_id}[/yellow]",
|
||||
title="Continue Your Journey",
|
||||
border_style="cyan"
|
||||
))
|
||||
else:
|
||||
self.console.print(Panel(
|
||||
f"[bold yellow]🔒 Next milestone locked[/bold yellow]\n\n"
|
||||
f"Complete more modules to unlock the next milestone.\n\n"
|
||||
f"[dim]Check progress:[/dim]\n"
|
||||
f"[dim]tito milestones progress[/dim]",
|
||||
title="Keep Building",
|
||||
border_style="yellow"
|
||||
))
|
||||
else:
|
||||
self.console.print(Panel(
|
||||
f"[bold green]🏆 ALL MILESTONES COMPLETED![/bold green]\n\n"
|
||||
f"You've verified 60+ years of neural network history!\n"
|
||||
f"Your TinyTorch implementation is complete and working. 🎓",
|
||||
title="Congratulations!",
|
||||
border_style="gold1"
|
||||
))
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
@@ -1,666 +0,0 @@
|
||||
"""
|
||||
Module Reset Command for TinyTorch CLI.
|
||||
|
||||
Provides comprehensive module reset functionality:
|
||||
- Backup current work before reset
|
||||
- Unexport from package
|
||||
- Restore pristine source from git or backup
|
||||
- Update progress tracking
|
||||
|
||||
This enables students to restart a module cleanly while preserving their work.
|
||||
"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import stat
|
||||
import subprocess
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
|
||||
class ModuleResetCommand(BaseCommand):
|
||||
"""Command to reset a module to clean state with backup functionality."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "reset"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Reset module to clean state (backup + unexport + restore)"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
"""Add reset command arguments."""
|
||||
parser.add_argument(
|
||||
"module_number", nargs="?", help="Module number to reset (01, 02, etc.)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--soft",
|
||||
action="store_true",
|
||||
help="Soft reset: backup + restore source (keep package export)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hard",
|
||||
action="store_true",
|
||||
help="Hard reset: backup + unexport + restore (full reset) [DEFAULT]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-git",
|
||||
action="store_true",
|
||||
help="Restore from git HEAD [DEFAULT]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--restore-backup",
|
||||
metavar="TIMESTAMP",
|
||||
help="Restore from specific backup timestamp",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-backups", action="store_true", help="List available backups"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-backup", action="store_true", help="Skip backup creation (dangerous)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force", action="store_true", help="Skip confirmation prompts"
|
||||
)
|
||||
|
||||
def get_module_mapping(self) -> Dict[str, str]:
|
||||
"""Get mapping from numbers to module names."""
|
||||
return {
|
||||
"01": "01_tensor",
|
||||
"02": "02_activations",
|
||||
"03": "03_layers",
|
||||
"04": "04_losses",
|
||||
"05": "05_autograd",
|
||||
"06": "06_optimizers",
|
||||
"07": "07_training",
|
||||
"08": "08_dataloader",
|
||||
"09": "09_spatial",
|
||||
"10": "10_tokenization",
|
||||
"11": "11_embeddings",
|
||||
"12": "12_attention",
|
||||
"13": "13_transformers",
|
||||
"14": "14_profiling",
|
||||
"15": "15_quantization",
|
||||
"16": "16_acceleration",
|
||||
"17": "17_compression",
|
||||
"18": "18_memoization",
|
||||
"19": "19_benchmarking",
|
||||
"20": "20_capstone",
|
||||
}
|
||||
|
||||
def normalize_module_number(self, module_input: str) -> str:
|
||||
"""Normalize module input to 2-digit format."""
|
||||
if module_input.isdigit():
|
||||
return f"{int(module_input):02d}"
|
||||
return module_input
|
||||
|
||||
def get_backup_dir(self) -> Path:
|
||||
"""Get the backup directory, creating it if needed."""
|
||||
backup_dir = self.config.project_root / ".tito" / "backups"
|
||||
backup_dir.mkdir(parents=True, exist_ok=True)
|
||||
return backup_dir
|
||||
|
||||
def list_backups(self, module_name: str) -> List[Dict]:
|
||||
"""List available backups for a module."""
|
||||
backup_dir = self.get_backup_dir()
|
||||
backups = []
|
||||
|
||||
# Find all backup directories for this module
|
||||
pattern = f"{module_name}_*"
|
||||
for backup_path in backup_dir.glob(pattern):
|
||||
if backup_path.is_dir():
|
||||
# Read metadata if it exists
|
||||
metadata_file = backup_path / "backup_metadata.json"
|
||||
if metadata_file.exists():
|
||||
try:
|
||||
with open(metadata_file, "r") as f:
|
||||
metadata = json.load(f)
|
||||
backups.append(
|
||||
{
|
||||
"path": backup_path,
|
||||
"timestamp": metadata.get("timestamp"),
|
||||
"git_hash": metadata.get("git_hash"),
|
||||
"files": metadata.get("files", []),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
# If metadata is corrupt, just use directory name
|
||||
timestamp = backup_path.name.split("_", 1)[1]
|
||||
backups.append(
|
||||
{"path": backup_path, "timestamp": timestamp, "files": []}
|
||||
)
|
||||
else:
|
||||
# No metadata, use directory name
|
||||
timestamp = backup_path.name.split("_", 1)[1]
|
||||
backups.append(
|
||||
{"path": backup_path, "timestamp": timestamp, "files": []}
|
||||
)
|
||||
|
||||
return sorted(backups, key=lambda x: x["timestamp"], reverse=True)
|
||||
|
||||
def show_backups_list(self, module_name: str) -> int:
|
||||
"""Display list of available backups for a module."""
|
||||
console = self.console
|
||||
backups = self.list_backups(module_name)
|
||||
|
||||
if not backups:
|
||||
console.print(
|
||||
Panel(
|
||||
f"[yellow]No backups found for module: {module_name}[/yellow]",
|
||||
title="No Backups",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
return 0
|
||||
|
||||
# Create table
|
||||
table = Table(title=f"Available Backups for {module_name}", show_header=True)
|
||||
table.add_column("Timestamp", style="cyan")
|
||||
table.add_column("Git Hash", style="dim")
|
||||
table.add_column("Files", style="green")
|
||||
|
||||
for backup in backups:
|
||||
table.add_row(
|
||||
backup["timestamp"],
|
||||
backup.get("git_hash", "unknown")[:8],
|
||||
str(len(backup.get("files", []))),
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
console.print(
|
||||
f"\n[dim]Restore a backup with:[/dim] [cyan]tito module reset {module_name} --restore-backup TIMESTAMP[/cyan]"
|
||||
)
|
||||
return 0
|
||||
|
||||
def create_backup(self, module_name: str) -> Optional[Path]:
|
||||
"""Create a backup of the current module state."""
|
||||
console = self.console
|
||||
|
||||
# Get module directory
|
||||
module_dir = self.config.modules_dir / module_name
|
||||
if not module_dir.exists():
|
||||
console.print(
|
||||
f"[red]Module directory not found: {module_dir}[/red]"
|
||||
)
|
||||
return None
|
||||
|
||||
# Create backup directory with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_dir = self.get_backup_dir() / f"{module_name}_{timestamp}"
|
||||
backup_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
console.print(f"[cyan]Creating backup: {backup_dir.name}[/cyan]")
|
||||
|
||||
# Get current git hash if in git repo
|
||||
git_hash = "unknown"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=self.config.project_root,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
git_hash = result.stdout.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Copy all Python files from module directory
|
||||
backed_up_files = []
|
||||
for py_file in module_dir.glob("*.py"):
|
||||
dest_file = backup_dir / py_file.name
|
||||
shutil.copy2(py_file, dest_file)
|
||||
backed_up_files.append(py_file.name)
|
||||
console.print(f" [dim]✓ Backed up: {py_file.name}[/dim]")
|
||||
|
||||
# Save metadata
|
||||
metadata = {
|
||||
"module_name": module_name,
|
||||
"timestamp": timestamp,
|
||||
"git_hash": git_hash,
|
||||
"files": backed_up_files,
|
||||
"backup_dir": str(backup_dir),
|
||||
}
|
||||
|
||||
metadata_file = backup_dir / "backup_metadata.json"
|
||||
with open(metadata_file, "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
console.print(f"[green]✓ Backup created successfully[/green]")
|
||||
return backup_dir
|
||||
|
||||
def unexport_module(self, module_name: str) -> bool:
|
||||
"""Remove module exports from the package."""
|
||||
console = self.console
|
||||
|
||||
# Get export target from module's #| default_exp directive
|
||||
module_dir = self.config.modules_dir / module_name
|
||||
short_name = module_name.split("_", 1)[1] if "_" in module_name else module_name
|
||||
dev_file = module_dir / f"{short_name}.py"
|
||||
|
||||
if not dev_file.exists():
|
||||
console.print(f"[yellow]Dev file not found: {dev_file}[/yellow]")
|
||||
return True # Nothing to unexport
|
||||
|
||||
# Read export target
|
||||
export_target = None
|
||||
try:
|
||||
with open(dev_file, "r") as f:
|
||||
content = f.read()
|
||||
import re
|
||||
|
||||
match = re.search(r"#\|\s*default_exp\s+([^\n\r]+)", content)
|
||||
if match:
|
||||
export_target = match.group(1).strip()
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Could not read export target: {e}[/yellow]")
|
||||
return True
|
||||
|
||||
if not export_target:
|
||||
console.print("[dim]No export target found (no #| default_exp)[/dim]")
|
||||
return True
|
||||
|
||||
# Convert export target to file path
|
||||
target_file = (
|
||||
self.config.project_root
|
||||
/ "tinytorch"
|
||||
/ export_target.replace(".", "/")
|
||||
).with_suffix(".py")
|
||||
|
||||
if not target_file.exists():
|
||||
console.print(f"[dim]Export file not found (already removed?): {target_file}[/dim]")
|
||||
return True
|
||||
|
||||
# Remove protection if file is read-only
|
||||
try:
|
||||
target_file.chmod(
|
||||
target_file.stat().st_mode | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Remove the exported file
|
||||
try:
|
||||
target_file.unlink()
|
||||
console.print(f" [dim]✓ Removed export: {target_file.relative_to(self.config.project_root)}[/dim]")
|
||||
return True
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to remove export: {e}[/red]")
|
||||
return False
|
||||
|
||||
def restore_from_git(self, module_name: str) -> bool:
|
||||
"""Restore module from git HEAD."""
|
||||
console = self.console
|
||||
|
||||
# Get module directory and dev file
|
||||
module_dir = self.config.modules_dir / module_name
|
||||
short_name = module_name.split("_", 1)[1] if "_" in module_name else module_name
|
||||
dev_file = module_dir / f"{short_name}.py"
|
||||
|
||||
console.print(f"[cyan]Restoring from git: {dev_file.relative_to(self.config.project_root)}[/cyan]")
|
||||
|
||||
# Check if file exists in git
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "ls-files", str(dev_file.relative_to(self.config.project_root))],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=self.config.project_root,
|
||||
)
|
||||
if result.returncode != 0 or not result.stdout.strip():
|
||||
console.print(
|
||||
f"[red]File not tracked in git: {dev_file}[/red]"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
console.print(f"[red]Git check failed: {e}[/red]")
|
||||
return False
|
||||
|
||||
# Restore from git
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "checkout", "HEAD", "--", str(dev_file.relative_to(self.config.project_root))],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=self.config.project_root,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
console.print(f"[green]✓ Restored from git HEAD[/green]")
|
||||
return True
|
||||
else:
|
||||
console.print(
|
||||
f"[red]Git checkout failed: {result.stderr}[/red]"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
console.print(f"[red]Failed to restore from git: {e}[/red]")
|
||||
return False
|
||||
|
||||
def restore_from_backup(self, module_name: str, timestamp: str) -> bool:
|
||||
"""Restore module from a specific backup."""
|
||||
console = self.console
|
||||
|
||||
# Find backup directory
|
||||
backup_dir = self.get_backup_dir() / f"{module_name}_{timestamp}"
|
||||
|
||||
if not backup_dir.exists():
|
||||
console.print(
|
||||
f"[red]Backup not found: {backup_dir.name}[/red]"
|
||||
)
|
||||
return False
|
||||
|
||||
# Get module directory
|
||||
module_dir = self.config.modules_dir / module_name
|
||||
|
||||
console.print(f"[cyan]Restoring from backup: {backup_dir.name}[/cyan]")
|
||||
|
||||
# Read metadata to get backed up files
|
||||
metadata_file = backup_dir / "backup_metadata.json"
|
||||
backed_up_files = []
|
||||
|
||||
if metadata_file.exists():
|
||||
try:
|
||||
with open(metadata_file, "r") as f:
|
||||
metadata = json.load(f)
|
||||
backed_up_files = metadata.get("files", [])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If no metadata, find all .py files in backup
|
||||
if not backed_up_files:
|
||||
backed_up_files = [f.name for f in backup_dir.glob("*.py")]
|
||||
|
||||
# Restore each file
|
||||
restored_count = 0
|
||||
for filename in backed_up_files:
|
||||
backup_file = backup_dir / filename
|
||||
dest_file = module_dir / filename
|
||||
|
||||
if backup_file.exists():
|
||||
try:
|
||||
shutil.copy2(backup_file, dest_file)
|
||||
console.print(f" [dim]✓ Restored: {filename}[/dim]")
|
||||
restored_count += 1
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f" [red]Failed to restore {filename}: {e}[/red]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f" [yellow]Backup file missing: {filename}[/yellow]"
|
||||
)
|
||||
|
||||
if restored_count > 0:
|
||||
console.print(
|
||||
f"[green]✓ Restored {restored_count} file(s) from backup[/green]"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
console.print("[red]Failed to restore any files from backup[/red]")
|
||||
return False
|
||||
|
||||
def update_progress_tracking(self, module_name: str, module_number: str) -> None:
|
||||
"""Update progress tracking to mark module as not completed."""
|
||||
console = self.console
|
||||
|
||||
# Update progress.json (module_workflow.py format)
|
||||
progress_file = self.config.project_root / "progress.json"
|
||||
if progress_file.exists():
|
||||
try:
|
||||
with open(progress_file, "r") as f:
|
||||
progress = json.load(f)
|
||||
|
||||
# Remove from completed modules
|
||||
if "completed_modules" in progress:
|
||||
if module_number in progress["completed_modules"]:
|
||||
progress["completed_modules"].remove(module_number)
|
||||
console.print(
|
||||
f" [dim]✓ Removed from completed modules[/dim]"
|
||||
)
|
||||
|
||||
# Update last_updated timestamp
|
||||
progress["last_updated"] = datetime.now().isoformat()
|
||||
|
||||
with open(progress_file, "w") as f:
|
||||
json.dump(progress, f, indent=2)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Could not update progress.json: {e}[/yellow]"
|
||||
)
|
||||
|
||||
# Update .tito/progress.json (comprehensive format)
|
||||
tito_progress_dir = self.config.project_root / ".tito"
|
||||
tito_progress_file = tito_progress_dir / "progress.json"
|
||||
|
||||
if tito_progress_file.exists():
|
||||
try:
|
||||
with open(tito_progress_file, "r") as f:
|
||||
progress = json.load(f)
|
||||
|
||||
# Remove from completed modules
|
||||
if "completed_modules" in progress:
|
||||
if module_name in progress["completed_modules"]:
|
||||
progress["completed_modules"].remove(module_name)
|
||||
|
||||
# Remove completion date
|
||||
if "completion_dates" in progress:
|
||||
if module_name in progress["completion_dates"]:
|
||||
del progress["completion_dates"][module_name]
|
||||
|
||||
with open(tito_progress_file, "w") as f:
|
||||
json.dump(progress, f, indent=2)
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Could not update .tito/progress.json: {e}[/yellow]"
|
||||
)
|
||||
|
||||
def check_git_status(self) -> bool:
|
||||
"""Check if there are uncommitted changes and warn user."""
|
||||
console = self.console
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=self.config.project_root,
|
||||
)
|
||||
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
console.print(
|
||||
Panel(
|
||||
"[yellow]⚠️ You have uncommitted changes in your repository![/yellow]\n\n"
|
||||
"[dim]Consider committing your work before resetting.[/dim]",
|
||||
title="Uncommitted Changes",
|
||||
border_style="yellow",
|
||||
)
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
# If git check fails, continue anyway
|
||||
return True
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Execute the reset command."""
|
||||
console = self.console
|
||||
|
||||
# Handle --list-backups
|
||||
if args.list_backups:
|
||||
if not args.module_number:
|
||||
console.print(
|
||||
"[red]Error: --list-backups requires a module number[/red]"
|
||||
)
|
||||
return 1
|
||||
|
||||
module_mapping = self.get_module_mapping()
|
||||
normalized = self.normalize_module_number(args.module_number)
|
||||
|
||||
if normalized not in module_mapping:
|
||||
console.print(f"[red]Invalid module number: {args.module_number}[/red]")
|
||||
return 1
|
||||
|
||||
module_name = module_mapping[normalized]
|
||||
return self.show_backups_list(module_name)
|
||||
|
||||
# Require module number
|
||||
if not args.module_number:
|
||||
console.print(
|
||||
Panel(
|
||||
"[red]Error: Module number required[/red]\n\n"
|
||||
"[dim]Examples:[/dim]\n"
|
||||
"[dim] tito module reset 01 # Reset module 01[/dim]\n"
|
||||
"[dim] tito module reset 01 --list-backups # Show backups[/dim]\n"
|
||||
"[dim] tito module reset 01 --soft # Keep package export[/dim]\n"
|
||||
"[dim] tito module reset 01 --restore-backup # Restore from backup[/dim]",
|
||||
title="Module Number Required",
|
||||
border_style="red",
|
||||
)
|
||||
)
|
||||
return 1
|
||||
|
||||
# Normalize and validate module number
|
||||
module_mapping = self.get_module_mapping()
|
||||
normalized = self.normalize_module_number(args.module_number)
|
||||
|
||||
if normalized not in module_mapping:
|
||||
console.print(f"[red]Invalid module number: {args.module_number}[/red]")
|
||||
console.print("Available modules: 01-20")
|
||||
return 1
|
||||
|
||||
module_name = module_mapping[normalized]
|
||||
|
||||
# Determine reset type
|
||||
is_hard_reset = args.hard or not args.soft # Default to hard reset
|
||||
|
||||
# Show reset plan
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold cyan]Module Reset: {module_name}[/bold cyan]\n\n"
|
||||
f"[bold]Reset Type:[/bold] {'Hard' if is_hard_reset else 'Soft'}\n"
|
||||
f"[bold]Actions:[/bold]\n"
|
||||
f" {'✓' if not args.no_backup else '✗'} Backup current work\n"
|
||||
f" {'✓' if is_hard_reset else '✗'} Unexport from package\n"
|
||||
f" ✓ Restore pristine source\n"
|
||||
f" ✓ Update progress tracking\n\n"
|
||||
f"[dim]{'Soft reset keeps package exports intact' if not is_hard_reset else 'Hard reset removes package exports'}[/dim]",
|
||||
title="Reset Plan",
|
||||
border_style="bright_yellow",
|
||||
)
|
||||
)
|
||||
|
||||
# Check git status (warn but don't block)
|
||||
self.check_git_status()
|
||||
|
||||
# Confirmation prompt (unless --force)
|
||||
if not args.force:
|
||||
console.print(
|
||||
"\n[yellow]This will reset the module to a clean state.[/yellow]"
|
||||
)
|
||||
if not args.no_backup:
|
||||
console.print("[green]Your current work will be backed up.[/green]")
|
||||
else:
|
||||
console.print(
|
||||
"[red]Your current work will NOT be backed up![/red]"
|
||||
)
|
||||
|
||||
try:
|
||||
response = input("\nContinue with reset? (y/N): ").strip().lower()
|
||||
if response not in ["y", "yes"]:
|
||||
console.print(
|
||||
Panel(
|
||||
"[cyan]Reset cancelled.[/cyan]",
|
||||
title="Cancelled",
|
||||
border_style="cyan",
|
||||
)
|
||||
)
|
||||
return 0
|
||||
except KeyboardInterrupt:
|
||||
console.print(
|
||||
Panel(
|
||||
"[cyan]Reset cancelled.[/cyan]",
|
||||
title="Cancelled",
|
||||
border_style="cyan",
|
||||
)
|
||||
)
|
||||
return 0
|
||||
|
||||
# Step 1: Create backup (unless --no-backup)
|
||||
if not args.no_backup:
|
||||
console.print("\n[bold]Step 1: Creating backup...[/bold]")
|
||||
backup_dir = self.create_backup(module_name)
|
||||
if not backup_dir:
|
||||
console.print("[red]Backup failed. Reset aborted.[/red]")
|
||||
return 1
|
||||
else:
|
||||
console.print(
|
||||
"\n[bold yellow]Step 1: Skipping backup (--no-backup)[/bold yellow]"
|
||||
)
|
||||
|
||||
# Step 2: Unexport from package (unless --soft)
|
||||
if is_hard_reset:
|
||||
console.print("\n[bold]Step 2: Removing package exports...[/bold]")
|
||||
if not self.unexport_module(module_name):
|
||||
console.print(
|
||||
"[yellow]Warning: Unexport may have failed (continuing)[/yellow]"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"\n[bold]Step 2: Keeping package exports (soft reset)[/bold]"
|
||||
)
|
||||
|
||||
# Step 3: Restore source
|
||||
console.print("\n[bold]Step 3: Restoring pristine source...[/bold]")
|
||||
|
||||
if args.restore_backup:
|
||||
# Restore from specific backup
|
||||
success = self.restore_from_backup(module_name, args.restore_backup)
|
||||
else:
|
||||
# Restore from git (default)
|
||||
success = self.restore_from_git(module_name)
|
||||
|
||||
if not success:
|
||||
console.print("[red]Restore failed. Module may be in inconsistent state.[/red]")
|
||||
if not args.no_backup and 'backup_dir' in locals():
|
||||
console.print(
|
||||
f"[yellow]Your work was backed up to: {backup_dir}[/yellow]"
|
||||
)
|
||||
return 1
|
||||
|
||||
# Step 4: Update progress tracking
|
||||
console.print("\n[bold]Step 4: Updating progress tracking...[/bold]")
|
||||
self.update_progress_tracking(module_name, normalized)
|
||||
|
||||
# Success summary
|
||||
console.print(
|
||||
Panel(
|
||||
f"[bold green]✓ Module {module_name} reset successfully![/bold green]\n\n"
|
||||
f"[green]Actions completed:[/green]\n"
|
||||
f" {'✓ Work backed up' if not args.no_backup else '✗ No backup created'}\n"
|
||||
f" {'✓ Package exports removed' if is_hard_reset else '✗ Package exports preserved'}\n"
|
||||
f" ✓ Source restored to pristine state\n"
|
||||
f" ✓ Progress tracking updated\n\n"
|
||||
f"[bold cyan]Next steps:[/bold cyan]\n"
|
||||
f" • [dim]tito module start {normalized}[/dim] - Begin working again\n"
|
||||
f" • [dim]tito module resume {normalized}[/dim] - Continue from where you left off\n"
|
||||
+ (
|
||||
f" • [dim]tito module reset {normalized} --list-backups[/dim] - View backups\n"
|
||||
if not args.no_backup
|
||||
else ""
|
||||
),
|
||||
title="Reset Complete",
|
||||
border_style="green",
|
||||
)
|
||||
)
|
||||
|
||||
return 0
|
||||
@@ -1,575 +0,0 @@
|
||||
"""
|
||||
Enhanced Module Workflow for TinyTorch CLI.
|
||||
|
||||
Implements the natural workflow:
|
||||
1. tito module 01 → Opens module 01 in Jupyter
|
||||
2. Student works and saves
|
||||
3. tito module complete 01 → Tests, exports, updates progress
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
|
||||
from .base import BaseCommand
|
||||
from .view import ViewCommand
|
||||
from .test import TestCommand
|
||||
from .export import ExportCommand
|
||||
from .module_reset import ModuleResetCommand
|
||||
from ..core.exceptions import ModuleNotFoundError
|
||||
|
||||
class ModuleWorkflowCommand(BaseCommand):
|
||||
"""Enhanced module command with natural workflow."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "module"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Module development workflow - open, work, complete"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
"""Add module workflow arguments."""
|
||||
# Add subcommands - clean lifecycle workflow
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='module_command',
|
||||
help='Module lifecycle operations'
|
||||
)
|
||||
|
||||
# START command - begin working on a module
|
||||
start_parser = subparsers.add_parser(
|
||||
'start',
|
||||
help='Start working on a module (first time)'
|
||||
)
|
||||
start_parser.add_argument(
|
||||
'module_number',
|
||||
help='Module number to start (01, 02, 03, etc.)'
|
||||
)
|
||||
|
||||
# RESUME command - continue working on a module
|
||||
resume_parser = subparsers.add_parser(
|
||||
'resume',
|
||||
help='Resume working on a module (continue previous work)'
|
||||
)
|
||||
resume_parser.add_argument(
|
||||
'module_number',
|
||||
nargs='?',
|
||||
help='Module number to resume (01, 02, 03, etc.) - defaults to last worked'
|
||||
)
|
||||
|
||||
# COMPLETE command - finish and validate a module
|
||||
complete_parser = subparsers.add_parser(
|
||||
'complete',
|
||||
help='Complete module: run tests, export if passing, update progress'
|
||||
)
|
||||
complete_parser.add_argument(
|
||||
'module_number',
|
||||
nargs='?',
|
||||
help='Module number to complete (01, 02, 03, etc.) - defaults to current'
|
||||
)
|
||||
complete_parser.add_argument(
|
||||
'--skip-tests',
|
||||
action='store_true',
|
||||
help='Skip integration tests'
|
||||
)
|
||||
complete_parser.add_argument(
|
||||
'--skip-export',
|
||||
action='store_true',
|
||||
help='Skip automatic export'
|
||||
)
|
||||
|
||||
# RESET command - reset module to clean state
|
||||
reset_parser = subparsers.add_parser(
|
||||
'reset',
|
||||
help='Reset module to clean state (backup + unexport + restore)'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'module_number',
|
||||
help='Module number to reset (01, 02, 03, etc.)'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--soft',
|
||||
action='store_true',
|
||||
help='Soft reset: backup + restore (keep exports)'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--hard',
|
||||
action='store_true',
|
||||
help='Hard reset: backup + unexport + restore [DEFAULT]'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--from-git',
|
||||
action='store_true',
|
||||
help='Restore from git HEAD [DEFAULT]'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--restore-backup',
|
||||
metavar='TIMESTAMP',
|
||||
help='Restore from specific backup'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--list-backups',
|
||||
action='store_true',
|
||||
help='List available backups'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--no-backup',
|
||||
action='store_true',
|
||||
help='Skip backup (dangerous)'
|
||||
)
|
||||
reset_parser.add_argument(
|
||||
'--force',
|
||||
action='store_true',
|
||||
help='Skip confirmation'
|
||||
)
|
||||
|
||||
# STATUS command - show progress
|
||||
status_parser = subparsers.add_parser(
|
||||
'status',
|
||||
help='Show module completion status and progress'
|
||||
)
|
||||
|
||||
def get_module_mapping(self) -> Dict[str, str]:
|
||||
"""Get mapping from numbers to module names."""
|
||||
return {
|
||||
"01": "01_tensor",
|
||||
"02": "02_activations",
|
||||
"03": "03_layers",
|
||||
"04": "04_losses",
|
||||
"05": "05_autograd",
|
||||
"06": "06_optimizers",
|
||||
"07": "07_training",
|
||||
"08": "08_spatial",
|
||||
"09": "09_dataloader",
|
||||
"10": "10_tokenization",
|
||||
"11": "11_embeddings",
|
||||
"12": "12_attention",
|
||||
"13": "13_transformers",
|
||||
"14": "14_profiling",
|
||||
"15": "15_acceleration",
|
||||
"16": "16_quantization",
|
||||
"17": "17_compression",
|
||||
"18": "18_caching",
|
||||
"19": "19_benchmarking",
|
||||
"20": "20_capstone",
|
||||
"21": "21_mlops"
|
||||
}
|
||||
|
||||
def normalize_module_number(self, module_input: str) -> str:
|
||||
"""Normalize module input to 2-digit format."""
|
||||
if module_input.isdigit():
|
||||
return f"{int(module_input):02d}"
|
||||
return module_input
|
||||
|
||||
def start_module(self, module_number: str) -> int:
|
||||
"""Start working on a module (first time)."""
|
||||
module_mapping = self.get_module_mapping()
|
||||
normalized = self.normalize_module_number(module_number)
|
||||
|
||||
if normalized not in module_mapping:
|
||||
self.console.print(f"[red]❌ Module {normalized} not found[/red]")
|
||||
self.console.print("💡 Available modules: 01-21")
|
||||
return 1
|
||||
|
||||
module_name = module_mapping[normalized]
|
||||
|
||||
# Check if already started
|
||||
if self.is_module_started(normalized):
|
||||
self.console.print(f"[yellow]⚠️ Module {normalized} already started[/yellow]")
|
||||
self.console.print(f"💡 Did you mean: [bold cyan]tito module resume {normalized}[/bold cyan]")
|
||||
return 1
|
||||
|
||||
# Mark as started
|
||||
self.mark_module_started(normalized)
|
||||
|
||||
self.console.print(f"🚀 Starting Module {normalized}: {module_name}")
|
||||
self.console.print("💡 Work in Jupyter, save your changes, then run:")
|
||||
self.console.print(f" [bold cyan]tito module complete {normalized}[/bold cyan]")
|
||||
|
||||
return self._open_jupyter(module_name)
|
||||
|
||||
def resume_module(self, module_number: Optional[str] = None) -> int:
|
||||
"""Resume working on a module (continue previous work)."""
|
||||
module_mapping = self.get_module_mapping()
|
||||
|
||||
# If no module specified, resume last worked
|
||||
if not module_number:
|
||||
last_worked = self.get_last_worked_module()
|
||||
if not last_worked:
|
||||
self.console.print("[yellow]⚠️ No module to resume[/yellow]")
|
||||
self.console.print("💡 Start with: [bold cyan]tito module start 01[/bold cyan]")
|
||||
return 1
|
||||
module_number = last_worked
|
||||
|
||||
normalized = self.normalize_module_number(module_number)
|
||||
|
||||
if normalized not in module_mapping:
|
||||
self.console.print(f"[red]❌ Module {normalized} not found[/red]")
|
||||
self.console.print("💡 Available modules: 01-21")
|
||||
return 1
|
||||
|
||||
module_name = module_mapping[normalized]
|
||||
|
||||
# Check if module was started
|
||||
if not self.is_module_started(normalized):
|
||||
self.console.print(f"[yellow]⚠️ Module {normalized} not started yet[/yellow]")
|
||||
self.console.print(f"💡 Start with: [bold cyan]tito module start {normalized}[/bold cyan]")
|
||||
return 1
|
||||
|
||||
# Update last worked
|
||||
self.update_last_worked(normalized)
|
||||
|
||||
self.console.print(f"🔄 Resuming Module {normalized}: {module_name}")
|
||||
self.console.print("💡 Continue your work, then run:")
|
||||
self.console.print(f" [bold cyan]tito module complete {normalized}[/bold cyan]")
|
||||
|
||||
return self._open_jupyter(module_name)
|
||||
|
||||
def _open_jupyter(self, module_name: str) -> int:
|
||||
"""Open Jupyter Lab for a module."""
|
||||
# Use the existing view command
|
||||
fake_args = Namespace()
|
||||
fake_args.module = module_name
|
||||
fake_args.force = False
|
||||
|
||||
view_command = ViewCommand(self.config)
|
||||
return view_command.run(fake_args)
|
||||
|
||||
def complete_module(self, module_number: Optional[str] = None, skip_tests: bool = False, skip_export: bool = False) -> int:
|
||||
"""Complete a module with testing and export."""
|
||||
module_mapping = self.get_module_mapping()
|
||||
|
||||
# If no module specified, complete current/last worked
|
||||
if not module_number:
|
||||
last_worked = self.get_last_worked_module()
|
||||
if not last_worked:
|
||||
self.console.print("[yellow]⚠️ No module to complete[/yellow]")
|
||||
self.console.print("💡 Start with: [bold cyan]tito module start 01[/bold cyan]")
|
||||
return 1
|
||||
module_number = last_worked
|
||||
|
||||
normalized = self.normalize_module_number(module_number)
|
||||
|
||||
if normalized not in module_mapping:
|
||||
self.console.print(f"[red]❌ Module {normalized} not found[/red]")
|
||||
return 1
|
||||
|
||||
module_name = module_mapping[normalized]
|
||||
|
||||
self.console.print(Panel(
|
||||
f"🎯 Completing Module {normalized}: {module_name}",
|
||||
title="Module Completion Workflow",
|
||||
border_style="bright_green"
|
||||
))
|
||||
|
||||
success = True
|
||||
|
||||
# Step 1: Run integration tests
|
||||
if not skip_tests:
|
||||
self.console.print("🧪 Running integration tests...")
|
||||
test_result = self.run_module_tests(module_name)
|
||||
if test_result != 0:
|
||||
self.console.print(f"[red]❌ Tests failed for {module_name}[/red]")
|
||||
self.console.print("💡 Fix the issues and try again")
|
||||
return 1
|
||||
self.console.print("✅ All tests passed!")
|
||||
|
||||
# Step 2: Export to package
|
||||
if not skip_export:
|
||||
self.console.print("📦 Exporting to TinyTorch package...")
|
||||
export_result = self.export_module(module_name)
|
||||
if export_result != 0:
|
||||
self.console.print(f"[red]❌ Export failed for {module_name}[/red]")
|
||||
success = False
|
||||
else:
|
||||
self.console.print("✅ Module exported successfully!")
|
||||
|
||||
# Step 3: Update progress tracking
|
||||
self.update_progress(normalized, module_name)
|
||||
|
||||
# Step 4: Check for milestone unlocks
|
||||
if success:
|
||||
self._check_milestone_unlocks(module_name)
|
||||
|
||||
# Step 5: Show next steps
|
||||
self.show_next_steps(normalized)
|
||||
|
||||
return 0 if success else 1
|
||||
|
||||
def run_module_tests(self, module_name: str) -> int:
|
||||
"""Run tests for a specific module."""
|
||||
try:
|
||||
# Run the module's inline tests
|
||||
module_dir = self.config.modules_dir / module_name
|
||||
dev_file = module_dir / f"{module_name.split('_')[1]}.py"
|
||||
|
||||
if not dev_file.exists():
|
||||
self.console.print(f"[yellow]⚠️ No dev file found: {dev_file}[/yellow]")
|
||||
return 0
|
||||
|
||||
# Execute the Python file to run inline tests
|
||||
result = subprocess.run([
|
||||
sys.executable, str(dev_file)
|
||||
], capture_output=True, text=True, cwd=module_dir)
|
||||
|
||||
if result.returncode == 0:
|
||||
return 0
|
||||
else:
|
||||
self.console.print(f"[red]Test output:[/red]\n{result.stdout}")
|
||||
if result.stderr:
|
||||
self.console.print(f"[red]Errors:[/red]\n{result.stderr}")
|
||||
return 1
|
||||
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Error running tests: {e}[/red]")
|
||||
return 1
|
||||
|
||||
def export_module(self, module_name: str) -> int:
|
||||
"""Export module to the TinyTorch package."""
|
||||
try:
|
||||
# Use the existing export command
|
||||
fake_args = Namespace()
|
||||
fake_args.module = module_name
|
||||
fake_args.force = False
|
||||
|
||||
export_command = ExportCommand(self.config)
|
||||
return export_command.run(fake_args)
|
||||
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Error exporting module: {e}[/red]")
|
||||
return 1
|
||||
|
||||
def get_progress_data(self) -> dict:
|
||||
"""Get current progress data."""
|
||||
progress_file = self.config.project_root / "progress.json"
|
||||
|
||||
try:
|
||||
import json
|
||||
if progress_file.exists():
|
||||
with open(progress_file, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
'started_modules': [],
|
||||
'completed_modules': [],
|
||||
'last_worked': None,
|
||||
'last_completed': None,
|
||||
'last_updated': None
|
||||
}
|
||||
|
||||
def save_progress_data(self, progress: dict) -> None:
|
||||
"""Save progress data."""
|
||||
progress_file = self.config.project_root / "progress.json"
|
||||
|
||||
try:
|
||||
import json
|
||||
from datetime import datetime
|
||||
progress['last_updated'] = datetime.now().isoformat()
|
||||
|
||||
with open(progress_file, 'w') as f:
|
||||
json.dump(progress, f, indent=2)
|
||||
except Exception as e:
|
||||
self.console.print(f"[yellow]⚠️ Could not save progress: {e}[/yellow]")
|
||||
|
||||
def is_module_started(self, module_number: str) -> bool:
|
||||
"""Check if a module has been started."""
|
||||
progress = self.get_progress_data()
|
||||
return module_number in progress.get('started_modules', [])
|
||||
|
||||
def is_module_completed(self, module_number: str) -> bool:
|
||||
"""Check if a module has been completed."""
|
||||
progress = self.get_progress_data()
|
||||
return module_number in progress.get('completed_modules', [])
|
||||
|
||||
def mark_module_started(self, module_number: str) -> None:
|
||||
"""Mark a module as started."""
|
||||
progress = self.get_progress_data()
|
||||
|
||||
if 'started_modules' not in progress:
|
||||
progress['started_modules'] = []
|
||||
|
||||
if module_number not in progress['started_modules']:
|
||||
progress['started_modules'].append(module_number)
|
||||
|
||||
progress['last_worked'] = module_number
|
||||
self.save_progress_data(progress)
|
||||
|
||||
def update_last_worked(self, module_number: str) -> None:
|
||||
"""Update the last worked module."""
|
||||
progress = self.get_progress_data()
|
||||
progress['last_worked'] = module_number
|
||||
self.save_progress_data(progress)
|
||||
|
||||
def get_last_worked_module(self) -> Optional[str]:
|
||||
"""Get the last worked module."""
|
||||
progress = self.get_progress_data()
|
||||
return progress.get('last_worked')
|
||||
|
||||
def update_progress(self, module_number: str, module_name: str) -> None:
|
||||
"""Update user progress tracking."""
|
||||
progress = self.get_progress_data()
|
||||
|
||||
# Update completed modules
|
||||
if 'completed_modules' not in progress:
|
||||
progress['completed_modules'] = []
|
||||
|
||||
if module_number not in progress['completed_modules']:
|
||||
progress['completed_modules'].append(module_number)
|
||||
|
||||
progress['last_completed'] = module_number
|
||||
self.save_progress_data(progress)
|
||||
|
||||
self.console.print(f"📈 Progress updated: {len(progress['completed_modules'])} modules completed")
|
||||
|
||||
def show_next_steps(self, completed_module: str) -> None:
|
||||
"""Show next steps after completing a module."""
|
||||
module_mapping = self.get_module_mapping()
|
||||
completed_num = int(completed_module)
|
||||
next_num = f"{completed_num + 1:02d}"
|
||||
|
||||
if next_num in module_mapping:
|
||||
next_module = module_mapping[next_num]
|
||||
self.console.print(Panel(
|
||||
f"🎉 Module {completed_module} completed!\n\n"
|
||||
f"Next steps:\n"
|
||||
f" [bold cyan]tito module {next_num}[/bold cyan] - Start {next_module}\n"
|
||||
f" [dim]tito module status[/dim] - View overall progress",
|
||||
title="What's Next?",
|
||||
border_style="green"
|
||||
))
|
||||
else:
|
||||
self.console.print(Panel(
|
||||
f"🎉 Module {completed_module} completed!\n\n"
|
||||
"🏆 Congratulations! You've completed all available modules!\n"
|
||||
"🚀 You're now ready to build production ML systems!",
|
||||
title="All Modules Complete!",
|
||||
border_style="gold1"
|
||||
))
|
||||
|
||||
def show_status(self) -> int:
|
||||
"""Show module completion status."""
|
||||
module_mapping = self.get_module_mapping()
|
||||
progress = self.get_progress_data()
|
||||
|
||||
started = progress.get('started_modules', [])
|
||||
completed = progress.get('completed_modules', [])
|
||||
last_worked = progress.get('last_worked')
|
||||
|
||||
self.console.print(Panel(
|
||||
"📊 Module Status & Progress",
|
||||
title="Your Learning Journey",
|
||||
border_style="bright_blue"
|
||||
))
|
||||
|
||||
for num, name in module_mapping.items():
|
||||
if num in completed:
|
||||
status = "✅"
|
||||
state = "completed"
|
||||
elif num in started:
|
||||
status = "🚀" if num == last_worked else "💻"
|
||||
state = "in progress" if num == last_worked else "started"
|
||||
else:
|
||||
status = "⏳"
|
||||
state = "not started"
|
||||
|
||||
marker = " ← current" if num == last_worked else ""
|
||||
self.console.print(f" {status} Module {num}: {name} ({state}){marker}")
|
||||
|
||||
# Summary
|
||||
self.console.print(f"\n📈 Progress: {len(completed)}/{len(module_mapping)} completed, {len(started)} started")
|
||||
|
||||
# Next steps
|
||||
if last_worked:
|
||||
if last_worked not in completed:
|
||||
self.console.print(f"💡 Continue: [bold cyan]tito module resume {last_worked}[/bold cyan]")
|
||||
self.console.print(f"💡 Or complete: [bold cyan]tito module complete {last_worked}[/bold cyan]")
|
||||
else:
|
||||
next_num = f"{int(last_worked) + 1:02d}"
|
||||
if next_num in module_mapping:
|
||||
self.console.print(f"💡 Next: [bold cyan]tito module start {next_num}[/bold cyan]")
|
||||
else:
|
||||
self.console.print("💡 Start with: [bold cyan]tito module start 01[/bold cyan]")
|
||||
|
||||
return 0
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Execute the module workflow command."""
|
||||
# Handle subcommands
|
||||
if hasattr(args, 'module_command') and args.module_command:
|
||||
if args.module_command == 'start':
|
||||
return self.start_module(args.module_number)
|
||||
elif args.module_command == 'resume':
|
||||
return self.resume_module(getattr(args, 'module_number', None))
|
||||
elif args.module_command == 'complete':
|
||||
return self.complete_module(
|
||||
getattr(args, 'module_number', None),
|
||||
getattr(args, 'skip_tests', False),
|
||||
getattr(args, 'skip_export', False)
|
||||
)
|
||||
elif args.module_command == 'reset':
|
||||
# Delegate to ModuleResetCommand
|
||||
reset_command = ModuleResetCommand(self.config)
|
||||
return reset_command.run(args)
|
||||
elif args.module_command == 'status':
|
||||
return self.show_status()
|
||||
|
||||
# Show help if no valid command
|
||||
self.console.print(Panel(
|
||||
"[bold cyan]Module Lifecycle Commands[/bold cyan]\n\n"
|
||||
"[bold]Core Workflow:[/bold]\n"
|
||||
" [bold green]tito module start 01[/bold green] - Start working on Module 01 (first time)\n"
|
||||
" [bold green]tito module resume 01[/bold green] - Resume working on Module 01 (continue)\n"
|
||||
" [bold green]tito module complete 01[/bold green] - Complete Module 01 (test + export)\n"
|
||||
" [bold yellow]tito module reset 01[/bold yellow] - Reset Module 01 to clean state (with backup)\n\n"
|
||||
"[bold]Smart Defaults:[/bold]\n"
|
||||
" [bold]tito module resume[/bold] - Resume last worked module\n"
|
||||
" [bold]tito module complete[/bold] - Complete current module\n"
|
||||
" [bold]tito module status[/bold] - Show progress with states\n\n"
|
||||
"[bold]Natural Learning Flow:[/bold]\n"
|
||||
" 1. [dim]tito module start 01[/dim] → Begin tensors (first time)\n"
|
||||
" 2. [dim]Work in Jupyter, save[/dim] → Ctrl+S to save progress\n"
|
||||
" 3. [dim]tito module complete 01[/dim] → Test, export, track progress\n"
|
||||
" 4. [dim]tito module start 02[/dim] → Begin activations\n"
|
||||
" 5. [dim]tito module resume 02[/dim] → Continue activations later\n\n"
|
||||
"[bold]Module States:[/bold]\n"
|
||||
" ⏳ Not started 🚀 In progress ✅ Completed\n\n"
|
||||
"[bold]Reset Options:[/bold]\n"
|
||||
" [dim]tito module reset 01 --list-backups[/dim] - View available backups\n"
|
||||
" [dim]tito module reset 01 --soft[/dim] - Keep package exports\n"
|
||||
" [dim]tito module reset 01 --restore-backup[/dim] - Restore from backup",
|
||||
title="Module Development Workflow",
|
||||
border_style="bright_cyan"
|
||||
))
|
||||
|
||||
return 0
|
||||
|
||||
def _check_milestone_unlocks(self, module_name: str) -> None:
|
||||
"""Check if completing this module unlocks any milestones."""
|
||||
try:
|
||||
# Import milestone tracker
|
||||
import sys
|
||||
from pathlib import Path as PathLib
|
||||
milestone_tracker_path = PathLib(__file__).parent.parent.parent / "tests" / "milestones"
|
||||
if str(milestone_tracker_path) not in sys.path:
|
||||
sys.path.insert(0, str(milestone_tracker_path))
|
||||
|
||||
from milestone_tracker import check_module_export
|
||||
|
||||
# Let milestone tracker handle everything
|
||||
check_module_export(module_name, console=self.console)
|
||||
|
||||
except ImportError:
|
||||
# Milestone tracker not available, skip silently
|
||||
pass
|
||||
except Exception as e:
|
||||
# Don't fail the workflow if milestone checking fails
|
||||
self.console.print(f"[dim]Note: Could not check milestone unlocks: {e}[/dim]")
|
||||
@@ -1,897 +1,91 @@
|
||||
"""
|
||||
TinyTorch Olympics Command
|
||||
TinyTorch Olympics - Coming Soon!
|
||||
|
||||
Special competition events with focused challenges, time-limited competitions,
|
||||
and unique recognition opportunities beyond the regular community leaderboard.
|
||||
Special competition events where students learn and compete together.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
import uuid
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.progress import track
|
||||
from rich.prompt import Prompt, Confirm
|
||||
from rich.console import Group
|
||||
from rich.align import Align
|
||||
from rich.text import Text
|
||||
|
||||
from .base import BaseCommand
|
||||
from ..core.exceptions import TinyTorchCLIError
|
||||
|
||||
|
||||
class OlympicsCommand(BaseCommand):
|
||||
"""Special competition events - Focused challenges and recognition"""
|
||||
|
||||
"""🏅 TinyTorch Olympics - Future competition events"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "olympics"
|
||||
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Special competition events with unique challenges and recognition"
|
||||
|
||||
return "🏅 Competition events - Coming Soon!"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
"""Add olympics subcommands."""
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='olympics_command',
|
||||
help='Olympics operations',
|
||||
metavar='COMMAND'
|
||||
)
|
||||
|
||||
# Events command
|
||||
events_parser = subparsers.add_parser(
|
||||
'events',
|
||||
help='View current and upcoming competition events'
|
||||
)
|
||||
events_parser.add_argument(
|
||||
'--upcoming',
|
||||
action='store_true',
|
||||
help='Show only upcoming events'
|
||||
)
|
||||
events_parser.add_argument(
|
||||
'--past',
|
||||
action='store_true',
|
||||
help='Show past competition results'
|
||||
)
|
||||
|
||||
# Compete command
|
||||
compete_parser = subparsers.add_parser(
|
||||
'compete',
|
||||
help='Enter a specific competition event'
|
||||
)
|
||||
compete_parser.add_argument(
|
||||
'--event',
|
||||
required=True,
|
||||
help='Event ID to compete in'
|
||||
)
|
||||
compete_parser.add_argument(
|
||||
'--accuracy',
|
||||
type=float,
|
||||
help='Accuracy achieved for this competition'
|
||||
)
|
||||
compete_parser.add_argument(
|
||||
'--model',
|
||||
help='Model description and approach used'
|
||||
)
|
||||
compete_parser.add_argument(
|
||||
'--code-url',
|
||||
help='Optional: Link to your competition code/approach'
|
||||
)
|
||||
compete_parser.add_argument(
|
||||
'--notes',
|
||||
help='Competition-specific notes, innovations, learnings'
|
||||
)
|
||||
|
||||
# Awards command
|
||||
awards_parser = subparsers.add_parser(
|
||||
'awards',
|
||||
help='View special recognition and achievement badges'
|
||||
)
|
||||
awards_parser.add_argument(
|
||||
'--personal',
|
||||
action='store_true',
|
||||
help='Show only your personal awards'
|
||||
)
|
||||
|
||||
# History command
|
||||
history_parser = subparsers.add_parser(
|
||||
'history',
|
||||
help='View past competition events and memorable moments'
|
||||
)
|
||||
history_parser.add_argument(
|
||||
'--year',
|
||||
type=int,
|
||||
help='Filter by specific year'
|
||||
)
|
||||
history_parser.add_argument(
|
||||
'--event-type',
|
||||
choices=['speed', 'accuracy', 'innovation', 'efficiency', 'community'],
|
||||
help='Filter by event type'
|
||||
)
|
||||
|
||||
"""Add olympics subcommands (coming soon)."""
|
||||
pass
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Execute olympics command."""
|
||||
command = getattr(args, 'olympics_command', None)
|
||||
|
||||
if not command:
|
||||
self._show_olympics_overview()
|
||||
return 0
|
||||
|
||||
if command == 'events':
|
||||
return self._show_events(args)
|
||||
elif command == 'compete':
|
||||
return self._compete_in_event(args)
|
||||
elif command == 'awards':
|
||||
return self._show_awards(args)
|
||||
elif command == 'history':
|
||||
return self._show_history(args)
|
||||
else:
|
||||
raise TinyTorchCLIError(f"Unknown olympics command: {command}")
|
||||
|
||||
def _show_olympics_overview(self) -> None:
|
||||
"""Show olympics overview and current special events."""
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
Align.center("[bold bright_gold]🏅 TinyTorch Olympics 🏅[/bold bright_gold]"),
|
||||
"",
|
||||
"[bold]Special Competition Events![/bold] Beyond the regular community leaderboard:",
|
||||
"",
|
||||
"🎯 [bold bright_blue]Focused Challenges[/bold bright_blue]",
|
||||
" • Time-limited competitions (24hr, 1week, 1month challenges)",
|
||||
" • Specific constraints (memory-efficient, fastest training, novel architectures)",
|
||||
" • Theme-based events (interpretability, fairness, efficiency)",
|
||||
"",
|
||||
"🏆 [bold bright_yellow]Special Recognition[/bold bright_yellow]",
|
||||
" • Olympic medals and achievement badges",
|
||||
" • Innovation awards for creative approaches",
|
||||
" • Community impact recognition",
|
||||
"",
|
||||
"🌟 [bold bright_green]Current Active Events[/bold bright_green]",
|
||||
" • Winter 2024 Speed Challenge (Training under 5 minutes)",
|
||||
" • Memory Efficiency Olympics (Models under 1MB)",
|
||||
" • Architecture Innovation Contest (Novel designs welcome)",
|
||||
"",
|
||||
"[bold]Available Commands:[/bold]",
|
||||
" [green]events[/green] - See current and upcoming competitions",
|
||||
" [green]compete[/green] - Enter a specific event",
|
||||
" [green]awards[/green] - View special recognition and badges",
|
||||
" [green]history[/green] - Past competitions and memorable moments",
|
||||
"",
|
||||
"[dim]💡 Note: Olympics are special events separate from daily community leaderboard[/dim]",
|
||||
),
|
||||
title="🥇 Competition Central",
|
||||
"""Show coming soon message with Olympics branding."""
|
||||
console = self.console
|
||||
|
||||
# ASCII Olympics Logo
|
||||
logo = """
|
||||
╔════════════════════════════════════════════════════════════╗
|
||||
║ ║
|
||||
║ 🏅 TINYTORCH OLYMPICS 🏅 ║
|
||||
║ ║
|
||||
║ ⚡ Learn • Build • Compete ⚡ ║
|
||||
║ ║
|
||||
║ 🔥🔥🔥 COMING SOON 🔥🔥🔥 ║
|
||||
║ ║
|
||||
╚════════════════════════════════════════════════════════════╝
|
||||
"""
|
||||
|
||||
message = Text()
|
||||
message.append(logo, style="bold yellow")
|
||||
message.append("\n\n")
|
||||
message.append("🎯 What's Coming:\n\n", style="bold cyan")
|
||||
message.append(" • ", style="cyan")
|
||||
message.append("Speed Challenges", style="bold white")
|
||||
message.append(" - Optimize inference latency\n", style="dim")
|
||||
message.append(" • ", style="cyan")
|
||||
message.append("Compression Competitions", style="bold white")
|
||||
message.append(" - Smallest model, best accuracy\n", style="dim")
|
||||
message.append(" • ", style="cyan")
|
||||
message.append("Accuracy Leaderboards", style="bold white")
|
||||
message.append(" - Push the limits on TinyML datasets\n", style="dim")
|
||||
message.append(" • ", style="cyan")
|
||||
message.append("Innovation Awards", style="bold white")
|
||||
message.append(" - Novel architectures and techniques\n", style="dim")
|
||||
message.append(" • ", style="cyan")
|
||||
message.append("Team Events", style="bold white")
|
||||
message.append(" - Collaborate and compete together\n\n", style="dim")
|
||||
|
||||
message.append("🏆 Why Olympics?\n\n", style="bold yellow")
|
||||
message.append("The TinyTorch Olympics will be a global competition where students\n", style="white")
|
||||
message.append("can showcase their ML engineering skills, learn from each other,\n", style="white")
|
||||
message.append("and earn recognition in the TinyML community.\n\n", style="white")
|
||||
|
||||
message.append("📅 Stay Tuned!\n\n", style="bold green")
|
||||
message.append("Follow TinyTorch updates for the competition launch announcement.\n", style="dim")
|
||||
message.append("In the meantime, keep building and perfecting your TinyTorch skills!\n\n", style="dim")
|
||||
|
||||
message.append("💡 Continue Your Journey:\n", style="bold cyan")
|
||||
message.append(" • Complete modules: ", style="white")
|
||||
message.append("tito module status\n", style="cyan")
|
||||
message.append(" • Track milestones: ", style="white")
|
||||
message.append("tito milestone status\n", style="cyan")
|
||||
message.append(" • Join community: ", style="white")
|
||||
message.append("tito community login\n", style="cyan")
|
||||
|
||||
console.print(Panel(
|
||||
Align.center(message),
|
||||
title="🔥 TinyTorch Olympics 🔥",
|
||||
border_style="bright_yellow",
|
||||
padding=(1, 2)
|
||||
))
|
||||
|
||||
def _show_events(self, args: Namespace) -> int:
|
||||
"""Show current and upcoming competition events."""
|
||||
# Load events data (mock for now)
|
||||
events = self._load_olympics_events()
|
||||
|
||||
if args.upcoming:
|
||||
events = [e for e in events if e["status"] == "upcoming"]
|
||||
title = "📅 Upcoming Competition Events"
|
||||
elif args.past:
|
||||
events = [e for e in events if e["status"] == "completed"]
|
||||
title = "🏛️ Past Competition Results"
|
||||
else:
|
||||
title = "🏅 All Competition Events"
|
||||
|
||||
if not events:
|
||||
status_text = "upcoming" if args.upcoming else "past" if args.past else "available"
|
||||
self.console.print(Panel(
|
||||
f"[yellow]No {status_text} events at this time![/yellow]\n\n"
|
||||
"Check back soon for new competition opportunities!",
|
||||
title="📅 No Events",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 0
|
||||
|
||||
# Create events table
|
||||
table = Table(title=title)
|
||||
table.add_column("Event", style="bold")
|
||||
table.add_column("Type", style="blue")
|
||||
table.add_column("Duration", style="green")
|
||||
table.add_column("Status", style="yellow")
|
||||
table.add_column("Prize/Recognition", style="bright_magenta")
|
||||
table.add_column("Participants", style="cyan", justify="right")
|
||||
|
||||
for event in events:
|
||||
status_display = self._get_status_display(event["status"], event.get("end_date"))
|
||||
|
||||
table.add_row(
|
||||
event["name"],
|
||||
event["type"],
|
||||
event["duration"],
|
||||
status_display,
|
||||
event["prize"],
|
||||
str(event.get("participants", 0))
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
# Show active event details
|
||||
active_events = [e for e in events if e["status"] == "active"]
|
||||
if active_events:
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
"[bold bright_green]🔥 Active Competitions You Can Join Now![/bold bright_green]",
|
||||
"",
|
||||
*[f"• [bold]{event['name']}[/bold]: {event['description']}" for event in active_events[:3]],
|
||||
"",
|
||||
"[bold]Join a competition:[/bold]",
|
||||
"[dim]tito olympics compete --event <event_id>[/dim]",
|
||||
),
|
||||
title="⚡ Join Now",
|
||||
border_style="bright_green",
|
||||
padding=(0, 1)
|
||||
))
|
||||
|
||||
|
||||
return 0
|
||||
|
||||
def _compete_in_event(self, args: Namespace) -> int:
|
||||
"""Enter a competition event."""
|
||||
# Check if user is registered for leaderboard
|
||||
if not self._is_user_registered():
|
||||
self.console.print(Panel(
|
||||
"[yellow]Please register for the community leaderboard first![/yellow]\n\n"
|
||||
"Olympics competitions require community membership:\n"
|
||||
"[bold]tito leaderboard register[/bold]",
|
||||
title="📝 Registration Required",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Load event details
|
||||
event = self._get_event_details(args.event)
|
||||
if not event:
|
||||
self.console.print(Panel(
|
||||
f"[red]Event '{args.event}' not found![/red]\n\n"
|
||||
"See available events: [bold]tito olympics events[/bold]",
|
||||
title="❌ Event Not Found",
|
||||
border_style="red"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Check if event is active
|
||||
if event["status"] != "active":
|
||||
self.console.print(Panel(
|
||||
f"[yellow]Event '{event['name']}' is not currently active![/yellow]\n\n"
|
||||
f"Status: {event['status']}\n"
|
||||
"See active events: [bold]tito olympics events[/bold]",
|
||||
title="⏰ Event Not Active",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Show event details and confirm participation
|
||||
self._show_event_details(event)
|
||||
|
||||
if not Confirm.ask("\n[bold]Compete in this event?[/bold]"):
|
||||
self.console.print("[dim]Maybe next time! 👋[/dim]")
|
||||
return 0
|
||||
|
||||
# Gather competition submission
|
||||
submission = self._gather_competition_submission(event, args)
|
||||
|
||||
# Validate submission meets event criteria
|
||||
validation_result = self._validate_submission(event, submission)
|
||||
if not validation_result["valid"]:
|
||||
self.console.print(Panel(
|
||||
f"[red]Submission doesn't meet event criteria![/red]\n\n"
|
||||
f"Issue: {validation_result['reason']}\n\n"
|
||||
"Please check event requirements and try again.",
|
||||
title="❌ Validation Failed",
|
||||
border_style="red"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Save competition entry
|
||||
self._save_competition_entry(event, submission)
|
||||
|
||||
# Show competition confirmation and standing
|
||||
self._show_competition_confirmation(event, submission)
|
||||
|
||||
return 0
|
||||
|
||||
def _show_awards(self, args: Namespace) -> int:
|
||||
"""Show special recognition and achievement badges."""
|
||||
if args.personal:
|
||||
return self._show_personal_awards()
|
||||
else:
|
||||
return self._show_all_awards()
|
||||
|
||||
def _show_personal_awards(self) -> int:
|
||||
"""Show user's personal awards and badges."""
|
||||
if not self._is_user_registered():
|
||||
self.console.print(Panel(
|
||||
"[yellow]Please register first to see your awards![/yellow]\n\n"
|
||||
"Run: [bold]tito leaderboard register[/bold]",
|
||||
title="📝 Registration Required",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 1
|
||||
|
||||
# Load user's Olympic achievements
|
||||
olympic_profile = self._load_user_olympic_profile()
|
||||
awards = olympic_profile.get("awards", [])
|
||||
competitions = olympic_profile.get("competitions", [])
|
||||
|
||||
if not awards and not competitions:
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
"[bold bright_blue]🌟 Your Olympic Journey Awaits![/bold bright_blue]",
|
||||
"",
|
||||
"You haven't participated in Olympics competitions yet.",
|
||||
"",
|
||||
"[bold]Start your journey:[/bold]",
|
||||
"• Check active events: [green]tito olympics events[/green]",
|
||||
"• Join a competition: [green]tito olympics compete --event <id>[/green]",
|
||||
"• Earn your first Olympic badge! 🏅",
|
||||
"",
|
||||
"[dim]Every Olympic participant gets recognition for participation![/dim]",
|
||||
),
|
||||
title="🏅 Your Olympic Profile",
|
||||
border_style="bright_blue",
|
||||
padding=(1, 2)
|
||||
))
|
||||
return 0
|
||||
|
||||
# Show awards and achievements
|
||||
self._display_personal_olympic_achievements(olympic_profile)
|
||||
return 0
|
||||
|
||||
def _show_all_awards(self) -> int:
|
||||
"""Show community awards and notable achievements."""
|
||||
# Mock awards data
|
||||
notable_awards = self._load_notable_awards()
|
||||
|
||||
# Recent awards table
|
||||
table = Table(title="🏆 Recent Olympic Achievements")
|
||||
table.add_column("Award", style="bold")
|
||||
table.add_column("Recipient", style="green")
|
||||
table.add_column("Event", style="blue")
|
||||
table.add_column("Achievement", style="yellow")
|
||||
table.add_column("Date", style="dim")
|
||||
|
||||
for award in notable_awards[:10]:
|
||||
table.add_row(
|
||||
award["award_type"],
|
||||
award["recipient"],
|
||||
award["event"],
|
||||
award["description"],
|
||||
award["date"]
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
# Award categories explanation
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
"[bold bright_yellow]🏅 Olympic Award Categories[/bold bright_yellow]",
|
||||
"",
|
||||
"🥇 [bold]Performance Awards[/bold]",
|
||||
" • Gold/Silver/Bronze medals for top competition results",
|
||||
" • Speed records, accuracy achievements, efficiency milestones",
|
||||
"",
|
||||
"🌟 [bold]Innovation Awards[/bold]",
|
||||
" • Novel Architecture Award for creative model designs",
|
||||
" • Optimization Genius for breakthrough efficiency techniques",
|
||||
" • Interpretability Champion for explainable AI contributions",
|
||||
"",
|
||||
"🤝 [bold]Community Awards[/bold]",
|
||||
" • Mentor Badge for helping other competitors",
|
||||
" • Knowledge Sharer for valuable insights and tutorials",
|
||||
" • Sportsperson Award for exceptional community spirit",
|
||||
"",
|
||||
"🎯 [bold]Special Recognition[/bold]",
|
||||
" • First Participation Badge (everyone gets this!)",
|
||||
" • Consistency Award for regular competition participation",
|
||||
" • Breakthrough Achievement for major personal improvements",
|
||||
),
|
||||
title="🏆 Recognition System",
|
||||
border_style="bright_yellow",
|
||||
padding=(0, 1)
|
||||
))
|
||||
|
||||
return 0
|
||||
|
||||
def _show_history(self, args: Namespace) -> int:
|
||||
"""Show past competition events and memorable moments."""
|
||||
# Load historical data
|
||||
history = self._load_olympics_history()
|
||||
|
||||
# Filter by year if specified
|
||||
if args.year:
|
||||
history = [h for h in history if h["year"] == args.year]
|
||||
|
||||
# Filter by event type if specified
|
||||
if args.event_type:
|
||||
history = [h for h in history if h["type"] == args.event_type]
|
||||
|
||||
if not history:
|
||||
filter_text = f" for {args.year}" if args.year else ""
|
||||
filter_text += f" ({args.event_type} events)" if args.event_type else ""
|
||||
|
||||
self.console.print(Panel(
|
||||
f"[yellow]No competition history found{filter_text}![/yellow]\n\n"
|
||||
"The Olympics program is just getting started!",
|
||||
title="📚 No History",
|
||||
border_style="yellow"
|
||||
))
|
||||
return 0
|
||||
|
||||
# Create history table
|
||||
table = Table(title="📚 TinyTorch Olympics History")
|
||||
table.add_column("Event", style="bold")
|
||||
table.add_column("Date", style="dim")
|
||||
table.add_column("Type", style="blue")
|
||||
table.add_column("Winner", style="green")
|
||||
table.add_column("Achievement", style="yellow")
|
||||
table.add_column("Memorable Moment", style="cyan")
|
||||
|
||||
for event in sorted(history, key=lambda x: x["date"], reverse=True):
|
||||
table.add_row(
|
||||
event["name"],
|
||||
event["date"],
|
||||
event["type"],
|
||||
event["winner"],
|
||||
event["winning_achievement"],
|
||||
event["memorable_moment"]
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
# Show legendary moments
|
||||
if not args.year and not args.event_type:
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
"[bold bright_gold]🌟 Legendary Olympic Moments[/bold bright_gold]",
|
||||
"",
|
||||
"🏆 [bold]The Great Speed Challenge 2024[/bold]",
|
||||
" Winner achieved 75% CIFAR-10 accuracy in just 47 seconds!",
|
||||
"",
|
||||
"🧠 [bold]Architecture Innovation Contest[/bold]",
|
||||
" Revolutionary attention mechanism reduced parameters by 90%",
|
||||
"",
|
||||
"🤝 [bold]Community Spirit Award[/bold]",
|
||||
" Competitor shared winning code to help others improve",
|
||||
"",
|
||||
"[dim]Each Olympics creates new legends in the TinyTorch community! 💫[/dim]",
|
||||
),
|
||||
title="🏛️ Hall of Fame",
|
||||
border_style="bright_gold",
|
||||
padding=(0, 1)
|
||||
))
|
||||
|
||||
return 0
|
||||
|
||||
def _load_olympics_events(self) -> List[Dict[str, Any]]:
|
||||
"""Load olympics events data (mock implementation)."""
|
||||
return [
|
||||
{
|
||||
"id": "winter2024_speed",
|
||||
"name": "Winter 2024 Speed Challenge",
|
||||
"type": "Speed",
|
||||
"status": "active",
|
||||
"duration": "24 hours",
|
||||
"description": "Train CIFAR-10 model to 70%+ accuracy in under 5 minutes",
|
||||
"prize": "🏆 Speed Medal + Recognition",
|
||||
"participants": 23,
|
||||
"start_date": "2024-01-15",
|
||||
"end_date": "2024-01-16",
|
||||
"criteria": {"min_accuracy": 70.0, "max_time_minutes": 5}
|
||||
},
|
||||
{
|
||||
"id": "memory2024_efficiency",
|
||||
"name": "Memory Efficiency Olympics",
|
||||
"type": "Efficiency",
|
||||
"status": "active",
|
||||
"duration": "1 week",
|
||||
"description": "Best CIFAR-10 accuracy with model under 1MB",
|
||||
"prize": "🥇 Efficiency Champion",
|
||||
"participants": 15,
|
||||
"start_date": "2024-01-10",
|
||||
"end_date": "2024-01-17",
|
||||
"criteria": {"max_model_size_mb": 1.0}
|
||||
},
|
||||
{
|
||||
"id": "innovation2024_arch",
|
||||
"name": "Architecture Innovation Contest",
|
||||
"type": "Innovation",
|
||||
"status": "upcoming",
|
||||
"duration": "2 weeks",
|
||||
"description": "Novel architectures and creative approaches welcome",
|
||||
"prize": "🌟 Innovation Award",
|
||||
"participants": 0,
|
||||
"start_date": "2024-02-01",
|
||||
"end_date": "2024-02-14",
|
||||
"criteria": {"novelty_required": True}
|
||||
},
|
||||
{
|
||||
"id": "autumn2023_classic",
|
||||
"name": "Autumn 2023 Classic",
|
||||
"type": "Accuracy",
|
||||
"status": "completed",
|
||||
"duration": "1 month",
|
||||
"description": "Best overall CIFAR-10 accuracy challenge",
|
||||
"prize": "🥇 Gold Medal",
|
||||
"participants": 87,
|
||||
"start_date": "2023-10-01",
|
||||
"end_date": "2023-10-31",
|
||||
"winner": "neural_champion",
|
||||
"winning_score": 84.2
|
||||
}
|
||||
]
|
||||
|
||||
def _get_status_display(self, status: str, end_date: Optional[str] = None) -> str:
|
||||
"""Get display-friendly status with timing information."""
|
||||
if status == "active":
|
||||
if end_date:
|
||||
# Calculate time remaining
|
||||
end = datetime.fromisoformat(end_date)
|
||||
now = datetime.now()
|
||||
if end > now:
|
||||
remaining = end - now
|
||||
if remaining.days > 0:
|
||||
return f"🔥 Active ({remaining.days}d left)"
|
||||
else:
|
||||
hours = remaining.seconds // 3600
|
||||
return f"🔥 Active ({hours}h left)"
|
||||
return "🔥 Active"
|
||||
elif status == "upcoming":
|
||||
return "📅 Upcoming"
|
||||
elif status == "completed":
|
||||
return "✅ Completed"
|
||||
else:
|
||||
return status.title()
|
||||
|
||||
def _is_user_registered(self) -> bool:
|
||||
"""Check if user is registered for community leaderboard."""
|
||||
from .leaderboard import LeaderboardCommand
|
||||
leaderboard_cmd = LeaderboardCommand(self.config)
|
||||
return leaderboard_cmd._load_user_profile() is not None
|
||||
|
||||
def _get_event_details(self, event_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get details for a specific event."""
|
||||
events = self._load_olympics_events()
|
||||
return next((e for e in events if e["id"] == event_id), None)
|
||||
|
||||
def _show_event_details(self, event: Dict[str, Any]) -> None:
|
||||
"""Show detailed information about an event."""
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
f"[bold bright_blue]{event['name']}[/bold bright_blue]",
|
||||
"",
|
||||
f"[bold]Type:[/bold] {event['type']}",
|
||||
f"[bold]Duration:[/bold] {event['duration']}",
|
||||
f"[bold]Current Participants:[/bold] {event.get('participants', 0)}",
|
||||
"",
|
||||
f"[bold]Challenge:[/bold]",
|
||||
f" {event['description']}",
|
||||
"",
|
||||
f"[bold]Recognition:[/bold]",
|
||||
f" {event['prize']}",
|
||||
"",
|
||||
f"[bold]Requirements:[/bold]",
|
||||
*[f" • {k.replace('_', ' ').title()}: {v}" for k, v in event.get('criteria', {}).items()],
|
||||
),
|
||||
title=f"🏅 {event['type']} Competition",
|
||||
border_style="bright_blue",
|
||||
padding=(1, 2)
|
||||
))
|
||||
|
||||
def _gather_competition_submission(self, event: Dict[str, Any], args: Namespace) -> Dict[str, Any]:
|
||||
"""Gather submission details for competition."""
|
||||
submission = {
|
||||
"event_id": event["id"],
|
||||
"submitted_date": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Get accuracy
|
||||
if args.accuracy is not None:
|
||||
submission["accuracy"] = args.accuracy
|
||||
else:
|
||||
submission["accuracy"] = float(Prompt.ask(
|
||||
f"[bold]Accuracy achieved on {event.get('dataset', 'the task')}[/bold]",
|
||||
default="0.0"
|
||||
))
|
||||
|
||||
# Get model description
|
||||
if args.model:
|
||||
submission["model"] = args.model
|
||||
else:
|
||||
submission["model"] = Prompt.ask(
|
||||
"[bold]Model description[/bold] (architecture, approach, innovations)",
|
||||
default="Custom Model"
|
||||
)
|
||||
|
||||
# Optional fields
|
||||
submission["code_url"] = args.code_url or Prompt.ask(
|
||||
"[bold]Code/approach URL[/bold] (optional)",
|
||||
default=""
|
||||
) or None
|
||||
|
||||
submission["notes"] = args.notes or Prompt.ask(
|
||||
"[bold]Competition notes[/bold] (innovations, challenges, learnings)",
|
||||
default=""
|
||||
) or None
|
||||
|
||||
# Event-specific metrics
|
||||
if "max_time_minutes" in event.get("criteria", {}):
|
||||
training_time = float(Prompt.ask(
|
||||
"[bold]Training time in minutes[/bold]",
|
||||
default="0.0"
|
||||
))
|
||||
submission["training_time_minutes"] = training_time
|
||||
|
||||
if "max_model_size_mb" in event.get("criteria", {}):
|
||||
model_size = float(Prompt.ask(
|
||||
"[bold]Model size in MB[/bold]",
|
||||
default="0.0"
|
||||
))
|
||||
submission["model_size_mb"] = model_size
|
||||
|
||||
return submission
|
||||
|
||||
def _validate_submission(self, event: Dict[str, Any], submission: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate submission meets event criteria."""
|
||||
criteria = event.get("criteria", {})
|
||||
|
||||
# Check minimum accuracy
|
||||
if "min_accuracy" in criteria:
|
||||
if submission["accuracy"] < criteria["min_accuracy"]:
|
||||
return {
|
||||
"valid": False,
|
||||
"reason": f"Accuracy {submission['accuracy']:.1f}% below required {criteria['min_accuracy']:.1f}%"
|
||||
}
|
||||
|
||||
# Check maximum training time
|
||||
if "max_time_minutes" in criteria:
|
||||
if submission.get("training_time_minutes", 0) > criteria["max_time_minutes"]:
|
||||
return {
|
||||
"valid": False,
|
||||
"reason": f"Training time {submission['training_time_minutes']:.1f}min exceeds limit {criteria['max_time_minutes']:.1f}min"
|
||||
}
|
||||
|
||||
# Check maximum model size
|
||||
if "max_model_size_mb" in criteria:
|
||||
if submission.get("model_size_mb", 0) > criteria["max_model_size_mb"]:
|
||||
return {
|
||||
"valid": False,
|
||||
"reason": f"Model size {submission['model_size_mb']:.1f}MB exceeds limit {criteria['max_model_size_mb']:.1f}MB"
|
||||
}
|
||||
|
||||
return {"valid": True}
|
||||
|
||||
def _save_competition_entry(self, event: Dict[str, Any], submission: Dict[str, Any]) -> None:
|
||||
"""Save competition entry to user's Olympic profile."""
|
||||
olympic_profile = self._load_user_olympic_profile()
|
||||
|
||||
if "competitions" not in olympic_profile:
|
||||
olympic_profile["competitions"] = []
|
||||
|
||||
olympic_profile["competitions"].append(submission)
|
||||
|
||||
# Add participation award if first competition
|
||||
if len(olympic_profile["competitions"]) == 1:
|
||||
award = {
|
||||
"type": "participation",
|
||||
"name": "First Olympic Participation",
|
||||
"description": "Welcomed to the Olympics community!",
|
||||
"event": event["name"],
|
||||
"earned_date": datetime.now().isoformat()
|
||||
}
|
||||
if "awards" not in olympic_profile:
|
||||
olympic_profile["awards"] = []
|
||||
olympic_profile["awards"].append(award)
|
||||
|
||||
self._save_user_olympic_profile(olympic_profile)
|
||||
|
||||
def _show_competition_confirmation(self, event: Dict[str, Any], submission: Dict[str, Any]) -> None:
|
||||
"""Show confirmation and current standing."""
|
||||
# Determine performance level for this competition
|
||||
ranking_message = self._get_competition_ranking_message(event, submission)
|
||||
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
Align.center("[bold bright_green]🎉 Competition Entry Submitted! 🎉[/bold bright_green]"),
|
||||
"",
|
||||
f"[bold]Event:[/bold] {event['name']}",
|
||||
f"[bold]Your Result:[/bold] {submission['accuracy']:.1f}% accuracy",
|
||||
f"[bold]Model:[/bold] {submission['model']}",
|
||||
"",
|
||||
ranking_message,
|
||||
"",
|
||||
"[bold bright_blue]🏅 Recognition Earned:[/bold bright_blue]",
|
||||
"• Olympic Participant Badge",
|
||||
"• Competition Experience Points",
|
||||
"• Community Recognition",
|
||||
"",
|
||||
"[bold]Next Steps:[/bold]",
|
||||
"• View your awards: [green]tito olympics awards --personal[/green]",
|
||||
"• See current standings: [green]tito olympics events[/green]",
|
||||
"• Join another event: [green]tito olympics events[/green]",
|
||||
),
|
||||
title="🥇 Olympic Achievement",
|
||||
border_style="bright_green",
|
||||
padding=(1, 2)
|
||||
))
|
||||
|
||||
def _get_competition_ranking_message(self, event: Dict[str, Any], submission: Dict[str, Any]) -> str:
|
||||
"""Get appropriate ranking/performance message for competition."""
|
||||
accuracy = submission["accuracy"]
|
||||
|
||||
# Mock competition standings for encouragement
|
||||
if accuracy >= 80:
|
||||
return "[bright_green]🏆 Outstanding performance! You're in contention for top prizes![/bright_green]"
|
||||
elif accuracy >= 70:
|
||||
return "[bright_blue]🎯 Strong showing! You're competing well in this event![/bright_blue]"
|
||||
elif accuracy >= 60:
|
||||
return "[bright_yellow]🌟 Good effort! Every competition teaches valuable lessons![/bright_yellow]"
|
||||
else:
|
||||
return "[bright_magenta]💝 Thank you for participating! Competition experience is valuable![/bright_magenta]"
|
||||
|
||||
def _load_user_olympic_profile(self) -> Dict[str, Any]:
|
||||
"""Load user's Olympic competition profile."""
|
||||
data_dir = Path.home() / ".tinytorch" / "olympics"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
profile_file = data_dir / "olympic_profile.json"
|
||||
|
||||
if profile_file.exists():
|
||||
with open(profile_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
return {
|
||||
"competitions": [],
|
||||
"awards": [],
|
||||
"created_date": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def _save_user_olympic_profile(self, profile: Dict[str, Any]) -> None:
|
||||
"""Save user's Olympic competition profile."""
|
||||
data_dir = Path.home() / ".tinytorch" / "olympics"
|
||||
profile_file = data_dir / "olympic_profile.json"
|
||||
|
||||
with open(profile_file, 'w') as f:
|
||||
json.dump(profile, f, indent=2)
|
||||
|
||||
def _display_personal_olympic_achievements(self, olympic_profile: Dict[str, Any]) -> None:
|
||||
"""Display user's personal Olympic achievements."""
|
||||
competitions = olympic_profile.get("competitions", [])
|
||||
awards = olympic_profile.get("awards", [])
|
||||
|
||||
# Summary stats
|
||||
total_competitions = len(competitions)
|
||||
best_accuracy = max([c["accuracy"] for c in competitions], default=0)
|
||||
events_participated = len(set(c["event_id"] for c in competitions))
|
||||
|
||||
self.console.print(Panel(
|
||||
Group(
|
||||
Align.center("[bold bright_gold]🏅 Your Olympic Journey 🏅[/bold bright_gold]"),
|
||||
"",
|
||||
f"🎯 Competitions Entered: {total_competitions}",
|
||||
f"🏆 Best Performance: {best_accuracy:.1f}% accuracy",
|
||||
f"🌟 Events Participated: {events_participated}",
|
||||
f"🥇 Awards Earned: {len(awards)}",
|
||||
),
|
||||
title="📊 Olympic Stats",
|
||||
border_style="bright_gold",
|
||||
padding=(1, 2)
|
||||
))
|
||||
|
||||
# Awards table
|
||||
if awards:
|
||||
awards_table = Table(title="🏆 Your Olympic Awards")
|
||||
awards_table.add_column("Award", style="bold")
|
||||
awards_table.add_column("Event", style="blue")
|
||||
awards_table.add_column("Description", style="green")
|
||||
awards_table.add_column("Date", style="dim")
|
||||
|
||||
for award in sorted(awards, key=lambda x: x["earned_date"], reverse=True):
|
||||
awards_table.add_row(
|
||||
award["name"],
|
||||
award["event"],
|
||||
award["description"],
|
||||
award["earned_date"][:10]
|
||||
)
|
||||
|
||||
self.console.print(awards_table)
|
||||
|
||||
# Recent competitions
|
||||
if competitions:
|
||||
recent_comps = sorted(competitions, key=lambda x: x["submitted_date"], reverse=True)[:5]
|
||||
|
||||
comps_table = Table(title="🎯 Recent Competition Entries")
|
||||
comps_table.add_column("Event", style="bold")
|
||||
comps_table.add_column("Accuracy", style="green", justify="right")
|
||||
comps_table.add_column("Model", style="blue")
|
||||
comps_table.add_column("Date", style="dim")
|
||||
|
||||
for comp in recent_comps:
|
||||
comps_table.add_row(
|
||||
comp["event_id"],
|
||||
f"{comp['accuracy']:.1f}%",
|
||||
comp["model"],
|
||||
comp["submitted_date"][:10]
|
||||
)
|
||||
|
||||
self.console.print(comps_table)
|
||||
|
||||
def _load_notable_awards(self) -> List[Dict[str, Any]]:
|
||||
"""Load notable community awards (mock implementation)."""
|
||||
return [
|
||||
{
|
||||
"award_type": "🥇 Gold Medal",
|
||||
"recipient": "speed_demon",
|
||||
"event": "Winter 2024 Speed Challenge",
|
||||
"description": "2.3 min training, 78.4% accuracy",
|
||||
"date": "2024-01-16"
|
||||
},
|
||||
{
|
||||
"award_type": "🌟 Innovation Award",
|
||||
"recipient": "arch_wizard",
|
||||
"event": "Memory Efficiency Olympics",
|
||||
"description": "Novel attention mechanism",
|
||||
"date": "2024-01-15"
|
||||
},
|
||||
{
|
||||
"award_type": "🤝 Community Spirit",
|
||||
"recipient": "helpful_mentor",
|
||||
"event": "Autumn 2023 Classic",
|
||||
"description": "Shared winning approach publicly",
|
||||
"date": "2023-11-01"
|
||||
},
|
||||
{
|
||||
"award_type": "🏆 Speed Record",
|
||||
"recipient": "lightning_fast",
|
||||
"event": "Winter 2024 Speed Challenge",
|
||||
"description": "47 second training record",
|
||||
"date": "2024-01-15"
|
||||
},
|
||||
{
|
||||
"award_type": "🎯 Accuracy Champion",
|
||||
"recipient": "precision_master",
|
||||
"event": "Architecture Innovation",
|
||||
"description": "86.7% CIFAR-10 accuracy",
|
||||
"date": "2024-01-10"
|
||||
}
|
||||
]
|
||||
|
||||
def _load_olympics_history(self) -> List[Dict[str, Any]]:
|
||||
"""Load historical Olympics data (mock implementation)."""
|
||||
return [
|
||||
{
|
||||
"name": "Autumn 2023 Classic",
|
||||
"date": "2023-10-31",
|
||||
"year": 2023,
|
||||
"type": "accuracy",
|
||||
"winner": "neural_champion",
|
||||
"winning_achievement": "84.2% CIFAR-10 accuracy",
|
||||
"memorable_moment": "First 80%+ achievement in community"
|
||||
},
|
||||
{
|
||||
"name": "Summer 2023 Speed Trial",
|
||||
"date": "2023-07-15",
|
||||
"year": 2023,
|
||||
"type": "speed",
|
||||
"winner": "velocity_victor",
|
||||
"winning_achievement": "3.2 minute training",
|
||||
"memorable_moment": "Breakthrough GPU optimization technique"
|
||||
},
|
||||
{
|
||||
"name": "Spring 2023 Innovation Fair",
|
||||
"date": "2023-04-20",
|
||||
"year": 2023,
|
||||
"type": "innovation",
|
||||
"winner": "creative_genius",
|
||||
"winning_achievement": "Self-organizing architecture",
|
||||
"memorable_moment": "Inspired 12 follow-up research papers"
|
||||
}
|
||||
]
|
||||
@@ -1,418 +0,0 @@
|
||||
"""
|
||||
🛡️ Protection command for TinyTorch CLI: Student protection system management.
|
||||
|
||||
Industry-standard approach to prevent students from accidentally breaking
|
||||
critical Variable/Tensor compatibility fixes that enable CIFAR-10 training.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
import warnings
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
|
||||
class ProtectCommand(BaseCommand):
|
||||
"""🛡️ Student Protection System for TinyTorch core files."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "protect"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "🛡️ Student protection system to prevent accidental core file edits"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
subparsers = parser.add_subparsers(
|
||||
dest='protect_command',
|
||||
help='Protection subcommands',
|
||||
metavar='SUBCOMMAND'
|
||||
)
|
||||
|
||||
# Enable protection
|
||||
enable_parser = subparsers.add_parser(
|
||||
'enable',
|
||||
help='🔒 Enable comprehensive student protection system'
|
||||
)
|
||||
enable_parser.add_argument(
|
||||
'--force',
|
||||
action='store_true',
|
||||
help='Force enable even if already protected'
|
||||
)
|
||||
|
||||
# Disable protection (for development)
|
||||
disable_parser = subparsers.add_parser(
|
||||
'disable',
|
||||
help='🔓 Disable protection system (for development only)'
|
||||
)
|
||||
disable_parser.add_argument(
|
||||
'--confirm',
|
||||
action='store_true',
|
||||
help='Confirm disabling protection'
|
||||
)
|
||||
|
||||
# Check protection status
|
||||
status_parser = subparsers.add_parser(
|
||||
'status',
|
||||
help='🔍 Check current protection status'
|
||||
)
|
||||
|
||||
# Validate core functionality
|
||||
validate_parser = subparsers.add_parser(
|
||||
'validate',
|
||||
help='✅ Validate core functionality works correctly'
|
||||
)
|
||||
validate_parser.add_argument(
|
||||
'--verbose',
|
||||
action='store_true',
|
||||
help='Show detailed validation output'
|
||||
)
|
||||
|
||||
# Quick health check
|
||||
check_parser = subparsers.add_parser(
|
||||
'check',
|
||||
help='⚡ Quick health check of critical functionality'
|
||||
)
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Execute the protection command."""
|
||||
console = Console()
|
||||
|
||||
# Show header
|
||||
console.print(Panel.fit(
|
||||
"🛡️ [bold cyan]TinyTorch Student Protection System[/bold cyan]\n"
|
||||
"Prevents accidental edits to critical core functionality",
|
||||
border_style="blue"
|
||||
))
|
||||
|
||||
# Route to appropriate subcommand
|
||||
if args.protect_command == 'enable':
|
||||
return self._enable_protection(console, args)
|
||||
elif args.protect_command == 'disable':
|
||||
return self._disable_protection(console, args)
|
||||
elif args.protect_command == 'status':
|
||||
return self._show_protection_status(console)
|
||||
elif args.protect_command == 'validate':
|
||||
return self._validate_functionality(console, args)
|
||||
elif args.protect_command == 'check':
|
||||
return self._quick_health_check(console)
|
||||
else:
|
||||
console.print("[red]❌ No protection subcommand specified[/red]")
|
||||
console.print("Use: [yellow]tito system protect --help[/yellow]")
|
||||
return 1
|
||||
|
||||
def _enable_protection(self, console: Console, args: Namespace) -> int:
|
||||
"""🔒 Enable comprehensive protection system."""
|
||||
console.print("[cyan]🔒 Enabling TinyTorch Student Protection System...[/blue]")
|
||||
console.print()
|
||||
|
||||
protection_count = 0
|
||||
|
||||
# 1. Set file permissions
|
||||
tinytorch_core = Path("tinytorch/core")
|
||||
if tinytorch_core.exists():
|
||||
console.print("[yellow]🔒 Setting core files to read-only...[/yellow]")
|
||||
for py_file in tinytorch_core.glob("*.py"):
|
||||
try:
|
||||
# Make file read-only
|
||||
py_file.chmod(stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH)
|
||||
protection_count += 1
|
||||
except OSError as e:
|
||||
console.print(f"[red]⚠️ Could not protect {py_file}: {e}[/red]")
|
||||
console.print(f"[green]✅ Protected {protection_count} core files[/green]")
|
||||
else:
|
||||
console.print("[yellow]⚠️ tinytorch/core/ not found - run export first[/yellow]")
|
||||
|
||||
# 2. Create .gitattributes
|
||||
console.print("[yellow]📝 Setting up Git attributes...[/yellow]")
|
||||
gitattributes_content = """# 🛡️ TinyTorch Protection: Mark auto-generated files
|
||||
# GitHub will show "Generated" label for these files
|
||||
tinytorch/core/*.py linguist-generated=true
|
||||
tinytorch/**/*.py linguist-generated=true
|
||||
|
||||
# Exclude from diff by default (reduces noise in pull requests)
|
||||
tinytorch/core/*.py -diff
|
||||
"""
|
||||
with open(".gitattributes", "w") as f:
|
||||
f.write(gitattributes_content)
|
||||
console.print("[green]✅ Git attributes configured[/green]")
|
||||
|
||||
# 3. Create pre-commit hook
|
||||
console.print("[yellow]🚫 Installing Git pre-commit hook...[/yellow]")
|
||||
git_hooks_dir = Path(".git/hooks")
|
||||
if git_hooks_dir.exists():
|
||||
precommit_hook = git_hooks_dir / "pre-commit"
|
||||
hook_content = """#!/bin/bash
|
||||
# 🛡️ TinyTorch Protection: Prevent committing auto-generated files
|
||||
|
||||
echo "🛡️ Checking for modifications to auto-generated files..."
|
||||
|
||||
# Check if any tinytorch/core files are staged
|
||||
CORE_FILES_MODIFIED=$(git diff --cached --name-only | grep "^tinytorch/core/")
|
||||
|
||||
if [ ! -z "$CORE_FILES_MODIFIED" ]; then
|
||||
echo ""
|
||||
echo "🚨 ERROR: Attempting to commit auto-generated files!"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
echo "The following auto-generated files are staged:"
|
||||
echo "$CORE_FILES_MODIFIED"
|
||||
echo ""
|
||||
echo "🛡️ PROTECTION TRIGGERED: These files are auto-generated from modules/"
|
||||
echo ""
|
||||
echo "TO FIX:"
|
||||
echo "1. Unstage these files: git reset HEAD tinytorch/core/"
|
||||
echo "2. Make changes in modules/ instead"
|
||||
echo "3. Run: tito module complete <module_name>"
|
||||
echo "4. Commit the source changes, not the generated files"
|
||||
echo ""
|
||||
echo "⚠️ This protection prevents breaking CIFAR-10 training!"
|
||||
echo ""
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ No auto-generated files being committed"
|
||||
"""
|
||||
with open(precommit_hook, "w") as f:
|
||||
f.write(hook_content)
|
||||
precommit_hook.chmod(0o755) # Make executable
|
||||
console.print("[green]✅ Git pre-commit hook installed[/green]")
|
||||
else:
|
||||
console.print("[yellow]⚠️ .git directory not found - skipping Git hooks[/yellow]")
|
||||
|
||||
# 4. Create VSCode settings
|
||||
console.print("[yellow]⚙️ Setting up VSCode protection...[/yellow]")
|
||||
vscode_dir = Path(".vscode")
|
||||
vscode_dir.mkdir(exist_ok=True)
|
||||
|
||||
python_default_interpreter = str(self.venv_path) + "/bin/python"
|
||||
vscode_settings = {
|
||||
"_comment_protection": "🛡️ TinyTorch Student Protection",
|
||||
"files.readonlyInclude": {
|
||||
"**/tinytorch/core/**/*.py": True
|
||||
},
|
||||
"files.readonlyFromPermissions": True,
|
||||
"files.decorations.colors": True,
|
||||
"files.decorations.badges": True,
|
||||
"explorer.decorations.colors": True,
|
||||
"explorer.decorations.badges": True,
|
||||
"python.defaultInterpreterPath": python_default_interpreter,
|
||||
"python.terminal.activateEnvironment": True
|
||||
}
|
||||
|
||||
import json
|
||||
with open(vscode_dir / "settings.json", "w") as f:
|
||||
json.dump(vscode_settings, f, indent=4)
|
||||
console.print("[green]✅ VSCode protection configured[/green]")
|
||||
|
||||
console.print()
|
||||
console.print(Panel.fit(
|
||||
"[green]🎉 Protection System Activated![/green]\n\n"
|
||||
"🔒 Core files are read-only\n"
|
||||
"📝 GitHub will label files as 'Generated'\n"
|
||||
"🚫 Git prevents committing generated files\n"
|
||||
"⚙️ VSCode shows protection warnings\n\n"
|
||||
"[cyan]Students are now protected from breaking CIFAR-10 training![/blue]",
|
||||
border_style="green"
|
||||
))
|
||||
|
||||
return 0
|
||||
|
||||
def _disable_protection(self, console: Console, args: Namespace) -> int:
|
||||
"""🔓 Disable protection system (for development)."""
|
||||
if not args.confirm:
|
||||
console.print("[red]❌ Protection disable requires --confirm flag[/red]")
|
||||
console.print("[yellow]This is to prevent accidental disabling[/yellow]")
|
||||
return 1
|
||||
|
||||
console.print("[yellow]🔓 Disabling TinyTorch Protection System...[/yellow]")
|
||||
|
||||
# Reset file permissions
|
||||
tinytorch_core = Path("tinytorch/core")
|
||||
if tinytorch_core.exists():
|
||||
for py_file in tinytorch_core.glob("*.py"):
|
||||
try:
|
||||
py_file.chmod(0o644) # Reset to normal permissions
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Remove protection files
|
||||
protection_files = [".gitattributes", ".git/hooks/pre-commit", ".vscode/settings.json"]
|
||||
for file_path in protection_files:
|
||||
path = Path(file_path)
|
||||
if path.exists():
|
||||
try:
|
||||
path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
console.print("[green]✅ Protection system disabled[/green]")
|
||||
console.print("[red]⚠️ Remember to re-enable before students use the system![/red]")
|
||||
|
||||
return 0
|
||||
|
||||
def _show_protection_status(self, console: Console) -> int:
|
||||
"""🔍 Show current protection status."""
|
||||
console.print("[cyan]🔍 TinyTorch Protection Status[/blue]")
|
||||
console.print()
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Protection Feature", style="cyan")
|
||||
table.add_column("Status", justify="center")
|
||||
table.add_column("Details", style="dim")
|
||||
|
||||
# Check file permissions
|
||||
tinytorch_core = Path("tinytorch/core")
|
||||
if tinytorch_core.exists():
|
||||
readonly_count = 0
|
||||
total_files = 0
|
||||
for py_file in tinytorch_core.glob("*.py"):
|
||||
total_files += 1
|
||||
if not (py_file.stat().st_mode & stat.S_IWRITE):
|
||||
readonly_count += 1
|
||||
|
||||
if readonly_count == total_files and total_files > 0:
|
||||
table.add_row("🔒 File Permissions", "[green]✅ PROTECTED[/green]", f"{readonly_count}/{total_files} files read-only")
|
||||
elif readonly_count > 0:
|
||||
table.add_row("🔒 File Permissions", "[yellow]⚠️ PARTIAL[/yellow]", f"{readonly_count}/{total_files} files read-only")
|
||||
else:
|
||||
table.add_row("🔒 File Permissions", "[red]❌ UNPROTECTED[/red]", "Files are writable")
|
||||
else:
|
||||
table.add_row("🔒 File Permissions", "[yellow]⚠️ N/A[/yellow]", "tinytorch/core/ not found")
|
||||
|
||||
# Check Git attributes
|
||||
gitattributes = Path(".gitattributes")
|
||||
if gitattributes.exists():
|
||||
table.add_row("📝 Git Attributes", "[green]✅ CONFIGURED[/green]", "Generated files marked")
|
||||
else:
|
||||
table.add_row("📝 Git Attributes", "[red]❌ MISSING[/red]", "No .gitattributes file")
|
||||
|
||||
# Check pre-commit hook
|
||||
precommit_hook = Path(".git/hooks/pre-commit")
|
||||
if precommit_hook.exists():
|
||||
table.add_row("🚫 Git Pre-commit", "[green]✅ ACTIVE[/green]", "Prevents core file commits")
|
||||
else:
|
||||
table.add_row("🚫 Git Pre-commit", "[red]❌ MISSING[/red]", "No pre-commit protection")
|
||||
|
||||
# Check VSCode settings
|
||||
vscode_settings = Path(".vscode/settings.json")
|
||||
if vscode_settings.exists():
|
||||
table.add_row("⚙️ VSCode Protection", "[green]✅ CONFIGURED[/green]", "Editor warnings enabled")
|
||||
else:
|
||||
table.add_row("⚙️ VSCode Protection", "[yellow]⚠️ MISSING[/yellow]", "No VSCode settings")
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
# Overall status
|
||||
protection_features = [
|
||||
tinytorch_core.exists() and all(not (f.stat().st_mode & stat.S_IWRITE) for f in tinytorch_core.glob("*.py")),
|
||||
gitattributes.exists(),
|
||||
precommit_hook.exists()
|
||||
]
|
||||
|
||||
if all(protection_features):
|
||||
console.print("[green]🛡️ Overall Status: FULLY PROTECTED[/green]")
|
||||
elif any(protection_features):
|
||||
console.print("[yellow]🛡️ Overall Status: PARTIALLY PROTECTED[/yellow]")
|
||||
console.print("[yellow]💡 Run 'tito system protect enable' to complete protection[/yellow]")
|
||||
else:
|
||||
console.print("[red]🛡️ Overall Status: UNPROTECTED[/red]")
|
||||
console.print("[red]⚠️ Run 'tito system protect enable' to protect against student errors[/red]")
|
||||
|
||||
return 0
|
||||
|
||||
def _validate_functionality(self, console: Console, args: Namespace) -> int:
|
||||
"""✅ Validate core functionality works correctly."""
|
||||
try:
|
||||
from tinytorch.core._validation import run_student_protection_checks
|
||||
console.print("[cyan]🔍 Running comprehensive validation...[/blue]")
|
||||
console.print()
|
||||
|
||||
try:
|
||||
run_student_protection_checks(verbose=args.verbose)
|
||||
console.print()
|
||||
console.print("[green]🎉 All validation checks passed![/green]")
|
||||
console.print("[green]✅ CIFAR-10 training should work correctly[/green]")
|
||||
return 0
|
||||
except Exception as e:
|
||||
console.print()
|
||||
console.print(f"[red]❌ Validation failed: {e}[/red]")
|
||||
console.print("[red]⚠️ CIFAR-10 training may not work properly[/red]")
|
||||
console.print("[yellow]💡 Check if core files have been accidentally modified[/yellow]")
|
||||
return 1
|
||||
|
||||
except ImportError:
|
||||
console.print("[red]❌ Validation system not available[/red]")
|
||||
console.print("[yellow]💡 Run module export to generate validation system[/yellow]")
|
||||
return 1
|
||||
|
||||
def _quick_health_check(self, console: Console) -> int:
|
||||
"""⚡ Quick health check of critical functionality."""
|
||||
console.print("[cyan]⚡ Quick Health Check[/blue]")
|
||||
console.print()
|
||||
|
||||
checks = []
|
||||
|
||||
# Check if core modules can be imported
|
||||
try:
|
||||
from tinytorch.core.tensor import Tensor
|
||||
checks.append(("Core Tensor", True, "Import successful"))
|
||||
except Exception as e:
|
||||
checks.append(("Core Tensor", False, str(e)))
|
||||
|
||||
try:
|
||||
from tinytorch.core.autograd import Variable
|
||||
checks.append(("Core Autograd", True, "Import successful"))
|
||||
except Exception as e:
|
||||
checks.append(("Core Autograd", False, str(e)))
|
||||
|
||||
try:
|
||||
from tinytorch.core.layers import matmul
|
||||
checks.append(("Core Layers", True, "Import successful"))
|
||||
except Exception as e:
|
||||
checks.append(("Core Layers", False, str(e)))
|
||||
|
||||
# Quick Variable/Tensor compatibility test
|
||||
try:
|
||||
from tinytorch.core.tensor import Tensor
|
||||
from tinytorch.core.autograd import Variable
|
||||
from tinytorch.core.layers import matmul
|
||||
|
||||
a = Variable(Tensor([[1, 2]]), requires_grad=True)
|
||||
b = Variable(Tensor([[3], [4]]), requires_grad=True)
|
||||
result = matmul(a, b)
|
||||
|
||||
if hasattr(result, 'requires_grad'):
|
||||
checks.append(("Variable Compatibility", True, "matmul works with Variables"))
|
||||
else:
|
||||
checks.append(("Variable Compatibility", False, "matmul doesn't return Variables"))
|
||||
|
||||
except Exception as e:
|
||||
checks.append(("Variable Compatibility", False, str(e)))
|
||||
|
||||
# Display results
|
||||
for check_name, passed, details in checks:
|
||||
status = "[green]✅ PASS[/green]" if passed else "[red]❌ FAIL[/red]"
|
||||
console.print(f"{status} {check_name}: {details}")
|
||||
|
||||
console.print()
|
||||
|
||||
# Overall status
|
||||
all_passed = all(passed for _, passed, _ in checks)
|
||||
if all_passed:
|
||||
console.print("[green]🎉 All health checks passed![/green]")
|
||||
return 0
|
||||
else:
|
||||
console.print("[red]❌ Some health checks failed[/red]")
|
||||
console.print("[yellow]💡 Run 'tito system protect validate --verbose' for details[/yellow]")
|
||||
return 1
|
||||
@@ -1,363 +0,0 @@
|
||||
"""
|
||||
Report command for TinyTorch CLI: generate comprehensive diagnostic report.
|
||||
|
||||
This command generates a JSON report containing all environment information,
|
||||
perfect for sharing with TAs, instructors, or when filing bug reports.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
|
||||
class ReportCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "report"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Generate comprehensive diagnostic report (JSON)"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
'-o', '--output',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Output file path (default: tinytorch-report-TIMESTAMP.json)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--stdout',
|
||||
action='store_true',
|
||||
help='Print JSON to stdout instead of file'
|
||||
)
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
console = self.console
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
"📋 Generating TinyTorch Diagnostic Report\n\n"
|
||||
"[dim]This report contains all environment information\n"
|
||||
"needed for debugging and support.[/dim]",
|
||||
title="System Report",
|
||||
border_style="bright_yellow"
|
||||
))
|
||||
console.print()
|
||||
|
||||
# Collect all diagnostic information
|
||||
report = self._collect_report_data()
|
||||
|
||||
# Determine output path
|
||||
if args.stdout:
|
||||
# Print to stdout
|
||||
print(json.dumps(report, indent=2))
|
||||
return 0
|
||||
else:
|
||||
# Write to file
|
||||
if args.output:
|
||||
output_path = Path(args.output)
|
||||
else:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
output_path = Path.cwd() / f"tinytorch-report-{timestamp}.json"
|
||||
|
||||
try:
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(report, f, indent=2)
|
||||
|
||||
console.print(Panel(
|
||||
f"[bold green]✅ Report Generated Successfully![/bold green]\n\n"
|
||||
f"📄 File: [cyan]{output_path}[/cyan]\n"
|
||||
f"📦 Size: {output_path.stat().st_size} bytes\n\n"
|
||||
f"[dim]Share this file with your TA or instructor for support.[/dim]",
|
||||
title="Report Complete",
|
||||
border_style="green"
|
||||
))
|
||||
|
||||
return 0
|
||||
except Exception as e:
|
||||
console.print(Panel(
|
||||
f"[red]❌ Failed to write report:[/red]\n\n{str(e)}",
|
||||
title="Error",
|
||||
border_style="red"
|
||||
))
|
||||
return 1
|
||||
|
||||
def _collect_report_data(self) -> dict:
|
||||
"""Collect comprehensive diagnostic information."""
|
||||
report = {
|
||||
"report_metadata": {
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"report_version": "1.0",
|
||||
"tinytorch_cli_version": self._get_tinytorch_version()
|
||||
},
|
||||
"system": self._collect_system_info(),
|
||||
"python": self._collect_python_info(),
|
||||
"environment": self._collect_environment_info(),
|
||||
"dependencies": self._collect_dependencies_info(),
|
||||
"tinytorch": self._collect_tinytorch_info(),
|
||||
"modules": self._collect_modules_info(),
|
||||
"git": self._collect_git_info(),
|
||||
"disk_memory": self._collect_disk_memory_info()
|
||||
}
|
||||
return report
|
||||
|
||||
def _get_tinytorch_version(self) -> str:
|
||||
"""Get TinyTorch version."""
|
||||
try:
|
||||
import tinytorch
|
||||
return getattr(tinytorch, '__version__', 'unknown')
|
||||
except ImportError:
|
||||
return "not_installed"
|
||||
|
||||
def _collect_system_info(self) -> dict:
|
||||
"""Collect system information."""
|
||||
return {
|
||||
"os": platform.system(),
|
||||
"os_release": platform.release(),
|
||||
"os_version": platform.version(),
|
||||
"machine": platform.machine(),
|
||||
"processor": platform.processor(),
|
||||
"platform": platform.platform(),
|
||||
"node": platform.node()
|
||||
}
|
||||
|
||||
def _collect_python_info(self) -> dict:
|
||||
"""Collect Python interpreter information."""
|
||||
return {
|
||||
"version": sys.version,
|
||||
"version_info": {
|
||||
"major": sys.version_info.major,
|
||||
"minor": sys.version_info.minor,
|
||||
"micro": sys.version_info.micro,
|
||||
"releaselevel": sys.version_info.releaselevel,
|
||||
"serial": sys.version_info.serial
|
||||
},
|
||||
"implementation": platform.python_implementation(),
|
||||
"compiler": platform.python_compiler(),
|
||||
"executable": sys.executable,
|
||||
"prefix": sys.prefix,
|
||||
"base_prefix": getattr(sys, 'base_prefix', sys.prefix),
|
||||
"path": sys.path
|
||||
}
|
||||
|
||||
def _collect_environment_info(self) -> dict:
|
||||
"""Collect environment variables and paths."""
|
||||
venv_exists = self.venv_path.exists()
|
||||
in_venv = (
|
||||
os.environ.get('VIRTUAL_ENV') is not None or
|
||||
(hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix) or
|
||||
hasattr(sys, 'real_prefix')
|
||||
)
|
||||
|
||||
return {
|
||||
"working_directory": str(Path.cwd()),
|
||||
"virtual_environment": {
|
||||
"exists": venv_exists,
|
||||
"active": in_venv,
|
||||
"path": str(self.venv_path) if venv_exists else None,
|
||||
"VIRTUAL_ENV": os.environ.get('VIRTUAL_ENV'),
|
||||
"CONDA_DEFAULT_ENV": os.environ.get('CONDA_DEFAULT_ENV'),
|
||||
"CONDA_PREFIX": os.environ.get('CONDA_PREFIX')
|
||||
},
|
||||
"PATH": os.environ.get('PATH', '').split(os.pathsep),
|
||||
"PYTHONPATH": os.environ.get('PYTHONPATH', '').split(os.pathsep) if os.environ.get('PYTHONPATH') else []
|
||||
}
|
||||
|
||||
def _collect_dependencies_info(self) -> dict:
|
||||
"""Collect installed package information."""
|
||||
dependencies = {}
|
||||
|
||||
# Core dependencies
|
||||
packages = [
|
||||
('numpy', 'numpy'),
|
||||
('pytest', 'pytest'),
|
||||
('PyYAML', 'yaml'),
|
||||
('rich', 'rich'),
|
||||
('jupyterlab', 'jupyterlab'),
|
||||
('jupytext', 'jupytext'),
|
||||
('nbformat', 'nbformat'),
|
||||
('nbgrader', 'nbgrader'),
|
||||
('nbconvert', 'nbconvert'),
|
||||
('jupyter', 'jupyter'),
|
||||
('matplotlib', 'matplotlib'),
|
||||
('psutil', 'psutil'),
|
||||
('black', 'black'),
|
||||
('isort', 'isort'),
|
||||
('flake8', 'flake8')
|
||||
]
|
||||
|
||||
for display_name, import_name in packages:
|
||||
try:
|
||||
module = __import__(import_name)
|
||||
version = getattr(module, '__version__', 'unknown')
|
||||
location = getattr(module, '__file__', 'unknown')
|
||||
dependencies[display_name] = {
|
||||
"installed": True,
|
||||
"version": version,
|
||||
"location": location
|
||||
}
|
||||
except ImportError:
|
||||
dependencies[display_name] = {
|
||||
"installed": False,
|
||||
"version": None,
|
||||
"location": None
|
||||
}
|
||||
|
||||
return dependencies
|
||||
|
||||
def _collect_tinytorch_info(self) -> dict:
|
||||
"""Collect TinyTorch package information."""
|
||||
try:
|
||||
import tinytorch
|
||||
version = getattr(tinytorch, '__version__', 'unknown')
|
||||
location = Path(tinytorch.__file__).parent
|
||||
|
||||
# Check if in development mode
|
||||
is_dev = (location / '../setup.py').exists() or (location / '../pyproject.toml').exists()
|
||||
|
||||
return {
|
||||
"installed": True,
|
||||
"version": version,
|
||||
"location": str(location),
|
||||
"development_mode": is_dev,
|
||||
"package_structure": {
|
||||
"has_init": (location / '__init__.py').exists(),
|
||||
"has_core": (location / 'core').exists(),
|
||||
"has_ops": (location / 'ops').exists()
|
||||
}
|
||||
}
|
||||
except ImportError:
|
||||
return {
|
||||
"installed": False,
|
||||
"version": None,
|
||||
"location": None,
|
||||
"development_mode": False,
|
||||
"package_structure": {}
|
||||
}
|
||||
|
||||
def _collect_modules_info(self) -> dict:
|
||||
"""Collect TinyTorch modules information."""
|
||||
modules_dir = Path.cwd() / "modules"
|
||||
|
||||
if not modules_dir.exists():
|
||||
return {"exists": False, "modules": []}
|
||||
|
||||
modules = []
|
||||
for module_path in sorted(modules_dir.iterdir()):
|
||||
if module_path.is_dir() and module_path.name.startswith(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')):
|
||||
module_info = {
|
||||
"name": module_path.name,
|
||||
"path": str(module_path),
|
||||
"has_notebook": any(module_path.glob("*.ipynb")),
|
||||
"has_dev_py": any(module_path.glob("*_dev.py")),
|
||||
"has_tests": (module_path / "tests").exists()
|
||||
}
|
||||
modules.append(module_info)
|
||||
|
||||
return {
|
||||
"exists": True,
|
||||
"count": len(modules),
|
||||
"modules": modules
|
||||
}
|
||||
|
||||
def _collect_git_info(self) -> dict:
|
||||
"""Collect git repository information."""
|
||||
git_dir = Path.cwd() / ".git"
|
||||
|
||||
if not git_dir.exists():
|
||||
return {"is_repo": False}
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
# Get current branch
|
||||
branch = subprocess.check_output(
|
||||
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True
|
||||
).strip()
|
||||
|
||||
# Get current commit
|
||||
commit = subprocess.check_output(
|
||||
['git', 'rev-parse', '--short', 'HEAD'],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True
|
||||
).strip()
|
||||
|
||||
# Get remote URL
|
||||
try:
|
||||
remote = subprocess.check_output(
|
||||
['git', 'remote', 'get-url', 'origin'],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True
|
||||
).strip()
|
||||
except:
|
||||
remote = None
|
||||
|
||||
# Check for uncommitted changes
|
||||
status = subprocess.check_output(
|
||||
['git', 'status', '--porcelain'],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True
|
||||
).strip()
|
||||
has_changes = len(status) > 0
|
||||
|
||||
return {
|
||||
"is_repo": True,
|
||||
"branch": branch,
|
||||
"commit": commit,
|
||||
"remote": remote,
|
||||
"has_uncommitted_changes": has_changes,
|
||||
"status": status if has_changes else None
|
||||
}
|
||||
except:
|
||||
return {"is_repo": True, "error": "Failed to get git info"}
|
||||
|
||||
def _collect_disk_memory_info(self) -> dict:
|
||||
"""Collect disk space and memory information."""
|
||||
info = {}
|
||||
|
||||
# Disk space
|
||||
try:
|
||||
disk_usage = shutil.disk_usage(Path.cwd())
|
||||
info["disk"] = {
|
||||
"total_bytes": disk_usage.total,
|
||||
"used_bytes": disk_usage.used,
|
||||
"free_bytes": disk_usage.free,
|
||||
"total_gb": round(disk_usage.total / (1024**3), 2),
|
||||
"used_gb": round(disk_usage.used / (1024**3), 2),
|
||||
"free_gb": round(disk_usage.free / (1024**3), 2),
|
||||
"percent_used": round((disk_usage.used / disk_usage.total) * 100, 1)
|
||||
}
|
||||
except Exception as e:
|
||||
info["disk"] = {"error": str(e)}
|
||||
|
||||
# Memory
|
||||
try:
|
||||
import psutil
|
||||
mem = psutil.virtual_memory()
|
||||
info["memory"] = {
|
||||
"total_bytes": mem.total,
|
||||
"available_bytes": mem.available,
|
||||
"used_bytes": mem.used,
|
||||
"total_gb": round(mem.total / (1024**3), 2),
|
||||
"available_gb": round(mem.available / (1024**3), 2),
|
||||
"used_gb": round(mem.used / (1024**3), 2),
|
||||
"percent_used": mem.percent
|
||||
}
|
||||
except ImportError:
|
||||
info["memory"] = {"error": "psutil not installed"}
|
||||
except Exception as e:
|
||||
info["memory"] = {"error": str(e)}
|
||||
|
||||
return info
|
||||
@@ -1,127 +0,0 @@
|
||||
"""
|
||||
Version command for TinyTorch CLI: show version information for TinyTorch and dependencies.
|
||||
"""
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from .base import BaseCommand
|
||||
|
||||
class VersionCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "version"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Show version information for TinyTorch and dependencies"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
# No arguments needed
|
||||
pass
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
console = self.console
|
||||
|
||||
console.print()
|
||||
console.print(Panel(
|
||||
"📦 TinyTorch Version Information",
|
||||
title="Version Info",
|
||||
border_style="bright_magenta"
|
||||
))
|
||||
console.print()
|
||||
|
||||
# Main Version Table
|
||||
version_table = Table(title="TinyTorch", show_header=True, header_style="bold cyan")
|
||||
version_table.add_column("Component", style="yellow", width=25)
|
||||
version_table.add_column("Version", style="white", width=50)
|
||||
|
||||
# TinyTorch Version
|
||||
try:
|
||||
import tinytorch
|
||||
tinytorch_version = getattr(tinytorch, '__version__', '0.1.0-dev')
|
||||
tinytorch_path = Path(tinytorch.__file__).parent
|
||||
|
||||
version_table.add_row("TinyTorch", f"v{tinytorch_version}")
|
||||
version_table.add_row(" └─ Installation", "Development Mode")
|
||||
version_table.add_row(" └─ Location", str(tinytorch_path))
|
||||
|
||||
# Check if it's a git repo
|
||||
git_dir = Path.cwd() / ".git"
|
||||
if git_dir.exists():
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
commit_hash = result.stdout.strip()
|
||||
version_table.add_row(" └─ Git Commit", commit_hash)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except ImportError:
|
||||
version_table.add_row("TinyTorch", "[red]Not Installed[/red]")
|
||||
|
||||
# Python Version
|
||||
python_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
version_table.add_row("Python", python_version)
|
||||
|
||||
console.print(version_table)
|
||||
console.print()
|
||||
|
||||
# Dependencies Version Table
|
||||
deps_table = Table(title="Core Dependencies", show_header=True, header_style="bold magenta")
|
||||
deps_table.add_column("Package", style="cyan", width=20)
|
||||
deps_table.add_column("Version", style="white", width=20)
|
||||
deps_table.add_column("Status", width=30)
|
||||
|
||||
dependencies = [
|
||||
('numpy', 'NumPy'),
|
||||
('pytest', 'pytest'),
|
||||
('yaml', 'PyYAML'),
|
||||
('rich', 'Rich'),
|
||||
('jupyterlab', 'JupyterLab'),
|
||||
('jupytext', 'Jupytext'),
|
||||
('nbformat', 'nbformat'),
|
||||
]
|
||||
|
||||
for import_name, display_name in dependencies:
|
||||
try:
|
||||
module = __import__(import_name)
|
||||
version = getattr(module, '__version__', 'unknown')
|
||||
deps_table.add_row(display_name, version, "[green]✅ Installed[/green]")
|
||||
except ImportError:
|
||||
deps_table.add_row(display_name, "—", "[red]❌ Not Installed[/red]")
|
||||
|
||||
console.print(deps_table)
|
||||
console.print()
|
||||
|
||||
# System Info (brief)
|
||||
system_table = Table(title="System", show_header=True, header_style="bold blue")
|
||||
system_table.add_column("Component", style="yellow", width=20)
|
||||
system_table.add_column("Value", style="white", width=50)
|
||||
|
||||
import platform
|
||||
system_table.add_row("OS", f"{platform.system()} {platform.release()}")
|
||||
system_table.add_row("Architecture", platform.machine())
|
||||
system_table.add_row("Python Implementation", platform.python_implementation())
|
||||
|
||||
console.print(system_table)
|
||||
console.print()
|
||||
|
||||
# Helpful info panel
|
||||
console.print(Panel(
|
||||
"[dim]💡 For complete system information, run:[/dim] [cyan]tito system info[/cyan]\n"
|
||||
"[dim]💡 To check environment health, run:[/dim] [cyan]tito system health[/cyan]",
|
||||
border_style="blue"
|
||||
))
|
||||
|
||||
return 0
|
||||
@@ -1,221 +0,0 @@
|
||||
"""
|
||||
View command for TinyTorch CLI: generates notebooks and opens Jupyter Lab.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from .base import BaseCommand
|
||||
from ..core.exceptions import ExecutionError, ModuleNotFoundError
|
||||
|
||||
class ViewCommand(BaseCommand):
|
||||
"""Command to generate notebooks and open Jupyter Lab."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "view"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Generate notebooks and open Jupyter Lab"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
"""Add view command arguments."""
|
||||
parser.add_argument(
|
||||
'module',
|
||||
nargs='?',
|
||||
help='View specific module (optional)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--force',
|
||||
action='store_true',
|
||||
help='Force rebuild even if notebook exists'
|
||||
)
|
||||
|
||||
def validate_args(self, args: Namespace) -> None:
|
||||
"""Validate view command arguments."""
|
||||
if args.module:
|
||||
module_dir = self.config.modules_dir / args.module
|
||||
if not module_dir.exists():
|
||||
raise ModuleNotFoundError(f"Module directory '{args.module}' not found")
|
||||
|
||||
# Look for the specific dev file for this module
|
||||
# Extract module name (e.g., "tensor" from "01_tensor")
|
||||
module_name = args.module.split('_', 1)[1] if '_' in args.module else args.module
|
||||
dev_file = module_dir / f"{module_name}.py"
|
||||
|
||||
if not dev_file.exists():
|
||||
# Fallback: look for any *.py file
|
||||
dev_files = list(module_dir.glob("*.py"))
|
||||
if not dev_files:
|
||||
raise ModuleNotFoundError(
|
||||
f"No dev file found in module '{args.module}'. Expected: {dev_file}"
|
||||
)
|
||||
|
||||
def _find_dev_files(self) -> List[Path]:
|
||||
"""Find all *.py files in modules directory."""
|
||||
dev_files = []
|
||||
for module_dir in self.config.modules_dir.iterdir():
|
||||
if module_dir.is_dir():
|
||||
# Look for any *.py file in the directory
|
||||
for dev_py in module_dir.glob("*.py"):
|
||||
dev_files.append(dev_py)
|
||||
return dev_files
|
||||
|
||||
def _convert_file(self, dev_file: Path, force: bool = False) -> Tuple[bool, str]:
|
||||
"""Convert a single Python file to notebook using Jupytext."""
|
||||
try:
|
||||
notebook_file = dev_file.with_suffix('.ipynb')
|
||||
|
||||
# Check if notebook exists and we're not forcing
|
||||
if notebook_file.exists() and not force:
|
||||
return True, f"{dev_file.name} → {notebook_file.name} (already exists)"
|
||||
|
||||
# Use Jupytext to convert Python file to notebook
|
||||
result = subprocess.run([
|
||||
"jupytext", "--to", "notebook", str(dev_file)
|
||||
], capture_output=True, text=True, timeout=30, cwd=dev_file.parent)
|
||||
|
||||
if result.returncode == 0:
|
||||
return True, f"{dev_file.name} → {notebook_file.name}"
|
||||
else:
|
||||
error_msg = result.stderr.strip() if result.stderr.strip() else "Conversion failed"
|
||||
return False, error_msg
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "Conversion timed out"
|
||||
except FileNotFoundError:
|
||||
return False, "Jupytext not found. Install with: pip install jupytext"
|
||||
except Exception as e:
|
||||
return False, f"Error: {str(e)}"
|
||||
|
||||
def _launch_jupyter_lab(self, target_dir: Path) -> bool:
|
||||
"""Launch Jupyter Lab in the specified directory."""
|
||||
try:
|
||||
# Change to target directory and launch Jupyter Lab
|
||||
subprocess.Popen([
|
||||
"jupyter", "lab", "--no-browser"
|
||||
], cwd=target_dir)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
self.console.print(Panel(
|
||||
"[red]❌ Jupyter Lab not found. Install with: pip install jupyterlab[/red]",
|
||||
title="Error",
|
||||
border_style="red"
|
||||
))
|
||||
return False
|
||||
except Exception as e:
|
||||
self.console.print(Panel(
|
||||
f"[red]❌ Failed to launch Jupyter Lab: {e}[/red]",
|
||||
title="Error",
|
||||
border_style="red"
|
||||
))
|
||||
return False
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
"""Execute the view command."""
|
||||
self.console.print(Panel(
|
||||
"📓 View: Generate Notebooks and Open Jupyter Lab",
|
||||
title="Interactive Development",
|
||||
border_style="bright_cyan"
|
||||
))
|
||||
|
||||
# Determine target directory for Jupyter Lab
|
||||
if args.module:
|
||||
target_dir = self.config.modules_dir / args.module
|
||||
# Find the specific dev file for this module
|
||||
module_name = args.module.split('_', 1)[1] if '_' in args.module else args.module
|
||||
dev_file = target_dir / f"{module_name}.py"
|
||||
|
||||
if dev_file.exists():
|
||||
dev_files = [dev_file]
|
||||
else:
|
||||
# Fallback: find any dev files
|
||||
dev_files = list(target_dir.glob("*.py"))
|
||||
|
||||
self.console.print(f"🔄 Generating notebook for module: {args.module}")
|
||||
else:
|
||||
target_dir = self.config.modules_dir
|
||||
dev_files = self._find_dev_files()
|
||||
if not dev_files:
|
||||
self.console.print(Panel(
|
||||
"[yellow]⚠️ No *.py files found in modules/[/yellow]",
|
||||
title="Nothing to Convert",
|
||||
border_style="yellow"
|
||||
))
|
||||
# Still launch Jupyter Lab even if no notebooks to generate
|
||||
self.console.print("🚀 Opening Jupyter Lab anyway...")
|
||||
if self._launch_jupyter_lab(target_dir):
|
||||
self._print_launch_info(target_dir)
|
||||
return 0
|
||||
self.console.print(f"🔄 Generating notebooks for {len(dev_files)} modules...")
|
||||
|
||||
# Generate notebooks
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for dev_file in dev_files:
|
||||
success, message = self._convert_file(dev_file, args.force)
|
||||
module_name = dev_file.parent.name
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
self.console.print(f" ✅ {module_name}: {message}")
|
||||
else:
|
||||
error_count += 1
|
||||
self.console.print(f" ❌ {module_name}: {message}")
|
||||
|
||||
# Launch Jupyter Lab
|
||||
self.console.print("\n🚀 Opening Jupyter Lab...")
|
||||
if not self._launch_jupyter_lab(target_dir):
|
||||
return 1
|
||||
|
||||
# Print summary and instructions
|
||||
self._print_summary(success_count, error_count, target_dir)
|
||||
|
||||
return 0 if error_count == 0 else 1
|
||||
|
||||
def _print_launch_info(self, target_dir: Path) -> None:
|
||||
"""Print Jupyter Lab launch information."""
|
||||
info_text = Text()
|
||||
info_text.append("🌟 Jupyter Lab launched successfully!\n\n", style="bold green")
|
||||
info_text.append("📍 Working directory: ", style="white")
|
||||
info_text.append(f"{target_dir}\n", style="cyan")
|
||||
info_text.append("🌐 Open your browser and navigate to the URL shown in the terminal\n", style="white")
|
||||
info_text.append("📁 Your notebooks will be available in the file browser\n", style="white")
|
||||
info_text.append("🔄 Press Ctrl+C in the terminal to stop Jupyter Lab", style="white")
|
||||
|
||||
self.console.print(Panel(
|
||||
info_text,
|
||||
title="Jupyter Lab Ready",
|
||||
border_style="green"
|
||||
))
|
||||
|
||||
def _print_summary(self, success_count: int, error_count: int, target_dir: Path) -> None:
|
||||
"""Print command execution summary."""
|
||||
summary_text = Text()
|
||||
|
||||
if success_count > 0:
|
||||
summary_text.append(f"✅ Successfully generated {success_count} notebook(s)\n", style="bold green")
|
||||
if error_count > 0:
|
||||
summary_text.append(f"❌ Failed to generate {error_count} notebook(s)\n", style="bold red")
|
||||
|
||||
summary_text.append("\n🌟 Jupyter Lab launched successfully!\n\n", style="bold green")
|
||||
summary_text.append("📍 Working directory: ", style="white")
|
||||
summary_text.append(f"{target_dir}\n", style="cyan")
|
||||
summary_text.append("🌐 Open your browser and navigate to the URL shown above\n", style="white")
|
||||
summary_text.append("📁 Your notebooks are ready for interactive development\n", style="white")
|
||||
summary_text.append("🔄 Press Ctrl+C in the terminal to stop Jupyter Lab", style="white")
|
||||
|
||||
border_style = "green" if error_count == 0 else "yellow"
|
||||
self.console.print(Panel(
|
||||
summary_text,
|
||||
title="View Command Complete",
|
||||
border_style=border_style
|
||||
))
|
||||
@@ -39,6 +39,7 @@ from .commands.setup import SetupCommand
|
||||
from .commands.benchmark import BenchmarkCommand
|
||||
from .commands.community import CommunityCommand
|
||||
from .commands.dev import DevCommand
|
||||
from .commands.olympics import OlympicsCommand
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -76,6 +77,7 @@ class TinyTorchCLI:
|
||||
# Community
|
||||
'community': CommunityCommand,
|
||||
'benchmark': BenchmarkCommand,
|
||||
'olympics': OlympicsCommand,
|
||||
# Shortcuts
|
||||
'export': ExportCommand,
|
||||
'test': TestCommand,
|
||||
@@ -84,7 +86,7 @@ class TinyTorchCLI:
|
||||
}
|
||||
|
||||
# Command categorization for help display
|
||||
self.student_commands = ['module', 'milestones', 'community', 'benchmark']
|
||||
self.student_commands = ['module', 'milestones', 'community', 'benchmark', 'olympics']
|
||||
self.developer_commands = ['dev', 'system', 'src', 'package', 'nbgrader']
|
||||
|
||||
# Welcome screen sections (used for both tito and tito --help)
|
||||
|
||||
Reference in New Issue
Block a user