mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2025-12-05 19:17:52 -06:00
Compare commits
2 Commits
d82e34e51a
...
277033d2f9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
277033d2f9 | ||
|
|
c682ec2ee3 |
92
tito/commands/login.py
Normal file
92
tito/commands/login.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# tito/commands/login.py
|
||||
import webbrowser
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from rich.prompt import Confirm
|
||||
from tito.commands.base import BaseCommand
|
||||
from tito.core.auth import AuthReceiver, save_credentials, delete_credentials, ENDPOINTS, is_logged_in
|
||||
|
||||
class LoginCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "login"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Log in to TinyTorch via web browser"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
parser.add_argument("--force", action="store_true", help="Force re-login")
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
# Adapted logic from api.py
|
||||
if args.force:
|
||||
delete_credentials()
|
||||
self.console.print("Cleared existing credentials.")
|
||||
|
||||
# Check if already logged in (unless force was used)
|
||||
if is_logged_in():
|
||||
self.console.print("[green]You are already logged in.[/green]")
|
||||
if Confirm.ask("[bold yellow]Do you want to force re-login?[/bold yellow]", default=False):
|
||||
delete_credentials()
|
||||
self.console.print("Cleared existing credentials. Proceeding with new login...")
|
||||
else:
|
||||
self.console.print("Login cancelled.")
|
||||
return 0
|
||||
|
||||
receiver = AuthReceiver()
|
||||
try:
|
||||
port = receiver.start()
|
||||
target_url = f"{ENDPOINTS['cli_login']}?redirect_port={port}"
|
||||
self.console.print(f"Opening browser to: [blue]{target_url}[/blue]")
|
||||
self.console.print("Waiting for authentication...")
|
||||
webbrowser.open(target_url)
|
||||
tokens = receiver.wait_for_tokens()
|
||||
if tokens:
|
||||
save_credentials(tokens)
|
||||
self.console.print(f"[green]Success! Logged in as {tokens['user_email']}[/green]")
|
||||
return 0
|
||||
else:
|
||||
self.console.print("[red]Login timed out.[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Error: {e}[/red]")
|
||||
return 1
|
||||
|
||||
|
||||
class LogoutCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "logout"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Log out of TinyTorch by clearing stored credentials"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
pass # No arguments needed
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
try:
|
||||
# Start local server for logout redirect
|
||||
receiver = AuthReceiver()
|
||||
port = receiver.start()
|
||||
|
||||
# Open browser to local logout endpoint
|
||||
logout_url = f"http://127.0.0.1:{port}/logout"
|
||||
self.console.print(f"Opening browser to complete logout...")
|
||||
webbrowser.open(logout_url)
|
||||
|
||||
# Give browser time to redirect and close
|
||||
time.sleep(2.0)
|
||||
|
||||
# Clean up server
|
||||
receiver.stop()
|
||||
|
||||
# Delete local credentials
|
||||
delete_credentials()
|
||||
self.console.print("[green]Successfully logged out of TinyTorch![/green]")
|
||||
return 0
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Error during logout: {e}[/red]")
|
||||
return 1
|
||||
@@ -16,6 +16,7 @@ from typing import Dict, Optional
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||
from rich.prompt import Confirm
|
||||
|
||||
from ..base import BaseCommand
|
||||
from ..view import ViewCommand
|
||||
@@ -24,6 +25,8 @@ from ..export import ExportCommand
|
||||
from .reset import ModuleResetCommand
|
||||
from .test import ModuleTestCommand
|
||||
from ...core.exceptions import ModuleNotFoundError
|
||||
from ...core import auth
|
||||
from ...core.submission import SubmissionHandler
|
||||
|
||||
class ModuleWorkflowCommand(BaseCommand):
|
||||
"""Enhanced module command with natural workflow."""
|
||||
@@ -521,9 +524,28 @@ class ModuleWorkflowCommand(BaseCommand):
|
||||
# Step 5: Check for milestone unlocks
|
||||
if success:
|
||||
self._check_milestone_unlocks(module_name)
|
||||
self._trigger_submission()
|
||||
|
||||
|
||||
return 0 if success else 1
|
||||
|
||||
def _trigger_submission(self):
|
||||
"""Asks the user to submit their progress if they are logged in."""
|
||||
self.console.print() # Add a blank line for spacing
|
||||
|
||||
if auth.is_logged_in():
|
||||
should_submit = Confirm.ask(
|
||||
"[bold yellow]Would you like to sync your progress with the TinyTorch website?[/bold yellow]",
|
||||
default=True
|
||||
)
|
||||
if should_submit:
|
||||
handler = SubmissionHandler(self.config, self.console)
|
||||
total_modules = len(self.get_module_mapping())
|
||||
handler.sync_progress(total_modules=total_modules)
|
||||
else:
|
||||
self.console.print("[dim]💡 Run 'tito login' to enable automatic progress syncing![/dim]")
|
||||
|
||||
|
||||
def run_module_tests(self, module_name: str, verbose: bool = True) -> int:
|
||||
"""
|
||||
Run comprehensive tests for a module:
|
||||
@@ -1157,4 +1179,4 @@ class ModuleWorkflowCommand(BaseCommand):
|
||||
pass
|
||||
except Exception as e:
|
||||
# Don't fail the workflow if milestone checking fails
|
||||
self.console.print(f"[dim]Note: Could not check milestone unlocks: {e}[/dim]")
|
||||
self.console.print(f"[dim]Note: Could not check milestone unlocks: {e}[/dim]")
|
||||
320
tito/core/auth.py
Normal file
320
tito/core/auth.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Simple secure JSON credentials storage system for TinyTorch CLI."""
|
||||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import threading
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import socket
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
# --- Configuration Constants ---
|
||||
API_BASE_URL = "https://tinytorch.netlify.app"
|
||||
|
||||
# API Endpoints
|
||||
ENDPOINTS = {
|
||||
"login": f"{API_BASE_URL}/api/auth/login",
|
||||
"leaderboard": f"{API_BASE_URL}/api/leaderboard",
|
||||
"submissions": f"{API_BASE_URL}/api/submissions",
|
||||
"cli_login": f"{API_BASE_URL}/cli-login",
|
||||
}
|
||||
|
||||
# Defaults
|
||||
LOCAL_SERVER_HOST = "127.0.0.1"
|
||||
AUTH_START_PORT = 54321
|
||||
AUTH_PORT_HUNT_RANGE = 100
|
||||
AUTH_CALLBACK_PATH = "/callback"
|
||||
CREDENTIALS_FILE_NAME = "credentials.json"
|
||||
|
||||
# Determine credentials directory (Standard Python way)
|
||||
CREDENTIALS_DIR = os.getenv("TINOTORCH_CREDENTIALS_DIR", str(Path.home() / ".tinytorch"))
|
||||
|
||||
|
||||
# --- Storage Logic ---
|
||||
|
||||
def _credentials_dir() -> Path:
|
||||
return Path(os.path.expanduser(CREDENTIALS_DIR))
|
||||
|
||||
def _credentials_path() -> Path:
|
||||
return _credentials_dir() / CREDENTIALS_FILE_NAME
|
||||
|
||||
def _ensure_dir() -> None:
|
||||
d = _credentials_dir()
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
os.chmod(d, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def save_credentials(data: Dict[str, str]) -> None:
|
||||
"""Persist credentials to disk safely and atomically."""
|
||||
_ensure_dir()
|
||||
p = _credentials_path()
|
||||
tmp = p.with_suffix(".tmp")
|
||||
with tmp.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(str(tmp), str(p))
|
||||
try:
|
||||
os.chmod(p, 0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def load_credentials() -> Optional[Dict[str, str]]:
|
||||
p = _credentials_path()
|
||||
if not p.exists():
|
||||
return None
|
||||
try:
|
||||
with p.open("r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def delete_credentials() -> None:
|
||||
p = _credentials_path()
|
||||
try:
|
||||
p.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# --- Public Auth Helpers ---
|
||||
|
||||
def get_token() -> Optional[str]:
|
||||
"""Retrieve the access token if it exists."""
|
||||
creds = load_credentials()
|
||||
if creds:
|
||||
return creds.get("access_token")
|
||||
return None
|
||||
|
||||
def is_logged_in() -> bool:
|
||||
"""Check if the user has valid credentials stored."""
|
||||
return get_token() is not None
|
||||
|
||||
def get_user_email() -> Optional[str]:
|
||||
"""Retrieve the user's email if it exists."""
|
||||
creds = load_credentials()
|
||||
if creds:
|
||||
return creds.get("user_email")
|
||||
return None
|
||||
|
||||
|
||||
def get_refresh_token() -> Optional[str]:
|
||||
"""Retrieve the refresh token if it exists."""
|
||||
creds = load_credentials()
|
||||
if creds:
|
||||
return creds.get("refresh_token")
|
||||
return None
|
||||
|
||||
def refresh_token(console: "Console") -> Optional[str]:
|
||||
"""Refresh the access token. If refresh fails, clear credentials to force re-login."""
|
||||
refresh_token_val = get_refresh_token()
|
||||
if not refresh_token_val:
|
||||
return None
|
||||
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import json
|
||||
|
||||
url = f"{API_BASE_URL}/api/auth/refresh"
|
||||
data = {"refreshToken": refresh_token_val}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
req = urllib.request.Request(
|
||||
url,
|
||||
data=json.dumps(data).encode('utf-8'),
|
||||
headers=headers,
|
||||
method="POST"
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req) as response:
|
||||
if response.status == 200:
|
||||
new_session = json.loads(response.read().decode('utf-8'))
|
||||
|
||||
# Handle nested session structures (adjust based on your actual API response)
|
||||
# Some APIs return { session: { ... } }, others return { access_token: ... } direct
|
||||
session_data = new_session.get('session', new_session)
|
||||
|
||||
if 'access_token' in session_data:
|
||||
new_access_token = session_data['access_token']
|
||||
# IMPORTANT: Always grab the new refresh token if the server rotates it
|
||||
new_refresh_token = session_data.get('refresh_token', refresh_token_val)
|
||||
|
||||
creds = load_credentials() or {}
|
||||
creds.update({
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
})
|
||||
save_credentials(creds)
|
||||
return new_access_token
|
||||
else:
|
||||
console.print("[red]Token refresh response is missing session data.[/red]")
|
||||
return None
|
||||
else:
|
||||
console.print(f"[red]Token refresh failed with status: {response.status}[/red]")
|
||||
return None
|
||||
|
||||
except urllib.error.HTTPError as e:
|
||||
# --- CRITICAL FIX HERE ---
|
||||
# If we get a 400 (Bad Request) or 401 (Unauthorized), the refresh token is dead.
|
||||
# We must delete the credentials so the user is forced to log in again.
|
||||
if e.code in [400, 401, 403]:
|
||||
console.print("[yellow]Session expired. Please log in again.[/yellow]")
|
||||
delete_credentials() # This deletes the JSON file
|
||||
return None
|
||||
|
||||
console.print(f"[red]Token refresh failed (HTTP {e.code}): {e.reason}[/red]")
|
||||
try:
|
||||
error_body = e.read().decode('utf-8')
|
||||
error_json = json.loads(error_body)
|
||||
console.print(f" [dim red]Error details: {error_json.get('error', 'No description provided.')}[/dim red]")
|
||||
except (json.JSONDecodeError, Exception):
|
||||
pass
|
||||
return None
|
||||
|
||||
except urllib.error.URLError as e:
|
||||
console.print(f"[red]Token refresh failed (Network error): {e.reason}[/red]")
|
||||
return None
|
||||
|
||||
# --- Auth Server Logic ---
|
||||
|
||||
class CallbackHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
parsed_path = urlparse(self.path)
|
||||
|
||||
if parsed_path.path == "/logout":
|
||||
self.send_response(302)
|
||||
self.send_header('Location', f"{API_BASE_URL}/logout")
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
if parsed_path.path != AUTH_CALLBACK_PATH:
|
||||
self.send_error(404, "Not Found")
|
||||
return
|
||||
|
||||
query_params = parse_qs(parsed_path.query)
|
||||
|
||||
if 'access_token' in query_params and 'refresh_token' in query_params:
|
||||
self.server.auth_data = {
|
||||
'access_token': query_params['access_token'][0],
|
||||
'refresh_token': query_params['refresh_token'][0],
|
||||
'user_email': query_params.get('email', [''])[0]
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'text/html')
|
||||
self.end_headers()
|
||||
|
||||
html_content = "<html><body><h1>🔥 Tinytorch <h1> <h2>Login Successful</h2><p>You can close this window and return to the CLI.</p><p><a href='https://tinytorch.netlify.app/dashboard'>Go to TinyTorch Dashboard</a></p><script>window.close()</script></body></html>"
|
||||
|
||||
self.wfile.write(html_content.encode('utf-8'))
|
||||
self.wfile.flush()
|
||||
|
||||
# Persist immediately
|
||||
try:
|
||||
save_credentials(self.server.auth_data)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
self.send_error(400, "Missing tokens in callback URL")
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
class LocalAuthServer(http.server.HTTPServer):
|
||||
def __init__(self, server_address, RequestHandlerClass):
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
self.auth_data: Optional[Dict[str, str]] = None
|
||||
|
||||
class AuthReceiver:
|
||||
def __init__(self, start_port: int = None):
|
||||
self.start_port = start_port if start_port is not None else AUTH_START_PORT
|
||||
self.server: Optional[LocalAuthServer] = None
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
self.port: int = 0
|
||||
|
||||
def start(self) -> int:
|
||||
port = self.start_port
|
||||
max_port = self.start_port + AUTH_PORT_HUNT_RANGE
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.server = LocalAuthServer((LOCAL_SERVER_HOST, port), CallbackHandler)
|
||||
self.port = self.server.server_address[1]
|
||||
break
|
||||
except OSError:
|
||||
port += 1
|
||||
if port > max_port:
|
||||
raise Exception("Could not find an open port for authentication.")
|
||||
|
||||
def serve_with_error_handling():
|
||||
try:
|
||||
self.server.serve_forever()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.thread = threading.Thread(target=serve_with_error_handling, daemon=True)
|
||||
self.thread.start()
|
||||
time.sleep(0.2)
|
||||
|
||||
# Check if server is ready
|
||||
max_wait = 2.0
|
||||
waited = 0.0
|
||||
server_ready = False
|
||||
|
||||
while waited < max_wait:
|
||||
try:
|
||||
if not self.thread.is_alive():
|
||||
break
|
||||
|
||||
test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
test_socket.settimeout(0.2)
|
||||
result = test_socket.connect_ex((LOCAL_SERVER_HOST, self.port))
|
||||
test_socket.close()
|
||||
if result == 0:
|
||||
server_ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
waited += 0.1
|
||||
|
||||
if not server_ready:
|
||||
self.stop()
|
||||
raise Exception(f"Server failed to start on port {self.port}")
|
||||
|
||||
return self.port
|
||||
|
||||
def wait_for_tokens(self, timeout: int = 120) -> Optional[Dict[str, str]]:
|
||||
start_time = time.time()
|
||||
try:
|
||||
while getattr(self.server, "auth_data", None) is None:
|
||||
if time.time() - start_time > timeout:
|
||||
return None
|
||||
time.sleep(0.25)
|
||||
|
||||
try:
|
||||
save_credentials(self.server.auth_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(1.0)
|
||||
|
||||
return self.server.auth_data
|
||||
finally:
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
if self.server:
|
||||
try:
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
except Exception:
|
||||
pass
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=1)
|
||||
231
tito/core/submission.py
Normal file
231
tito/core/submission.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Handles data aggregation and submission to the Supabase Edge Function.
|
||||
|
||||
This version is refactored into a class-based handler that integrates
|
||||
with the TinyTorch CLI's config and console objects, using only standard libraries.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich import box
|
||||
|
||||
# Local import for auth handler
|
||||
from . import auth
|
||||
from .config import CLIConfig
|
||||
|
||||
class SubmissionError(Exception):
|
||||
"""Custom exception for submission-related errors."""
|
||||
pass
|
||||
|
||||
class SubmissionHandler:
|
||||
"""
|
||||
Handles assembling progress data and submitting it to a remote server.
|
||||
"""
|
||||
|
||||
def __init__(self, config: CLIConfig, console: Console):
|
||||
"""
|
||||
Initialize the handler with CLI config and console.
|
||||
|
||||
Args:
|
||||
config: The CLI configuration object.
|
||||
console: The rich console for output.
|
||||
"""
|
||||
self.config = config
|
||||
self.console = console
|
||||
self.auth_handler = auth # Using the auth module directly for now
|
||||
|
||||
# TODO: In the future, the API endpoint could be made configurable via CLIConfig
|
||||
self.edge_function_url = "https://zrvmjrxhokwwmjacyhpq.supabase.co/functions/v1/upload-progress"
|
||||
|
||||
# Derive paths from the project root in config
|
||||
self.tito_dir = self.config.project_root / ".tito"
|
||||
self.progress_file = self.config.project_root / "progress.json"
|
||||
self.milestones_file = self.tito_dir / "milestones.json"
|
||||
self.config_file = self.tito_dir / "config.json" # Though config is passed via CLIConfig
|
||||
|
||||
def _read_json_safe(self, path: Path) -> Dict[str, Any]:
|
||||
"""Helper to read JSON files safely."""
|
||||
if not path.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
self.console.print(f"[yellow]Warning: Could not read {path}: {e}[/yellow]")
|
||||
return {}
|
||||
|
||||
def _format_milestones(self, local_data: Dict) -> list:
|
||||
"""Transforms local milestone storage format to the API array format."""
|
||||
unlocked = local_data.get("unlocked_milestones", [])
|
||||
completed = local_data.get("completed_milestones", [])
|
||||
unlock_dates = local_data.get("unlock_dates", {})
|
||||
completion_dates = local_data.get("completion_dates", {})
|
||||
|
||||
# You might want a lookup map for real names
|
||||
milestone_names = {
|
||||
"01": "1957: Perceptron",
|
||||
"02": "1969: XOR Problem",
|
||||
"03": "1986: MLP Revival",
|
||||
"04": "1998: CNN Revolution",
|
||||
"05": "2017: Transformer Era",
|
||||
"06": "2018: MLPerf Benchmarking",
|
||||
}
|
||||
|
||||
formatted = []
|
||||
for m_id in unlocked:
|
||||
formatted.append({
|
||||
"id": m_id,
|
||||
"name": milestone_names.get(m_id, f"Milestone {m_id}"),
|
||||
"unlocked_at": unlock_dates.get(m_id),
|
||||
"completed": m_id in completed,
|
||||
"completed_at": completion_dates.get(m_id)
|
||||
})
|
||||
return formatted
|
||||
|
||||
def assemble_payload(self, total_modules: int = 20) -> Dict[str, Any]:
|
||||
"""
|
||||
Reads distinct local files and assembles the Unified Payload.
|
||||
"""
|
||||
progress_data = self._read_json_safe(self.progress_file)
|
||||
milestone_data = self._read_json_safe(self.milestones_file)
|
||||
|
||||
completed_modules = progress_data.get("completed_modules", [])
|
||||
|
||||
payload = {
|
||||
# user_id will be derived from the auth token on the backend,
|
||||
# but we can send a placeholder if needed for schema validation.
|
||||
"user_id": self.auth_handler.get_user_email() or "anonymous", # Using get_user_email from auth module
|
||||
"timestamp": progress_data.get("last_updated", ""),
|
||||
"version": "1.0",
|
||||
"module_progress": {
|
||||
"total_modules": total_modules,
|
||||
"completed_count": len(completed_modules),
|
||||
"completed_modules": completed_modules,
|
||||
"completion_dates": progress_data.get("completion_dates", {}),
|
||||
"completion_percentage": (len(completed_modules) / total_modules) * 100 if total_modules > 0 else 0,
|
||||
},
|
||||
"milestone_progress": {
|
||||
"total_milestones": 6,
|
||||
"unlocked_count": milestone_data.get("total_unlocked", 0),
|
||||
"unlocked_milestones": self._format_milestones(milestone_data)
|
||||
},
|
||||
"statistics": {
|
||||
"current_streak_days": progress_data.get("streak", 0)
|
||||
}
|
||||
}
|
||||
return payload
|
||||
|
||||
def sync_progress(self, total_modules: int = 20, is_retry: bool = False) -> bool:
|
||||
"""
|
||||
Main public function to assemble data and upload it.
|
||||
"""
|
||||
token = self.auth_handler.get_token()
|
||||
if not token:
|
||||
self.console.print("❌ [bold red]You are not logged in.[/bold red] Please run 'tito login' first.")
|
||||
return False
|
||||
|
||||
if not is_retry:
|
||||
self.console.print("📦 Assembling local progress...")
|
||||
|
||||
try:
|
||||
payload = self.assemble_payload(total_modules=total_modules)
|
||||
if not is_retry:
|
||||
self.console.print("Submitting payload:")
|
||||
table = Table(show_header=False, box=box.MINIMAL, padding=(0, 1))
|
||||
table.add_column("Field", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("User ID", payload['user_id'])
|
||||
table.add_row("Timestamp", payload['timestamp'])
|
||||
table.add_row("Version", payload['version'])
|
||||
|
||||
table.add_row("")
|
||||
table.add_row("[bold]Module Progress[/bold]")
|
||||
table.add_row(" Total Modules", str(payload['module_progress']['total_modules']))
|
||||
table.add_row(" Completed", str(payload['module_progress']['completed_count']))
|
||||
table.add_row(" Completed Modules", ", ".join(payload['module_progress']['completed_modules']))
|
||||
table.add_row(" Completion %", f"{payload['module_progress']['completion_percentage']:.2f}%")
|
||||
|
||||
table.add_row("")
|
||||
table.add_row("[bold]Milestone Progress[/bold]")
|
||||
table.add_row(" Total Milestones", str(payload['milestone_progress']['total_milestones']))
|
||||
table.add_row(" Unlocked", str(payload['milestone_progress']['unlocked_count']))
|
||||
|
||||
table.add_row("")
|
||||
table.add_row("[bold]Statistics[/bold]")
|
||||
table.add_row(" Current Streak", str(payload['statistics']['current_streak_days']))
|
||||
|
||||
self.console.print(table)
|
||||
except Exception as e:
|
||||
self.console.print(f"❌ [red]Error assembling payload: {e}[/red]")
|
||||
return False
|
||||
|
||||
if not is_retry:
|
||||
self.console.print("🚀 Syncing with TinyTorch Cloud...")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
req = urllib.request.Request(
|
||||
self.edge_function_url,
|
||||
data=json.dumps(payload).encode('utf-8'),
|
||||
headers=headers,
|
||||
method="POST"
|
||||
)
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as response: # Added timeout
|
||||
if 200 <= response.status < 300:
|
||||
resp_body = json.loads(response.read().decode('utf-8'))
|
||||
self.console.print("✅ [bold green]Sync Successful![/bold green]")
|
||||
self.console.print(f" Modules Synced: {resp_body.get('synced_modules', 'N/A')}")
|
||||
return True
|
||||
else:
|
||||
self.console.print(f"⚠️ Server returned status: {response.status}")
|
||||
# Try to read error message from response body
|
||||
try:
|
||||
error_resp = json.loads(response.read().decode('utf-8'))
|
||||
self.console.print(f" [dim red]Error details: {error_resp.get('error', 'No message provided.')}[/dim red]")
|
||||
except json.JSONDecodeError:
|
||||
self.console.print(f" [dim red]Error details: {response.read().decode('utf-8')[:200]}...[/dim red]") # Truncate long body
|
||||
return False
|
||||
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code == 401 and not is_retry:
|
||||
self.console.print("🔑 Token expired. Attempting to refresh...")
|
||||
new_token = self.auth_handler.refresh_token(self.console)
|
||||
if new_token:
|
||||
self.console.print("✅ Token refreshed successfully. Retrying submission...")
|
||||
return self.sync_progress(total_modules=total_modules, is_retry=True)
|
||||
else:
|
||||
self.console.print("❌ [bold red]Token refresh failed.[/bold red]")
|
||||
self.console.print(" Run 'tito login --force' to refresh.")
|
||||
return False
|
||||
elif e.code == 401 and is_retry:
|
||||
self.console.print("❌ [bold red]Unauthorized.[/bold red] Your session may have expired.")
|
||||
self.console.print(" Run 'tito login --force' to refresh.")
|
||||
return False
|
||||
else:
|
||||
self.console.print(f"❌ [red]Upload failed (HTTP {e.code}): {e.reason}[/red]")
|
||||
try: # Attempt to read error body if available
|
||||
error_body = e.read().decode('utf-8')
|
||||
error_json = json.loads(error_body)
|
||||
self.console.print(f" [dim red]Error details: {error_json.get('error', 'No message provided.')}[/dim red]")
|
||||
except (json.JSONDecodeError, Exception):
|
||||
self.console.print(f" [dim red]Error details: {error_body[:200]}...[/dim red]")
|
||||
return False
|
||||
except urllib.error.URLError as e:
|
||||
self.console.print(f"❌ [red]Network error:[/red] Could not connect to the server.")
|
||||
self.console.print(f" [dim]{e.reason}[/dim]")
|
||||
return False
|
||||
except TimeoutError:
|
||||
self.console.print("❌ [red]Network error:[/red] Connection timed out.")
|
||||
return False
|
||||
@@ -38,6 +38,7 @@ from .commands.milestone import MilestoneCommand
|
||||
from .commands.setup import SetupCommand
|
||||
from .commands.benchmark import BenchmarkCommand
|
||||
from .commands.community import CommunityCommand
|
||||
from .commands.login import LoginCommand, LogoutCommand
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@@ -79,6 +80,9 @@ class TinyTorchCLI:
|
||||
'test': TestCommand,
|
||||
'grade': GradeCommand,
|
||||
'logo': LogoCommand,
|
||||
# Authentication commands
|
||||
'login': LoginCommand,
|
||||
'logout': LogoutCommand,
|
||||
}
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
|
||||
Reference in New Issue
Block a user