mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-29 00:37:32 -05:00
386 lines
13 KiB
Python
386 lines
13 KiB
Python
"""Simple secure JSON credentials storage system for TinyTorch CLI."""
|
|
from __future__ import annotations
|
|
|
|
import http.server
|
|
import threading
|
|
import json
|
|
import os
|
|
import time
|
|
import socket
|
|
import webbrowser
|
|
from pathlib import Path
|
|
from typing import Optional, Dict
|
|
from urllib.parse import urlparse, parse_qs
|
|
|
|
# --- Configuration Constants ---
# Base URL of the TinyTorch web app; every endpoint below is derived from it.
API_BASE_URL = "https://tinytorch.netlify.app"

# API Endpoints
ENDPOINTS = {
    "login": f"{API_BASE_URL}/api/auth/login",
    "leaderboard": f"{API_BASE_URL}/api/leaderboard",
    "submissions": f"{API_BASE_URL}/api/submissions",
    "cli_login": f"{API_BASE_URL}/cli-login",
}

# Defaults
LOCAL_SERVER_HOST = "127.0.0.1"   # loopback only: the callback server is never exposed externally
AUTH_START_PORT = 54321           # first port tried for the local callback server
AUTH_PORT_HUNT_RANGE = 100        # how many ports above AUTH_START_PORT may be tried
AUTH_CALLBACK_PATH = "/callback"
CREDENTIALS_FILE_NAME = "credentials.json"

# Determine credentials directory (Standard Python way).
# The correctly spelled TINYTORCH_CREDENTIALS_DIR takes precedence. The historical
# misspelling TINOTORCH_CREDENTIALS_DIR is still honoured so existing setups keep
# working. An unset or empty variable falls through to ~/.tinytorch.
CREDENTIALS_DIR = (
    os.getenv("TINYTORCH_CREDENTIALS_DIR")
    or os.getenv("TINOTORCH_CREDENTIALS_DIR")
    or str(Path.home() / ".tinytorch")
)
|
|
|
|
|
|
# --- Storage Logic ---
|
|
|
|
def _credentials_dir() -> Path:
    """Return the configured credentials directory with a leading ``~`` expanded."""
    expanded = os.path.expanduser(CREDENTIALS_DIR)
    return Path(expanded)
|
|
|
|
def _credentials_path() -> Path:
    """Return the path of the credentials file inside the credentials directory."""
    return _credentials_dir().joinpath(CREDENTIALS_FILE_NAME)
|
|
|
|
def _ensure_dir() -> None:
    """Create the credentials directory (and parents) and restrict it to the owner."""
    cred_dir = _credentials_dir()
    cred_dir.mkdir(parents=True, exist_ok=True)
    try:
        cred_dir.chmod(0o700)
    except OSError:
        # Best effort: a failed chmod must not prevent credentials from being stored.
        pass
|
|
|
|
def save_credentials(data: Dict[str, str]) -> None:
    """Persist credentials to disk safely and atomically.

    The JSON is written to a sibling ``.tmp`` file, flushed and fsynced, and
    then moved over the real credentials file with ``os.replace`` so readers
    never observe a half-written file.

    Args:
        data: Mapping to serialise as the credentials JSON document.

    Raises:
        OSError: If the temp file cannot be written or moved into place.
        TypeError: If ``data`` is not JSON-serialisable.
    """
    _ensure_dir()
    target = _credentials_path()
    tmp = target.with_suffix(".tmp")

    # Create the temp file with owner-only permissions from the start, so the
    # tokens are never readable by other users while they are being written.
    # (If a stale temp file already exists its old mode is kept; the chmod on
    # the final file below covers that case.)
    fd = os.open(str(tmp), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(str(tmp), str(target))
    except BaseException:
        # Do not leave a partially written token file behind on failure.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise

    try:
        os.chmod(target, 0o600)
    except OSError:
        # Best effort: permission bits may be unsupported on this filesystem.
        pass
|
|
|
|
def load_credentials() -> Optional[Dict[str, str]]:
    """Load the stored credentials from disk.

    Returns:
        The parsed credentials mapping, or ``None`` when the file is missing,
        unreadable, not valid UTF-8 JSON, or does not contain a JSON object.
    """
    try:
        # Open directly instead of checking exists() first: a missing file
        # raises FileNotFoundError (an OSError), which avoids a check/use race.
        with _credentials_path().open("r", encoding="utf-8") as f:
            loaded = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return None

    # Callers use .get() on the result, so anything but an object is unusable.
    return loaded if isinstance(loaded, dict) else None
|
|
|
|
def delete_credentials() -> None:
    """Remove the credentials file, ignoring OS errors such as the file being absent."""
    target = _credentials_path()
    try:
        os.remove(target)
    except OSError:
        pass
|
|
|
|
# --- Public Auth Helpers ---
|
|
|
|
def get_token() -> Optional[str]:
    """Return the stored ``access_token``, or ``None`` when no credentials are loaded."""
    stored = load_credentials()
    if not stored:
        return None
    return stored.get("access_token")
|
|
|
|
def is_logged_in() -> bool:
    """Report whether an access token is stored locally.

    Only the presence of a token is checked; it is not validated here.
    """
    token = get_token()
    return token is not None
|
|
|
|
def get_user_email() -> Optional[str]:
    """Return the stored ``user_email``, or ``None`` when no credentials are loaded."""
    stored = load_credentials()
    if not stored:
        return None
    return stored.get("user_email")
|
|
|
|
|
|
def get_refresh_token() -> Optional[str]:
    """Return the stored ``refresh_token``, or ``None`` when no credentials are loaded."""
    stored = load_credentials()
    if not stored:
        return None
    return stored.get("refresh_token")
|
|
|
|
def refresh_token(console: "Console", timeout: float = 15.0) -> Optional[str]:
    """Refresh the access token. If refresh fails, clear credentials to force re-login.

    POSTs the stored refresh token to ``{API_BASE_URL}/api/auth/refresh`` and,
    on success, saves the new access token (and the rotated refresh token when
    the server sends one) to the credentials file.

    Args:
        console: Object whose ``print`` method is used for user-facing messages
            (presumably a rich ``Console`` -- confirm against callers).
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the CLI indefinitely.

    Returns:
        The new access token, or ``None`` when there is no stored refresh
        token or the refresh failed for any reason.
    """
    refresh_token_val = get_refresh_token()
    if not refresh_token_val:
        return None

    import urllib.error
    import urllib.request

    url = f"{API_BASE_URL}/api/auth/refresh"
    payload = json.dumps({"refreshToken": refresh_token_val}).encode("utf-8")
    req = urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    try:
        with urllib.request.urlopen(req, timeout=timeout) as response:
            status = response.status
            body = response.read()
    except urllib.error.HTTPError as e:
        # 400 / 401 / 403 mean the refresh token itself was rejected. Delete
        # the stored credentials so the user is forced to log in again.
        if e.code in (400, 401, 403):
            console.print("[yellow]Session expired. Please log in again.[/yellow]")
            delete_credentials()  # This deletes the JSON file
            return None

        console.print(f"[red]Token refresh failed (HTTP {e.code}): {e.reason}[/red]")
        try:
            error_json = json.loads(e.read().decode("utf-8"))
            console.print(f" [dim red]Error details: {error_json.get('error', 'No description provided.')}[/dim red]")
        except (OSError, ValueError, AttributeError):
            # The error body is optional detail; an unreadable one is ignored.
            pass
        return None
    except urllib.error.URLError as e:
        console.print(f"[red]Token refresh failed (Network error): {e.reason}[/red]")
        return None
    except OSError as e:
        # Timeouts and resets while reading the body are not wrapped in URLError.
        console.print(f"[red]Token refresh failed (Network error): {e}[/red]")
        return None

    if status != 200:
        console.print(f"[red]Token refresh failed with status: {status}[/red]")
        return None

    try:
        new_session = json.loads(body.decode("utf-8"))
    except ValueError:
        console.print("[red]Token refresh response is missing session data.[/red]")
        return None

    # Some APIs return { session: { ... } }, others return { access_token: ... } directly.
    session_data = new_session.get("session", new_session) if isinstance(new_session, dict) else None
    if not isinstance(session_data, dict) or "access_token" not in session_data:
        console.print("[red]Token refresh response is missing session data.[/red]")
        return None

    new_access_token = session_data["access_token"]
    # Always take the new refresh token if the server rotates it; keep the
    # current one when the field is absent or null.
    new_refresh_token = session_data.get("refresh_token") or refresh_token_val

    creds = load_credentials() or {}
    creds.update({
        "access_token": new_access_token,
        "refresh_token": new_refresh_token,
    })
    save_credentials(creds)
    return new_access_token
|
|
|
|
# --- Auth Server Logic ---
|
|
|
|
class CallbackHandler(http.server.BaseHTTPRequestHandler):
    """Handle the browser redirect that delivers the login tokens to the CLI.

    Routes:
        ``/logout``            -> 302 redirect to ``{API_BASE_URL}/logout``.
        ``AUTH_CALLBACK_PATH`` -> read ``access_token`` / ``refresh_token`` /
                                  ``email`` from the query string, publish them
                                  on ``self.server.auth_data``, show a success
                                  page and persist the credentials.
        anything else          -> 404.
    """

    # Page shown in the browser once the tokens have been received.
    # NOTE(review): the dashboard link is hard-coded rather than derived from API_BASE_URL.
    _SUCCESS_PAGE = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>TinyTorch CLI Login</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background-color: #f8f9fa;
            color: #333;
        }
        .card {
            background: white;
            padding: 2rem;
            border-radius: 12px;
            box-shadow: 0 4px 6px rgba(0,0,0,0.1);
            text-align: center;
            max-width: 400px;
        }
        h1 { color: #FF5733; margin-bottom: 0.5rem; }
        h2 { margin-top: 0; margin-bottom: 1.5rem; font-weight: 600; }
        p { color: #666; line-height: 1.5; margin-bottom: 1.5rem; }
        .button {
            display: inline-block;
            background: #FF5733;
            color: white;
            text-decoration: none;
            padding: 10px 20px;
            border-radius: 6px;
            font-weight: 500;
            transition: background 0.2s;
        }
        .button:hover { background: #E64A29; }
        .success-icon { font-size: 48px; margin-bottom: 1rem; display: block; }
    </style>
</head>
<body>
    <div class="card">
        <span class="success-icon">✅</span>
        <h1>Tiny🔥Torch</h1>
        <h2>Login Successful</h2>
        <p>You have successfully authenticated with the CLI. You can now close this window and return to your terminal.</p>
        <a href="https://tinytorch.netlify.app/dashboard" class="button">Go to Dashboard</a>
    </div>
    <script>
        // Optional: Attempt to close window automatically after delay
        setTimeout(() => {
            // window.close(); // Most browsers block this unless script opened window
        }, 3000);
    </script>
</body>
</html>
"""

    def do_GET(self):
        """Dispatch a GET request to the logout redirect or the token callback."""
        request_url = urlparse(self.path)
        route = request_url.path

        if route == "/logout":
            self.send_response(302)
            self.send_header('Location', f"{API_BASE_URL}/logout")
            self.end_headers()
            return

        if route != AUTH_CALLBACK_PATH:
            self.send_error(404, "Not Found")
            return

        params = parse_qs(request_url.query)
        has_tokens = 'access_token' in params and 'refresh_token' in params
        if not has_tokens:
            self.send_error(400, "Missing tokens in callback URL")
            return

        # Publish the tokens on the server object; AuthReceiver polls this attribute.
        tokens = {
            'access_token': params['access_token'][0],
            'refresh_token': params['refresh_token'][0],
            'user_email': params.get('email', [''])[0]
        }
        self.server.auth_data = tokens

        self.send_response(200)
        self.send_header('Content-type', 'text/html; charset=utf-8')
        self.end_headers()
        self.wfile.write(self._SUCCESS_PAGE.encode('utf-8'))
        self.wfile.flush()

        # Persist immediately; a failure here is ignored on purpose because the
        # tokens remain available on the server object.
        try:
            save_credentials(self.server.auth_data)
        except Exception:
            pass

    def log_message(self, format, *args):
        """Suppress the default per-request logging of the base handler."""
        pass
|
|
|
|
class LocalAuthServer(http.server.HTTPServer):
    """HTTP server that carries the tokens received by its request handler."""

    def __init__(self, server_address, RequestHandlerClass):
        """Bind the server and start with no authentication data received."""
        # Filled in by the request handler once the callback arrives.
        self.auth_data: Optional[Dict[str, str]] = None
        super().__init__(server_address, RequestHandlerClass)
|
|
|
|
class AuthReceiver:
    """Run a local callback server in a background thread and wait for login tokens.

    Typical use: ``start()`` to bind a port and begin serving, then
    ``wait_for_tokens()`` to block until the handler has stored the tokens on
    the server (or the timeout expires). ``wait_for_tokens()`` always stops
    the server before returning.
    """

    def __init__(self, start_port: Optional[int] = None):
        """Remember the first port to try; defaults to ``AUTH_START_PORT``."""
        self.start_port = start_port if start_port is not None else AUTH_START_PORT
        self.server: Optional[LocalAuthServer] = None
        self.thread: Optional[threading.Thread] = None
        self.port: int = 0

    def start(self) -> int:
        """Bind the callback server to a free port and serve in a daemon thread.

        Ports ``start_port`` .. ``start_port + AUTH_PORT_HUNT_RANGE`` (inclusive)
        are tried in order.

        Returns:
            The port the server is listening on.

        Raises:
            RuntimeError: If no port in the range can be bound, or the server
                does not accept connections within about two seconds.
        """
        last_port = self.start_port + AUTH_PORT_HUNT_RANGE
        bind_error = None
        for port in range(self.start_port, last_port + 1):
            try:
                self.server = LocalAuthServer((LOCAL_SERVER_HOST, port), CallbackHandler)
            except OSError as exc:
                bind_error = exc
                continue
            # Read the port back from the socket (differs from `port` when 0 was requested).
            self.port = self.server.server_address[1]
            break
        else:
            raise RuntimeError("Could not find an open port for authentication.") from bind_error

        server = self.server

        def serve_with_error_handling():
            # Errors in the serving loop end the thread quietly; the readiness
            # check below notices a dead thread and reports the failure.
            try:
                server.serve_forever()
            except Exception:
                pass

        self.thread = threading.Thread(target=serve_with_error_handling, daemon=True)
        self.thread.start()
        time.sleep(0.2)

        # Check if server is ready by probing the port until it accepts a connection.
        deadline = time.monotonic() + 2.0
        server_ready = False
        while time.monotonic() < deadline and self.thread.is_alive():
            try:
                # The context manager closes the probe even if connect_ex raises.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                    probe.settimeout(0.2)
                    if probe.connect_ex((LOCAL_SERVER_HOST, self.port)) == 0:
                        server_ready = True
                        break
            except OSError:
                pass
            time.sleep(0.1)

        if not server_ready:
            self.stop()
            raise RuntimeError(f"Server failed to start on port {self.port}")

        return self.port

    def wait_for_tokens(self, timeout: int = 120) -> Optional[Dict[str, str]]:
        """Block until the callback delivered tokens or ``timeout`` seconds passed.

        Returns:
            The token mapping stored on the server by the handler, or ``None``
            on timeout. The server is stopped in either case.
        """
        # Monotonic clock: immune to wall-clock adjustments while waiting.
        deadline = time.monotonic() + timeout
        try:
            while getattr(self.server, "auth_data", None) is None:
                if time.monotonic() > deadline:
                    return None
                time.sleep(0.25)

            auth_data = self.server.auth_data
            # The handler already tries to persist; this second attempt is a
            # best-effort fallback, so a failure must not lose the tokens.
            try:
                save_credentials(auth_data)
            except Exception:
                pass

            # Keep serving briefly so the success page can finish being sent.
            time.sleep(1.0)
            return auth_data
        finally:
            self.stop()

    def stop(self):
        """Shut the server down, close its socket and join the serving thread."""
        if self.server:
            # shutdown() blocks until serve_forever() exits, so only call it
            # while the serving thread is actually running.
            if self.thread is not None and self.thread.is_alive():
                try:
                    self.server.shutdown()
                except Exception:
                    pass
            # Always release the listening socket, even if shutdown failed.
            try:
                self.server.server_close()
            except Exception:
                pass
        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=1)
|