feat: Add login component and associated tests

This commit is contained in:
kai
2025-11-25 15:28:42 -05:00
parent 8e55d72aaa
commit 55f984a8f0
3 changed files with 563 additions and 0 deletions

86
tito/commands/login.py Normal file
View File

@@ -0,0 +1,86 @@
# tito/commands/login.py
import webbrowser
import time
from argparse import ArgumentParser, Namespace
from tito.commands.base import BaseCommand
from tito.core.auth import AuthReceiver, save_credentials, delete_credentials, ENDPOINTS, is_logged_in
class LoginCommand(BaseCommand):
@property
def name(self) -> str:
return "login"
@property
def description(self) -> str:
return "Log in to TinyTorch via web browser"
def add_arguments(self, parser: ArgumentParser) -> None:
parser.add_argument("--force", action="store_true", help="Force re-login")
def run(self, args: Namespace) -> int:
# Adapted logic from api.py
if args.force:
delete_credentials()
self.console.print("Cleared existing credentials.")
# Check if already logged in (unless force was used)
if is_logged_in():
self.console.print("[green]Already logged in to TinyTorch![/green]")
return 0
receiver = AuthReceiver()
try:
port = receiver.start()
target_url = f"{ENDPOINTS['cli_login']}?redirect_port={port}"
self.console.print(f"Opening browser to: [blue]{target_url}[/blue]")
self.console.print("Waiting for authentication...")
webbrowser.open(target_url)
tokens = receiver.wait_for_tokens()
if tokens:
save_credentials(tokens)
self.console.print(f"[green]Success! Logged in as {tokens['user_email']}[/green]")
return 0
else:
self.console.print("[red]Login timed out.[/red]")
return 1
except Exception as e:
self.console.print(f"[red]Error: {e}[/red]")
return 1
class LogoutCommand(BaseCommand):
@property
def name(self) -> str:
return "logout"
@property
def description(self) -> str:
return "Log out of TinyTorch by clearing stored credentials"
def add_arguments(self, parser: ArgumentParser) -> None:
pass # No arguments needed
def run(self, args: Namespace) -> int:
try:
# Start local server for logout redirect
receiver = AuthReceiver()
port = receiver.start()
# Open browser to local logout endpoint
logout_url = f"http://127.0.0.1:{port}/logout"
self.console.print(f"Opening browser to complete logout...")
webbrowser.open(logout_url)
# Give browser time to redirect and close
time.sleep(2.0)
# Clean up server
receiver.stop()
# Delete local credentials
delete_credentials()
self.console.print("[green]Successfully logged out of TinyTorch![/green]")
return 0
except Exception as e:
self.console.print(f"[red]Error during logout: {e}[/red]")
return 1

236
tito/core/auth.py Normal file
View File

@@ -0,0 +1,236 @@
"""Simple secure JSON credentials storage system for TinyTorch CLI."""
from __future__ import annotations
import http.server
import threading
import json
import os
import time
import socket
import webbrowser
from pathlib import Path
from typing import Optional, Dict
from urllib.parse import urlparse, parse_qs
# --- Configuration Constants ---
API_BASE_URL = "https://tinytorch.netlify.app"
# API Endpoints
ENDPOINTS = {
"login": f"{API_BASE_URL}/api/auth/login",
"leaderboard": f"{API_BASE_URL}/api/leaderboard",
"submissions": f"{API_BASE_URL}/api/submissions",
"cli_login": f"{API_BASE_URL}/cli-login",
}
# Defaults
LOCAL_SERVER_HOST = "127.0.0.1"
AUTH_START_PORT = 54321
AUTH_PORT_HUNT_RANGE = 100
AUTH_CALLBACK_PATH = "/callback"
CREDENTIALS_FILE_NAME = "credentials.json"
# Determine credentials directory (Standard Python way)
CREDENTIALS_DIR = os.getenv("TINOTORCH_CREDENTIALS_DIR", str(Path.home() / ".tinytorch"))
# --- Storage Logic ---
def _credentials_dir() -> Path:
return Path(os.path.expanduser(CREDENTIALS_DIR))
def _credentials_path() -> Path:
return _credentials_dir() / CREDENTIALS_FILE_NAME
def _ensure_dir() -> None:
d = _credentials_dir()
d.mkdir(parents=True, exist_ok=True)
try:
os.chmod(d, 0o700)
except OSError:
pass
def save_credentials(data: Dict[str, str]) -> None:
"""Persist credentials to disk safely and atomically."""
_ensure_dir()
p = _credentials_path()
tmp = p.with_suffix(".tmp")
with tmp.open("w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(str(tmp), str(p))
try:
os.chmod(p, 0o600)
except OSError:
pass
def load_credentials() -> Optional[Dict[str, str]]:
p = _credentials_path()
if not p.exists():
return None
try:
with p.open("r", encoding="utf-8") as f:
return json.load(f)
except (OSError, json.JSONDecodeError):
return None
def delete_credentials() -> None:
p = _credentials_path()
try:
p.unlink()
except OSError:
pass
# --- Public Auth Helpers ---
def get_token() -> Optional[str]:
"""Retrieve the access token if it exists."""
creds = load_credentials()
if creds:
return creds.get("access_token")
return None
def is_logged_in() -> bool:
"""Check if the user has valid credentials stored."""
return get_token() is not None
# --- Auth Server Logic ---
class CallbackHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):
parsed_path = urlparse(self.path)
if parsed_path.path == "/logout":
self.send_response(302)
self.send_header('Location', f"{API_BASE_URL}/logout")
self.end_headers()
return
if parsed_path.path != AUTH_CALLBACK_PATH:
self.send_error(404, "Not Found")
return
query_params = parse_qs(parsed_path.query)
if 'access_token' in query_params and 'refresh_token' in query_params:
self.server.auth_data = {
'access_token': query_params['access_token'][0],
'refresh_token': query_params['refresh_token'][0],
'user_email': query_params.get('email', [''])[0]
}
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()
html_content = "<html><body><h1>🔥 Tinytorch <h1> <h2>Login Successful</h2><p>You can close this window and return to the CLI.</p><p><a href='https://tinytorch.netlify.app/dashboard'>Go to TinyTorch Dashboard</a></p><script>window.close()</script></body></html>"
self.wfile.write(html_content.encode('utf-8'))
self.wfile.flush()
# Persist immediately
try:
save_credentials(self.server.auth_data)
except Exception:
pass
else:
self.send_error(400, "Missing tokens in callback URL")
def log_message(self, format, *args):
pass
class LocalAuthServer(http.server.HTTPServer):
def __init__(self, server_address, RequestHandlerClass):
super().__init__(server_address, RequestHandlerClass)
self.auth_data: Optional[Dict[str, str]] = None
class AuthReceiver:
def __init__(self, start_port: int = None):
self.start_port = start_port if start_port is not None else AUTH_START_PORT
self.server: Optional[LocalAuthServer] = None
self.thread: Optional[threading.Thread] = None
self.port: int = 0
def start(self) -> int:
port = self.start_port
max_port = self.start_port + AUTH_PORT_HUNT_RANGE
while True:
try:
self.server = LocalAuthServer((LOCAL_SERVER_HOST, port), CallbackHandler)
self.port = self.server.server_address[1]
break
except OSError:
port += 1
if port > max_port:
raise Exception("Could not find an open port for authentication.")
def serve_with_error_handling():
try:
self.server.serve_forever()
except Exception:
pass
self.thread = threading.Thread(target=serve_with_error_handling, daemon=True)
self.thread.start()
time.sleep(0.2)
# Check if server is ready
max_wait = 2.0
waited = 0.0
server_ready = False
while waited < max_wait:
try:
if not self.thread.is_alive():
break
test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
test_socket.settimeout(0.2)
result = test_socket.connect_ex((LOCAL_SERVER_HOST, self.port))
test_socket.close()
if result == 0:
server_ready = True
break
except Exception:
pass
time.sleep(0.1)
waited += 0.1
if not server_ready:
self.stop()
raise Exception(f"Server failed to start on port {self.port}")
return self.port
def wait_for_tokens(self, timeout: int = 120) -> Optional[Dict[str, str]]:
start_time = time.time()
try:
while getattr(self.server, "auth_data", None) is None:
if time.time() - start_time > timeout:
return None
time.sleep(0.25)
try:
save_credentials(self.server.auth_data)
except Exception:
pass
time.sleep(1.0)
return self.server.auth_data
finally:
self.stop()
def stop(self):
if self.server:
try:
self.server.shutdown()
self.server.server_close()
except Exception:
pass
if self.thread and self.thread.is_alive():
self.thread.join(timeout=1)

241
tito/tests/test_auth.py Normal file
View File

@@ -0,0 +1,241 @@
"""Tests for the authentication module (tito.core.auth)."""
import json
import os
import tempfile
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
import pytest
from tito.core.auth import (
save_credentials,
load_credentials,
delete_credentials,
get_token,
is_logged_in,
AuthReceiver,
CallbackHandler,
LocalAuthServer,
)
class TestCredentialsStorage:
"""Test credential storage functions."""
def setup_method(self):
"""Set up temporary directory for testing."""
self.temp_dir = tempfile.mkdtemp()
# Mock the credentials directory to use our temp dir
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
self.test_data = {"access_token": "test_token", "refresh_token": "refresh", "user_email": "test@example.com"}
def test_save_and_load_credentials(self):
"""Test saving and loading credentials."""
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
save_credentials(self.test_data)
loaded = load_credentials()
assert loaded == self.test_data
def test_load_nonexistent_credentials(self):
"""Test loading when no credentials exist."""
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
loaded = load_credentials()
assert loaded is None
def test_delete_credentials(self):
"""Test deleting credentials."""
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
save_credentials(self.test_data)
delete_credentials()
loaded = load_credentials()
assert loaded is None
def test_get_token(self):
"""Test getting access token."""
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
save_credentials(self.test_data)
token = get_token()
assert token == "test_token"
def test_get_token_no_credentials(self):
"""Test getting token when no credentials exist."""
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
token = get_token()
assert token is None
def test_is_logged_in(self):
"""Test checking login status."""
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
assert not is_logged_in()
save_credentials(self.test_data)
assert is_logged_in()
class TestAuthReceiver:
"""Test the AuthReceiver class."""
@patch('tito.core.auth.LocalAuthServer')
@patch('socket.socket')
@patch('time.sleep')
@patch('threading.Thread')
def test_start_server(self, mock_thread_class, mock_sleep, mock_socket_class, mock_server_class):
"""Test starting the auth server."""
# Mock the server instance
mock_server = Mock()
mock_server.server_address = ('127.0.0.1', 54321)
mock_server_class.return_value = mock_server
# Mock socket for port checking
mock_socket = Mock()
mock_socket.connect_ex.return_value = 0 # Success
mock_socket_class.return_value = mock_socket
# Mock thread
mock_thread = Mock()
mock_thread.is_alive.return_value = True
mock_thread_class.return_value = mock_thread
receiver = AuthReceiver()
port = receiver.start()
assert port == 54321
mock_server_class.assert_called_once()
receiver.stop()
@patch('webbrowser.open')
@patch('time.sleep')
def test_wait_for_tokens_timeout(self, mock_sleep, mock_open):
"""Test waiting for tokens with timeout."""
receiver = AuthReceiver()
# Mock server without auth_data
receiver.server = Mock()
receiver.server.auth_data = None
tokens = receiver.wait_for_tokens(timeout=0.1)
assert tokens is None
@patch('webbrowser.open')
@patch('time.sleep')
def test_wait_for_tokens_success(self, mock_sleep, mock_open):
"""Test successful token reception."""
receiver = AuthReceiver()
test_tokens = {"access_token": "token", "refresh_token": "refresh", "user_email": "user@example.com"}
# Mock server with auth_data
receiver.server = Mock()
receiver.server.auth_data = test_tokens
with patch('tito.core.auth.save_credentials') as mock_save:
tokens = receiver.wait_for_tokens(timeout=1)
assert tokens == test_tokens
mock_save.assert_called_with(test_tokens)
class TestCallbackHandler:
"""Test the CallbackHandler class."""
def test_do_get_callback_success(self):
"""Test successful callback handling."""
# Create handler directly without server initialization
handler = CallbackHandler.__new__(CallbackHandler)
# Mock the required attributes
handler.path = "/callback?access_token=test&refresh_token=refresh&email=user@example.com"
handler.send_response = Mock()
handler.send_header = Mock()
handler.end_headers = Mock()
handler.wfile = Mock()
handler.server = Mock()
with patch('tito.core.auth.save_credentials') as mock_save:
handler.do_GET()
assert handler.server.auth_data == {
"access_token": "test",
"refresh_token": "refresh",
"user_email": "user@example.com"
}
mock_save.assert_called_once()
def test_do_get_invalid_path(self):
"""Test handling of invalid callback path."""
handler = CallbackHandler.__new__(CallbackHandler)
handler.path = "/invalid"
handler.send_error = Mock()
handler.server = Mock()
handler.do_GET()
handler.send_error.assert_called_with(404, "Not Found")
# Integration test example
def test_full_login_flow():
"""Integration test for the full login flow (mocked)."""
# This would test the entire flow from AuthReceiver to storage
# In a real scenario, you'd mock the HTTP server and browser
pass
class TestLoginCommand:
"""Test the LoginCommand behavior."""
@patch('tito.core.auth.is_logged_in')
@patch('tito.commands.base.get_console')
def test_already_logged_in(self, mock_get_console, mock_is_logged_in):
"""Test that login command exits early if already logged in."""
from tito.commands.login import LoginCommand
mock_is_logged_in.return_value = True
# Mock console
mock_console = Mock()
mock_get_console.return_value = mock_console
# Mock config
mock_config = Mock()
command = LoginCommand(mock_config)
# Create mock args
args = Mock()
args.force = False
result = command.run(args)
assert result == 0
mock_console.print.assert_called_with("[green]Already logged in to TinyTorch![/green]")
class TestLogoutCommand:
"""Test the LogoutCommand behavior."""
@patch('tito.commands.login.AuthReceiver')
@patch('webbrowser.open')
@patch('time.sleep')
@patch('tito.commands.login.delete_credentials')
@patch('tito.commands.base.get_console')
def test_logout_with_browser(self, mock_get_console, mock_delete, mock_sleep, mock_open, mock_receiver_class):
"""Test that logout command opens browser and deletes credentials."""
from tito.commands.login import LogoutCommand
# Mock console
mock_console = Mock()
mock_get_console.return_value = mock_console
# Mock receiver
mock_receiver = Mock()
mock_receiver.start.return_value = 54321
mock_receiver_class.return_value = mock_receiver
# Mock config
mock_config = Mock()
command = LogoutCommand(mock_config)
# Create mock args
args = Mock()
result = command.run(args)
assert result == 0
mock_receiver.start.assert_called_once()
mock_open.assert_called_once_with("http://127.0.0.1:54321/logout")
mock_delete.assert_called_once()
mock_console.print.assert_called_with("[green]Successfully logged out of TinyTorch![/green]")