Simplify top bar to match MLSysBook style

- Clean navigation bar: brand left, links right
- Icon+text buttons: MLSysBook, Subscribe, Star, Community
- Responsive: icons only on mobile
- Removed old WIP banner complexity
This commit is contained in:
Vijay Janapa Reddi
2025-12-02 21:24:22 -05:00
parent 2c5ef2e3cf
commit af0beb9408
6 changed files with 900 additions and 284 deletions

78
tito/commands/dev/dev.py Normal file
View File

@@ -0,0 +1,78 @@
"""
Developer command group for TinyTorch CLI.
These commands are for TinyTorch developers and instructors, not students.
They help with:
- Pre-commit/pre-release verification (preflight)
- CI/CD integration
- Development workflows
"""
from argparse import ArgumentParser, Namespace
from rich.panel import Panel
from ..base import BaseCommand
from .preflight import PreflightCommand
class DevCommand(BaseCommand):
"""Developer tools command group."""
@property
def name(self) -> str:
return "dev"
@property
def description(self) -> str:
return "Developer tools: preflight checks, CI/CD, workflows"
def add_arguments(self, parser: ArgumentParser) -> None:
subparsers = parser.add_subparsers(
dest='dev_command',
help='Developer subcommands',
metavar='SUBCOMMAND'
)
# Preflight subcommand
preflight_parser = subparsers.add_parser(
'preflight',
help='Run preflight verification checks before commit/release'
)
preflight_cmd = PreflightCommand(self.config)
preflight_cmd.add_arguments(preflight_parser)
def run(self, args: Namespace) -> int:
console = self.console
if not hasattr(args, 'dev_command') or not args.dev_command:
console.print(Panel(
"[bold cyan]Developer Commands[/bold cyan]\n\n"
"[bold]For developers and instructors - not for students.[/bold]\n\n"
"Available subcommands:\n"
" • [bold]preflight[/bold] - Run verification checks before commit/release\n\n"
"[bold cyan]Preflight Levels:[/bold cyan]\n"
" [dim]tito dev preflight[/dim] Standard checks (~30s)\n"
" [dim]tito dev preflight --quick[/dim] Quick checks only (~10s)\n"
" [dim]tito dev preflight --full[/dim] Full validation (~2-5min)\n"
" [dim]tito dev preflight --release[/dim] Release validation (~10-30min)\n\n"
"[bold cyan]CI/CD Integration:[/bold cyan]\n"
" [dim]tito dev preflight --ci[/dim] Non-interactive, exit codes\n"
" [dim]tito dev preflight --json[/dim] JSON output for automation\n\n"
"[dim]Example: tito dev preflight --full[/dim]",
title="🛠️ Developer Tools",
border_style="bright_cyan"
))
return 0
# Execute the appropriate subcommand
if args.dev_command == 'preflight':
cmd = PreflightCommand(self.config)
return cmd.execute(args)
else:
console.print(Panel(
f"[red]Unknown dev subcommand: {args.dev_command}[/red]",
title="Error",
border_style="red"
))
return 1

View File

@@ -0,0 +1,692 @@
"""
Preflight checks for TinyTorch development and releases.
This command runs comprehensive verification before commits, PRs, or releases.
The same checks can be used in CI/CD pipelines.
Usage:
tito dev preflight # Standard preflight (quick + structure)
tito dev preflight --full # Full validation (includes module tests)
tito dev preflight --release # Release validation (comprehensive)
tito dev preflight --ci # CI mode (non-interactive, exit codes)
"""
import subprocess
import sys
import time
import json
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from enum import Enum
from rich.panel import Panel
from rich.table import Table
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
from rich.console import Group
from rich.text import Text
from rich import box
from ..base import BaseCommand
class CheckStatus(Enum):
"""Status of a preflight check."""
PASS = "pass"
FAIL = "fail"
WARN = "warn"
SKIP = "skip"
@dataclass
class CheckResult:
"""Result of a single preflight check."""
name: str
status: CheckStatus
message: str = ""
duration_ms: int = 0
details: List[str] = field(default_factory=list)
@dataclass
class CheckCategory:
"""A category of preflight checks."""
name: str
emoji: str
checks: List[CheckResult] = field(default_factory=list)
@property
def passed(self) -> int:
return sum(1 for c in self.checks if c.status == CheckStatus.PASS)
@property
def failed(self) -> int:
return sum(1 for c in self.checks if c.status == CheckStatus.FAIL)
@property
def warned(self) -> int:
return sum(1 for c in self.checks if c.status == CheckStatus.WARN)
@property
def all_passed(self) -> bool:
return self.failed == 0
class PreflightCommand(BaseCommand):
"""Run preflight checks before commits, PRs, or releases."""
@property
def name(self) -> str:
return "preflight"
@property
def description(self) -> str:
return "Run preflight verification checks"
def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument(
'--quick',
action='store_true',
help='Quick checks only (~10 seconds)'
)
parser.add_argument(
'--full',
action='store_true',
help='Full validation including module tests (~2-5 minutes)'
)
parser.add_argument(
'--release',
action='store_true',
help='Release validation - comprehensive (~10-30 minutes)'
)
parser.add_argument(
'--ci',
action='store_true',
help='CI mode: non-interactive, structured output, strict exit codes'
)
parser.add_argument(
'--json',
action='store_true',
help='Output results as JSON (implies --ci)'
)
parser.add_argument(
'--fix',
action='store_true',
help='Attempt to auto-fix common issues'
)
def run(self, args: Namespace) -> int:
console = self.console
project_root = Path.cwd()
start_time = time.time()
# Determine check level
if args.release:
level = "release"
level_emoji = "🚀"
level_desc = "Release Validation"
elif args.full:
level = "full"
level_emoji = "🔍"
level_desc = "Full Validation"
elif args.quick:
level = "quick"
level_emoji = ""
level_desc = "Quick Checks"
else:
level = "standard"
level_emoji = "✈️"
level_desc = "Standard Preflight"
is_ci = args.ci or args.json
# Show header (unless JSON output)
if not args.json:
console.print(Panel(
f"[bold cyan]{level_emoji} {level_desc}[/bold cyan]\n\n"
f"Running verification checks before {'CI/CD' if is_ci else 'your next step'}...\n"
f"[dim]Level: {level} | CI Mode: {is_ci}[/dim]",
title="TinyTorch Preflight",
border_style="bright_cyan"
))
console.print()
# Run checks based on level
categories = []
# Level 1: Quick checks (always run)
categories.append(self._check_structure(project_root))
categories.append(self._check_cli(project_root))
categories.append(self._check_imports(project_root))
# Level 2: Standard checks
if level in ["standard", "full", "release"]:
categories.append(self._check_git_state(project_root))
# Level 3: Full checks
if level in ["full", "release"]:
categories.append(self._check_module_tests(project_root, quick=(level != "release")))
# Level 4: Release checks
if level == "release":
categories.append(self._check_milestones(project_root))
categories.append(self._check_e2e(project_root))
categories.append(self._check_docs(project_root))
# Calculate totals
total_passed = sum(c.passed for c in categories)
total_failed = sum(c.failed for c in categories)
total_warned = sum(c.warned for c in categories)
total_checks = sum(len(c.checks) for c in categories)
all_passed = total_failed == 0
duration = time.time() - start_time
# Output results
if args.json:
self._output_json(categories, all_passed, duration)
else:
self._output_rich(categories, all_passed, duration, total_passed, total_failed, total_warned, total_checks, level, is_ci)
return 0 if all_passed else 1
def _run_command(self, cmd: List[str], cwd: Path, timeout: int = 60) -> Tuple[int, str, str]:
"""Run a command and return (exit_code, stdout, stderr)."""
try:
result = subprocess.run(
cmd,
cwd=cwd,
capture_output=True,
text=True,
timeout=timeout
)
return result.returncode, result.stdout, result.stderr
except subprocess.TimeoutExpired:
return -1, "", "Command timed out"
except Exception as e:
return -1, "", str(e)
def _check_structure(self, project_root: Path) -> CheckCategory:
"""Check project structure and required files."""
category = CheckCategory(name="Project Structure", emoji="📁")
# Required directories
required_dirs = [
("modules/", "Module notebooks directory"),
("src/", "Source files directory"),
("tinytorch/", "Package directory"),
("milestones/", "Milestone scripts"),
("tests/", "Test directory"),
("tito/", "CLI directory"),
]
for dir_path, desc in required_dirs:
start = time.time()
path = project_root / dir_path
if path.exists() and path.is_dir():
category.checks.append(CheckResult(
name=f"{dir_path} exists",
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name=f"{dir_path} exists",
status=CheckStatus.FAIL,
message=f"Missing: {desc}",
duration_ms=int((time.time() - start) * 1000)
))
# Required files
required_files = [
"pyproject.toml",
"requirements.txt",
"README.md",
]
for file_path in required_files:
start = time.time()
path = project_root / file_path
if path.exists():
category.checks.append(CheckResult(
name=f"{file_path} exists",
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name=f"{file_path} exists",
status=CheckStatus.FAIL,
message=f"Missing required file",
duration_ms=int((time.time() - start) * 1000)
))
# Check module count
start = time.time()
modules_dir = project_root / "modules"
if modules_dir.exists():
module_count = len([d for d in modules_dir.iterdir() if d.is_dir() and d.name[0].isdigit()])
if module_count >= 15:
category.checks.append(CheckResult(
name=f"Module count ({module_count})",
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name=f"Module count ({module_count})",
status=CheckStatus.WARN,
message=f"Expected 20+ modules, found {module_count}",
duration_ms=int((time.time() - start) * 1000)
))
return category
def _check_cli(self, project_root: Path) -> CheckCategory:
"""Check CLI commands work."""
category = CheckCategory(name="CLI Commands", emoji="🖥️")
cli_checks = [
(["--version"], "tito --version"),
(["--help"], "tito --help"),
(["module", "status"], "tito module status"),
(["system", "info"], "tito system info"),
(["milestones", "list", "--simple"], "tito milestones list"),
]
for args, name in cli_checks:
start = time.time()
cmd = [sys.executable, "-m", "tito.main"] + args
code, stdout, stderr = self._run_command(cmd, project_root, timeout=30)
duration = int((time.time() - start) * 1000)
if code == 0:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.PASS,
duration_ms=duration
))
else:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.FAIL,
message=stderr[:100] if stderr else "Command failed",
duration_ms=duration
))
return category
def _check_imports(self, project_root: Path) -> CheckCategory:
"""Check that key imports work."""
category = CheckCategory(name="Package Imports", emoji="📦")
imports = [
("import tinytorch", "tinytorch package"),
("from tinytorch import Tensor", "Tensor class"),
("from tito.main import TinyTorchCLI", "CLI class"),
]
for import_stmt, name in imports:
start = time.time()
cmd = [sys.executable, "-c", import_stmt]
code, stdout, stderr = self._run_command(cmd, project_root, timeout=10)
duration = int((time.time() - start) * 1000)
if code == 0:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.PASS,
duration_ms=duration
))
else:
# Import failures might be expected if module not exported yet
category.checks.append(CheckResult(
name=name,
status=CheckStatus.WARN,
message="Import failed (may need export)",
duration_ms=duration
))
return category
def _check_git_state(self, project_root: Path) -> CheckCategory:
"""Check git repository state."""
category = CheckCategory(name="Git State", emoji="🔀")
# Check if git repo
start = time.time()
code, stdout, stderr = self._run_command(["git", "status", "--porcelain"], project_root)
duration = int((time.time() - start) * 1000)
if code != 0:
category.checks.append(CheckResult(
name="Git repository",
status=CheckStatus.WARN,
message="Not a git repository",
duration_ms=duration
))
return category
category.checks.append(CheckResult(
name="Git repository",
status=CheckStatus.PASS,
duration_ms=duration
))
# Check for uncommitted changes
start = time.time()
if stdout.strip():
lines = stdout.strip().split('\n')
category.checks.append(CheckResult(
name="Clean working tree",
status=CheckStatus.WARN,
message=f"{len(lines)} uncommitted changes",
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name="Clean working tree",
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
# Check current branch
start = time.time()
code, stdout, stderr = self._run_command(["git", "rev-parse", "--abbrev-ref", "HEAD"], project_root)
branch = stdout.strip() if code == 0 else "unknown"
category.checks.append(CheckResult(
name=f"Branch: {branch}",
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
return category
def _check_module_tests(self, project_root: Path, quick: bool = True) -> CheckCategory:
"""Run module tests."""
category = CheckCategory(name="Module Tests", emoji="🧪")
# Determine which tests to run
if quick:
# Just run a few key tests
test_targets = [
("tests/01_tensor/", "Module 01 tests"),
("tests/02_activations/", "Module 02 tests"),
]
else:
# Run all module tests
test_targets = [
("tests/", "All tests"),
]
for test_path, name in test_targets:
start = time.time()
full_path = project_root / test_path
if not full_path.exists():
category.checks.append(CheckResult(
name=name,
status=CheckStatus.SKIP,
message="Test directory not found",
duration_ms=0
))
continue
cmd = [sys.executable, "-m", "pytest", str(full_path), "-v", "--tb=short", "-q"]
timeout = 300 if not quick else 60
code, stdout, stderr = self._run_command(cmd, project_root, timeout=timeout)
duration = int((time.time() - start) * 1000)
if code == 0:
# Parse test count from output
passed_count = "all"
for line in stdout.split('\n'):
if 'passed' in line:
passed_count = line.strip()
break
category.checks.append(CheckResult(
name=name,
status=CheckStatus.PASS,
message=passed_count,
duration_ms=duration
))
else:
# Extract failure info
failed_info = "Tests failed"
for line in stdout.split('\n'):
if 'failed' in line.lower() or 'error' in line.lower():
failed_info = line.strip()[:80]
break
category.checks.append(CheckResult(
name=name,
status=CheckStatus.FAIL,
message=failed_info,
duration_ms=duration
))
return category
def _check_milestones(self, project_root: Path) -> CheckCategory:
"""Check milestone scripts exist and are runnable."""
category = CheckCategory(name="Milestones", emoji="🏆")
milestones_dir = project_root / "milestones"
if not milestones_dir.exists():
category.checks.append(CheckResult(
name="Milestones directory",
status=CheckStatus.FAIL,
message="milestones/ not found"
))
return category
# Check key milestone scripts exist
milestone_scripts = [
("01_1957_perceptron/02_rosenblatt_trained.py", "Perceptron script"),
("02_1969_xor/02_xor_solved.py", "XOR script"),
("03_1986_mlp/01_rumelhart_tinydigits.py", "MLP script"),
]
for script_path, name in milestone_scripts:
start = time.time()
full_path = milestones_dir / script_path
if full_path.exists():
category.checks.append(CheckResult(
name=name,
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.WARN,
message="Script not found",
duration_ms=int((time.time() - start) * 1000)
))
return category
def _check_e2e(self, project_root: Path) -> CheckCategory:
"""Run E2E tests."""
category = CheckCategory(name="E2E Tests", emoji="🔄")
e2e_dir = project_root / "tests" / "e2e"
if not e2e_dir.exists():
category.checks.append(CheckResult(
name="E2E test directory",
status=CheckStatus.WARN,
message="tests/e2e/ not found"
))
return category
# Run quick E2E tests
start = time.time()
cmd = [sys.executable, "-m", "pytest", str(e2e_dir), "-v", "-k", "quick", "--tb=short"]
code, stdout, stderr = self._run_command(cmd, project_root, timeout=120)
duration = int((time.time() - start) * 1000)
if code == 0:
category.checks.append(CheckResult(
name="E2E quick tests",
status=CheckStatus.PASS,
duration_ms=duration
))
else:
category.checks.append(CheckResult(
name="E2E quick tests",
status=CheckStatus.FAIL,
message="E2E tests failed",
duration_ms=duration
))
return category
def _check_docs(self, project_root: Path) -> CheckCategory:
"""Check documentation exists."""
category = CheckCategory(name="Documentation", emoji="📚")
doc_files = [
("README.md", "Main README"),
("docs/getting-started.md", "Getting Started"),
("CONTRIBUTING.md", "Contributing Guide"),
]
for file_path, name in doc_files:
start = time.time()
full_path = project_root / file_path
if full_path.exists():
# Check it's not empty
size = full_path.stat().st_size
if size > 100:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.PASS,
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.WARN,
message="File seems empty",
duration_ms=int((time.time() - start) * 1000)
))
else:
category.checks.append(CheckResult(
name=name,
status=CheckStatus.WARN,
message="File not found",
duration_ms=int((time.time() - start) * 1000)
))
return category
def _output_json(self, categories: List[CheckCategory], all_passed: bool, duration: float):
"""Output results as JSON for CI/CD."""
output = {
"success": all_passed,
"duration_seconds": round(duration, 2),
"categories": []
}
for category in categories:
cat_data = {
"name": category.name,
"passed": category.passed,
"failed": category.failed,
"warned": category.warned,
"checks": []
}
for check in category.checks:
cat_data["checks"].append({
"name": check.name,
"status": check.status.value,
"message": check.message,
"duration_ms": check.duration_ms
})
output["categories"].append(cat_data)
print(json.dumps(output, indent=2))
def _output_rich(self, categories: List[CheckCategory], all_passed: bool, duration: float,
total_passed: int, total_failed: int, total_warned: int, total_checks: int,
level: str, is_ci: bool):
"""Output results with rich formatting."""
console = self.console
for category in categories:
# Create table for category
table = Table(
show_header=False,
box=None,
padding=(0, 1),
expand=True
)
table.add_column("Status", width=3)
table.add_column("Check", style="bold")
table.add_column("Message", style="dim")
table.add_column("Time", style="dim", justify="right", width=8)
for check in category.checks:
if check.status == CheckStatus.PASS:
status_icon = "[green]✓[/green]"
elif check.status == CheckStatus.FAIL:
status_icon = "[red]✗[/red]"
elif check.status == CheckStatus.WARN:
status_icon = "[yellow]⚠[/yellow]"
else:
status_icon = "[dim]○[/dim]"
time_str = f"{check.duration_ms}ms" if check.duration_ms > 0 else ""
table.add_row(status_icon, check.name, check.message, time_str)
# Category header
status_summary = f"[green]{category.passed}✓[/green]"
if category.failed > 0:
status_summary += f" [red]{category.failed}✗[/red]"
if category.warned > 0:
status_summary += f" [yellow]{category.warned}⚠[/yellow]"
console.print(f"\n[bold]{category.emoji} {category.name}[/bold] {status_summary}")
console.print(table)
# Summary
console.print()
if all_passed:
console.print(Panel(
f"[bold green]✅ All preflight checks passed![/bold green]\n\n"
f"[green]{total_passed}[/green] passed "
f"[yellow]{total_warned}[/yellow] warnings "
f"[dim]{duration:.1f}s[/dim]\n\n"
f"[dim]Ready for: commit, PR, or {level} deployment[/dim]",
title="Preflight Complete",
border_style="green"
))
else:
console.print(Panel(
f"[bold red]❌ Preflight checks failed[/bold red]\n\n"
f"[green]{total_passed}[/green] passed "
f"[red]{total_failed}[/red] failed "
f"[yellow]{total_warned}[/yellow] warnings "
f"[dim]{duration:.1f}s[/dim]\n\n"
f"[dim]Fix the issues above before proceeding[/dim]",
title="Preflight Failed",
border_style="red"
))
# Show next steps
if not is_ci:
if all_passed:
if level == "quick":
console.print("\n[dim]💡 For thorough validation: tito dev preflight --full[/dim]")
elif level == "standard":
console.print("\n[dim]💡 For release validation: tito dev preflight --release[/dim]")
else:
console.print("\n[dim]💡 Fix issues and re-run: tito dev preflight[/dim]")

View File

@@ -38,6 +38,7 @@ from .commands.milestone import MilestoneCommand
from .commands.setup import SetupCommand
from .commands.benchmark import BenchmarkCommand
from .commands.community import CommunityCommand
from .commands.dev import DevCommand
# Configure logging
logging.basicConfig(
@@ -66,6 +67,7 @@ class TinyTorchCLI:
'system': SystemCommand,
'module': ModuleWorkflowCommand,
# Developer tools
'dev': DevCommand,
'src': SrcCommand,
'package': PackageCommand,
'nbgrader': NBGraderCommand,
@@ -83,7 +85,7 @@ class TinyTorchCLI:
# Command categorization for help display
self.student_commands = ['module', 'milestones', 'community', 'benchmark']
self.developer_commands = ['system', 'src', 'package', 'nbgrader']
self.developer_commands = ['dev', 'system', 'src', 'package', 'nbgrader']
# Welcome screen sections (used for both tito and tito --help)
self.welcome_sections = {
@@ -280,7 +282,12 @@ class TinyTorchCLI:
self.config.no_color = True
# Show banner for interactive commands (except logo which has its own display)
if parsed_args.command and not self.config.no_color and parsed_args.command != 'logo':
# Skip banner for dev command with --json flag (CI/CD output)
skip_banner = (
parsed_args.command == 'logo' or
(parsed_args.command == 'dev' and hasattr(parsed_args, 'json') and parsed_args.json)
)
if parsed_args.command and not self.config.no_color and not skip_banner:
print_banner()
# Validate environment for most commands (skip for doctor)