2 Commits

Author SHA1 Message Date
Kai Kleinbard
277033d2f9 Merge pull request #16 from MLSysBook/feat/submission-login-flow-v2
updates to submission to server database and login flow
2025-11-30 21:05:32 -05:00
kai
c682ec2ee3 updates to submission to server database and login flow 2025-11-30 21:01:57 -05:00
5 changed files with 670 additions and 1 deletions

92
tito/commands/login.py Normal file
View File

@@ -0,0 +1,92 @@
# tito/commands/login.py
import webbrowser
import time
from argparse import ArgumentParser, Namespace
from rich.prompt import Confirm
from tito.commands.base import BaseCommand
from tito.core.auth import AuthReceiver, save_credentials, delete_credentials, ENDPOINTS, is_logged_in
class LoginCommand(BaseCommand):
@property
def name(self) -> str:
return "login"
@property
def description(self) -> str:
return "Log in to TinyTorch via web browser"
def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument("--force", action="store_true", help="Force re-login")
def run(self, args: Namespace) -> int:
# Adapted logic from api.py
if args.force:
delete_credentials()
self.console.print("Cleared existing credentials.")
# Check if already logged in (unless force was used)
if is_logged_in():
self.console.print("[green]You are already logged in.[/green]")
if Confirm.ask("[bold yellow]Do you want to force re-login?[/bold yellow]", default=False):
delete_credentials()
self.console.print("Cleared existing credentials. Proceeding with new login...")
else:
self.console.print("Login cancelled.")
return 0
receiver = AuthReceiver()
try:
port = receiver.start()
target_url = f"{ENDPOINTS['cli_login']}?redirect_port={port}"
self.console.print(f"Opening browser to: [blue]{target_url}[/blue]")
self.console.print("Waiting for authentication...")
webbrowser.open(target_url)
tokens = receiver.wait_for_tokens()
if tokens:
save_credentials(tokens)
self.console.print(f"[green]Success! Logged in as {tokens['user_email']}[/green]")
return 0
else:
self.console.print("[red]Login timed out.[/red]")
return 1
except Exception as e:
self.console.print(f"[red]Error: {e}[/red]")
return 1
class LogoutCommand(BaseCommand):
@property
def name(self) -> str:
return "logout"
@property
def description(self) -> str:
return "Log out of TinyTorch by clearing stored credentials"
def add_arguments(self, parser: ArgumentParser) -> None:
pass # No arguments needed
def run(self, args: Namespace) -> int:
try:
# Start local server for logout redirect
receiver = AuthReceiver()
port = receiver.start()
# Open browser to local logout endpoint
logout_url = f"http://127.0.0.1:{port}/logout"
self.console.print(f"Opening browser to complete logout...")
webbrowser.open(logout_url)
# Give browser time to redirect and close
time.sleep(2.0)
# Clean up server
receiver.stop()
# Delete local credentials
delete_credentials()
self.console.print("[green]Successfully logged out of TinyTorch![/green]")
return 0
except Exception as e:
self.console.print(f"[red]Error during logout: {e}[/red]")
return 1

View File

@@ -16,6 +16,7 @@ from typing import Dict, Optional
from rich.panel import Panel
from rich.text import Text
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Confirm
from ..base import BaseCommand
from ..view import ViewCommand
@@ -24,6 +25,8 @@ from ..export import ExportCommand
from .reset import ModuleResetCommand
from .test import ModuleTestCommand
from ...core.exceptions import ModuleNotFoundError
from ...core import auth
from ...core.submission import SubmissionHandler
class ModuleWorkflowCommand(BaseCommand):
"""Enhanced module command with natural workflow."""
@@ -521,9 +524,28 @@ class ModuleWorkflowCommand(BaseCommand):
# Step 5: Check for milestone unlocks
if success:
self._check_milestone_unlocks(module_name)
self._trigger_submission()
return 0 if success else 1
def _trigger_submission(self):
"""Asks the user to submit their progress if they are logged in."""
self.console.print() # Add a blank line for spacing
if auth.is_logged_in():
should_submit = Confirm.ask(
"[bold yellow]Would you like to sync your progress with the TinyTorch website?[/bold yellow]",
default=True
)
if should_submit:
handler = SubmissionHandler(self.config, self.console)
total_modules = len(self.get_module_mapping())
handler.sync_progress(total_modules=total_modules)
else:
self.console.print("[dim]💡 Run 'tito login' to enable automatic progress syncing![/dim]")
def run_module_tests(self, module_name: str, verbose: bool = True) -> int:
"""
Run comprehensive tests for a module:
@@ -1157,4 +1179,4 @@ class ModuleWorkflowCommand(BaseCommand):
pass
except Exception as e:
# Don't fail the workflow if milestone checking fails
self.console.print(f"[dim]Note: Could not check milestone unlocks: {e}[/dim]")
self.console.print(f"[dim]Note: Could not check milestone unlocks: {e}[/dim]")

320
tito/core/auth.py Normal file
View File

@@ -0,0 +1,320 @@
"""Simple secure JSON credentials storage system for TinyTorch CLI."""
from __future__ import annotations
import http.server
import threading
import json
import os
import time
import socket
import webbrowser
from pathlib import Path
from typing import Optional, Dict
from urllib.parse import urlparse, parse_qs
# --- Configuration Constants ---
API_BASE_URL = "https://tinytorch.netlify.app"
# API Endpoints
ENDPOINTS = {
"login": f"{API_BASE_URL}/api/auth/login",
"leaderboard": f"{API_BASE_URL}/api/leaderboard",
"submissions": f"{API_BASE_URL}/api/submissions",
"cli_login": f"{API_BASE_URL}/cli-login",
}
# Defaults
LOCAL_SERVER_HOST = "127.0.0.1"
AUTH_START_PORT = 54321
AUTH_PORT_HUNT_RANGE = 100
AUTH_CALLBACK_PATH = "/callback"
CREDENTIALS_FILE_NAME = "credentials.json"
# Determine credentials directory (Standard Python way)
CREDENTIALS_DIR = os.getenv("TINOTORCH_CREDENTIALS_DIR", str(Path.home() / ".tinytorch"))
# --- Storage Logic ---
def _credentials_dir() -> Path:
return Path(os.path.expanduser(CREDENTIALS_DIR))
def _credentials_path() -> Path:
return _credentials_dir() / CREDENTIALS_FILE_NAME
def _ensure_dir() -> None:
d = _credentials_dir()
d.mkdir(parents=True, exist_ok=True)
try:
os.chmod(d, 0o700)
except OSError:
pass
def save_credentials(data: Dict[str, str]) -> None:
"""Persist credentials to disk safely and atomically."""
_ensure_dir()
p = _credentials_path()
tmp = p.with_suffix(".tmp")
with tmp.open("w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(str(tmp), str(p))
try:
os.chmod(p, 0o600)
except OSError:
pass
def load_credentials() -> Optional[Dict[str, str]]:
p = _credentials_path()
if not p.exists():
return None
try:
with p.open("r", encoding="utf-8") as f:
return json.load(f)
except (OSError, json.JSONDecodeError):
return None
def delete_credentials() -> None:
p = _credentials_path()
try:
p.unlink()
except OSError:
pass
# --- Public Auth Helpers ---
def get_token() -> Optional[str]:
"""Retrieve the access token if it exists."""
creds = load_credentials()
if creds:
return creds.get("access_token")
return None
def is_logged_in() -> bool:
"""Check if the user has valid credentials stored."""
return get_token() is not None
def get_user_email() -> Optional[str]:
"""Retrieve the user's email if it exists."""
creds = load_credentials()
if creds:
return creds.get("user_email")
return None
def get_refresh_token() -> Optional[str]:
"""Retrieve the refresh token if it exists."""
creds = load_credentials()
if creds:
return creds.get("refresh_token")
return None
def refresh_token(console: "Console") -> Optional[str]:
"""Refresh the access token. If refresh fails, clear credentials to force re-login."""
refresh_token_val = get_refresh_token()
if not refresh_token_val:
return None
import urllib.request
import urllib.error
import json
url = f"{API_BASE_URL}/api/auth/refresh"
data = {"refreshToken": refresh_token_val}
headers = {"Content-Type": "application/json"}
req = urllib.request.Request(
url,
data=json.dumps(data).encode('utf-8'),
headers=headers,
method="POST"
)
try:
with urllib.request.urlopen(req) as response:
if response.status == 200:
new_session = json.loads(response.read().decode('utf-8'))
# Handle nested session structures (adjust based on your actual API response)
# Some APIs return { session: { ... } }, others return { access_token: ... } direct
session_data = new_session.get('session', new_session)
if 'access_token' in session_data:
new_access_token = session_data['access_token']
# IMPORTANT: Always grab the new refresh token if the server rotates it
new_refresh_token = session_data.get('refresh_token', refresh_token_val)
creds = load_credentials() or {}
creds.update({
"access_token": new_access_token,
"refresh_token": new_refresh_token,
})
save_credentials(creds)
return new_access_token
else:
console.print("[red]Token refresh response is missing session data.[/red]")
return None
else:
console.print(f"[red]Token refresh failed with status: {response.status}[/red]")
return None
except urllib.error.HTTPError as e:
# --- CRITICAL FIX HERE ---
# If we get a 400 (Bad Request) or 401 (Unauthorized), the refresh token is dead.
# We must delete the credentials so the user is forced to log in again.
if e.code in [400, 401, 403]:
console.print("[yellow]Session expired. Please log in again.[/yellow]")
delete_credentials() # This deletes the JSON file
return None
console.print(f"[red]Token refresh failed (HTTP {e.code}): {e.reason}[/red]")
try:
error_body = e.read().decode('utf-8')
error_json = json.loads(error_body)
console.print(f" [dim red]Error details: {error_json.get('error', 'No description provided.')}[/dim red]")
except (json.JSONDecodeError, Exception):
pass
return None
except urllib.error.URLError as e:
console.print(f"[red]Token refresh failed (Network error): {e.reason}[/red]")
return None
# --- Auth Server Logic ---
class CallbackHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):
parsed_path = urlparse(self.path)
if parsed_path.path == "/logout":
self.send_response(302)
self.send_header('Location', f"{API_BASE_URL}/logout")
self.end_headers()
return
if parsed_path.path != AUTH_CALLBACK_PATH:
self.send_error(404, "Not Found")
return
query_params = parse_qs(parsed_path.query)
if 'access_token' in query_params and 'refresh_token' in query_params:
self.server.auth_data = {
'access_token': query_params['access_token'][0],
'refresh_token': query_params['refresh_token'][0],
'user_email': query_params.get('email', [''])[0]
}
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()
html_content = "<html><body><h1>🔥 Tinytorch <h1> <h2>Login Successful</h2><p>You can close this window and return to the CLI.</p><p><a href='https://tinytorch.netlify.app/dashboard'>Go to TinyTorch Dashboard</a></p><script>window.close()</script></body></html>"
self.wfile.write(html_content.encode('utf-8'))
self.wfile.flush()
# Persist immediately
try:
save_credentials(self.server.auth_data)
except Exception:
pass
else:
self.send_error(400, "Missing tokens in callback URL")
def log_message(self, format, *args):
pass
class LocalAuthServer(http.server.HTTPServer):
def __init__(self, server_address, RequestHandlerClass):
super().__init__(server_address, RequestHandlerClass)
self.auth_data: Optional[Dict[str, str]] = None
class AuthReceiver:
def __init__(self, start_port: int = None):
self.start_port = start_port if start_port is not None else AUTH_START_PORT
self.server: Optional[LocalAuthServer] = None
self.thread: Optional[threading.Thread] = None
self.port: int = 0
def start(self) -> int:
port = self.start_port
max_port = self.start_port + AUTH_PORT_HUNT_RANGE
while True:
try:
self.server = LocalAuthServer((LOCAL_SERVER_HOST, port), CallbackHandler)
self.port = self.server.server_address[1]
break
except OSError:
port += 1
if port > max_port:
raise Exception("Could not find an open port for authentication.")
def serve_with_error_handling():
try:
self.server.serve_forever()
except Exception:
pass
self.thread = threading.Thread(target=serve_with_error_handling, daemon=True)
self.thread.start()
time.sleep(0.2)
# Check if server is ready
max_wait = 2.0
waited = 0.0
server_ready = False
while waited < max_wait:
try:
if not self.thread.is_alive():
break
test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
test_socket.settimeout(0.2)
result = test_socket.connect_ex((LOCAL_SERVER_HOST, self.port))
test_socket.close()
if result == 0:
server_ready = True
break
except Exception:
pass
time.sleep(0.1)
waited += 0.1
if not server_ready:
self.stop()
raise Exception(f"Server failed to start on port {self.port}")
return self.port
def wait_for_tokens(self, timeout: int = 120) -> Optional[Dict[str, str]]:
start_time = time.time()
try:
while getattr(self.server, "auth_data", None) is None:
if time.time() - start_time > timeout:
return None
time.sleep(0.25)
try:
save_credentials(self.server.auth_data)
except Exception:
pass
time.sleep(1.0)
return self.server.auth_data
finally:
self.stop()
def stop(self):
if self.server:
try:
self.server.shutdown()
self.server.server_close()
except Exception:
pass
if self.thread and self.thread.is_alive():
self.thread.join(timeout=1)

231
tito/core/submission.py Normal file
View File

@@ -0,0 +1,231 @@
"""
Handles data aggregation and submission to the Supabase Edge Function.
This version is refactored into a class-based handler that integrates
with the TinyTorch CLI's config and console objects, using only standard libraries.
"""
import json
import os
import urllib.request
import urllib.error
from pathlib import Path
from typing import Dict, Any, Optional
from rich.console import Console
from rich.table import Table
from rich import box
# Local import for auth handler
from . import auth
from .config import CLIConfig
class SubmissionError(Exception):
"""Custom exception for submission-related errors."""
pass
class SubmissionHandler:
"""
Handles assembling progress data and submitting it to a remote server.
"""
def __init__(self, config: CLIConfig, console: Console):
"""
Initialize the handler with CLI config and console.
Args:
config: The CLI configuration object.
console: The rich console for output.
"""
self.config = config
self.console = console
self.auth_handler = auth # Using the auth module directly for now
# TODO: In the future, the API endpoint could be made configurable via CLIConfig
self.edge_function_url = "https://zrvmjrxhokwwmjacyhpq.supabase.co/functions/v1/upload-progress"
# Derive paths from the project root in config
self.tito_dir = self.config.project_root / ".tito"
self.progress_file = self.config.project_root / "progress.json"
self.milestones_file = self.tito_dir / "milestones.json"
self.config_file = self.tito_dir / "config.json" # Though config is passed via CLIConfig
def _read_json_safe(self, path: Path) -> Dict[str, Any]:
"""Helper to read JSON files safely."""
if not path.exists():
return {}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, IOError) as e:
self.console.print(f"[yellow]Warning: Could not read {path}: {e}[/yellow]")
return {}
def _format_milestones(self, local_data: Dict) -> list:
"""Transforms local milestone storage format to the API array format."""
unlocked = local_data.get("unlocked_milestones", [])
completed = local_data.get("completed_milestones", [])
unlock_dates = local_data.get("unlock_dates", {})
completion_dates = local_data.get("completion_dates", {})
# You might want a lookup map for real names
milestone_names = {
"01": "1957: Perceptron",
"02": "1969: XOR Problem",
"03": "1986: MLP Revival",
"04": "1998: CNN Revolution",
"05": "2017: Transformer Era",
"06": "2018: MLPerf Benchmarking",
}
formatted = []
for m_id in unlocked:
formatted.append({
"id": m_id,
"name": milestone_names.get(m_id, f"Milestone {m_id}"),
"unlocked_at": unlock_dates.get(m_id),
"completed": m_id in completed,
"completed_at": completion_dates.get(m_id)
})
return formatted
def assemble_payload(self, total_modules: int = 20) -> Dict[str, Any]:
"""
Reads distinct local files and assembles the Unified Payload.
"""
progress_data = self._read_json_safe(self.progress_file)
milestone_data = self._read_json_safe(self.milestones_file)
completed_modules = progress_data.get("completed_modules", [])
payload = {
# user_id will be derived from the auth token on the backend,
# but we can send a placeholder if needed for schema validation.
"user_id": self.auth_handler.get_user_email() or "anonymous", # Using get_user_email from auth module
"timestamp": progress_data.get("last_updated", ""),
"version": "1.0",
"module_progress": {
"total_modules": total_modules,
"completed_count": len(completed_modules),
"completed_modules": completed_modules,
"completion_dates": progress_data.get("completion_dates", {}),
"completion_percentage": (len(completed_modules) / total_modules) * 100 if total_modules > 0 else 0,
},
"milestone_progress": {
"total_milestones": 6,
"unlocked_count": milestone_data.get("total_unlocked", 0),
"unlocked_milestones": self._format_milestones(milestone_data)
},
"statistics": {
"current_streak_days": progress_data.get("streak", 0)
}
}
return payload
def sync_progress(self, total_modules: int = 20, is_retry: bool = False) -> bool:
"""
Main public function to assemble data and upload it.
"""
token = self.auth_handler.get_token()
if not token:
self.console.print("❌ [bold red]You are not logged in.[/bold red] Please run 'tito login' first.")
return False
if not is_retry:
self.console.print("📦 Assembling local progress...")
try:
payload = self.assemble_payload(total_modules=total_modules)
if not is_retry:
self.console.print("Submitting payload:")
table = Table(show_header=False, box=box.MINIMAL, padding=(0, 1))
table.add_column("Field", style="dim")
table.add_column("Value")
table.add_row("User ID", payload['user_id'])
table.add_row("Timestamp", payload['timestamp'])
table.add_row("Version", payload['version'])
table.add_row("")
table.add_row("[bold]Module Progress[/bold]")
table.add_row(" Total Modules", str(payload['module_progress']['total_modules']))
table.add_row(" Completed", str(payload['module_progress']['completed_count']))
table.add_row(" Completed Modules", ", ".join(payload['module_progress']['completed_modules']))
table.add_row(" Completion %", f"{payload['module_progress']['completion_percentage']:.2f}%")
table.add_row("")
table.add_row("[bold]Milestone Progress[/bold]")
table.add_row(" Total Milestones", str(payload['milestone_progress']['total_milestones']))
table.add_row(" Unlocked", str(payload['milestone_progress']['unlocked_count']))
table.add_row("")
table.add_row("[bold]Statistics[/bold]")
table.add_row(" Current Streak", str(payload['statistics']['current_streak_days']))
self.console.print(table)
except Exception as e:
self.console.print(f"❌ [red]Error assembling payload: {e}[/red]")
return False
if not is_retry:
self.console.print("🚀 Syncing with TinyTorch Cloud...")
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json"
}
req = urllib.request.Request(
self.edge_function_url,
data=json.dumps(payload).encode('utf-8'),
headers=headers,
method="POST"
)
try:
with urllib.request.urlopen(req, timeout=15) as response: # Added timeout
if 200 <= response.status < 300:
resp_body = json.loads(response.read().decode('utf-8'))
self.console.print("✅ [bold green]Sync Successful![/bold green]")
self.console.print(f" Modules Synced: {resp_body.get('synced_modules', 'N/A')}")
return True
else:
self.console.print(f"⚠️ Server returned status: {response.status}")
# Try to read error message from response body
try:
error_resp = json.loads(response.read().decode('utf-8'))
self.console.print(f" [dim red]Error details: {error_resp.get('error', 'No message provided.')}[/dim red]")
except json.JSONDecodeError:
self.console.print(f" [dim red]Error details: {response.read().decode('utf-8')[:200]}...[/dim red]") # Truncate long body
return False
except urllib.error.HTTPError as e:
if e.code == 401 and not is_retry:
self.console.print("🔑 Token expired. Attempting to refresh...")
new_token = self.auth_handler.refresh_token(self.console)
if new_token:
self.console.print("✅ Token refreshed successfully. Retrying submission...")
return self.sync_progress(total_modules=total_modules, is_retry=True)
else:
self.console.print("❌ [bold red]Token refresh failed.[/bold red]")
self.console.print(" Run 'tito login --force' to refresh.")
return False
elif e.code == 401 and is_retry:
self.console.print("❌ [bold red]Unauthorized.[/bold red] Your session may have expired.")
self.console.print(" Run 'tito login --force' to refresh.")
return False
else:
self.console.print(f"❌ [red]Upload failed (HTTP {e.code}): {e.reason}[/red]")
try: # Attempt to read error body if available
error_body = e.read().decode('utf-8')
error_json = json.loads(error_body)
self.console.print(f" [dim red]Error details: {error_json.get('error', 'No message provided.')}[/dim red]")
except (json.JSONDecodeError, Exception):
self.console.print(f" [dim red]Error details: {error_body[:200]}...[/dim red]")
return False
except urllib.error.URLError as e:
self.console.print(f"❌ [red]Network error:[/red] Could not connect to the server.")
self.console.print(f" [dim]{e.reason}[/dim]")
return False
except TimeoutError:
self.console.print("❌ [red]Network error:[/red] Connection timed out.")
return False

View File

@@ -38,6 +38,7 @@ from .commands.milestone import MilestoneCommand
from .commands.setup import SetupCommand
from .commands.benchmark import BenchmarkCommand
from .commands.community import CommunityCommand
from .commands.login import LoginCommand, LogoutCommand
# Configure logging
logging.basicConfig(
@@ -79,6 +80,9 @@ class TinyTorchCLI:
'test': TestCommand,
'grade': GradeCommand,
'logo': LogoCommand,
# Authentication commands
'login': LoginCommand,
'logout': LogoutCommand,
}
def create_parser(self) -> argparse.ArgumentParser: