mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-04-28 01:57:14 -05:00
better fallback mechanism
This commit is contained in:
@@ -103,8 +103,16 @@ async def create_fallback_source(
|
||||
name=source.name,
|
||||
source_type=source.source_type,
|
||||
enabled=source.enabled,
|
||||
created_at=source.created_at.isoformat(),
|
||||
updated_at=source.updated_at.isoformat(),
|
||||
created_at=(
|
||||
source.created_at.isoformat()
|
||||
if isinstance(source.created_at, datetime)
|
||||
else source.created_at
|
||||
),
|
||||
updated_at=(
|
||||
source.updated_at.isoformat()
|
||||
if isinstance(source.updated_at, datetime)
|
||||
else source.updated_at
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -149,8 +157,16 @@ async def list_fallback_sources(
|
||||
name=s.name,
|
||||
source_type=s.source_type,
|
||||
enabled=s.enabled,
|
||||
created_at=s.created_at.isoformat(),
|
||||
updated_at=s.updated_at.isoformat(),
|
||||
created_at=(
|
||||
s.created_at.isoformat()
|
||||
if isinstance(s.created_at, datetime)
|
||||
else s.created_at
|
||||
),
|
||||
updated_at=(
|
||||
s.updated_at.isoformat()
|
||||
if isinstance(s.updated_at, datetime)
|
||||
else s.updated_at
|
||||
),
|
||||
)
|
||||
for s in sources
|
||||
]
|
||||
@@ -183,8 +199,16 @@ async def get_fallback_source(source_id: int, _admin=Depends(verify_admin_token)
|
||||
name=source.name,
|
||||
source_type=source.source_type,
|
||||
enabled=source.enabled,
|
||||
created_at=source.created_at.isoformat(),
|
||||
updated_at=source.updated_at.isoformat(),
|
||||
created_at=(
|
||||
source.created_at.isoformat()
|
||||
if isinstance(source.created_at, datetime)
|
||||
else source.created_at
|
||||
),
|
||||
updated_at=(
|
||||
source.updated_at.isoformat()
|
||||
if isinstance(source.updated_at, datetime)
|
||||
else source.updated_at
|
||||
),
|
||||
)
|
||||
|
||||
except FallbackSource.DoesNotExist:
|
||||
@@ -253,8 +277,16 @@ async def update_fallback_source(
|
||||
name=source.name,
|
||||
source_type=source.source_type,
|
||||
enabled=source.enabled,
|
||||
created_at=source.created_at.isoformat(),
|
||||
updated_at=source.updated_at.isoformat(),
|
||||
created_at=(
|
||||
source.created_at.isoformat()
|
||||
if isinstance(source.created_at, datetime)
|
||||
else source.created_at
|
||||
),
|
||||
updated_at=(
|
||||
source.updated_at.isoformat()
|
||||
if isinstance(source.updated_at, datetime)
|
||||
else source.updated_at
|
||||
),
|
||||
)
|
||||
|
||||
except FallbackSource.DoesNotExist:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import io
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import Response
|
||||
from PIL import Image
|
||||
|
||||
@@ -17,6 +17,7 @@ from kohakuhub.db_operations import (
|
||||
)
|
||||
from kohakuhub.logger import get_logger
|
||||
from kohakuhub.auth.dependencies import get_current_user, get_optional_user
|
||||
from kohakuhub.api.fallback import with_user_fallback
|
||||
|
||||
logger = get_logger("AVATAR")
|
||||
|
||||
@@ -194,17 +195,23 @@ async def upload_user_avatar(
|
||||
|
||||
|
||||
@router.get("/users/{username}/avatar")
|
||||
@with_user_fallback("avatar")
|
||||
async def get_user_avatar(
|
||||
username: str, _user: User | None = Depends(get_optional_user)
|
||||
username: str,
|
||||
request: Request,
|
||||
fallback: bool = True,
|
||||
_user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get user avatar image.
|
||||
|
||||
Args:
|
||||
username: Username
|
||||
request: FastAPI request object
|
||||
fallback: Enable fallback to external sources
|
||||
_user: Optional authenticated user (for logging)
|
||||
|
||||
Returns:
|
||||
JPEG image
|
||||
JPEG image (can be from local or fallback source)
|
||||
|
||||
Raises:
|
||||
HTTPException: If user not found or no avatar
|
||||
@@ -336,17 +343,23 @@ async def upload_org_avatar(
|
||||
|
||||
|
||||
@router.get("/organizations/{org_name}/avatar")
|
||||
@with_user_fallback("avatar")
|
||||
async def get_org_avatar(
|
||||
org_name: str, _user: User | None = Depends(get_optional_user)
|
||||
org_name: str,
|
||||
request: Request,
|
||||
fallback: bool = True,
|
||||
_user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get organization avatar image.
|
||||
|
||||
Args:
|
||||
org_name: Organization name
|
||||
request: FastAPI request object
|
||||
fallback: Enable fallback to external sources
|
||||
_user: Optional authenticated user (for logging)
|
||||
|
||||
Returns:
|
||||
JPEG image
|
||||
JPEG image (can be from local or fallback source)
|
||||
|
||||
Raises:
|
||||
HTTPException: If organization not found or no avatar
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Decorators for adding fallback functionality to endpoints."""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Literal
|
||||
|
||||
@@ -12,8 +13,10 @@ from kohakuhub.logger import get_logger
|
||||
from kohakuhub.api.fallback.operations import (
|
||||
fetch_external_list,
|
||||
try_fallback_info,
|
||||
try_fallback_org_avatar,
|
||||
try_fallback_resolve,
|
||||
try_fallback_tree,
|
||||
try_fallback_user_avatar,
|
||||
try_fallback_user_profile,
|
||||
try_fallback_user_repos,
|
||||
)
|
||||
@@ -22,7 +25,7 @@ from kohakuhub.api.fallback.config import get_enabled_sources
|
||||
logger = get_logger("FALLBACK_DEC")
|
||||
|
||||
OperationType = Literal["resolve", "tree", "info", "revision", "paths_info"]
|
||||
UserOperationType = Literal["profile", "repos"]
|
||||
UserOperationType = Literal["profile", "repos", "avatar"]
|
||||
|
||||
|
||||
def with_repo_fallback(operation: OperationType):
|
||||
@@ -38,10 +41,25 @@ def with_repo_fallback(operation: OperationType):
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
# Get function signature to extract default values
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract fallback param from Request object (if available)
|
||||
fallback_enabled = None
|
||||
# Extract fallback param - priority: query param > kwargs > function default > True
|
||||
fallback_enabled = True # Default to True
|
||||
|
||||
# Check function signature default
|
||||
if "fallback" in sig.parameters:
|
||||
default = sig.parameters["fallback"].default
|
||||
if default != inspect.Parameter.empty:
|
||||
fallback_enabled = default
|
||||
|
||||
# Check kwargs (FastAPI injected value)
|
||||
if "fallback" in kwargs:
|
||||
fallback_enabled = kwargs["fallback"]
|
||||
|
||||
# Check query param (highest priority - overrides everything)
|
||||
request = kwargs.get("request")
|
||||
if request and hasattr(request, "query_params"):
|
||||
fallback_param = request.query_params.get("fallback")
|
||||
@@ -52,8 +70,8 @@ def with_repo_fallback(operation: OperationType):
|
||||
"no",
|
||||
)
|
||||
|
||||
# Check if fallback is enabled globally and not disabled by query param
|
||||
if not cfg.fallback.enabled or fallback_enabled is False:
|
||||
# Check if fallback is enabled globally and not disabled by param
|
||||
if not cfg.fallback.enabled or not fallback_enabled:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Extract repo info from kwargs
|
||||
@@ -189,16 +207,29 @@ def with_list_aggregation(repo_type: str):
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
# Get function signature to find 'fallback' parameter position
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
fallback_index = (
|
||||
param_names.index("fallback") if "fallback" in param_names else -1
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract fallback parameter (5th arg or kwargs)
|
||||
# Functions called as: _list_xxx_with_aggregation(author, limit, sort, user, fallback)
|
||||
fallback_enabled = kwargs.get("fallback", True)
|
||||
if fallback_enabled is None:
|
||||
fallback_enabled = True
|
||||
if len(args) > 4:
|
||||
args = list(args)
|
||||
fallback_enabled = args.pop()
|
||||
# Extract fallback parameter from args or kwargs
|
||||
fallback_enabled = True # Default to True
|
||||
|
||||
# Try kwargs first
|
||||
if "fallback" in kwargs:
|
||||
fallback_enabled = kwargs["fallback"]
|
||||
# Try positional args
|
||||
elif fallback_index >= 0 and len(args) > fallback_index:
|
||||
fallback_enabled = args[fallback_index]
|
||||
# Use default from signature
|
||||
elif fallback_index >= 0:
|
||||
default = sig.parameters["fallback"].default
|
||||
if default != inspect.Parameter.empty:
|
||||
fallback_enabled = default
|
||||
|
||||
logger.info(
|
||||
f"with_list_aggregation decorator params: fallback_enabled={fallback_enabled}"
|
||||
@@ -206,7 +237,7 @@ def with_list_aggregation(repo_type: str):
|
||||
|
||||
# Check if fallback is enabled globally and not disabled by param
|
||||
if not cfg.fallback.enabled or not fallback_enabled:
|
||||
# Call without fallback - need to remove fallback from args
|
||||
# Call without fallback
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Get local results
|
||||
@@ -308,10 +339,25 @@ def with_user_fallback(operation: UserOperationType):
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
# Get function signature to extract default values
|
||||
sig = inspect.signature(func)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract fallback param from Request object (if available)
|
||||
fallback_enabled = None
|
||||
# Extract fallback param - priority: query param > kwargs > function default > True
|
||||
fallback_enabled = True # Default to True
|
||||
|
||||
# Check function signature default
|
||||
if "fallback" in sig.parameters:
|
||||
default = sig.parameters["fallback"].default
|
||||
if default != inspect.Parameter.empty:
|
||||
fallback_enabled = default
|
||||
|
||||
# Check kwargs (FastAPI injected value)
|
||||
if "fallback" in kwargs:
|
||||
fallback_enabled = kwargs["fallback"]
|
||||
|
||||
# Check query param (highest priority - overrides everything)
|
||||
request = kwargs.get("request")
|
||||
if request and hasattr(request, "query_params"):
|
||||
fallback_param = request.query_params.get("fallback")
|
||||
@@ -322,8 +368,8 @@ def with_user_fallback(operation: UserOperationType):
|
||||
"no",
|
||||
)
|
||||
|
||||
# Check if fallback is enabled globally and not disabled by query param
|
||||
if not cfg.fallback.enabled or fallback_enabled is False:
|
||||
# Check if fallback is enabled globally and not disabled by param
|
||||
if not cfg.fallback.enabled or not fallback_enabled:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Extract username/org_name from kwargs
|
||||
@@ -373,12 +419,29 @@ def with_user_fallback(operation: UserOperationType):
|
||||
case "repos":
|
||||
result = await try_fallback_user_repos(username)
|
||||
|
||||
case "avatar":
|
||||
# Check if it's org or user based on parameter name
|
||||
org_name = kwargs.get("org_name")
|
||||
if org_name:
|
||||
result = await try_fallback_org_avatar(org_name)
|
||||
else:
|
||||
result = await try_fallback_user_avatar(username)
|
||||
|
||||
case _:
|
||||
logger.warning(f"Unknown user fallback operation: {operation}")
|
||||
result = None
|
||||
|
||||
if result:
|
||||
logger.success(f"Fallback SUCCESS for user {operation}: {username}")
|
||||
# For avatar operation, wrap bytes in Response
|
||||
if operation == "avatar" and isinstance(result, bytes):
|
||||
return Response(
|
||||
content=result,
|
||||
media_type="image/jpeg",
|
||||
headers={
|
||||
"Cache-Control": "public, max-age=86400", # 24 hour cache
|
||||
},
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# Not found in any source
|
||||
|
||||
@@ -452,6 +452,134 @@ async def try_fallback_user_profile(username: str) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
async def try_fallback_user_avatar(username: str) -> Optional[bytes]:
|
||||
"""Try to get user avatar from fallback sources.
|
||||
|
||||
For HuggingFace: Get avatar URL from overview, then download it
|
||||
For KohakuHub: Call /api/users/{username}/avatar directly
|
||||
|
||||
Args:
|
||||
username: Username to lookup
|
||||
|
||||
Returns:
|
||||
Avatar image bytes (JPEG) or None if not found
|
||||
"""
|
||||
sources = get_enabled_sources(namespace="") # Global sources only
|
||||
|
||||
if not sources:
|
||||
return None
|
||||
|
||||
for source in sources:
|
||||
try:
|
||||
client = FallbackClient(
|
||||
source_url=source["url"],
|
||||
source_type=source["source_type"],
|
||||
token=source.get("token"),
|
||||
)
|
||||
|
||||
match source["source_type"]:
|
||||
case "huggingface":
|
||||
# Get avatar URL from user overview
|
||||
user_path = f"/api/users/{username}/overview"
|
||||
user_response = await client.get(user_path, "model")
|
||||
|
||||
if 200 <= user_response.status_code < 400:
|
||||
hf_data = user_response.json()
|
||||
avatar_url = hf_data.get("avatarUrl")
|
||||
|
||||
if avatar_url:
|
||||
# Download avatar image
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as http_client:
|
||||
avatar_response = await http_client.get(avatar_url)
|
||||
if avatar_response.status_code == 200:
|
||||
logger.info(
|
||||
f"Fallback user avatar SUCCESS: {username} from {source['name']}"
|
||||
)
|
||||
return avatar_response.content
|
||||
|
||||
logger.debug(f"HF user avatar not found: {username}")
|
||||
continue
|
||||
|
||||
case "kohakuhub":
|
||||
# Other KohakuHub instances - call avatar endpoint directly
|
||||
avatar_path = f"/api/users/{username}/avatar"
|
||||
response = await client.get(avatar_path, "model")
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
f"Fallback user avatar SUCCESS: {username} from {source['name']}"
|
||||
)
|
||||
return response.content
|
||||
|
||||
elif not should_retry_source(response):
|
||||
return None
|
||||
|
||||
case _:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Fallback user avatar failed for {source['name']}: {e}")
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def try_fallback_org_avatar(org_name: str) -> Optional[bytes]:
|
||||
"""Try to get organization avatar from fallback sources.
|
||||
|
||||
For KohakuHub: Call /api/organizations/{org_name}/avatar directly
|
||||
For HuggingFace: Organizations don't have avatars in the API
|
||||
|
||||
Args:
|
||||
org_name: Organization name to lookup
|
||||
|
||||
Returns:
|
||||
Avatar image bytes (JPEG) or None if not found
|
||||
"""
|
||||
sources = get_enabled_sources(namespace="") # Global sources only
|
||||
|
||||
if not sources:
|
||||
return None
|
||||
|
||||
for source in sources:
|
||||
try:
|
||||
client = FallbackClient(
|
||||
source_url=source["url"],
|
||||
source_type=source["source_type"],
|
||||
token=source.get("token"),
|
||||
)
|
||||
|
||||
match source["source_type"]:
|
||||
case "kohakuhub":
|
||||
# Other KohakuHub instances - call avatar endpoint directly
|
||||
avatar_path = f"/api/organizations/{org_name}/avatar"
|
||||
response = await client.get(avatar_path, "model")
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(
|
||||
f"Fallback org avatar SUCCESS: {org_name} from {source['name']}"
|
||||
)
|
||||
return response.content
|
||||
|
||||
elif not should_retry_source(response):
|
||||
return None
|
||||
|
||||
case "huggingface":
|
||||
# HuggingFace doesn't provide org avatars via API
|
||||
continue
|
||||
|
||||
case _:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Fallback org avatar failed for {source['name']}: {e}")
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def try_fallback_user_repos(username: str) -> Optional[dict]:
|
||||
"""Try to get user repositories from fallback sources.
|
||||
|
||||
|
||||
@@ -276,7 +276,9 @@ async def get_revision(
|
||||
namespace: str,
|
||||
name: str,
|
||||
revision: str,
|
||||
request: Request,
|
||||
expand: Optional[str] = None,
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get revision information for a repository.
|
||||
@@ -461,6 +463,8 @@ async def resolve_file_head(
|
||||
name: str,
|
||||
revision: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get file metadata (HEAD request).
|
||||
@@ -487,6 +491,7 @@ async def resolve_file_get(
|
||||
revision: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Download file (GET request).
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Organization related API endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from kohakuhub.db import User, UserOrganization, db
|
||||
@@ -57,8 +57,12 @@ async def create_organization_endpoint(
|
||||
|
||||
@router.get("/{org_name}")
|
||||
@with_user_fallback("profile")
|
||||
async def get_organization_info(org_name: str):
|
||||
"""Get organization details."""
|
||||
async def get_organization_info(org_name: str, request: Request, fallback: bool = True):
|
||||
"""Get organization details.
|
||||
|
||||
Query params:
|
||||
fallback: Set to "false" to disable fallback to external sources (default: true)
|
||||
"""
|
||||
org = get_organization(org_name)
|
||||
if not org:
|
||||
raise HTTPException(404, detail=_ERR_ORG_NOT_FOUND)
|
||||
|
||||
@@ -48,6 +48,7 @@ async def get_repo_info(
|
||||
namespace: str,
|
||||
repo_name: str,
|
||||
request: Request,
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get repository information (without revision).
|
||||
@@ -460,8 +461,10 @@ async def list_repos(
|
||||
@with_user_fallback("repos")
|
||||
async def list_user_repos(
|
||||
username: str,
|
||||
request: Request,
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
sort: str = Query("recent", regex="^(recent|likes|downloads)$"),
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""List all repositories for a specific user/namespace.
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db import File, Repository, User
|
||||
@@ -225,10 +225,12 @@ async def list_repo_tree(
|
||||
repo_type: RepoType,
|
||||
namespace: str,
|
||||
repo_name: str,
|
||||
request: Request,
|
||||
revision: str = "main",
|
||||
path: str = "",
|
||||
recursive: bool = False,
|
||||
expand: bool = False,
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""List repository file tree.
|
||||
@@ -313,8 +315,10 @@ async def get_paths_info(
|
||||
namespace: str,
|
||||
repo_name: str,
|
||||
revision: str,
|
||||
request: Request,
|
||||
paths: list[str] = Form(...),
|
||||
expand: bool = Form(False),
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get information about specific paths in a repository.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from kohakuhub.db import User
|
||||
@@ -103,11 +103,13 @@ async def update_user_settings(
|
||||
|
||||
@router.get("/users/{username}/profile")
|
||||
@with_user_fallback("profile")
|
||||
async def get_user_profile(username: str):
|
||||
async def get_user_profile(username: str, request: Request, fallback: bool = True):
|
||||
"""Get user public profile information.
|
||||
|
||||
Args:
|
||||
username: Username to query
|
||||
request: FastAPI request object
|
||||
fallback: Enable fallback to external sources
|
||||
|
||||
Returns:
|
||||
Public profile data
|
||||
|
||||
@@ -106,7 +106,9 @@ async def public_resolve_head(
|
||||
name: str,
|
||||
revision: str,
|
||||
path: str,
|
||||
request: Request,
|
||||
type: str = "model",
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Public HEAD endpoint without /api prefix - returns file metadata only."""
|
||||
@@ -137,6 +139,7 @@ async def public_resolve_get(
|
||||
path: str,
|
||||
request: Request,
|
||||
type: str = "model",
|
||||
fallback: bool = True,
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Public GET endpoint without /api prefix - redirects to S3 download."""
|
||||
|
||||
Reference in New Issue
Block a user