better fallback mechanism

This commit is contained in:
Kohaku-Blueleaf
2025-10-20 13:20:43 +08:00
parent 9ff2749f7c
commit 7c65ef4a5d
10 changed files with 294 additions and 37 deletions

View File

@@ -103,8 +103,16 @@ async def create_fallback_source(
name=source.name,
source_type=source.source_type,
enabled=source.enabled,
created_at=source.created_at.isoformat(),
updated_at=source.updated_at.isoformat(),
created_at=(
source.created_at.isoformat()
if isinstance(source.created_at, datetime)
else source.created_at
),
updated_at=(
source.updated_at.isoformat()
if isinstance(source.updated_at, datetime)
else source.updated_at
),
)
except Exception as e:
@@ -149,8 +157,16 @@ async def list_fallback_sources(
name=s.name,
source_type=s.source_type,
enabled=s.enabled,
created_at=s.created_at.isoformat(),
updated_at=s.updated_at.isoformat(),
created_at=(
s.created_at.isoformat()
if isinstance(s.created_at, datetime)
else s.created_at
),
updated_at=(
s.updated_at.isoformat()
if isinstance(s.updated_at, datetime)
else s.updated_at
),
)
for s in sources
]
@@ -183,8 +199,16 @@ async def get_fallback_source(source_id: int, _admin=Depends(verify_admin_token)
name=source.name,
source_type=source.source_type,
enabled=source.enabled,
created_at=source.created_at.isoformat(),
updated_at=source.updated_at.isoformat(),
created_at=(
source.created_at.isoformat()
if isinstance(source.created_at, datetime)
else source.created_at
),
updated_at=(
source.updated_at.isoformat()
if isinstance(source.updated_at, datetime)
else source.updated_at
),
)
except FallbackSource.DoesNotExist:
@@ -253,8 +277,16 @@ async def update_fallback_source(
name=source.name,
source_type=source.source_type,
enabled=source.enabled,
created_at=source.created_at.isoformat(),
updated_at=source.updated_at.isoformat(),
created_at=(
source.created_at.isoformat()
if isinstance(source.created_at, datetime)
else source.created_at
),
updated_at=(
source.updated_at.isoformat()
if isinstance(source.updated_at, datetime)
else source.updated_at
),
)
except FallbackSource.DoesNotExist:

View File

@@ -3,7 +3,7 @@
import io
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.responses import Response
from PIL import Image
@@ -17,6 +17,7 @@ from kohakuhub.db_operations import (
)
from kohakuhub.logger import get_logger
from kohakuhub.auth.dependencies import get_current_user, get_optional_user
from kohakuhub.api.fallback import with_user_fallback
logger = get_logger("AVATAR")
@@ -194,17 +195,23 @@ async def upload_user_avatar(
@router.get("/users/{username}/avatar")
@with_user_fallback("avatar")
async def get_user_avatar(
username: str, _user: User | None = Depends(get_optional_user)
username: str,
request: Request,
fallback: bool = True,
_user: User | None = Depends(get_optional_user),
):
"""Get user avatar image.
Args:
username: Username
request: FastAPI request object
fallback: Enable fallback to external sources
_user: Optional authenticated user (for logging)
Returns:
JPEG image
JPEG image (can be from local or fallback source)
Raises:
HTTPException: If user not found or no avatar
@@ -336,17 +343,23 @@ async def upload_org_avatar(
@router.get("/organizations/{org_name}/avatar")
@with_user_fallback("avatar")
async def get_org_avatar(
org_name: str, _user: User | None = Depends(get_optional_user)
org_name: str,
request: Request,
fallback: bool = True,
_user: User | None = Depends(get_optional_user),
):
"""Get organization avatar image.
Args:
org_name: Organization name
request: FastAPI request object
fallback: Enable fallback to external sources
_user: Optional authenticated user (for logging)
Returns:
JPEG image
JPEG image (can be from local or fallback source)
Raises:
HTTPException: If organization not found or no avatar

View File

@@ -1,6 +1,7 @@
"""Decorators for adding fallback functionality to endpoints."""
import asyncio
import inspect
from functools import wraps
from typing import Literal
@@ -12,8 +13,10 @@ from kohakuhub.logger import get_logger
from kohakuhub.api.fallback.operations import (
fetch_external_list,
try_fallback_info,
try_fallback_org_avatar,
try_fallback_resolve,
try_fallback_tree,
try_fallback_user_avatar,
try_fallback_user_profile,
try_fallback_user_repos,
)
@@ -22,7 +25,7 @@ from kohakuhub.api.fallback.config import get_enabled_sources
logger = get_logger("FALLBACK_DEC")
OperationType = Literal["resolve", "tree", "info", "revision", "paths_info"]
UserOperationType = Literal["profile", "repos"]
UserOperationType = Literal["profile", "repos", "avatar"]
def with_repo_fallback(operation: OperationType):
@@ -38,10 +41,25 @@ def with_repo_fallback(operation: OperationType):
"""
def decorator(func):
# Get function signature to extract default values
sig = inspect.signature(func)
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract fallback param from Request object (if available)
fallback_enabled = None
# Extract fallback param - priority: query param > kwargs > function default > True
fallback_enabled = True # Default to True
# Check function signature default
if "fallback" in sig.parameters:
default = sig.parameters["fallback"].default
if default != inspect.Parameter.empty:
fallback_enabled = default
# Check kwargs (FastAPI injected value)
if "fallback" in kwargs:
fallback_enabled = kwargs["fallback"]
# Check query param (highest priority - overrides everything)
request = kwargs.get("request")
if request and hasattr(request, "query_params"):
fallback_param = request.query_params.get("fallback")
@@ -52,8 +70,8 @@ def with_repo_fallback(operation: OperationType):
"no",
)
# Check if fallback is enabled globally and not disabled by query param
if not cfg.fallback.enabled or fallback_enabled is False:
# Check if fallback is enabled globally and not disabled by param
if not cfg.fallback.enabled or not fallback_enabled:
return await func(*args, **kwargs)
# Extract repo info from kwargs
@@ -189,16 +207,29 @@ def with_list_aggregation(repo_type: str):
"""
def decorator(func):
# Get function signature to find 'fallback' parameter position
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
fallback_index = (
param_names.index("fallback") if "fallback" in param_names else -1
)
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract fallback parameter (5th arg or kwargs)
# Functions called as: _list_xxx_with_aggregation(author, limit, sort, user, fallback)
fallback_enabled = kwargs.get("fallback", True)
if fallback_enabled is None:
fallback_enabled = True
if len(args) > 4:
args = list(args)
fallback_enabled = args.pop()
# Extract fallback parameter from args or kwargs
fallback_enabled = True # Default to True
# Try kwargs first
if "fallback" in kwargs:
fallback_enabled = kwargs["fallback"]
# Try positional args
elif fallback_index >= 0 and len(args) > fallback_index:
fallback_enabled = args[fallback_index]
# Use default from signature
elif fallback_index >= 0:
default = sig.parameters["fallback"].default
if default != inspect.Parameter.empty:
fallback_enabled = default
logger.info(
f"with_list_aggregation decorator params: fallback_enabled={fallback_enabled}"
@@ -206,7 +237,7 @@ def with_list_aggregation(repo_type: str):
# Check if fallback is enabled globally and not disabled by param
if not cfg.fallback.enabled or not fallback_enabled:
# Call without fallback - need to remove fallback from args
# Call without fallback
return await func(*args, **kwargs)
# Get local results
@@ -308,10 +339,25 @@ def with_user_fallback(operation: UserOperationType):
"""
def decorator(func):
# Get function signature to extract default values
sig = inspect.signature(func)
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract fallback param from Request object (if available)
fallback_enabled = None
# Extract fallback param - priority: query param > kwargs > function default > True
fallback_enabled = True # Default to True
# Check function signature default
if "fallback" in sig.parameters:
default = sig.parameters["fallback"].default
if default != inspect.Parameter.empty:
fallback_enabled = default
# Check kwargs (FastAPI injected value)
if "fallback" in kwargs:
fallback_enabled = kwargs["fallback"]
# Check query param (highest priority - overrides everything)
request = kwargs.get("request")
if request and hasattr(request, "query_params"):
fallback_param = request.query_params.get("fallback")
@@ -322,8 +368,8 @@ def with_user_fallback(operation: UserOperationType):
"no",
)
# Check if fallback is enabled globally and not disabled by query param
if not cfg.fallback.enabled or fallback_enabled is False:
# Check if fallback is enabled globally and not disabled by param
if not cfg.fallback.enabled or not fallback_enabled:
return await func(*args, **kwargs)
# Extract username/org_name from kwargs
@@ -373,12 +419,29 @@ def with_user_fallback(operation: UserOperationType):
case "repos":
result = await try_fallback_user_repos(username)
case "avatar":
# Check if it's org or user based on parameter name
org_name = kwargs.get("org_name")
if org_name:
result = await try_fallback_org_avatar(org_name)
else:
result = await try_fallback_user_avatar(username)
case _:
logger.warning(f"Unknown user fallback operation: {operation}")
result = None
if result:
logger.success(f"Fallback SUCCESS for user {operation}: {username}")
# For avatar operation, wrap bytes in Response
if operation == "avatar" and isinstance(result, bytes):
return Response(
content=result,
media_type="image/jpeg",
headers={
"Cache-Control": "public, max-age=86400", # 24 hour cache
},
)
return result
else:
# Not found in any source

View File

@@ -452,6 +452,134 @@ async def try_fallback_user_profile(username: str) -> Optional[dict]:
return None
async def try_fallback_user_avatar(username: str) -> Optional[bytes]:
"""Try to get user avatar from fallback sources.
For HuggingFace: Get avatar URL from overview, then download it
For KohakuHub: Call /api/users/{username}/avatar directly
Args:
username: Username to lookup
Returns:
Avatar image bytes (JPEG) or None if not found
"""
sources = get_enabled_sources(namespace="") # Global sources only
if not sources:
return None
for source in sources:
try:
client = FallbackClient(
source_url=source["url"],
source_type=source["source_type"],
token=source.get("token"),
)
match source["source_type"]:
case "huggingface":
# Get avatar URL from user overview
user_path = f"/api/users/{username}/overview"
user_response = await client.get(user_path, "model")
if 200 <= user_response.status_code < 400:
hf_data = user_response.json()
avatar_url = hf_data.get("avatarUrl")
if avatar_url:
# Download avatar image
import httpx
async with httpx.AsyncClient(timeout=30.0) as http_client:
avatar_response = await http_client.get(avatar_url)
if avatar_response.status_code == 200:
logger.info(
f"Fallback user avatar SUCCESS: {username} from {source['name']}"
)
return avatar_response.content
logger.debug(f"HF user avatar not found: {username}")
continue
case "kohakuhub":
# Other KohakuHub instances - call avatar endpoint directly
avatar_path = f"/api/users/{username}/avatar"
response = await client.get(avatar_path, "model")
if response.status_code == 200:
logger.info(
f"Fallback user avatar SUCCESS: {username} from {source['name']}"
)
return response.content
elif not should_retry_source(response):
return None
case _:
continue
except Exception as e:
logger.warning(f"Fallback user avatar failed for {source['name']}: {e}")
continue
return None
async def try_fallback_org_avatar(org_name: str) -> Optional[bytes]:
"""Try to get organization avatar from fallback sources.
For KohakuHub: Call /api/organizations/{org_name}/avatar directly
For HuggingFace: Organizations don't have avatars in the API
Args:
org_name: Organization name to lookup
Returns:
Avatar image bytes (JPEG) or None if not found
"""
sources = get_enabled_sources(namespace="") # Global sources only
if not sources:
return None
for source in sources:
try:
client = FallbackClient(
source_url=source["url"],
source_type=source["source_type"],
token=source.get("token"),
)
match source["source_type"]:
case "kohakuhub":
# Other KohakuHub instances - call avatar endpoint directly
avatar_path = f"/api/organizations/{org_name}/avatar"
response = await client.get(avatar_path, "model")
if response.status_code == 200:
logger.info(
f"Fallback org avatar SUCCESS: {org_name} from {source['name']}"
)
return response.content
elif not should_retry_source(response):
return None
case "huggingface":
# HuggingFace doesn't provide org avatars via API
continue
case _:
continue
except Exception as e:
logger.warning(f"Fallback org avatar failed for {source['name']}: {e}")
continue
return None
async def try_fallback_user_repos(username: str) -> Optional[dict]:
"""Try to get user repositories from fallback sources.

View File

@@ -276,7 +276,9 @@ async def get_revision(
namespace: str,
name: str,
revision: str,
request: Request,
expand: Optional[str] = None,
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Get revision information for a repository.
@@ -461,6 +463,8 @@ async def resolve_file_head(
name: str,
revision: str,
path: str,
request: Request,
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Get file metadata (HEAD request).
@@ -487,6 +491,7 @@ async def resolve_file_get(
revision: str,
path: str,
request: Request,
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Download file (GET request).

View File

@@ -1,6 +1,6 @@
"""Organization related API endpoints."""
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from kohakuhub.db import User, UserOrganization, db
@@ -57,8 +57,12 @@ async def create_organization_endpoint(
@router.get("/{org_name}")
@with_user_fallback("profile")
async def get_organization_info(org_name: str):
"""Get organization details."""
async def get_organization_info(org_name: str, request: Request, fallback: bool = True):
"""Get organization details.
Query params:
fallback: Set to "false" to disable fallback to external sources (default: true)
"""
org = get_organization(org_name)
if not org:
raise HTTPException(404, detail=_ERR_ORG_NOT_FOUND)

View File

@@ -48,6 +48,7 @@ async def get_repo_info(
namespace: str,
repo_name: str,
request: Request,
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Get repository information (without revision).
@@ -460,8 +461,10 @@ async def list_repos(
@with_user_fallback("repos")
async def list_user_repos(
username: str,
request: Request,
limit: int = Query(100, ge=1, le=1000),
sort: str = Query("recent", regex="^(recent|likes|downloads)$"),
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""List all repositories for a specific user/namespace.

View File

@@ -4,7 +4,7 @@ import asyncio
from datetime import datetime
from typing import Literal
from fastapi import APIRouter, Depends, Form
from fastapi import APIRouter, Depends, Form, Request
from kohakuhub.config import cfg
from kohakuhub.db import File, Repository, User
@@ -225,10 +225,12 @@ async def list_repo_tree(
repo_type: RepoType,
namespace: str,
repo_name: str,
request: Request,
revision: str = "main",
path: str = "",
recursive: bool = False,
expand: bool = False,
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""List repository file tree.
@@ -313,8 +315,10 @@ async def get_paths_info(
namespace: str,
repo_name: str,
revision: str,
request: Request,
paths: list[str] = Form(...),
expand: bool = Form(False),
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Get information about specific paths in a repository.

View File

@@ -3,7 +3,7 @@
import json
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, EmailStr
from kohakuhub.db import User
@@ -103,11 +103,13 @@ async def update_user_settings(
@router.get("/users/{username}/profile")
@with_user_fallback("profile")
async def get_user_profile(username: str):
async def get_user_profile(username: str, request: Request, fallback: bool = True):
"""Get user public profile information.
Args:
username: Username to query
request: FastAPI request object
fallback: Enable fallback to external sources
Returns:
Public profile data

View File

@@ -106,7 +106,9 @@ async def public_resolve_head(
name: str,
revision: str,
path: str,
request: Request,
type: str = "model",
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Public HEAD endpoint without /api prefix - returns file metadata only."""
@@ -137,6 +139,7 @@ async def public_resolve_get(
path: str,
request: Request,
type: str = "model",
fallback: bool = True,
user: User | None = Depends(get_optional_user),
):
"""Public GET endpoint without /api prefix - redirects to S3 download."""