mirror of
https://github.com/MLSysBook/TinyTorch.git
synced 2026-04-28 23:09:41 -05:00
feat: Add login component and associated tests
This commit is contained in:
86
tito/commands/login.py
Normal file
86
tito/commands/login.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# tito/commands/login.py
|
||||
import webbrowser
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from tito.commands.base import BaseCommand
|
||||
from tito.core.auth import AuthReceiver, save_credentials, delete_credentials, ENDPOINTS, is_logged_in
|
||||
|
||||
class LoginCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "login"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Log in to TinyTorch via web browser"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
parser.add_argument("--force", action="store_true", help="Force re-login")
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
# Adapted logic from api.py
|
||||
if args.force:
|
||||
delete_credentials()
|
||||
self.console.print("Cleared existing credentials.")
|
||||
|
||||
# Check if already logged in (unless force was used)
|
||||
if is_logged_in():
|
||||
self.console.print("[green]Already logged in to TinyTorch![/green]")
|
||||
return 0
|
||||
|
||||
receiver = AuthReceiver()
|
||||
try:
|
||||
port = receiver.start()
|
||||
target_url = f"{ENDPOINTS['cli_login']}?redirect_port={port}"
|
||||
self.console.print(f"Opening browser to: [blue]{target_url}[/blue]")
|
||||
self.console.print("Waiting for authentication...")
|
||||
webbrowser.open(target_url)
|
||||
tokens = receiver.wait_for_tokens()
|
||||
if tokens:
|
||||
save_credentials(tokens)
|
||||
self.console.print(f"[green]Success! Logged in as {tokens['user_email']}[/green]")
|
||||
return 0
|
||||
else:
|
||||
self.console.print("[red]Login timed out.[/red]")
|
||||
return 1
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Error: {e}[/red]")
|
||||
return 1
|
||||
|
||||
|
||||
class LogoutCommand(BaseCommand):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "logout"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Log out of TinyTorch by clearing stored credentials"
|
||||
|
||||
def add_arguments(self, parser: ArgumentParser) -> None:
|
||||
pass # No arguments needed
|
||||
|
||||
def run(self, args: Namespace) -> int:
|
||||
try:
|
||||
# Start local server for logout redirect
|
||||
receiver = AuthReceiver()
|
||||
port = receiver.start()
|
||||
|
||||
# Open browser to local logout endpoint
|
||||
logout_url = f"http://127.0.0.1:{port}/logout"
|
||||
self.console.print(f"Opening browser to complete logout...")
|
||||
webbrowser.open(logout_url)
|
||||
|
||||
# Give browser time to redirect and close
|
||||
time.sleep(2.0)
|
||||
|
||||
# Clean up server
|
||||
receiver.stop()
|
||||
|
||||
# Delete local credentials
|
||||
delete_credentials()
|
||||
self.console.print("[green]Successfully logged out of TinyTorch![/green]")
|
||||
return 0
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]Error during logout: {e}[/red]")
|
||||
return 1
|
||||
236
tito/core/auth.py
Normal file
236
tito/core/auth.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Simple secure JSON credentials storage system for TinyTorch CLI."""
|
||||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import threading
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import socket
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
# --- Configuration Constants ---
|
||||
API_BASE_URL = "https://tinytorch.netlify.app"
|
||||
|
||||
# API Endpoints
|
||||
ENDPOINTS = {
|
||||
"login": f"{API_BASE_URL}/api/auth/login",
|
||||
"leaderboard": f"{API_BASE_URL}/api/leaderboard",
|
||||
"submissions": f"{API_BASE_URL}/api/submissions",
|
||||
"cli_login": f"{API_BASE_URL}/cli-login",
|
||||
}
|
||||
|
||||
# Defaults
|
||||
LOCAL_SERVER_HOST = "127.0.0.1"
|
||||
AUTH_START_PORT = 54321
|
||||
AUTH_PORT_HUNT_RANGE = 100
|
||||
AUTH_CALLBACK_PATH = "/callback"
|
||||
CREDENTIALS_FILE_NAME = "credentials.json"
|
||||
|
||||
# Determine credentials directory (Standard Python way)
|
||||
CREDENTIALS_DIR = os.getenv("TINOTORCH_CREDENTIALS_DIR", str(Path.home() / ".tinytorch"))
|
||||
|
||||
|
||||
# --- Storage Logic ---
|
||||
|
||||
def _credentials_dir() -> Path:
|
||||
return Path(os.path.expanduser(CREDENTIALS_DIR))
|
||||
|
||||
def _credentials_path() -> Path:
|
||||
return _credentials_dir() / CREDENTIALS_FILE_NAME
|
||||
|
||||
def _ensure_dir() -> None:
|
||||
d = _credentials_dir()
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
os.chmod(d, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def save_credentials(data: Dict[str, str]) -> None:
|
||||
"""Persist credentials to disk safely and atomically."""
|
||||
_ensure_dir()
|
||||
p = _credentials_path()
|
||||
tmp = p.with_suffix(".tmp")
|
||||
with tmp.open("w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(str(tmp), str(p))
|
||||
try:
|
||||
os.chmod(p, 0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def load_credentials() -> Optional[Dict[str, str]]:
|
||||
p = _credentials_path()
|
||||
if not p.exists():
|
||||
return None
|
||||
try:
|
||||
with p.open("r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def delete_credentials() -> None:
|
||||
p = _credentials_path()
|
||||
try:
|
||||
p.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# --- Public Auth Helpers ---
|
||||
|
||||
def get_token() -> Optional[str]:
|
||||
"""Retrieve the access token if it exists."""
|
||||
creds = load_credentials()
|
||||
if creds:
|
||||
return creds.get("access_token")
|
||||
return None
|
||||
|
||||
def is_logged_in() -> bool:
|
||||
"""Check if the user has valid credentials stored."""
|
||||
return get_token() is not None
|
||||
|
||||
|
||||
|
||||
# --- Auth Server Logic ---
|
||||
|
||||
class CallbackHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
parsed_path = urlparse(self.path)
|
||||
|
||||
if parsed_path.path == "/logout":
|
||||
self.send_response(302)
|
||||
self.send_header('Location', f"{API_BASE_URL}/logout")
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
if parsed_path.path != AUTH_CALLBACK_PATH:
|
||||
self.send_error(404, "Not Found")
|
||||
return
|
||||
|
||||
query_params = parse_qs(parsed_path.query)
|
||||
|
||||
if 'access_token' in query_params and 'refresh_token' in query_params:
|
||||
self.server.auth_data = {
|
||||
'access_token': query_params['access_token'][0],
|
||||
'refresh_token': query_params['refresh_token'][0],
|
||||
'user_email': query_params.get('email', [''])[0]
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'text/html')
|
||||
self.end_headers()
|
||||
|
||||
html_content = "<html><body><h1>🔥 Tinytorch <h1> <h2>Login Successful</h2><p>You can close this window and return to the CLI.</p><p><a href='https://tinytorch.netlify.app/dashboard'>Go to TinyTorch Dashboard</a></p><script>window.close()</script></body></html>"
|
||||
|
||||
self.wfile.write(html_content.encode('utf-8'))
|
||||
self.wfile.flush()
|
||||
|
||||
# Persist immediately
|
||||
try:
|
||||
save_credentials(self.server.auth_data)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
self.send_error(400, "Missing tokens in callback URL")
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
class LocalAuthServer(http.server.HTTPServer):
|
||||
def __init__(self, server_address, RequestHandlerClass):
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
self.auth_data: Optional[Dict[str, str]] = None
|
||||
|
||||
class AuthReceiver:
|
||||
def __init__(self, start_port: int = None):
|
||||
self.start_port = start_port if start_port is not None else AUTH_START_PORT
|
||||
self.server: Optional[LocalAuthServer] = None
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
self.port: int = 0
|
||||
|
||||
def start(self) -> int:
|
||||
port = self.start_port
|
||||
max_port = self.start_port + AUTH_PORT_HUNT_RANGE
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.server = LocalAuthServer((LOCAL_SERVER_HOST, port), CallbackHandler)
|
||||
self.port = self.server.server_address[1]
|
||||
break
|
||||
except OSError:
|
||||
port += 1
|
||||
if port > max_port:
|
||||
raise Exception("Could not find an open port for authentication.")
|
||||
|
||||
def serve_with_error_handling():
|
||||
try:
|
||||
self.server.serve_forever()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.thread = threading.Thread(target=serve_with_error_handling, daemon=True)
|
||||
self.thread.start()
|
||||
time.sleep(0.2)
|
||||
|
||||
# Check if server is ready
|
||||
max_wait = 2.0
|
||||
waited = 0.0
|
||||
server_ready = False
|
||||
|
||||
while waited < max_wait:
|
||||
try:
|
||||
if not self.thread.is_alive():
|
||||
break
|
||||
|
||||
test_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
test_socket.settimeout(0.2)
|
||||
result = test_socket.connect_ex((LOCAL_SERVER_HOST, self.port))
|
||||
test_socket.close()
|
||||
if result == 0:
|
||||
server_ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
waited += 0.1
|
||||
|
||||
if not server_ready:
|
||||
self.stop()
|
||||
raise Exception(f"Server failed to start on port {self.port}")
|
||||
|
||||
return self.port
|
||||
|
||||
def wait_for_tokens(self, timeout: int = 120) -> Optional[Dict[str, str]]:
|
||||
start_time = time.time()
|
||||
try:
|
||||
while getattr(self.server, "auth_data", None) is None:
|
||||
if time.time() - start_time > timeout:
|
||||
return None
|
||||
time.sleep(0.25)
|
||||
|
||||
try:
|
||||
save_credentials(self.server.auth_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time.sleep(1.0)
|
||||
|
||||
return self.server.auth_data
|
||||
finally:
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
if self.server:
|
||||
try:
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
except Exception:
|
||||
pass
|
||||
if self.thread and self.thread.is_alive():
|
||||
self.thread.join(timeout=1)
|
||||
241
tito/tests/test_auth.py
Normal file
241
tito/tests/test_auth.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Tests for the authentication module (tito.core.auth)."""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from tito.core.auth import (
|
||||
save_credentials,
|
||||
load_credentials,
|
||||
delete_credentials,
|
||||
get_token,
|
||||
is_logged_in,
|
||||
AuthReceiver,
|
||||
CallbackHandler,
|
||||
LocalAuthServer,
|
||||
)
|
||||
|
||||
|
||||
class TestCredentialsStorage:
|
||||
"""Test credential storage functions."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up temporary directory for testing."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
# Mock the credentials directory to use our temp dir
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
self.test_data = {"access_token": "test_token", "refresh_token": "refresh", "user_email": "test@example.com"}
|
||||
|
||||
def test_save_and_load_credentials(self):
|
||||
"""Test saving and loading credentials."""
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
save_credentials(self.test_data)
|
||||
loaded = load_credentials()
|
||||
assert loaded == self.test_data
|
||||
|
||||
def test_load_nonexistent_credentials(self):
|
||||
"""Test loading when no credentials exist."""
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
loaded = load_credentials()
|
||||
assert loaded is None
|
||||
|
||||
def test_delete_credentials(self):
|
||||
"""Test deleting credentials."""
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
save_credentials(self.test_data)
|
||||
delete_credentials()
|
||||
loaded = load_credentials()
|
||||
assert loaded is None
|
||||
|
||||
def test_get_token(self):
|
||||
"""Test getting access token."""
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
save_credentials(self.test_data)
|
||||
token = get_token()
|
||||
assert token == "test_token"
|
||||
|
||||
def test_get_token_no_credentials(self):
|
||||
"""Test getting token when no credentials exist."""
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
token = get_token()
|
||||
assert token is None
|
||||
|
||||
def test_is_logged_in(self):
|
||||
"""Test checking login status."""
|
||||
with patch('tito.core.auth.CREDENTIALS_DIR', self.temp_dir):
|
||||
assert not is_logged_in()
|
||||
save_credentials(self.test_data)
|
||||
assert is_logged_in()
|
||||
|
||||
|
||||
class TestAuthReceiver:
|
||||
"""Test the AuthReceiver class."""
|
||||
|
||||
@patch('tito.core.auth.LocalAuthServer')
|
||||
@patch('socket.socket')
|
||||
@patch('time.sleep')
|
||||
@patch('threading.Thread')
|
||||
def test_start_server(self, mock_thread_class, mock_sleep, mock_socket_class, mock_server_class):
|
||||
"""Test starting the auth server."""
|
||||
# Mock the server instance
|
||||
mock_server = Mock()
|
||||
mock_server.server_address = ('127.0.0.1', 54321)
|
||||
mock_server_class.return_value = mock_server
|
||||
|
||||
# Mock socket for port checking
|
||||
mock_socket = Mock()
|
||||
mock_socket.connect_ex.return_value = 0 # Success
|
||||
mock_socket_class.return_value = mock_socket
|
||||
|
||||
# Mock thread
|
||||
mock_thread = Mock()
|
||||
mock_thread.is_alive.return_value = True
|
||||
mock_thread_class.return_value = mock_thread
|
||||
|
||||
receiver = AuthReceiver()
|
||||
port = receiver.start()
|
||||
|
||||
assert port == 54321
|
||||
mock_server_class.assert_called_once()
|
||||
receiver.stop()
|
||||
|
||||
@patch('webbrowser.open')
|
||||
@patch('time.sleep')
|
||||
def test_wait_for_tokens_timeout(self, mock_sleep, mock_open):
|
||||
"""Test waiting for tokens with timeout."""
|
||||
receiver = AuthReceiver()
|
||||
# Mock server without auth_data
|
||||
receiver.server = Mock()
|
||||
receiver.server.auth_data = None
|
||||
|
||||
tokens = receiver.wait_for_tokens(timeout=0.1)
|
||||
assert tokens is None
|
||||
|
||||
@patch('webbrowser.open')
|
||||
@patch('time.sleep')
|
||||
def test_wait_for_tokens_success(self, mock_sleep, mock_open):
|
||||
"""Test successful token reception."""
|
||||
receiver = AuthReceiver()
|
||||
test_tokens = {"access_token": "token", "refresh_token": "refresh", "user_email": "user@example.com"}
|
||||
|
||||
# Mock server with auth_data
|
||||
receiver.server = Mock()
|
||||
receiver.server.auth_data = test_tokens
|
||||
|
||||
with patch('tito.core.auth.save_credentials') as mock_save:
|
||||
tokens = receiver.wait_for_tokens(timeout=1)
|
||||
assert tokens == test_tokens
|
||||
mock_save.assert_called_with(test_tokens)
|
||||
|
||||
|
||||
class TestCallbackHandler:
|
||||
"""Test the CallbackHandler class."""
|
||||
|
||||
def test_do_get_callback_success(self):
|
||||
"""Test successful callback handling."""
|
||||
# Create handler directly without server initialization
|
||||
handler = CallbackHandler.__new__(CallbackHandler)
|
||||
|
||||
# Mock the required attributes
|
||||
handler.path = "/callback?access_token=test&refresh_token=refresh&email=user@example.com"
|
||||
handler.send_response = Mock()
|
||||
handler.send_header = Mock()
|
||||
handler.end_headers = Mock()
|
||||
handler.wfile = Mock()
|
||||
handler.server = Mock()
|
||||
|
||||
with patch('tito.core.auth.save_credentials') as mock_save:
|
||||
handler.do_GET()
|
||||
assert handler.server.auth_data == {
|
||||
"access_token": "test",
|
||||
"refresh_token": "refresh",
|
||||
"user_email": "user@example.com"
|
||||
}
|
||||
mock_save.assert_called_once()
|
||||
|
||||
def test_do_get_invalid_path(self):
|
||||
"""Test handling of invalid callback path."""
|
||||
handler = CallbackHandler.__new__(CallbackHandler)
|
||||
|
||||
handler.path = "/invalid"
|
||||
handler.send_error = Mock()
|
||||
handler.server = Mock()
|
||||
|
||||
handler.do_GET()
|
||||
handler.send_error.assert_called_with(404, "Not Found")
|
||||
|
||||
|
||||
# Integration test example
|
||||
def test_full_login_flow():
|
||||
"""Integration test for the full login flow (mocked)."""
|
||||
# This would test the entire flow from AuthReceiver to storage
|
||||
# In a real scenario, you'd mock the HTTP server and browser
|
||||
pass
|
||||
|
||||
|
||||
class TestLoginCommand:
|
||||
"""Test the LoginCommand behavior."""
|
||||
|
||||
@patch('tito.core.auth.is_logged_in')
|
||||
@patch('tito.commands.base.get_console')
|
||||
def test_already_logged_in(self, mock_get_console, mock_is_logged_in):
|
||||
"""Test that login command exits early if already logged in."""
|
||||
from tito.commands.login import LoginCommand
|
||||
|
||||
mock_is_logged_in.return_value = True
|
||||
|
||||
# Mock console
|
||||
mock_console = Mock()
|
||||
mock_get_console.return_value = mock_console
|
||||
|
||||
# Mock config
|
||||
mock_config = Mock()
|
||||
command = LoginCommand(mock_config)
|
||||
|
||||
# Create mock args
|
||||
args = Mock()
|
||||
args.force = False
|
||||
|
||||
result = command.run(args)
|
||||
|
||||
assert result == 0
|
||||
mock_console.print.assert_called_with("[green]Already logged in to TinyTorch![/green]")
|
||||
|
||||
|
||||
class TestLogoutCommand:
|
||||
"""Test the LogoutCommand behavior."""
|
||||
|
||||
@patch('tito.commands.login.AuthReceiver')
|
||||
@patch('webbrowser.open')
|
||||
@patch('time.sleep')
|
||||
@patch('tito.commands.login.delete_credentials')
|
||||
@patch('tito.commands.base.get_console')
|
||||
def test_logout_with_browser(self, mock_get_console, mock_delete, mock_sleep, mock_open, mock_receiver_class):
|
||||
"""Test that logout command opens browser and deletes credentials."""
|
||||
from tito.commands.login import LogoutCommand
|
||||
|
||||
# Mock console
|
||||
mock_console = Mock()
|
||||
mock_get_console.return_value = mock_console
|
||||
|
||||
# Mock receiver
|
||||
mock_receiver = Mock()
|
||||
mock_receiver.start.return_value = 54321
|
||||
mock_receiver_class.return_value = mock_receiver
|
||||
|
||||
# Mock config
|
||||
mock_config = Mock()
|
||||
command = LogoutCommand(mock_config)
|
||||
|
||||
# Create mock args
|
||||
args = Mock()
|
||||
|
||||
result = command.run(args)
|
||||
|
||||
assert result == 0
|
||||
mock_receiver.start.assert_called_once()
|
||||
mock_open.assert_called_once_with("http://127.0.0.1:54321/logout")
|
||||
mock_delete.assert_called_once()
|
||||
mock_console.print.assert_called_with("[green]Successfully logged out of TinyTorch![/green]")
|
||||
Reference in New Issue
Block a user