mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
Add Auth/User system
This commit is contained in:
@@ -16,13 +16,14 @@ dependencies = [
|
||||
"fastapi",
|
||||
"httpx",
|
||||
"uvicorn",
|
||||
"pydantic",
|
||||
"pydantic[email]",
|
||||
"toml",
|
||||
"peewee",
|
||||
"boto3",
|
||||
"lakefs-client",
|
||||
"psycopg2-binary",
|
||||
"pyyaml",
|
||||
"bcrypt",
|
||||
]
|
||||
urls = { "Homepage" = "https://kblueleaf.net/Kohaku-Hub" }
|
||||
|
||||
|
||||
87
scripts/test_auth.py
Normal file
87
scripts/test_auth.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test authentication system."""
|
||||
|
||||
import requests
|
||||
|
||||
BASE_URL = "http://127.0.0.1:48888/api"
|
||||
|
||||
# 1. Register
|
||||
print("=== Testing Registration ===")
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/auth/register",
|
||||
json={
|
||||
"username": "testuser2",
|
||||
"email": "test2@example.com",
|
||||
"password": "testpass123",
|
||||
},
|
||||
)
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 2. Login
|
||||
print("=== Testing Login ===")
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/auth/login", json={"username": "testuser2", "password": "testpass123"}
|
||||
)
|
||||
print(f"Status: {resp.status_code}")
|
||||
result = resp.json()
|
||||
print(f"Response: {result}\n")
|
||||
|
||||
# Save session cookie and secret
|
||||
session = requests.Session()
|
||||
session.cookies.update(resp.cookies)
|
||||
session_secret = result.get("session_secret")
|
||||
print(f"Session secret: {session_secret}\n")
|
||||
|
||||
# 3. Get current user
|
||||
print("=== Testing Get Current User ===")
|
||||
resp = session.get(f"{BASE_URL}/auth/me")
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 4. Create API token
|
||||
print("=== Testing Create Token ===")
|
||||
resp = session.post(f"{BASE_URL}/auth/tokens/create", json={"name": "test-token"})
|
||||
print(f"Status: {resp.status_code}")
|
||||
result = resp.json()
|
||||
print(f"Response: {result}\n")
|
||||
|
||||
token = result["token"]
|
||||
print(f"Generated token: {token}\n")
|
||||
print(f"Session secret for encryption: {result['session_secret']}\n")
|
||||
|
||||
# 5-1. List tokens
|
||||
print("=== Testing List Tokens ===")
|
||||
resp = session.get(f"{BASE_URL}/auth/tokens")
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 6. Test token-based auth
|
||||
print("=== Testing Token Auth ===")
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
resp = requests.get(f"{BASE_URL}/auth/me", headers=headers)
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 5-2. List tokens
|
||||
print("=== Testing List Tokens ===")
|
||||
resp = session.get(f"{BASE_URL}/auth/tokens")
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 7. Logout
|
||||
print("=== Testing Logout ===")
|
||||
resp = session.post(f"{BASE_URL}/auth/logout")
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 8. Verify session cleared
|
||||
print("=== Testing Session Cleared ===")
|
||||
resp = session.get(f"{BASE_URL}/auth/me")
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
|
||||
# 9. Verify token still works
|
||||
print("=== Testing Token Still Works ===")
|
||||
resp = requests.get(f"{BASE_URL}/auth/me", headers=headers)
|
||||
print(f"Status: {resp.status_code}")
|
||||
print(f"Response: {resp.json()}\n")
|
||||
@@ -1,10 +1,13 @@
|
||||
"""Authentication and authorization for Kohaku Hub API.
|
||||
|
||||
TODO: Implement real authentication system.
|
||||
Currently returns mock user for development.
|
||||
Integrates with the new auth system in kohakuhub.auth module.
|
||||
"""
|
||||
|
||||
from ..db import User, db
|
||||
from ..db import db, User
|
||||
from ..auth.dependencies import (
|
||||
get_current_user as auth_get_current_user,
|
||||
get_optional_user,
|
||||
)
|
||||
|
||||
|
||||
def get_db():
|
||||
@@ -20,19 +23,6 @@ def get_db():
|
||||
def get_current_user():
|
||||
"""Get current authenticated user.
|
||||
|
||||
TODO: Implement real authentication:
|
||||
- Parse Authorization header (Bearer token)
|
||||
- Validate token against database or JWT
|
||||
- Return actual User object
|
||||
- Raise HTTPException(401) if invalid
|
||||
|
||||
Returns:
|
||||
Mock user object for development.
|
||||
Now delegates to the real auth system.
|
||||
"""
|
||||
|
||||
# Mock user for development
|
||||
class MockUser:
|
||||
username = "me"
|
||||
id = 1
|
||||
|
||||
return MockUser()
|
||||
return auth_get_current_user()
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Utility API endpoints for Kohaku Hub."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
import yaml
|
||||
from ..db import User
|
||||
from .auth import get_optional_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -34,21 +37,30 @@ def validate_yaml(body: ValidateYamlPayload):
|
||||
|
||||
|
||||
@router.get("/whoami-v2")
|
||||
def whoami_v2():
|
||||
"""Get current user information.
|
||||
def whoami_v2(user: User = Depends(get_optional_user)):
|
||||
"""Get current user information (HuggingFace compatible).
|
||||
|
||||
TODO: Implement real user info retrieval from authentication system.
|
||||
|
||||
Returns:
|
||||
User information object
|
||||
Matches HuggingFace Hub /api/whoami-v2 endpoint format.
|
||||
Returns user info if authenticated, 401 if not.
|
||||
"""
|
||||
# Mock response matching HuggingFace Hub format
|
||||
if not user:
|
||||
raise HTTPException(401, detail="Invalid user token")
|
||||
|
||||
# Get user's organizations (stub for now - can be implemented later)
|
||||
orgs = []
|
||||
|
||||
return {
|
||||
"name": "me",
|
||||
"type": "user",
|
||||
"displayName": "me",
|
||||
"email": None,
|
||||
"orgs": [],
|
||||
"id": str(user.id),
|
||||
"name": user.username,
|
||||
"fullname": user.username,
|
||||
"email": user.email,
|
||||
"emailVerified": user.email_verified,
|
||||
"canPay": False,
|
||||
"isPro": False,
|
||||
"periodEnd": None,
|
||||
"orgs": orgs,
|
||||
"auth": {
|
||||
"type": "access_token",
|
||||
"accessToken": {"displayName": "Auto-generated token", "role": "write"},
|
||||
},
|
||||
}
|
||||
|
||||
6
src/kohakuhub/auth/__init__.py
Normal file
6
src/kohakuhub/auth/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Authentication module."""
|
||||
|
||||
from .routes import router
|
||||
from .dependencies import get_current_user, get_optional_user
|
||||
|
||||
__all__ = ["router", "get_current_user", "get_optional_user"]
|
||||
58
src/kohakuhub/auth/dependencies.py
Normal file
58
src/kohakuhub/auth/dependencies.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""FastAPI dependencies for authentication."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from fastapi import Cookie, Header, HTTPException
|
||||
|
||||
from ..db import User, Session, Token
|
||||
from .utils import hash_token
|
||||
|
||||
|
||||
def get_current_user(
|
||||
session_id: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> User:
|
||||
"""Get current authenticated user from session or token."""
|
||||
|
||||
# Try session-based auth first (web UI)
|
||||
if session_id:
|
||||
session = Session.get_or_none(
|
||||
(Session.session_id == session_id)
|
||||
& (Session.expires_at > datetime.now(timezone.utc))
|
||||
)
|
||||
if session:
|
||||
user = User.get_or_none(User.id == session.user_id)
|
||||
if user and user.is_active:
|
||||
return user
|
||||
|
||||
# Try token-based auth (API)
|
||||
if authorization:
|
||||
if not authorization.startswith("Bearer "):
|
||||
raise HTTPException(401, detail="Invalid authorization header")
|
||||
|
||||
token_str = authorization[7:] # Remove "Bearer "
|
||||
token_hash = hash_token(token_str)
|
||||
|
||||
token = Token.get_or_none(Token.token_hash == token_hash)
|
||||
if token:
|
||||
# Update last used
|
||||
Token.update(last_used=datetime.now(timezone.utc)).where(
|
||||
Token.id == token.id
|
||||
).execute()
|
||||
|
||||
user = User.get_or_none(User.id == token.user_id)
|
||||
if user and user.is_active:
|
||||
return user
|
||||
|
||||
raise HTTPException(401, detail="Not authenticated")
|
||||
|
||||
|
||||
def get_optional_user(
|
||||
session_id: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Optional[User]:
|
||||
"""Get current user if authenticated, otherwise None."""
|
||||
try:
|
||||
return get_current_user(session_id, authorization)
|
||||
except HTTPException:
|
||||
return None
|
||||
53
src/kohakuhub/auth/email.py
Normal file
53
src/kohakuhub/auth/email.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Email utilities for authentication."""
|
||||
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from ..config import cfg
|
||||
|
||||
|
||||
def send_verification_email(to_email: str, username: str, token: str) -> bool:
|
||||
"""Send email verification email."""
|
||||
if not cfg.smtp.enabled:
|
||||
print(
|
||||
f"[EMAIL] SMTP disabled. Verification link: {cfg.app.base_url}/auth/verify?token={token}"
|
||||
)
|
||||
return True
|
||||
|
||||
subject = "Verify your Kohaku Hub account"
|
||||
verify_link = f"{cfg.app.base_url}/auth/verify?token={token}"
|
||||
|
||||
body = f"""
|
||||
Hello {username},
|
||||
|
||||
Please verify your email address by clicking the link below:
|
||||
|
||||
{verify_link}
|
||||
|
||||
This link will expire in 24 hours.
|
||||
|
||||
If you didn't create this account, please ignore this email.
|
||||
|
||||
Best regards,
|
||||
Kohaku Hub
|
||||
"""
|
||||
|
||||
try:
|
||||
msg = MIMEMultipart()
|
||||
msg["From"] = cfg.smtp.from_email
|
||||
msg["To"] = to_email
|
||||
msg["Subject"] = subject
|
||||
msg.attach(MIMEText(body, "plain"))
|
||||
|
||||
with smtplib.SMTP(cfg.smtp.host, cfg.smtp.port) as server:
|
||||
if cfg.smtp.use_tls:
|
||||
server.starttls()
|
||||
if cfg.smtp.username and cfg.smtp.password:
|
||||
server.login(cfg.smtp.username, cfg.smtp.password)
|
||||
server.send_message(msg)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[EMAIL] Failed to send verification email: {e}")
|
||||
return False
|
||||
230
src/kohakuhub/auth/routes.py
Normal file
230
src/kohakuhub/auth/routes.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Authentication API routes."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Response, Depends
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from ..config import cfg
|
||||
from ..db import User, EmailVerification, Session, Token
|
||||
from .utils import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
generate_token,
|
||||
hash_token,
|
||||
generate_session_secret,
|
||||
get_expiry_time,
|
||||
)
|
||||
from .email import send_verification_email
|
||||
from .dependencies import get_current_user, get_optional_user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
username: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class CreateTokenRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
def register(req: RegisterRequest):
|
||||
"""Register new user."""
|
||||
|
||||
# Check if username or email already exists
|
||||
if User.get_or_none(User.username == req.username):
|
||||
raise HTTPException(400, detail="Username already exists")
|
||||
|
||||
if User.get_or_none(User.email == req.email):
|
||||
raise HTTPException(400, detail="Email already exists")
|
||||
|
||||
# Create user
|
||||
user = User.create(
|
||||
username=req.username,
|
||||
email=req.email,
|
||||
password_hash=hash_password(req.password),
|
||||
email_verified=not cfg.auth.require_email_verification,
|
||||
)
|
||||
|
||||
# Send verification email if required
|
||||
if cfg.auth.require_email_verification:
|
||||
token = generate_token()
|
||||
EmailVerification.create(
|
||||
user=user.id, token=token, expires_at=get_expiry_time(24)
|
||||
)
|
||||
|
||||
if not send_verification_email(req.email, req.username, token):
|
||||
return {
|
||||
"success": True,
|
||||
"message": "User created but failed to send verification email",
|
||||
"email_verified": False,
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "User created. Please check your email to verify your account.",
|
||||
"email_verified": False,
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "User created successfully",
|
||||
"email_verified": True,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/verify-email")
|
||||
def verify_email(token: str):
|
||||
"""Verify email with token."""
|
||||
|
||||
verification = EmailVerification.get_or_none(
|
||||
(EmailVerification.token == token)
|
||||
& (EmailVerification.expires_at > datetime.now(timezone.utc))
|
||||
)
|
||||
|
||||
if not verification:
|
||||
raise HTTPException(400, detail="Invalid or expired verification token")
|
||||
|
||||
# Update user
|
||||
User.update(email_verified=True).where(User.id == verification.user).execute()
|
||||
|
||||
# Delete verification token
|
||||
EmailVerification.delete().where(EmailVerification.id == verification.id).execute()
|
||||
|
||||
return {"success": True, "message": "Email verified successfully"}
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
def login(req: LoginRequest, response: Response):
|
||||
"""Login and create session."""
|
||||
|
||||
user = User.get_or_none(User.username == req.username)
|
||||
|
||||
if not user or not verify_password(req.password, user.password_hash):
|
||||
raise HTTPException(401, detail="Invalid username or password")
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(403, detail="Account is disabled")
|
||||
|
||||
if cfg.auth.require_email_verification and not user.email_verified:
|
||||
raise HTTPException(403, detail="Please verify your email first")
|
||||
|
||||
# Create session
|
||||
session_id = generate_token()
|
||||
session_secret = generate_session_secret()
|
||||
|
||||
Session.create(
|
||||
session_id=session_id,
|
||||
user_id=user.id,
|
||||
secret=session_secret,
|
||||
expires_at=get_expiry_time(cfg.auth.session_expire_hours),
|
||||
)
|
||||
|
||||
# Set cookie
|
||||
response.set_cookie(
|
||||
key="session_id",
|
||||
value=session_id,
|
||||
httponly=True,
|
||||
max_age=cfg.auth.session_expire_hours * 3600,
|
||||
samesite="lax",
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Logged in successfully",
|
||||
"username": user.username,
|
||||
"session_secret": session_secret,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(response: Response, user: User = Depends(get_current_user)):
|
||||
"""Logout and destroy session."""
|
||||
|
||||
# Delete all user sessions
|
||||
Session.delete().where(Session.user_id == user.id).execute()
|
||||
|
||||
# Clear cookie
|
||||
response.delete_cookie(key="session_id")
|
||||
|
||||
return {"success": True, "message": "Logged out successfully"}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def get_me(user: User = Depends(get_current_user)):
|
||||
"""Get current user info (internal endpoint)."""
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"email_verified": user.email_verified,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/tokens")
|
||||
def list_tokens(user: User = Depends(get_current_user)):
|
||||
"""List user's API tokens."""
|
||||
|
||||
tokens = Token.select().where(Token.user_id == user.id)
|
||||
|
||||
return {
|
||||
"tokens": [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"last_used": t.last_used.isoformat() if t.last_used else None,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
}
|
||||
for t in tokens
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tokens/create")
|
||||
def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)):
|
||||
"""Create new API token."""
|
||||
|
||||
# Generate token
|
||||
token_str = generate_token()
|
||||
token_hash_val = hash_token(token_str)
|
||||
|
||||
# Save to database
|
||||
token = Token.create(user_id=user.id, token_hash=token_hash_val, name=req.name)
|
||||
|
||||
# Get session secret for encryption (if in web session)
|
||||
session = Session.get_or_none(Session.user_id == user.id)
|
||||
session_secret = session.secret if session else None
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"token": token_str,
|
||||
"token_id": token.id,
|
||||
"session_secret": session_secret,
|
||||
"message": "Token created. Save it securely - you won't see it again!",
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/tokens/{token_id}")
|
||||
def revoke_token(token_id: int, user: User = Depends(get_current_user)):
|
||||
"""Revoke an API token."""
|
||||
|
||||
token = Token.get_or_none((Token.id == token_id) & (Token.user_id == user.id))
|
||||
|
||||
if not token:
|
||||
raise HTTPException(404, detail="Token not found")
|
||||
|
||||
token.delete_instance()
|
||||
|
||||
return {"success": True, "message": "Token revoked successfully"}
|
||||
39
src/kohakuhub/auth/utils.py
Normal file
39
src/kohakuhub/auth/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Authentication utilities."""
|
||||
|
||||
import secrets
|
||||
import hashlib
|
||||
import bcrypt
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password with bcrypt."""
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""Verify password against hash."""
|
||||
try:
|
||||
return bcrypt.checkpw(password.encode(), password_hash.encode())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def generate_token() -> str:
|
||||
"""Generate random token (32 bytes = 64 hex chars)."""
|
||||
return secrets.token_hex(32)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash token with SHA3-512."""
|
||||
return hashlib.sha3_512(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def generate_session_secret() -> str:
|
||||
"""Generate session secret for token encryption."""
|
||||
return secrets.token_hex(16)
|
||||
|
||||
|
||||
def get_expiry_time(hours: int) -> datetime:
|
||||
"""Get expiry time from now."""
|
||||
return datetime.now(timezone.utc) + timedelta(hours=hours)
|
||||
@@ -23,10 +23,27 @@ class LakeFSConfig(BaseModel):
|
||||
repo_namespace: str = "hf"
|
||||
|
||||
|
||||
class SMTPConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
host: str = "localhost"
|
||||
port: int = 587
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
from_email: str = "noreply@localhost"
|
||||
use_tls: bool = True
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
require_email_verification: bool = False
|
||||
session_secret: str = "change-me-in-production"
|
||||
session_expire_hours: int = 168 # 7 days
|
||||
token_expire_days: int = 365
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
base_url: str
|
||||
api_base: str = "/api"
|
||||
db_backend: str = "sqlite" # "sqlite" or "postgres"
|
||||
db_backend: str = "sqlite"
|
||||
database_url: str = "sqlite:///./hub.db"
|
||||
lfs_threshold_bytes: int = 10 * 1024 * 1024
|
||||
debug_log_payloads: bool = False
|
||||
@@ -35,6 +52,8 @@ class AppConfig(BaseModel):
|
||||
class Config(BaseModel):
|
||||
s3: S3Config
|
||||
lakefs: LakeFSConfig
|
||||
smtp: SMTPConfig = SMTPConfig()
|
||||
auth: AuthConfig = AuthConfig()
|
||||
app: AppConfig
|
||||
|
||||
|
||||
@@ -42,8 +61,6 @@ class Config(BaseModel):
|
||||
def load_config(path: str = None) -> Config:
|
||||
path = path or os.environ.get("HUB_CONFIG", None)
|
||||
if path is None:
|
||||
# use environment var (use .get with default value)
|
||||
# this is crucial for Docker Compose startup method
|
||||
s3_config = S3Config(
|
||||
public_endpoint=os.environ["KOHAKU_HUB_S3_PUBLIC_ENDPOINT"],
|
||||
endpoint=os.environ["KOHAKU_HUB_S3_ENDPOINT"],
|
||||
@@ -59,6 +76,34 @@ def load_config(path: str = None) -> Config:
|
||||
secret_key=os.environ["KOHAKU_HUB_LAKEFS_SECRET_KEY"],
|
||||
repo_namespace=os.environ.get("KOHAKU_HUB_LAKEFS_REPO_NAMESPACE", ""),
|
||||
)
|
||||
|
||||
smtp_config = SMTPConfig(
|
||||
enabled=os.environ.get("KOHAKU_HUB_SMTP_ENABLED", "false").lower()
|
||||
== "true",
|
||||
host=os.environ.get("KOHAKU_HUB_SMTP_HOST", "localhost"),
|
||||
port=int(os.environ.get("KOHAKU_HUB_SMTP_PORT", "587")),
|
||||
username=os.environ.get("KOHAKU_HUB_SMTP_USERNAME", ""),
|
||||
password=os.environ.get("KOHAKU_HUB_SMTP_PASSWORD", ""),
|
||||
from_email=os.environ.get("KOHAKU_HUB_SMTP_FROM", "noreply@localhost"),
|
||||
use_tls=os.environ.get("KOHAKU_HUB_SMTP_TLS", "true").lower() == "true",
|
||||
)
|
||||
|
||||
auth_config = AuthConfig(
|
||||
require_email_verification=os.environ.get(
|
||||
"KOHAKU_HUB_REQUIRE_EMAIL_VERIFICATION", "false"
|
||||
).lower()
|
||||
== "true",
|
||||
session_secret=os.environ.get(
|
||||
"KOHAKU_HUB_SESSION_SECRET", "change-me-in-production"
|
||||
),
|
||||
session_expire_hours=int(
|
||||
os.environ.get("KOHAKU_HUB_SESSION_EXPIRE_HOURS", "168")
|
||||
),
|
||||
token_expire_days=int(
|
||||
os.environ.get("KOHAKU_HUB_TOKEN_EXPIRE_DAYS", "365")
|
||||
),
|
||||
)
|
||||
|
||||
app_config = AppConfig(
|
||||
base_url=os.environ.get("KOHAKU_HUB_BASE_URL", "127.0.0.1:48888"),
|
||||
api_base=os.environ.get("KOHAKU_HUB_API_BASE", "/api"),
|
||||
@@ -69,7 +114,13 @@ def load_config(path: str = None) -> Config:
|
||||
),
|
||||
)
|
||||
|
||||
return Config(s3=s3_config, lakefs=lakefs_config, app=app_config)
|
||||
return Config(
|
||||
s3=s3_config,
|
||||
lakefs=lakefs_config,
|
||||
smtp=smtp_config,
|
||||
auth=auth_config,
|
||||
app=app_config,
|
||||
)
|
||||
else:
|
||||
with open(path, "rb") as f:
|
||||
raw = tomllib.load(f)
|
||||
|
||||
@@ -11,6 +11,7 @@ from peewee import (
|
||||
Model,
|
||||
SqliteDatabase,
|
||||
PostgresqlDatabase,
|
||||
TextField,
|
||||
)
|
||||
from .config import cfg
|
||||
|
||||
@@ -21,7 +22,6 @@ def _sqlite_path(url: str) -> str:
|
||||
|
||||
# Choose DB backend
|
||||
if cfg.app.db_backend == "postgres":
|
||||
# Example: postgresql://user:pass@host:5432/dbname
|
||||
url = cfg.app.database_url.replace("postgresql://", "")
|
||||
user_pass, host_db = url.split("@")
|
||||
user, password = user_pass.split(":")
|
||||
@@ -52,6 +52,37 @@ class BaseModel(Model):
|
||||
class User(BaseModel):
|
||||
id = AutoField()
|
||||
username = CharField(unique=True, index=True)
|
||||
email = CharField(unique=True, index=True)
|
||||
password_hash = CharField()
|
||||
email_verified = BooleanField(default=False)
|
||||
is_active = BooleanField(default=True)
|
||||
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
|
||||
|
||||
|
||||
class EmailVerification(BaseModel):
|
||||
id = AutoField()
|
||||
user = IntegerField(index=True)
|
||||
token = CharField(unique=True, index=True)
|
||||
expires_at = DateTimeField()
|
||||
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
|
||||
|
||||
|
||||
class Session(BaseModel):
|
||||
id = AutoField()
|
||||
session_id = CharField(unique=True, index=True)
|
||||
user_id = IntegerField(index=True)
|
||||
secret = CharField()
|
||||
expires_at = DateTimeField()
|
||||
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
id = AutoField()
|
||||
user_id = IntegerField(index=True)
|
||||
token_hash = CharField(unique=True, index=True)
|
||||
name = CharField()
|
||||
last_used = DateTimeField(null=True)
|
||||
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
|
||||
|
||||
|
||||
class Repository(BaseModel):
|
||||
@@ -61,6 +92,7 @@ class Repository(BaseModel):
|
||||
name = CharField(index=True)
|
||||
full_id = CharField(unique=True, index=True)
|
||||
private = BooleanField(default=False)
|
||||
owner_id = IntegerField(index=True, default=1)
|
||||
created_at = DateTimeField(default=partial(datetime.now, tz=timezone.utc))
|
||||
|
||||
class Meta:
|
||||
@@ -97,4 +129,7 @@ class StagingUpload(BaseModel):
|
||||
|
||||
def init_db():
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([User, Repository, File, StagingUpload], safe=True)
|
||||
db.create_tables(
|
||||
[User, EmailVerification, Session, Token, Repository, File, StagingUpload],
|
||||
safe=True,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .api import basic, file, lfs, utils
|
||||
from .auth import router as auth_router
|
||||
from .config import cfg
|
||||
from .db import Repository
|
||||
from .api.file import resolve_file
|
||||
@@ -13,7 +14,6 @@ from .api.s3_utils import init_storage
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Load the ML model
|
||||
init_storage()
|
||||
yield
|
||||
|
||||
@@ -26,7 +26,6 @@ app = FastAPI(
|
||||
)
|
||||
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -35,32 +34,19 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount API routers with configured prefix
|
||||
app.include_router(auth_router, prefix=cfg.app.api_base)
|
||||
app.include_router(basic.router, prefix=cfg.app.api_base, tags=["repositories"])
|
||||
app.include_router(file.router, prefix=cfg.app.api_base, tags=["files"])
|
||||
app.include_router(lfs.router, tags=["lfs"])
|
||||
app.include_router(utils.router, prefix=cfg.app.api_base, tags=["utils"])
|
||||
|
||||
|
||||
# Public download endpoint (no /api prefix, matches HuggingFace URL pattern)
|
||||
@app.get("/{namespace}/{name}/resolve/{revision}/{path:path}")
|
||||
@app.head("/{namespace}/{name}/resolve/{revision}/{path:path}")
|
||||
async def public_resolve(
|
||||
namespace: str, name: str, revision: str, path: str, request: Request
|
||||
):
|
||||
"""Public download endpoint without /api prefix.
|
||||
|
||||
Matches HuggingFace Hub URL pattern for direct file downloads.
|
||||
Defaults to model repository type.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID (e.g., "org/repo")
|
||||
revision: Branch name or commit hash
|
||||
path: File path within repository
|
||||
|
||||
Returns:
|
||||
File download response or redirect
|
||||
"""
|
||||
"""Public download endpoint without /api prefix."""
|
||||
|
||||
repo = Repository.get_or_none(name=name, namespace=namespace)
|
||||
if not repo:
|
||||
@@ -85,6 +71,7 @@ def root():
|
||||
"description": "HuggingFace-compatible hub with LakeFS and S3 storage",
|
||||
"endpoints": {
|
||||
"api": cfg.app.api_base,
|
||||
"auth": f"{cfg.app.api_base}/auth",
|
||||
"docs": "/docs",
|
||||
"redoc": "/redoc",
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user