Add Auth/User system

This commit is contained in:
Kohaku-Blueleaf
2025-10-02 17:40:56 +08:00
parent e6c02b455e
commit 5a9dfeb6dc
12 changed files with 605 additions and 56 deletions

View File

@@ -16,13 +16,14 @@ dependencies = [
"fastapi",
"httpx",
"uvicorn",
"pydantic",
"pydantic[email]",
"toml",
"peewee",
"boto3",
"lakefs-client",
"psycopg2-binary",
"pyyaml",
"bcrypt",
]
urls = { "Homepage" = "https://kblueleaf.net/Kohaku-Hub" }

87
scripts/test_auth.py Normal file
View File

@@ -0,0 +1,87 @@
"""Test authentication system."""
import requests
BASE_URL = "http://127.0.0.1:48888/api"
# 1. Register
print("=== Testing Registration ===")
resp = requests.post(
f"{BASE_URL}/auth/register",
json={
"username": "testuser2",
"email": "test2@example.com",
"password": "testpass123",
},
)
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 2. Login
print("=== Testing Login ===")
resp = requests.post(
f"{BASE_URL}/auth/login", json={"username": "testuser2", "password": "testpass123"}
)
print(f"Status: {resp.status_code}")
result = resp.json()
print(f"Response: {result}\n")
# Save session cookie and secret
session = requests.Session()
session.cookies.update(resp.cookies)
session_secret = result.get("session_secret")
print(f"Session secret: {session_secret}\n")
# 3. Get current user
print("=== Testing Get Current User ===")
resp = session.get(f"{BASE_URL}/auth/me")
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 4. Create API token
print("=== Testing Create Token ===")
resp = session.post(f"{BASE_URL}/auth/tokens/create", json={"name": "test-token"})
print(f"Status: {resp.status_code}")
result = resp.json()
print(f"Response: {result}\n")
token = result["token"]
print(f"Generated token: {token}\n")
print(f"Session secret for encryption: {result['session_secret']}\n")
# 5-1. List tokens
print("=== Testing List Tokens ===")
resp = session.get(f"{BASE_URL}/auth/tokens")
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 6. Test token-based auth
print("=== Testing Token Auth ===")
headers = {"Authorization": f"Bearer {token}"}
resp = requests.get(f"{BASE_URL}/auth/me", headers=headers)
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 5-2. List tokens
print("=== Testing List Tokens ===")
resp = session.get(f"{BASE_URL}/auth/tokens")
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 7. Logout
print("=== Testing Logout ===")
resp = session.post(f"{BASE_URL}/auth/logout")
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 8. Verify session cleared
print("=== Testing Session Cleared ===")
resp = session.get(f"{BASE_URL}/auth/me")
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")
# 9. Verify token still works
print("=== Testing Token Still Works ===")
resp = requests.get(f"{BASE_URL}/auth/me", headers=headers)
print(f"Status: {resp.status_code}")
print(f"Response: {resp.json()}\n")

View File

@@ -1,10 +1,13 @@
"""Authentication and authorization for Kohaku Hub API.
TODO: Implement real authentication system.
Currently returns mock user for development.
Integrates with the new auth system in kohakuhub.auth module.
"""
from ..db import User, db
from ..db import db, User
from ..auth.dependencies import (
get_current_user as auth_get_current_user,
get_optional_user,
)
def get_db():
@@ -20,19 +23,6 @@ def get_db():
def get_current_user():
"""Get current authenticated user.
TODO: Implement real authentication:
- Parse Authorization header (Bearer token)
- Validate token against database or JWT
- Return actual User object
- Raise HTTPException(401) if invalid
Returns:
Mock user object for development.
Now delegates to the real auth system.
"""
# Mock user for development
class MockUser:
username = "me"
id = 1
return MockUser()
return auth_get_current_user()

View File

@@ -1,9 +1,12 @@
"""Utility API endpoints for Kohaku Hub."""
from fastapi import APIRouter, HTTPException, Request
import yaml
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import yaml
from ..db import User
from .auth import get_optional_user
router = APIRouter()
@@ -34,21 +37,30 @@ def validate_yaml(body: ValidateYamlPayload):
@router.get("/whoami-v2")
def whoami_v2():
"""Get current user information.
def whoami_v2(user: User = Depends(get_optional_user)):
"""Get current user information (HuggingFace compatible).
TODO: Implement real user info retrieval from authentication system.
Returns:
User information object
Matches HuggingFace Hub /api/whoami-v2 endpoint format.
Returns user info if authenticated, 401 if not.
"""
# Mock response matching HuggingFace Hub format
if not user:
raise HTTPException(401, detail="Invalid user token")
# Get user's organizations (stub for now - can be implemented later)
orgs = []
return {
"name": "me",
"type": "user",
"displayName": "me",
"email": None,
"orgs": [],
"id": str(user.id),
"name": user.username,
"fullname": user.username,
"email": user.email,
"emailVerified": user.email_verified,
"canPay": False,
"isPro": False,
"periodEnd": None,
"orgs": orgs,
"auth": {
"type": "access_token",
"accessToken": {"displayName": "Auto-generated token", "role": "write"},
},
}

View File

@@ -0,0 +1,6 @@
"""Authentication module."""
from .routes import router
from .dependencies import get_current_user, get_optional_user
__all__ = ["router", "get_current_user", "get_optional_user"]

View File

@@ -0,0 +1,58 @@
"""FastAPI dependencies for authentication."""
from datetime import datetime, timezone
from typing import Optional
from fastapi import Cookie, Header, HTTPException
from ..db import User, Session, Token
from .utils import hash_token
def get_current_user(
session_id: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> User:
"""Get current authenticated user from session or token."""
# Try session-based auth first (web UI)
if session_id:
session = Session.get_or_none(
(Session.session_id == session_id)
& (Session.expires_at > datetime.now(timezone.utc))
)
if session:
user = User.get_or_none(User.id == session.user_id)
if user and user.is_active:
return user
# Try token-based auth (API)
if authorization:
if not authorization.startswith("Bearer "):
raise HTTPException(401, detail="Invalid authorization header")
token_str = authorization[7:] # Remove "Bearer "
token_hash = hash_token(token_str)
token = Token.get_or_none(Token.token_hash == token_hash)
if token:
# Update last used
Token.update(last_used=datetime.now(timezone.utc)).where(
Token.id == token.id
).execute()
user = User.get_or_none(User.id == token.user_id)
if user and user.is_active:
return user
raise HTTPException(401, detail="Not authenticated")
def get_optional_user(
session_id: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> Optional[User]:
"""Get current user if authenticated, otherwise None."""
try:
return get_current_user(session_id, authorization)
except HTTPException:
return None

View File

@@ -0,0 +1,53 @@
"""Email utilities for authentication."""
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from ..config import cfg
def send_verification_email(to_email: str, username: str, token: str) -> bool:
"""Send email verification email."""
if not cfg.smtp.enabled:
print(
f"[EMAIL] SMTP disabled. Verification link: {cfg.app.base_url}/auth/verify?token={token}"
)
return True
subject = "Verify your Kohaku Hub account"
verify_link = f"{cfg.app.base_url}/auth/verify?token={token}"
body = f"""
Hello {username},
Please verify your email address by clicking the link below:
{verify_link}
This link will expire in 24 hours.
If you didn't create this account, please ignore this email.
Best regards,
Kohaku Hub
"""
try:
msg = MIMEMultipart()
msg["From"] = cfg.smtp.from_email
msg["To"] = to_email
msg["Subject"] = subject
msg.attach(MIMEText(body, "plain"))
with smtplib.SMTP(cfg.smtp.host, cfg.smtp.port) as server:
if cfg.smtp.use_tls:
server.starttls()
if cfg.smtp.username and cfg.smtp.password:
server.login(cfg.smtp.username, cfg.smtp.password)
server.send_message(msg)
return True
except Exception as e:
print(f"[EMAIL] Failed to send verification email: {e}")
return False

View File

@@ -0,0 +1,230 @@
"""Authentication API routes."""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, HTTPException, Response, Depends
from pydantic import BaseModel, EmailStr
from ..config import cfg
from ..db import User, EmailVerification, Session, Token
from .utils import (
hash_password,
verify_password,
generate_token,
hash_token,
generate_session_secret,
get_expiry_time,
)
from .email import send_verification_email
from .dependencies import get_current_user, get_optional_user
router = APIRouter(prefix="/auth", tags=["auth"])
class RegisterRequest(BaseModel):
username: str
email: EmailStr
password: str
class LoginRequest(BaseModel):
username: str
password: str
class CreateTokenRequest(BaseModel):
name: str
@router.post("/register")
def register(req: RegisterRequest):
"""Register new user."""
# Check if username or email already exists
if User.get_or_none(User.username == req.username):
raise HTTPException(400, detail="Username already exists")
if User.get_or_none(User.email == req.email):
raise HTTPException(400, detail="Email already exists")
# Create user
user = User.create(
username=req.username,
email=req.email,
password_hash=hash_password(req.password),
email_verified=not cfg.auth.require_email_verification,
)
# Send verification email if required
if cfg.auth.require_email_verification:
token = generate_token()
EmailVerification.create(
user=user.id, token=token, expires_at=get_expiry_time(24)
)
if not send_verification_email(req.email, req.username, token):
return {
"success": True,
"message": "User created but failed to send verification email",
"email_verified": False,
}
return {
"success": True,
"message": "User created. Please check your email to verify your account.",
"email_verified": False,
}
return {
"success": True,
"message": "User created successfully",
"email_verified": True,
}
@router.get("/verify-email")
def verify_email(token: str):
"""Verify email with token."""
verification = EmailVerification.get_or_none(
(EmailVerification.token == token)
& (EmailVerification.expires_at > datetime.now(timezone.utc))
)
if not verification:
raise HTTPException(400, detail="Invalid or expired verification token")
# Update user
User.update(email_verified=True).where(User.id == verification.user).execute()
# Delete verification token
EmailVerification.delete().where(EmailVerification.id == verification.id).execute()
return {"success": True, "message": "Email verified successfully"}
@router.post("/login")
def login(req: LoginRequest, response: Response):
"""Login and create session."""
user = User.get_or_none(User.username == req.username)
if not user or not verify_password(req.password, user.password_hash):
raise HTTPException(401, detail="Invalid username or password")
if not user.is_active:
raise HTTPException(403, detail="Account is disabled")
if cfg.auth.require_email_verification and not user.email_verified:
raise HTTPException(403, detail="Please verify your email first")
# Create session
session_id = generate_token()
session_secret = generate_session_secret()
Session.create(
session_id=session_id,
user_id=user.id,
secret=session_secret,
expires_at=get_expiry_time(cfg.auth.session_expire_hours),
)
# Set cookie
response.set_cookie(
key="session_id",
value=session_id,
httponly=True,
max_age=cfg.auth.session_expire_hours * 3600,
samesite="lax",
)
return {
"success": True,
"message": "Logged in successfully",
"username": user.username,
"session_secret": session_secret,
}
@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
"""Logout and destroy session."""
# Delete all user sessions
Session.delete().where(Session.user_id == user.id).execute()
# Clear cookie
response.delete_cookie(key="session_id")
return {"success": True, "message": "Logged out successfully"}
@router.get("/me")
def get_me(user: User = Depends(get_current_user)):
"""Get current user info (internal endpoint)."""
return {
"id": user.id,
"username": user.username,
"email": user.email,
"email_verified": user.email_verified,
"created_at": user.created_at.isoformat(),
}
@router.get("/tokens")
def list_tokens(user: User = Depends(get_current_user)):
"""List user's API tokens."""
tokens = Token.select().where(Token.user_id == user.id)
return {
"tokens": [
{
"id": t.id,
"name": t.name,
"last_used": t.last_used.isoformat() if t.last_used else None,
"created_at": t.created_at.isoformat(),
}
for t in tokens
]
}
@router.post("/tokens/create")
def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)):
"""Create new API token."""
# Generate token
token_str = generate_token()
token_hash_val = hash_token(token_str)
# Save to database
token = Token.create(user_id=user.id, token_hash=token_hash_val, name=req.name)
# Get session secret for encryption (if in web session)
session = Session.get_or_none(Session.user_id == user.id)
session_secret = session.secret if session else None
return {
"success": True,
"token": token_str,
"token_id": token.id,
"session_secret": session_secret,
"message": "Token created. Save it securely - you won't see it again!",
}
@router.delete("/tokens/{token_id}")
def revoke_token(token_id: int, user: User = Depends(get_current_user)):
"""Revoke an API token."""
token = Token.get_or_none((Token.id == token_id) & (Token.user_id == user.id))
if not token:
raise HTTPException(404, detail="Token not found")
token.delete_instance()
return {"success": True, "message": "Token revoked successfully"}

View File

@@ -0,0 +1,39 @@
"""Authentication utilities."""
import secrets
import hashlib
import bcrypt
from datetime import datetime, timedelta, timezone
def hash_password(password: str) -> str:
"""Hash password with bcrypt."""
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(password: str, password_hash: str) -> bool:
"""Verify password against hash."""
try:
return bcrypt.checkpw(password.encode(), password_hash.encode())
except Exception:
return False
def generate_token() -> str:
"""Generate random token (32 bytes = 64 hex chars)."""
return secrets.token_hex(32)
def hash_token(token: str) -> str:
"""Hash token with SHA3-512."""
return hashlib.sha3_512(token.encode()).hexdigest()
def generate_session_secret() -> str:
"""Generate session secret for token encryption."""
return secrets.token_hex(16)
def get_expiry_time(hours: int) -> datetime:
"""Get expiry time from now."""
return datetime.now(timezone.utc) + timedelta(hours=hours)

View File

@@ -23,10 +23,27 @@ class LakeFSConfig(BaseModel):
repo_namespace: str = "hf"
class SMTPConfig(BaseModel):
enabled: bool = False
host: str = "localhost"
port: int = 587
username: str = ""
password: str = ""
from_email: str = "noreply@localhost"
use_tls: bool = True
class AuthConfig(BaseModel):
require_email_verification: bool = False
session_secret: str = "change-me-in-production"
session_expire_hours: int = 168 # 7 days
token_expire_days: int = 365
class AppConfig(BaseModel):
base_url: str
api_base: str = "/api"
db_backend: str = "sqlite" # "sqlite" or "postgres"
db_backend: str = "sqlite"
database_url: str = "sqlite:///./hub.db"
lfs_threshold_bytes: int = 10 * 1024 * 1024
debug_log_payloads: bool = False
@@ -35,6 +52,8 @@ class AppConfig(BaseModel):
class Config(BaseModel):
s3: S3Config
lakefs: LakeFSConfig
smtp: SMTPConfig = SMTPConfig()
auth: AuthConfig = AuthConfig()
app: AppConfig
@@ -42,8 +61,6 @@ class Config(BaseModel):
def load_config(path: str = None) -> Config:
path = path or os.environ.get("HUB_CONFIG", None)
if path is None:
# use environment var (use .get with default value)
# this is crucial for Docker Compose startup method
s3_config = S3Config(
public_endpoint=os.environ["KOHAKU_HUB_S3_PUBLIC_ENDPOINT"],
endpoint=os.environ["KOHAKU_HUB_S3_ENDPOINT"],
@@ -59,6 +76,34 @@ def load_config(path: str = None) -> Config:
secret_key=os.environ["KOHAKU_HUB_LAKEFS_SECRET_KEY"],
repo_namespace=os.environ.get("KOHAKU_HUB_LAKEFS_REPO_NAMESPACE", ""),
)
smtp_config = SMTPConfig(
enabled=os.environ.get("KOHAKU_HUB_SMTP_ENABLED", "false").lower()
== "true",
host=os.environ.get("KOHAKU_HUB_SMTP_HOST", "localhost"),
port=int(os.environ.get("KOHAKU_HUB_SMTP_PORT", "587")),
username=os.environ.get("KOHAKU_HUB_SMTP_USERNAME", ""),
password=os.environ.get("KOHAKU_HUB_SMTP_PASSWORD", ""),
from_email=os.environ.get("KOHAKU_HUB_SMTP_FROM", "noreply@localhost"),
use_tls=os.environ.get("KOHAKU_HUB_SMTP_TLS", "true").lower() == "true",
)
auth_config = AuthConfig(
require_email_verification=os.environ.get(
"KOHAKU_HUB_REQUIRE_EMAIL_VERIFICATION", "false"
).lower()
== "true",
session_secret=os.environ.get(
"KOHAKU_HUB_SESSION_SECRET", "change-me-in-production"
),
session_expire_hours=int(
os.environ.get("KOHAKU_HUB_SESSION_EXPIRE_HOURS", "168")
),
token_expire_days=int(
os.environ.get("KOHAKU_HUB_TOKEN_EXPIRE_DAYS", "365")
),
)
app_config = AppConfig(
base_url=os.environ.get("KOHAKU_HUB_BASE_URL", "127.0.0.1:48888"),
api_base=os.environ.get("KOHAKU_HUB_API_BASE", "/api"),
@@ -69,7 +114,13 @@ def load_config(path: str = None) -> Config:
),
)
return Config(s3=s3_config, lakefs=lakefs_config, app=app_config)
return Config(
s3=s3_config,
lakefs=lakefs_config,
smtp=smtp_config,
auth=auth_config,
app=app_config,
)
else:
with open(path, "rb") as f:
raw = tomllib.load(f)

View File

@@ -11,6 +11,7 @@ from peewee import (
Model,
SqliteDatabase,
PostgresqlDatabase,
TextField,
)
from .config import cfg
@@ -21,7 +22,6 @@ def _sqlite_path(url: str) -> str:
# Choose DB backend
if cfg.app.db_backend == "postgres":
# Example: postgresql://user:pass@host:5432/dbname
url = cfg.app.database_url.replace("postgresql://", "")
user_pass, host_db = url.split("@")
user, password = user_pass.split(":")
@@ -52,6 +52,37 @@ class BaseModel(Model):
class User(BaseModel):
id = AutoField()
username = CharField(unique=True, index=True)
email = CharField(unique=True, index=True)
password_hash = CharField()
email_verified = BooleanField(default=False)
is_active = BooleanField(default=True)
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
class EmailVerification(BaseModel):
id = AutoField()
user = IntegerField(index=True)
token = CharField(unique=True, index=True)
expires_at = DateTimeField()
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
class Session(BaseModel):
id = AutoField()
session_id = CharField(unique=True, index=True)
user_id = IntegerField(index=True)
secret = CharField()
expires_at = DateTimeField()
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
class Token(BaseModel):
id = AutoField()
user_id = IntegerField(index=True)
token_hash = CharField(unique=True, index=True)
name = CharField()
last_used = DateTimeField(null=True)
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
class Repository(BaseModel):
@@ -61,6 +92,7 @@ class Repository(BaseModel):
name = CharField(index=True)
full_id = CharField(unique=True, index=True)
private = BooleanField(default=False)
owner_id = IntegerField(index=True, default=1)
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
class Meta:
@@ -97,4 +129,7 @@ class StagingUpload(BaseModel):
def init_db():
db.connect(reuse_if_open=True)
db.create_tables([User, Repository, File, StagingUpload], safe=True)
db.create_tables(
[User, EmailVerification, Session, Token, Repository, File, StagingUpload],
safe=True,
)

View File

@@ -5,6 +5,7 @@ from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from .api import basic, file, lfs, utils
from .auth import router as auth_router
from .config import cfg
from .db import Repository
from .api.file import resolve_file
@@ -13,7 +14,6 @@ from .api.s3_utils import init_storage
@asynccontextmanager
async def lifespan(app: FastAPI):
# Load the ML model
init_storage()
yield
@@ -26,7 +26,6 @@ app = FastAPI(
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -35,32 +34,19 @@ app.add_middleware(
allow_headers=["*"],
)
# Mount API routers with configured prefix
app.include_router(auth_router, prefix=cfg.app.api_base)
app.include_router(basic.router, prefix=cfg.app.api_base, tags=["repositories"])
app.include_router(file.router, prefix=cfg.app.api_base, tags=["files"])
app.include_router(lfs.router, tags=["lfs"])
app.include_router(utils.router, prefix=cfg.app.api_base, tags=["utils"])
# Public download endpoint (no /api prefix, matches HuggingFace URL pattern)
@app.get("/{namespace}/{name}/resolve/{revision}/{path:path}")
@app.head("/{namespace}/{name}/resolve/{revision}/{path:path}")
async def public_resolve(
namespace: str, name: str, revision: str, path: str, request: Request
):
"""Public download endpoint without /api prefix.
Matches HuggingFace Hub URL pattern for direct file downloads.
Defaults to model repository type.
Args:
repo_id: Repository ID (e.g., "org/repo")
revision: Branch name or commit hash
path: File path within repository
Returns:
File download response or redirect
"""
"""Public download endpoint without /api prefix."""
repo = Repository.get_or_none(name=name, namespace=namespace)
if not repo:
@@ -85,6 +71,7 @@ def root():
"description": "HuggingFace-compatible hub with LakeFS and S3 storage",
"endpoints": {
"api": cfg.app.api_base,
"auth": f"{cfg.app.api_base}/auth",
"docs": "/docs",
"redoc": "/redoc",
},