mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-29 16:22:36 -05:00
Replace plain argparse help with custom Rich-formatted output. Changes: intercept the --help flag before argparse processes it; show the ASCII logo at the top (consistent with running `tito` alone); display commands in a colored Rich table; reuse _generate_welcome_text() for consistency; add a formatted global-options section. Benefits: consistent Rich formatting across `tito` and `tito --help`; a professional table layout for commands (green names, dim descriptions); color-coded sections (cyan headers, yellow options, green commands); the same visual experience whether you run `tito` or `tito --help`; all command info pulled dynamically from the self.commands dict. Before: plain black-and-white argparse output. After: colorful, formatted Rich output with an ASCII logo and tables.
337 lines
12 KiB
Python
337 lines
12 KiB
Python
"""
|
|
TinyTorch CLI Main Entry Point
|
|
|
|
A professional command-line interface with proper architecture:
|
|
- Clean separation of concerns
|
|
- Proper error handling
|
|
- Logging support
|
|
- Configuration management
|
|
- Extensible command system
|
|
"""
|
|
|
|
import argparse
import logging
import os
import re
import sys
from pathlib import Path
from typing import Dict, Type, Optional, List
|
|
|
|
# Silence tinytorch's autograd messages. This must execute before any of the
# tinytorch imports below -- presumably the flag is read at import time
# (TODO confirm). Any TINYTORCH_QUIET already in the environment is overwritten.
os.environ.update(TINYTORCH_QUIET='1')
|
|
|
|
from .core.config import CLIConfig
|
|
from .core.virtual_env_manager import get_venv_path
|
|
from .core.console import get_console, print_banner, print_error, print_ascii_logo
|
|
from .core.exceptions import TinyTorchCLIError
|
|
from rich.panel import Panel
|
|
from .commands.base import BaseCommand
|
|
from .commands.test import TestCommand
|
|
from .commands.export import ExportCommand
|
|
from .commands.src import SrcCommand
|
|
from .commands.system import SystemCommand
|
|
from .commands.module import ModuleWorkflowCommand
|
|
from .commands.package import PackageCommand
|
|
from .commands.nbgrader import NBGraderCommand
|
|
from .commands.grade import GradeCommand
|
|
from .commands.logo import LogoCommand
|
|
from .commands.milestone import MilestoneCommand
|
|
from .commands.setup import SetupCommand
|
|
from .commands.benchmark import BenchmarkCommand
|
|
from .commands.community import CommunityCommand
|
|
|
|
# Root logging for the CLI process: INFO and above, sent both to 'tito-cli.log'
# (a path relative to the current working directory) and to stderr.
# logging.basicConfig is a no-op when the root logger already has handlers,
# but the FileHandler below is still constructed (and opens its file) first.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.INFO,
    handlers=[logging.FileHandler('tito-cli.log'), logging.StreamHandler(sys.stderr)],
)

# Logger for this module, named after the module itself.
logger = logging.getLogger(__name__)
|
|
|
|
class TinyTorchCLI:
    """Main CLI application class.

    Owns the command registry, builds the argparse parser from it, renders the
    Rich-formatted welcome and help screens, and dispatches to the selected
    command.
    """

    # Section titles paired with the ``welcome_sections`` keys they render,
    # in display order.
    _WELCOME_LAYOUT = (
        ('Quick Start', 'quick_start'),
        ('Track Progress', 'track_progress'),
        ('Community', 'community'),
        ('Help & Docs', 'help_docs'),
    )

    # (flags, description) rows shown under "Global Options" by ``_show_help``.
    # Keep in sync with the global arguments registered in ``create_parser``.
    _GLOBAL_OPTIONS = (
        ('--help, -h', 'Show this help message'),
        ('--version', 'Show version number'),
        ('--verbose, -v', 'Enable verbose output'),
        ('--no-color', 'Disable colored output'),
    )

    # Matches Rich markup tags such as [green], [/green] or [bold cyan].
    # Following Rich's tag syntax, a tag must start with a lowercase letter,
    # '#', '/' or '@', so literal text like [OPTIONS] is left untouched.
    # Backslash-escaped brackets are not handled.
    _MARKUP_TAG = re.compile(r'\[[a-z#/@][^\[\]]*\]')

    def __init__(self):
        """Initialize the CLI application."""
        self.config = CLIConfig.from_project_root()
        self.console = get_console()

        # SINGLE SOURCE OF TRUTH: All valid commands registered here
        self.commands: Dict[str, Type[BaseCommand]] = {
            # Essential
            'setup': SetupCommand,
            # Workflow (student-facing)
            'system': SystemCommand,
            'module': ModuleWorkflowCommand,
            # Developer tools
            'src': SrcCommand,
            'package': PackageCommand,
            'nbgrader': NBGraderCommand,
            # Progress tracking
            'milestones': MilestoneCommand,
            # Community
            'community': CommunityCommand,
            'benchmark': BenchmarkCommand,
            # Shortcuts
            'export': ExportCommand,
            'test': TestCommand,
            'grade': GradeCommand,
            'logo': LogoCommand,
        }

        # Command categorization for the argparse epilog
        self.student_commands = ['module', 'milestones', 'community', 'benchmark']
        self.developer_commands = ['system', 'src', 'package', 'nbgrader']

        # Welcome screen sections (used for both tito and tito --help).
        # Each entry is (Rich-markup command, plain description).
        self.welcome_sections = {
            'quick_start': [
                ('[green]tito setup[/green]', 'First-time setup'),
                ('[green]tito module start 01[/green]', 'Start Module 01 (tensors)'),
                ('[green]tito module complete 01[/green]', 'Test, export, and track progress'),
            ],
            'track_progress': [
                ('[yellow]tito module status[/yellow]', 'View module progress'),
                ('[yellow]tito milestones status[/yellow]', 'View unlocked capabilities'),
            ],
            'community': [
                ('[cyan]tito community login[/cyan]', 'Log in to TinyTorch'),
                ('[cyan]tito community leaderboard[/cyan]', 'View global leaderboard'),
            ],
            'help_docs': [
                ('[magenta]tito system doctor[/magenta]', 'Check environment health'),
                ('[magenta]tito --help[/magenta]', 'See all commands'),
            ]
        }

    @classmethod
    def _plain(cls, markup: str) -> str:
        """Return ``markup`` with Rich markup tags removed, i.e. its visible text."""
        return cls._MARKUP_TAG.sub('', markup)

    @staticmethod
    def _summarize(description: str) -> str:
        """Shorten a command description for the plain-text argparse epilog.

        Keeps the first sentence and, within it, only the text before a spaced
        " - " separator. Hyphenated words such as "end-to-end" stay intact.
        """
        first_sentence = description.split('.')[0]
        return first_sentence.split(' - ')[0].strip()

    def _generate_welcome_text(self) -> str:
        """Generate the Rich-markup welcome text shown by ``tito`` and ``tito --help``.

        Sections follow ``_WELCOME_LAYOUT`` and are separated by a blank line.
        Descriptions line up in one column across all sections. Padding is
        computed from each command's visible text, not from its markup, because
        tags of different lengths (e.g. [cyan] vs [magenta]) would otherwise
        shift the column.
        """
        entries = [
            entry
            for _, key in self._WELCOME_LAYOUT
            for entry in self.welcome_sections.get(key, [])
        ]
        width = max((len(self._plain(cmd)) for cmd, _ in entries), default=0)

        blocks = []
        for title, key in self._WELCOME_LAYOUT:
            rows = [f"[bold cyan]{title}:[/bold cyan]"]
            for cmd, desc in self.welcome_sections.get(key, []):
                padding = ' ' * (width - len(self._plain(cmd)))
                rows.append(f"  {cmd}{padding}  {desc}")
            blocks.append("\n".join(rows))
        return "\n\n".join(blocks)

    def _generate_epilog(self) -> str:
        """Generate the plain-text argparse epilog from the registered commands.

        Lists the student and developer commands with shortened descriptions,
        then the Quick Start commands with their Rich markup stripped. Names in
        the category lists that are not registered in ``self.commands`` are
        skipped.
        """
        sections = []
        for heading, names in (("Student Commands:", self.student_commands),
                               ("Developer Commands:", self.developer_commands)):
            rows = [heading]
            for cmd_name in names:
                command_class = self.commands.get(cmd_name)
                if command_class is None:
                    continue
                summary = self._summarize(command_class(self.config).description)
                rows.append(f"  {cmd_name:<12} {summary}")
            sections.append("\n".join(rows))

        quick_start = ["Quick Start:"]
        for cmd, desc in self.welcome_sections['quick_start']:
            quick_start.append(f"  {self._plain(cmd):<28} {desc}")
        sections.append("\n".join(quick_start))

        return "\n\n".join(sections)

    def create_parser(self) -> argparse.ArgumentParser:
        """Create the main argument parser with one subparser per registered command."""
        parser = argparse.ArgumentParser(
            prog="tito",
            description="Tiny🔥Torch CLI - Build ML systems from scratch",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=self._generate_epilog()
        )

        # Global options
        parser.add_argument(
            '--version',
            action='version',
            version='Tiny🔥Torch CLI 0.1.0'
        )
        parser.add_argument(
            '--verbose', '-v',
            action='store_true',
            help='Enable verbose output'
        )
        parser.add_argument(
            '--no-color',
            action='store_true',
            help='Disable colored output'
        )

        # Subcommands
        subparsers = parser.add_subparsers(
            dest='command',
            help='Available commands',
            metavar='COMMAND'
        )

        for command_name, command_class in self.commands.items():
            # Temporary instance only to read metadata and register arguments.
            temp_command = command_class(self.config)
            cmd_parser = subparsers.add_parser(
                command_name,
                help=temp_command.description
            )
            temp_command.add_arguments(cmd_parser)

        return parser

    def validate_environment(self) -> bool:
        """Validate the environment and report any issues found.

        Returns:
            Always ``True`` for now: issues are printed but do not block the
            command (a temporary development-time relaxation).
        """
        issues = self.config.validate(get_venv_path())

        if issues:
            print_error(
                "Environment validation failed:\n" + "\n".join(f" • {issue}" for issue in issues),
                "Environment Issues"
            )
            # Point at the real diagnostic command: there is no top-level
            # `doctor` command, it lives under `system`.
            self.console.print("\n[dim]Run 'tito system doctor' for detailed diagnosis[/dim]")

        # Issues deliberately do not block execution (temporary, for development).
        return True

    def _show_help(self) -> int:
        """Print the custom Rich-formatted help screen and return exit status 0."""
        from rich.table import Table

        print_ascii_logo()

        table = Table(show_header=True, header_style="bold cyan", box=None, padding=(0, 2))
        table.add_column("Command", style="green", width=15)
        table.add_column("Description", style="dim")
        for cmd_name, cmd_class in self.commands.items():
            table.add_row(cmd_name, cmd_class(self.config).description)

        console = self.console
        console.print()
        console.print("[bold cyan]Tiny🔥Torch CLI[/bold cyan] - Build ML systems from scratch")
        console.print()
        console.print("[bold]Usage:[/bold] [cyan]tito[/cyan] [yellow]COMMAND[/yellow] [dim][OPTIONS][/dim]")
        console.print()
        console.print("[bold cyan]Available Commands:[/bold cyan]")
        console.print(table)
        console.print()
        console.print(self._generate_welcome_text())
        console.print()
        console.print("[bold cyan]Global Options:[/bold cyan]")
        # Pad inside the tags so the visible description column lines up.
        for flags, desc in self._GLOBAL_OPTIONS:
            console.print(f"  [yellow]{flags:<15}[/yellow] {desc}")
        console.print()

        return 0

    def run(self, args: Optional[List[str]] = None) -> int:
        """Parse ``args`` and execute the selected command.

        Args:
            args: Command-line arguments excluding the program name. ``None``
                means ``sys.argv[1:]`` (argparse's own default). It is resolved
                up front so that a lone ``-h``/``--help`` still gets the Rich
                help screen.

        Returns:
            Exit status: the command's own status, 0 for the welcome or help
            screens, 130 on Ctrl-C, and 1 on CLI or unexpected errors.
            argparse's SystemExit (``--version``, usage errors) is not caught
            here.
        """
        if args is None:
            args = sys.argv[1:]

        try:
            # A bare -h/--help gets the Rich help screen instead of argparse's.
            if len(args) == 1 and args[0] in ('-h', '--help'):
                return self._show_help()

            parser = self.create_parser()
            parsed_args = parser.parse_args(args)

            # Update config with global options
            if getattr(parsed_args, 'verbose', False):
                self.config.verbose = True
                logging.getLogger().setLevel(logging.DEBUG)

            if getattr(parsed_args, 'no_color', False):
                self.config.no_color = True

            # Show banner for interactive commands (except logo which has its own display)
            if parsed_args.command and not self.config.no_color and parsed_args.command != 'logo':
                print_banner()

            # Validate environment for most commands (skip for `system doctor`)
            skip_validation = (
                parsed_args.command in (None, 'version', 'help')
                or (parsed_args.command == 'system'
                    and getattr(parsed_args, 'system_command', None) == 'doctor')
            )
            if not skip_validation and not self.validate_environment():
                return 1

            # No command: show the logo and the welcome panel.
            if not parsed_args.command:
                print_ascii_logo()
                self.console.print(Panel(
                    self._generate_welcome_text(),
                    title="Welcome to Tiny🔥Torch!",
                    border_style="bright_green"
                ))
                return 0

            command_class = self.commands.get(parsed_args.command)
            if command_class is None:
                print_error(f"Unknown command: {parsed_args.command}")
                return 1
            return command_class(self.config).execute(parsed_args)

        except KeyboardInterrupt:
            self.console.print("\n[yellow]Operation cancelled by user[/yellow]")
            return 130
        except TinyTorchCLIError as e:
            logger.error(f"CLI error: {e}")
            print_error(str(e))
            return 1
        except Exception as e:
            logger.exception("Unexpected error in CLI")
            print_error(f"Unexpected error: {e}")
            return 1
|
|
def main() -> int:
    """Command-line entry point.

    Builds a fresh ``TinyTorchCLI`` and passes it every command-line argument
    after the program name.

    Returns:
        The exit status produced by ``TinyTorchCLI.run``.
    """
    argv = sys.argv[1:]
    return TinyTorchCLI().run(argv)
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())