Fix incorrect usage of async functions

This commit is contained in:
Kohaku-Blueleaf
2025-10-06 13:36:04 +08:00
parent fa5594cc57
commit d487adf7b3
11 changed files with 616 additions and 514 deletions

1
.gitignore vendored
View File

@@ -1,4 +1,5 @@
*_old.*
*_refactor.*
CLAUDE.md
.claude/
example.md

View File

@@ -14,6 +14,7 @@ from kohakuhub.api.utils.hf import (
from kohakuhub.api.utils.lakefs import get_lakefs_client, lakefs_repo_name
from kohakuhub.auth.dependencies import get_current_user
from kohakuhub.auth.permissions import check_repo_delete_permission
from kohakuhub.db_async import get_repository
from kohakuhub.db import Repository, User
from kohakuhub.logger import get_logger
@@ -30,7 +31,7 @@ class CreateBranchPayload(BaseModel):
@router.post("/{repo_type}s/{namespace}/{name}/branch")
def create_branch(
async def create_branch(
repo_type: str,
namespace: str,
name: str,
@@ -90,7 +91,7 @@ def create_branch(
@router.delete("/{repo_type}s/{namespace}/{name}/branch/{branch}")
def delete_branch(
async def delete_branch(
repo_type: str,
namespace: str,
name: str,
@@ -112,9 +113,7 @@ def delete_branch(
repo_id = f"{namespace}/{name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -150,7 +149,7 @@ class CreateTagPayload(BaseModel):
@router.post("/{repo_type}s/{namespace}/{name}/tag")
def create_tag(
async def create_tag(
repo_type: str,
namespace: str,
name: str,
@@ -172,9 +171,7 @@ def create_tag(
repo_id = f"{namespace}/{name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -210,7 +207,7 @@ def create_tag(
@router.delete("/{repo_type}s/{namespace}/{name}/tag/{tag}")
def delete_tag(
async def delete_tag(
repo_type: str,
namespace: str,
name: str,
@@ -232,9 +229,7 @@ def delete_tag(
repo_id = f"{namespace}/{name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)

View File

@@ -6,9 +6,10 @@ from fastapi import APIRouter, Depends, Query
from kohakuhub.api.utils.hf import hf_repo_not_found, hf_server_error
from kohakuhub.api.utils.lakefs import get_lakefs_client, lakefs_repo_name
from kohakuhub.async_utils import get_async_lakefs_client, run_in_executor
from kohakuhub.async_utils import get_async_lakefs_client, run_in_lakefs_executor
from kohakuhub.auth.dependencies import get_optional_user
from kohakuhub.auth.permissions import check_repo_read_permission
from kohakuhub.db_async import get_repository
from kohakuhub.db import Repository, User
from kohakuhub.logger import get_logger
@@ -43,9 +44,7 @@ async def list_commits(
repo_id = f"{namespace}/{name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -70,7 +69,7 @@ async def list_commits(
if after:
kwargs["after"] = after
log_result = await run_in_executor(
log_result = await run_in_lakefs_executor(
client.refs.log_commits,
lakefs_repo, # repository (positional)
branch, # ref (positional)

View File

@@ -25,6 +25,7 @@ from kohakuhub.auth.permissions import (
check_repo_write_permission,
)
from kohakuhub.config import cfg
from kohakuhub.db_async import execute_db_query, get_file, get_repository
from kohakuhub.db import File, Repository, User
from kohakuhub.logger import get_logger
@@ -73,9 +74,7 @@ async def preupload(
"""
repo_id = f"{namespace}/{name}"
# Verify repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type.value)
)
repo_row = await get_repository(repo_type.value, namespace, name)
if not repo_row:
raise HTTPException(404, detail={"error": "Repository not found"})
@@ -116,9 +115,7 @@ async def preupload(
# Check for existing file with same content
if sha256:
# If sha256 provided, use it for comparison (most reliable)
existing = File.get_or_none(
(File.repo_full_id == repo_id) & (File.path_in_repo == path)
)
existing = await get_file(repo_id, path)
if existing and existing.sha256 == sha256 and existing.size == size:
should_ignore = True
elif sample and upload_mode == "regular":
@@ -195,9 +192,7 @@ async def get_revision(
"""
repo_id = f"{namespace}/{name}"
# Check if repository exists in database first
repo_row = Repository.get_or_none(
(Repository.repo_type == repo_type.value) & (Repository.full_id == repo_id)
)
repo_row = await get_repository(repo_type.value, namespace, name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type.value)
@@ -276,9 +271,12 @@ async def _get_file_metadata(
repo_id = f"{namespace}/{name}"
# Check repository exists and read permission
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
def _get_repo():
return Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await execute_db_query(_get_repo)
if repo_row:
check_repo_read_permission(repo_row, user)
@@ -311,7 +309,7 @@ async def _get_file_metadata(
bucket, key = parse_s3_uri(physical_address)
# Generate presigned download URL
presigned_url = generate_download_presigned_url(
presigned_url = await generate_download_presigned_url(
bucket=bucket,
key=key,
expires_in=3600, # 1 hour
@@ -323,9 +321,7 @@ async def _get_file_metadata(
# Get correct checksum from database
# sha256 column stores: git blob SHA1 for non-LFS, SHA256 for LFS
file_record = File.get_or_none(
(File.repo_full_id == repo_id) & (File.path_in_repo == path)
)
file_record = await get_file(repo_id, path)
# HuggingFace expects plain SHA256 hex (64 characters, unquoted)
# For non-LFS: use git blob SHA1, for LFS: use SHA256

View File

@@ -1,4 +1,4 @@
"""Git LFS Batch API implementation.
"""Git LFS Batch API implementation - Refactored version.
This module implements the Git LFS Batch API specification for handling
large file uploads (>10MB). It provides presigned S3 URLs for direct uploads.
@@ -6,7 +6,7 @@ large file uploads (>10MB). It provides presigned S3 URLs for direct uploads.
import base64
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
@@ -18,13 +18,14 @@ from kohakuhub.api.utils.s3 import (
get_object_metadata,
object_exists,
)
from kohakuhub.auth.dependencies import get_current_user, get_optional_user
from kohakuhub.auth.dependencies import get_optional_user
from kohakuhub.auth.permissions import (
check_repo_read_permission,
check_repo_write_permission,
)
from kohakuhub.config import cfg
from kohakuhub.db import File, Repository, User
from kohakuhub.db_async import execute_db_query, get_file_by_sha256
from kohakuhub.db import Repository, User
from kohakuhub.logger import get_logger
logger = get_logger("LFS")
@@ -42,17 +43,9 @@ class LFSBatchRequest(BaseModel):
"""LFS batch API request."""
operation: str # "upload" or "download"
transfers: Optional[List[str]] = ["basic"]
objects: List[LFSObject]
hash_algo: Optional[str] = "sha256"
class LFSAction(BaseModel):
"""LFS upload/download action."""
href: str
header: Optional[dict] = None
expires_at: Optional[str] = None
transfers: list[str] | None = ["basic"]
objects: list[LFSObject]
hash_algo: str | None = "sha256"
class LFSError(BaseModel):
@@ -67,26 +60,172 @@ class LFSObjectResponse(BaseModel):
oid: str
size: int
authenticated: Optional[bool] = True
actions: Optional[dict] = None # Contains "upload", "verify", "download"
error: Optional[LFSError] = None # Must use LFSError model
authenticated: bool | None = True
actions: dict | None = None # Contains "upload", "verify", "download"
error: LFSError | None = None # Must use LFSError model
class LFSBatchResponse(BaseModel):
"""LFS batch API response."""
transfer: str = "basic"
objects: List[LFSObjectResponse]
objects: list[LFSObjectResponse]
hash_algo: str = "sha256"
def get_lfs_key(oid: str) -> str:
"""Generate S3 key for LFS object.
Args:
oid: SHA256 hash
Returns:
S3 key with balanced directory structure
"""
return f"lfs/{oid[:2]}/{oid[2:4]}/{oid}"
async def process_upload_object(oid: str, size: int, repo_id: str) -> LFSObjectResponse:
"""Process single LFS object for upload operation.
Args:
oid: Object ID (SHA256)
size: File size in bytes
repo_id: Repository ID
Returns:
LFS object response with upload actions or error
"""
lfs_key = get_lfs_key(oid)
# Check if object already exists (deduplication)
existing = await get_file_by_sha256(oid)
if existing and existing.size == size:
# Object exists, no upload needed
return LFSObjectResponse(
oid=oid,
size=size,
authenticated=True,
# No actions = already exists
)
# Check if multipart upload required (>5GB)
multipart_threshold = 5 * 1024 * 1024 * 1024 # 5GB
if size > multipart_threshold:
return LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(
code=501, # Not Implemented
message="Multipart upload not yet implemented for files >5GB",
),
)
# Single PUT upload
try:
# Convert SHA256 hex to base64 for S3 checksum verification
checksum_sha256 = base64.b64encode(bytes.fromhex(oid)).decode("utf-8")
upload_info = await generate_upload_presigned_url(
bucket=cfg.s3.bucket,
key=lfs_key,
expires_in=3600, # 1 hour
content_type="application/octet-stream",
checksum_sha256=checksum_sha256,
)
return LFSObjectResponse(
oid=oid,
size=size,
authenticated=True,
actions={
"upload": {
"href": upload_info["url"],
"expires_at": upload_info["expires_at"],
"header": upload_info.get("headers", {}),
},
"verify": {
"href": f"{cfg.app.base_url}/api/{repo_id}.git/info/lfs/verify",
"expires_at": upload_info["expires_at"],
},
},
)
except Exception as e:
return LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(
code=500,
message=f"Failed to generate upload URL: {str(e)}",
),
)
async def process_download_object(oid: str, size: int) -> LFSObjectResponse:
"""Process single LFS object for download operation.
Args:
oid: Object ID (SHA256)
size: File size in bytes
Returns:
LFS object response with download actions or error
"""
lfs_key = get_lfs_key(oid)
# Check if object exists
existing = await get_file_by_sha256(oid)
if not existing:
return LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(code=404, message="Object not found"),
)
# Object exists, provide download URL
try:
download_url = await generate_download_presigned_url(
bucket=cfg.s3.bucket,
key=lfs_key,
expires_in=3600,
)
expires_at = (
datetime.now(timezone.utc) + timedelta(seconds=3600)
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
return LFSObjectResponse(
oid=oid,
size=size,
authenticated=True,
actions={
"download": {
"href": download_url,
"expires_at": expires_at,
},
},
)
except Exception as e:
return LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(
code=500,
message=f"Failed to generate download URL: {str(e)}",
),
)
@router.post("/{repo_type}s/{namespace}/{name}.git/info/lfs/objects/batch")
@router.post("/{namespace}/{name}.git/info/lfs/objects/batch")
async def lfs_batch(
namespace: str,
name: str,
request: Request,
repo_type: Optional[str] = "model",
repo_type: str | None = "model",
user: User | None = Depends(get_optional_user),
):
"""Git LFS Batch API endpoint.
@@ -98,6 +237,7 @@ async def lfs_batch(
namespace: Repository namespace
name: Repository name
request: FastAPI request with LFS batch payload
repo_type: Repository type (default: model)
user: Current authenticated user (optional for downloads)
Returns:
@@ -116,178 +256,56 @@ async def lfs_batch(
raise HTTPException(400, detail={"error": f"Invalid LFS batch request: {e}"})
# Check repository exists and permissions
# Note: We need to infer repo_type from context or use a default
# For LFS, we'll check all repo types
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
def _get_repo():
return Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await execute_db_query(_get_repo)
if repo_row:
operation = batch_req.operation
if operation == "upload":
# Upload requires authentication and write permission
if not user:
raise HTTPException(
401, detail={"error": "Authentication required for upload"}
)
check_repo_write_permission(repo_row, user)
elif operation == "download":
# Download requires read permission (may be public)
check_repo_read_permission(repo_row, user)
match operation:
case "upload":
# Upload requires authentication and write permission
if not user:
raise HTTPException(
401, detail={"error": "Authentication required for upload"}
)
check_repo_write_permission(repo_row, user)
case "download":
# Download requires read permission (may be public)
check_repo_read_permission(repo_row, user)
if cfg.app.debug_log_payloads:
logger.debug("==== LFS Batch Request ====")
logger.debug(body)
operation = batch_req.operation
# Process objects using match-case
objects_response = []
# LFS files stored in: s3://bucket/lfs/{oid[:2]}/{oid[2:4]}/{oid}
# This provides a balanced directory structure
for obj in batch_req.objects:
oid = obj.oid
size = obj.size
# S3 key for LFS object (content-addressable)
lfs_key = f"lfs/{oid[:2]}/{oid[2:4]}/{oid}"
match batch_req.operation:
case "upload":
response_obj = await process_upload_object(oid, size, repo_id)
objects_response.append(response_obj)
# Check if object already exists (deduplication)
existing = File.get_or_none(File.sha256 == oid)
if operation == "upload":
if existing and existing.size == size:
# Object exists, no upload needed
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
authenticated=True,
# No actions = already exists
)
)
else:
# Object needs upload
# Check if multipart upload required (>5GB)
multipart_threshold = 5 * 1024 * 1024 * 1024 # 5GB
if size > multipart_threshold:
# Return error for multipart uploads (not yet implemented)
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(
code=501, # Not Implemented
message="Multipart upload not yet implemented for files >5GB",
),
)
)
else:
# Single PUT upload
try:
# Convert SHA256 hex to base64 for S3 checksum verification
checksum_sha256 = base64.b64encode(bytes.fromhex(oid)).decode(
"utf-8"
)
upload_info = generate_upload_presigned_url(
bucket=cfg.s3.bucket,
key=lfs_key,
expires_in=3600, # 1 hour
content_type="application/octet-stream",
checksum_sha256=checksum_sha256,
)
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
authenticated=True,
actions={
"upload": {
"href": upload_info["url"],
"expires_at": upload_info["expires_at"],
"header": upload_info.get("headers", {}),
},
"verify": {
"href": f"{cfg.app.base_url}/api/{repo_id}.git/info/lfs/verify",
"expires_at": upload_info["expires_at"],
},
},
)
)
except Exception as e:
# Error generating presigned URL
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(
code=500,
message=f"Failed to generate upload URL: {str(e)}",
),
)
)
elif operation == "download":
if not existing:
# Object not found - use proper error format
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(code=404, message="Object not found"),
)
)
else:
# Object exists, provide download URL
try:
download_url = generate_download_presigned_url(
bucket=cfg.s3.bucket,
key=lfs_key,
expires_in=3600,
)
expires_at = (
datetime.now(timezone.utc) + timedelta(seconds=3600)
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
authenticated=True,
actions={
"download": {
"href": download_url,
"expires_at": expires_at,
},
},
)
)
except Exception as e:
# Error generating download URL
objects_response.append(
LFSObjectResponse(
oid=oid,
size=size,
error=LFSError(
code=500,
message=f"Failed to generate download URL: {str(e)}",
),
)
)
case "download":
response_obj = await process_download_object(oid, size)
objects_response.append(response_obj)
# Return response with exclude_none to omit null fields
# This ensures "error": null is not included in the JSON
response = LFSBatchResponse(
transfer="basic",
objects=objects_response,
hash_algo="sha256",
)
# Use JSONResponse to ensure proper serialization with exclude_none
return JSONResponse(
content=response.model_dump(exclude_none=True),
media_type="application/vnd.git-lfs+json",
@@ -301,7 +319,8 @@ async def lfs_verify(namespace: str, name: str, request: Request):
Called by client after successful upload to confirm the file.
Args:
repo_id: Repository ID
namespace: Repository namespace
name: Repository name
request: FastAPI request with verification data
Returns:
@@ -319,15 +338,15 @@ async def lfs_verify(namespace: str, name: str, request: Request):
if not oid:
raise HTTPException(400, detail={"error": "Missing OID"})
lfs_key = f"lfs/{oid[:2]}/{oid[2:4]}/{oid}"
lfs_key = get_lfs_key(oid)
if not object_exists(cfg.s3.bucket, lfs_key):
if not await object_exists(cfg.s3.bucket, lfs_key):
raise HTTPException(404, detail={"error": "Object not found in storage"})
# Optionally verify size
if size:
try:
metadata = get_object_metadata(cfg.s3.bucket, lfs_key)
metadata = await get_object_metadata(cfg.s3.bucket, lfs_key)
if metadata["size"] != size:
raise HTTPException(400, detail={"error": "Size mismatch"})
except Exception:

View File

@@ -20,6 +20,7 @@ from kohakuhub.auth.permissions import (
check_repo_delete_permission,
)
from kohakuhub.config import cfg
from kohakuhub.db_async import execute_db_query, get_repository
from kohakuhub.db import File, Repository, StagingUpload, User, init_db
from kohakuhub.logger import get_logger
@@ -41,7 +42,7 @@ class CreateRepoPayload(BaseModel):
@router.post("/repos/create")
def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_user)):
async def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_user)):
"""Create a new repository.
Args:
@@ -62,9 +63,8 @@ def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_use
full_id = f"{namespace}/{payload.name}"
lakefs_repo = lakefs_repo_name(payload.type, full_id)
if Repository.get_or_none(
(Repository.full_id == full_id) & (Repository.repo_type == payload.type)
):
existing_repo = await get_repository(payload.type, namespace, payload.name)
if existing_repo:
return hf_error_response(
400,
HFErrorCode.REPO_EXISTS,
@@ -88,13 +88,16 @@ def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_use
return hf_server_error(f"LakeFS repository creation failed: {str(e)}")
# Store in database for listing/metadata
Repository.get_or_create(
repo_type=payload.type,
namespace=namespace,
name=payload.name,
full_id=full_id,
defaults={"private": payload.private},
)
def _create_repo():
Repository.get_or_create(
repo_type=payload.type,
namespace=namespace,
name=payload.name,
full_id=full_id,
defaults={"private": payload.private, "owner_id": user.id},
)
await execute_db_query(_create_repo)
return {
"url": f"{cfg.app.base_url}/{payload.type}s/{full_id}",
@@ -133,9 +136,7 @@ async def delete_repo(
lakefs_repo = lakefs_repo_name(repo_type, full_id)
# 1. Check if repository exists in database
repo_row = Repository.get_or_none(
(Repository.full_id == full_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, payload.name)
if not repo_row:
# NOTE: HuggingFace client expects 400 for delete repo not found
@@ -160,14 +161,20 @@ async def delete_repo(
logger.info(f"LakeFS repository {lakefs_repo} not found/already deleted (OK)")
# 3. Delete related metadata from database (manual cascade)
def _delete_db_records():
try:
# Delete related file records first
File.delete().where(File.repo_full_id == full_id).execute()
StagingUpload.delete().where(StagingUpload.repo_full_id == full_id).execute()
repo_row.delete_instance()
logger.success(f"Successfully deleted database records for: {full_id}")
except Exception as e:
logger.exception(f"Database deletion failed for {full_id}", e)
raise
try:
# Delete related file records first
File.delete().where(File.repo_full_id == full_id).execute()
StagingUpload.delete().where(StagingUpload.repo_full_id == full_id).execute()
repo_row.delete_instance()
logger.success(f"Successfully deleted database records for: {full_id}")
await execute_db_query(_delete_db_records)
except Exception as e:
logger.exception(f"Database deletion failed for {full_id}", e)
return hf_server_error(f"Database deletion failed for {full_id}: {str(e)}")
# 4. Return success response (200 OK with a simple message)
@@ -184,7 +191,7 @@ class MoveRepoPayload(BaseModel):
@router.post("/repos/move")
def move_repo(
async def move_repo(
payload: MoveRepoPayload,
user: User = Depends(get_current_user),
):
@@ -204,9 +211,12 @@ def move_repo(
repo_type = payload.type
# Check if source repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == from_id) & (Repository.repo_type == repo_type)
)
from_parts = from_id.split("/", 1)
if len(from_parts) != 2:
return hf_error_response(400, HFErrorCode.INVALID_REPO_ID, "Invalid source repository ID")
from_namespace, from_name = from_parts
repo_row = await get_repository(repo_type, from_namespace, from_name)
if not repo_row:
return hf_repo_not_found(from_id, repo_type)
@@ -215,9 +225,12 @@ def move_repo(
check_repo_delete_permission(repo_row, user)
# Check if destination already exists
existing = Repository.get_or_none(
(Repository.full_id == to_id) & (Repository.repo_type == repo_type)
)
to_parts = to_id.split("/", 1)
if len(to_parts) != 2:
return hf_error_response(400, HFErrorCode.INVALID_REPO_ID, "Invalid destination repository ID")
to_namespace, to_name = to_parts
existing = await get_repository(repo_type, to_namespace, to_name)
if existing:
return hf_error_response(
400,
@@ -225,34 +238,27 @@ def move_repo(
f"Repository {to_id} already exists",
)
# Parse destination namespace and name
if "/" not in to_id:
return hf_error_response(
400,
HFErrorCode.INVALID_REPO_ID,
"Invalid repository ID format (must be namespace/name)",
)
to_namespace, to_name = to_id.split("/", 1)
# Check if user has permission to use destination namespace
check_namespace_permission(to_namespace, user)
# Update database records
# Update repository record
Repository.update(
namespace=to_namespace,
name=to_name,
full_id=to_id,
).where(Repository.id == repo_row.id).execute()
def _update_db_records():
# Update repository record
Repository.update(
namespace=to_namespace,
name=to_name,
full_id=to_id,
).where(Repository.id == repo_row.id).execute()
# Update related file records
File.update(repo_full_id=to_id).where(File.repo_full_id == from_id).execute()
# Update related file records
File.update(repo_full_id=to_id).where(File.repo_full_id == from_id).execute()
# Update staging uploads
StagingUpload.update(repo_full_id=to_id).where(
StagingUpload.repo_full_id == from_id
).execute()
# Update staging uploads
StagingUpload.update(repo_full_id=to_id).where(
StagingUpload.repo_full_id == from_id
).execute()
await execute_db_query(_update_db_records)
# Note: LakeFS repository rename not implemented yet
# Would require creating new LakeFS repo and migrating data

View File

@@ -15,6 +15,7 @@ from kohakuhub.async_utils import get_async_lakefs_client
from kohakuhub.auth.dependencies import get_optional_user
from kohakuhub.auth.permissions import check_repo_read_permission
from kohakuhub.config import cfg
from kohakuhub.db_async import execute_db_query, get_organization, get_repository, get_user_by_username
from kohakuhub.db import Organization, Repository, User, UserOrganization
from kohakuhub.logger import get_logger
@@ -71,9 +72,7 @@ async def get_repo_info(
)
# Check if repository exists in database
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, repo_name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -217,16 +216,19 @@ async def list_repos(
)
# Query database
q = Repository.select().where(Repository.repo_type == rt)
def _query_repos():
q = Repository.select().where(Repository.repo_type == rt)
# Filter by author if specified
if author:
q = q.where(Repository.namespace == author)
# Filter by author if specified
if author:
q = q.where(Repository.namespace == author)
# Apply privacy filtering
q = _filter_repos_by_privacy(q, user, author)
# Apply privacy filtering
q = _filter_repos_by_privacy(q, user, author)
rows = q.limit(limit)
return list(q.limit(limit))
rows = await execute_db_query(_query_repos)
# Format response with lastModified from LakeFS
client = get_lakefs_client()
@@ -297,10 +299,10 @@ async def list_user_repos(
Dict with models, datasets, and spaces lists
"""
# Check if the username exists
target_user = User.get_or_none(User.username == username)
target_user = await get_user_by_username(username)
if not target_user:
# Could also be an organization
target_org = Organization.get_or_none(Organization.name == username)
target_org = await get_organization(username)
if not target_org:
return hf_error_response(
404,
@@ -315,14 +317,17 @@ async def list_user_repos(
}
for repo_type in ["model", "dataset", "space"]:
q = Repository.select().where(
(Repository.repo_type == repo_type) & (Repository.namespace == username)
)
def _query_repos():
q = Repository.select().where(
(Repository.repo_type == repo_type) & (Repository.namespace == username)
)
# Apply privacy filtering
q = _filter_repos_by_privacy(q, user, username)
# Apply privacy filtering
q = _filter_repos_by_privacy(q, user, username)
rows = q.limit(limit)
return list(q.limit(limit))
rows = await execute_db_query(_query_repos)
key = repo_type + "s"
repos_list = []

View File

@@ -1,6 +1,7 @@
"""Repository tree listing and path information endpoints."""
"""Repository tree listing and path information endpoints - Refactored version."""
from typing import List, Literal
from datetime import datetime
from typing import Literal
from fastapi import APIRouter, Depends, Form
@@ -16,6 +17,7 @@ from kohakuhub.async_utils import get_async_lakefs_client
from kohakuhub.auth.dependencies import get_optional_user
from kohakuhub.auth.permissions import check_repo_read_permission
from kohakuhub.config import cfg
from kohakuhub.db_async import execute_db_query, get_file
from kohakuhub.db import File, Repository, User
from kohakuhub.logger import get_logger
@@ -25,6 +27,189 @@ router = APIRouter()
RepoType = Literal["model", "dataset", "space"]
async def fetch_lakefs_objects(
lakefs_repo: str, revision: str, prefix: str, recursive: bool
) -> list:
"""Fetch all objects from LakeFS with pagination.
Args:
lakefs_repo: LakeFS repository name
revision: Branch or commit
prefix: Path prefix
recursive: Whether to list recursively
Returns:
List of all LakeFS objects
Raises:
Exception: If listing fails
"""
async_client = get_async_lakefs_client()
all_results = []
after = ""
has_more = True
while has_more:
result = await async_client.list_objects(
repository=lakefs_repo,
ref=revision,
prefix=prefix,
delimiter="" if recursive else "/",
amount=1000, # Max per request
after=after,
)
all_results.extend(result.results)
# Check pagination
if result.pagination and result.pagination.has_more:
after = result.pagination.next_offset
has_more = True
else:
has_more = False
return all_results
async def calculate_folder_stats(
lakefs_repo: str, revision: str, folder_path: str
) -> tuple[int, float | None]:
"""Calculate folder size and latest modification time.
Args:
lakefs_repo: LakeFS repository name
revision: Branch or commit
folder_path: Full folder path
Returns:
Tuple of (total_size, latest_mtime)
"""
folder_size = 0
folder_latest_mtime = None
try:
async_client = get_async_lakefs_client()
folder_contents = await async_client.list_objects(
repository=lakefs_repo,
ref=revision,
prefix=folder_path,
delimiter="", # No delimiter = recursive
amount=1000,
)
# Calculate total size and find latest modification
for child_obj in folder_contents.results:
if child_obj.path_type == "object":
folder_size += child_obj.size_bytes or 0
if hasattr(child_obj, "mtime") and child_obj.mtime:
if (
folder_latest_mtime is None
or child_obj.mtime > folder_latest_mtime
):
folder_latest_mtime = child_obj.mtime
except Exception as e:
logger.debug(f"Could not calculate stats for folder {folder_path}: {str(e)}")
return folder_size, folder_latest_mtime
async def convert_file_object(
obj, repo_id: str, prefix_len: int
) -> dict:
"""Convert LakeFS file object to HuggingFace format.
Args:
obj: LakeFS object
repo_id: Repository ID
prefix_len: Length of path prefix to remove
Returns:
HuggingFace formatted file object
"""
is_lfs = obj.size_bytes > cfg.app.lfs_threshold_bytes
# Remove prefix from path to get relative path
relative_path = obj.path[prefix_len:] if prefix_len else obj.path
# Get correct checksum from database
file_record = await get_file(repo_id, obj.path)
checksum = (
file_record.sha256
if file_record and file_record.sha256
else obj.checksum
)
file_obj = {
"type": "file",
"oid": checksum, # Git blob SHA1 for non-LFS, SHA256 for LFS
"size": obj.size_bytes,
"path": relative_path,
}
# Add last modified info if available
if hasattr(obj, "mtime") and obj.mtime:
file_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
# Add LFS metadata if it's an LFS file
if is_lfs:
file_obj["lfs"] = {
"oid": checksum, # SHA256 for LFS files
"size": obj.size_bytes,
"pointerSize": 134, # Standard Git LFS pointer size
}
return file_obj
async def convert_directory_object(
obj, lakefs_repo: str, revision: str, prefix_len: int
) -> dict:
"""Convert LakeFS directory object to HuggingFace format.
Args:
obj: LakeFS common_prefix object
lakefs_repo: LakeFS repository name
revision: Branch or commit
prefix_len: Length of path prefix to remove
Returns:
HuggingFace formatted directory object
"""
# Remove prefix from path to get relative path
relative_path = obj.path[prefix_len:] if prefix_len else obj.path
# Calculate folder stats
folder_size, folder_latest_mtime = await calculate_folder_stats(
lakefs_repo, revision, obj.path
)
dir_obj = {
"type": "directory",
"oid": (
obj.checksum if hasattr(obj, "checksum") and obj.checksum else ""
),
"size": folder_size,
"path": relative_path.rstrip("/"), # Remove trailing slash
}
# Add last modified info
if folder_latest_mtime:
dir_obj["lastModified"] = datetime.fromtimestamp(
folder_latest_mtime
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
elif hasattr(obj, "mtime") and obj.mtime:
dir_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
return dir_obj
@router.get("/{repo_type}s/{namespace}/{repo_name}/tree/{revision}{path:path}")
async def list_repo_tree(
repo_type: RepoType,
@@ -40,18 +225,6 @@ async def list_repo_tree(
Returns a flat list of files and folders in HuggingFace format.
Response format matches HuggingFace API:
[
{
"type": "file", # or "directory"
"oid": "sha256_hash",
"size": 1234,
"path": "relative/path/to/file.txt",
"lfs": {"oid": "...", "size": ..., "pointerSize": ...} # if LFS file
},
...
]
Args:
repo_type: Type of repository
namespace: Repository namespace
@@ -69,9 +242,12 @@ async def list_repo_tree(
repo_id = f"{namespace}/{repo_name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
def _get_repo():
return Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await execute_db_query(_get_repo)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -86,34 +262,11 @@ async def list_repo_tree(
if prefix and not prefix.endswith("/"):
prefix += "/"
# Fetch all objects from LakeFS
try:
# List objects from LakeFS with pagination support
async_client = get_async_lakefs_client()
# Collect all results with pagination
all_results = []
after = ""
has_more = True
while has_more:
result = await async_client.list_objects(
repository=lakefs_repo,
ref=revision,
prefix=prefix,
delimiter="" if recursive else "/",
amount=1000, # Max per request
after=after,
)
all_results.extend(result.results)
# Check pagination
if result.pagination and result.pagination.has_more:
after = result.pagination.next_offset
has_more = True
else:
has_more = False
all_results = await fetch_lakefs_objects(
lakefs_repo, revision, prefix, recursive
)
except Exception as e:
# Check for specific error types
if is_lakefs_not_found_error(e):
@@ -127,115 +280,23 @@ async def list_repo_tree(
logger.exception(f"Failed to list objects for {repo_id}", e)
return hf_server_error(f"Failed to list objects: {str(e)}")
# Convert LakeFS objects to HuggingFace format (flat list)
# Convert LakeFS objects to HuggingFace format
result_list = []
prefix_len = len(prefix)
for obj in all_results:
if obj.path_type == "object":
# File object
is_lfs = obj.size_bytes > cfg.app.lfs_threshold_bytes
match obj.path_type:
case "object":
# File object
file_obj = await convert_file_object(obj, repo_id, prefix_len)
result_list.append(file_obj)
# Remove prefix from path to get relative path
relative_path = obj.path[prefix_len:] if prefix else obj.path
# Get correct checksum from database
# sha256 column stores: git blob SHA1 for non-LFS, SHA256 for LFS
file_record = File.get_or_none(
(File.repo_full_id == repo_id) & (File.path_in_repo == obj.path)
)
checksum = (
file_record.sha256
if file_record and file_record.sha256
else obj.checksum
)
file_obj = {
"type": "file",
"oid": checksum, # Git blob SHA1 for non-LFS, SHA256 for LFS
"size": obj.size_bytes,
"path": relative_path,
}
# Add last modified info if available from LakeFS
if hasattr(obj, "mtime") and obj.mtime:
from datetime import datetime
file_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
case "common_prefix":
# Directory object
dir_obj = await convert_directory_object(
obj, lakefs_repo, revision, prefix_len
)
# Add LFS metadata if it's an LFS file
if is_lfs:
file_obj["lfs"] = {
"oid": checksum, # SHA256 for LFS files
"size": obj.size_bytes,
"pointerSize": 134, # Standard Git LFS pointer size
}
result_list.append(file_obj)
elif obj.path_type == "common_prefix":
# Directory object
# Remove prefix from path to get relative path
relative_path = obj.path[prefix_len:] if prefix else obj.path
# Calculate folder stats by listing its contents recursively
folder_size = 0
folder_latest_mtime = None
try:
# List all objects in this folder recursively
async_client = get_async_lakefs_client()
folder_contents = await async_client.list_objects(
repository=lakefs_repo,
ref=revision,
prefix=obj.path, # Use full path as prefix
delimiter="", # No delimiter = recursive
amount=1000,
)
# Calculate total size and find latest modification
for child_obj in folder_contents.results:
if child_obj.path_type == "object":
folder_size += child_obj.size_bytes or 0
if hasattr(child_obj, "mtime") and child_obj.mtime:
if (
folder_latest_mtime is None
or child_obj.mtime > folder_latest_mtime
):
folder_latest_mtime = child_obj.mtime
except Exception as e:
logger.debug(
f"Could not calculate stats for folder {obj.path}: {str(e)}"
)
dir_obj = {
"type": "directory",
"oid": (
obj.checksum if hasattr(obj, "checksum") and obj.checksum else ""
),
"size": folder_size,
"path": relative_path.rstrip("/"), # Remove trailing slash
}
# Add last modified info
if folder_latest_mtime:
from datetime import datetime
dir_obj["lastModified"] = datetime.fromtimestamp(
folder_latest_mtime
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
elif hasattr(obj, "mtime") and obj.mtime:
from datetime import datetime
dir_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
"%Y-%m-%dT%H:%M:%S.%fZ"
)
result_list.append(dir_obj)
result_list.append(dir_obj)
return result_list
@@ -246,18 +307,13 @@ async def get_paths_info(
namespace: str,
repo_name: str,
revision: str,
paths: List[str] = Form(...),
paths: list[str] = Form(...),
expand: bool = Form(False),
user: User | None = Depends(get_optional_user),
):
"""Get information about specific paths in a repository.
This endpoint matches HuggingFace Hub API format:
POST /api/{repo_type}s/{namespace}/{repo_name}/paths-info/{revision}
Form data:
paths: List of paths to query
expand: Whether to include extended metadata
This endpoint matches HuggingFace Hub API format.
Args:
repo_type: Type of repository (model/dataset/space)
@@ -275,9 +331,12 @@ async def get_paths_info(
repo_id = f"{namespace}/{repo_name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
def _get_repo():
return Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await execute_db_query(_get_repo)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -307,10 +366,7 @@ async def get_paths_info(
is_lfs = obj_stats.size_bytes > cfg.app.lfs_threshold_bytes
# Get correct checksum from database
# sha256 column stores: git blob SHA1 for non-LFS, SHA256 for LFS
file_record = File.get_or_none(
(File.repo_full_id == repo_id) & (File.path_in_repo == clean_path)
)
file_record = await get_file(repo_id, clean_path)
checksum = (
file_record.sha256

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, EmailStr
from kohakuhub.api.utils.hf import hf_repo_not_found
from kohakuhub.auth.dependencies import get_current_user
from kohakuhub.auth.permissions import check_repo_delete_permission
from kohakuhub.db_async import execute_db_query, get_organization, get_repository, get_user_organization
from kohakuhub.db import Organization, Repository, User, UserOrganization
from kohakuhub.logger import get_logger
@@ -28,7 +29,7 @@ class UpdateUserSettingsRequest(BaseModel):
@router.put("/users/{username}/settings")
def update_user_settings(
async def update_user_settings(
username: str,
req: UpdateUserSettingsRequest,
user: User = Depends(get_current_user),
@@ -50,13 +51,19 @@ def update_user_settings(
# Update fields if provided
if req.email is not None:
# Check if email is already taken by another user
existing = User.get_or_none((User.email == req.email) & (User.id != user.id))
def _check_email():
return User.get_or_none((User.email == req.email) & (User.id != user.id))
existing = await execute_db_query(_check_email)
if existing:
raise HTTPException(400, detail="Email already in use")
User.update(email=req.email, email_verified=False).where(
User.id == user.id
).execute()
def _update_email():
User.update(email=req.email, email_verified=False).where(
User.id == user.id
).execute()
await execute_db_query(_update_email)
# TODO: Send new verification email
return {"success": True, "message": "User settings updated successfully"}
@@ -72,7 +79,7 @@ class UpdateOrganizationSettingsRequest(BaseModel):
@router.put("/organizations/{org_name}/settings")
def update_organization_settings(
async def update_organization_settings(
org_name: str,
req: UpdateOrganizationSettingsRequest,
user: User = Depends(get_current_user),
@@ -87,14 +94,12 @@ def update_organization_settings(
Returns:
Success message
"""
org = Organization.get_or_none(Organization.name == org_name)
org = await get_organization(org_name)
if not org:
raise HTTPException(404, detail="Organization not found")
# Check if user is admin of the organization
user_org = UserOrganization.get_or_none(
(UserOrganization.user == user.id) & (UserOrganization.organization == org.id)
)
user_org = await get_user_organization(user.id, org.id)
if not user_org or user_org.role not in ["admin", "super-admin"]:
raise HTTPException(
403, detail="Not authorized to update organization settings"
@@ -102,9 +107,12 @@ def update_organization_settings(
# Update fields if provided
if req.description is not None:
Organization.update(description=req.description).where(
Organization.id == org.id
).execute()
def _update_org():
Organization.update(description=req.description).where(
Organization.id == org.id
).execute()
await execute_db_query(_update_org)
return {"success": True, "message": "Organization settings updated successfully"}
@@ -122,7 +130,7 @@ class UpdateRepoSettingsPayload(BaseModel):
@router.put("/{repo_type}s/{namespace}/{name}/settings")
def update_repo_settings(
async def update_repo_settings(
repo_type: str,
namespace: str,
name: str,
@@ -146,9 +154,7 @@ def update_repo_settings(
repo_id = f"{namespace}/{name}"
# Check if repository exists
repo_row = Repository.get_or_none(
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
)
repo_row = await get_repository(repo_type, namespace, name)
if not repo_row:
return hf_repo_not_found(repo_id, repo_type)
@@ -158,9 +164,12 @@ def update_repo_settings(
# Update fields if provided
if payload.private is not None:
Repository.update(private=payload.private).where(
Repository.id == repo_row.id
).execute()
def _update_private():
Repository.update(private=payload.private).where(
Repository.id == repo_row.id
).execute()
await execute_db_query(_update_private)
# Note: gated functionality not yet implemented in database schema
# Would require adding a 'gated' field to Repository model

View File

@@ -6,6 +6,24 @@ from fastapi import APIRouter, HTTPException, Response, Depends
from pydantic import BaseModel, EmailStr
from ..config import cfg
from ..db_async import (
create_email_verification,
create_session,
create_token,
create_user,
delete_email_verification,
delete_session,
delete_token,
execute_db_query,
get_email_verification,
get_session,
get_token_by_hash,
get_user_by_email,
get_user_by_id,
get_user_by_username,
list_user_tokens,
update_user,
)
from ..db import User, EmailVerification, Session, Token
from ..logger import get_logger
@@ -41,22 +59,22 @@ class CreateTokenRequest(BaseModel):
@router.post("/register")
def register(req: RegisterRequest):
async def register(req: RegisterRequest):
"""Register new user."""
logger.info(f"Registration attempt for username: {req.username}")
# Check if username or email already exists
if User.get_or_none(User.username == req.username):
if await get_user_by_username(req.username):
logger.warning(f"Registration failed: username '{req.username}' already exists")
raise HTTPException(400, detail="Username already exists")
if User.get_or_none(User.email == req.email):
if await get_user_by_email(req.email):
logger.warning(f"Registration failed: email '{req.email}' already exists")
raise HTTPException(400, detail="Email already exists")
# Create user
user = User.create(
user = await create_user(
username=req.username,
email=req.email,
password_hash=hash_password(req.password),
@@ -68,8 +86,8 @@ def register(req: RegisterRequest):
# Send verification email if required
if cfg.auth.require_email_verification:
token = generate_token()
EmailVerification.create(
user=user.id, token=token, expires_at=get_expiry_time(24)
await create_email_verification(
user_id=user.id, token=token, expires_at=get_expiry_time(24)
)
if not send_verification_email(req.email, req.username, token):
@@ -93,18 +111,15 @@ def register(req: RegisterRequest):
@router.get("/verify-email")
def verify_email(token: str, response: Response):
async def verify_email(token: str, response: Response):
"""Verify email with token and automatically log in user."""
from fastapi.responses import RedirectResponse
logger.info(f"Email verification attempt with token: {token[:8]}...")
verification = EmailVerification.get_or_none(
(EmailVerification.token == token)
& (EmailVerification.expires_at > datetime.now(timezone.utc))
)
verification = await get_email_verification(token)
if not verification:
if not verification or verification.expires_at <= datetime.now(timezone.utc):
logger.warning(f"Invalid or expired verification token: {token[:8]}...")
# Redirect to login with error message
return RedirectResponse(
@@ -113,16 +128,16 @@ def verify_email(token: str, response: Response):
)
# Get user
user = User.get_or_none(User.id == verification.user)
user = await get_user_by_id(verification.user)
if not user:
logger.error(f"User not found for verification token: {token[:8]}...")
return RedirectResponse(url="/?error=user_not_found", status_code=302)
# Update user email verification status
User.update(email_verified=True).where(User.id == verification.user).execute()
await update_user(user, email_verified=True)
# Delete verification token
EmailVerification.delete().where(EmailVerification.id == verification.id).execute()
await delete_email_verification(verification)
logger.success(f"Email verified for user: {user.username}")
@@ -130,7 +145,7 @@ def verify_email(token: str, response: Response):
session_id = generate_token()
session_secret = generate_session_secret()
Session.create(
await create_session(
session_id=session_id,
user_id=user.id,
secret=session_secret,
@@ -155,12 +170,12 @@ def verify_email(token: str, response: Response):
@router.post("/login")
def login(req: LoginRequest, response: Response):
async def login(req: LoginRequest, response: Response):
"""Login and create session."""
logger.info(f"Login attempt for user: {req.username}")
user = User.get_or_none(User.username == req.username)
user = await get_user_by_username(req.username)
if not user or not verify_password(req.password, user.password_hash):
logger.warning(f"Failed login attempt for: {req.username}")
@@ -178,7 +193,7 @@ def login(req: LoginRequest, response: Response):
session_id = generate_token()
session_secret = generate_session_secret()
Session.create(
await create_session(
session_id=session_id,
user_id=user.id,
secret=session_secret,
@@ -205,13 +220,16 @@ def login(req: LoginRequest, response: Response):
@router.post("/logout")
def logout(response: Response, user: User = Depends(get_current_user)):
async def logout(response: Response, user: User = Depends(get_current_user)):
"""Logout and destroy session."""
logger.info(f"Logout request for user: {user.username}")
# Delete all user sessions
deleted_count = Session.delete().where(Session.user_id == user.id).execute()
def _delete_sessions():
return Session.delete().where(Session.user_id == user.id).execute()
deleted_count = await execute_db_query(_delete_sessions)
# Clear cookie
response.delete_cookie(key="session_id")
@@ -237,10 +255,10 @@ def get_me(user: User = Depends(get_current_user)):
@router.get("/tokens")
def list_tokens(user: User = Depends(get_current_user)):
async def list_tokens(user: User = Depends(get_current_user)):
"""List user's API tokens."""
tokens = Token.select().where(Token.user_id == user.id)
tokens = await list_user_tokens(user.id)
return {
"tokens": [
@@ -256,7 +274,7 @@ def list_tokens(user: User = Depends(get_current_user)):
@router.post("/tokens/create")
def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)):
async def create_token_endpoint(req: CreateTokenRequest, user: User = Depends(get_current_user)):
"""Create new API token."""
# Generate token
@@ -264,10 +282,13 @@ def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)
token_hash_val = hash_token(token_str)
# Save to database
token = Token.create(user_id=user.id, token_hash=token_hash_val, name=req.name)
token = await create_token(user_id=user.id, token_hash=token_hash_val, name=req.name)
# Get session secret for encryption (if in web session)
session = Session.get_or_none(Session.user_id == user.id)
def _get_session():
return Session.get_or_none(Session.user_id == user.id)
session = await execute_db_query(_get_session)
session_secret = session.secret if session else None
return {
@@ -280,14 +301,17 @@ def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)
@router.delete("/tokens/{token_id}")
def revoke_token(token_id: int, user: User = Depends(get_current_user)):
async def revoke_token(token_id: int, user: User = Depends(get_current_user)):
"""Revoke an API token."""
token = Token.get_or_none((Token.id == token_id) & (Token.user_id == user.id))
def _get_token():
return Token.get_or_none((Token.id == token_id) & (Token.user_id == user.id))
token = await execute_db_query(_get_token)
if not token:
raise HTTPException(404, detail="Token not found")
token.delete_instance()
await delete_token(token.id)
return {"success": True, "message": "Token revoked successfully"}

View File

@@ -3,11 +3,16 @@
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from kohakuhub.db_async import (
execute_db_query,
get_organization,
get_user_by_username,
get_user_organization,
list_organization_members as list_org_members_async,
)
from kohakuhub.db import User, UserOrganization
from kohakuhub.auth.dependencies import get_current_user
from kohakuhub.logger import get_logger
logger = get_logger("ORG")
from kohakuhub.org.utils import (
create_organization as create_org_util,
get_organization_details as get_org_details_util,
@@ -17,6 +22,8 @@ from kohakuhub.org.utils import (
update_member_role as update_member_role_util,
)
logger = get_logger("ORG")
router = APIRouter()
@@ -26,7 +33,7 @@ class CreateOrganizationPayload(BaseModel):
@router.post("/create")
def create_organization(
async def create_organization(
payload: CreateOrganizationPayload, user: User = Depends(get_current_user)
):
"""Create a new organization."""
@@ -35,7 +42,7 @@ def create_organization(
@router.get("/{org_name}")
def get_organization(org_name: str):
async def get_organization(org_name: str):
"""Get organization details."""
org = get_org_details_util(org_name)
if not org:
@@ -53,7 +60,7 @@ class AddMemberPayload(BaseModel):
@router.post("/{org_name}/members")
def add_member(
async def add_member(
org_name: str,
payload: AddMemberPayload,
current_user: User = Depends(get_current_user),
@@ -64,10 +71,7 @@ def add_member(
raise HTTPException(404, detail="Organization not found")
# Check if the current user is an admin of the organization
user_org = UserOrganization.get_or_none(
(UserOrganization.user == current_user.id)
& (UserOrganization.organization == org.id)
)
user_org = await get_user_organization(current_user.id, org.id)
if not user_org or user_org.role not in ["admin", "super-admin"]:
raise HTTPException(403, detail="Not authorized to add members")
@@ -76,7 +80,7 @@ def add_member(
@router.delete("/{org_name}/members/{username}")
def remove_member(
async def remove_member(
org_name: str,
username: str,
current_user: User = Depends(get_current_user),
@@ -87,10 +91,7 @@ def remove_member(
raise HTTPException(404, detail="Organization not found")
# Check if the current user is an admin of the organization
user_org = UserOrganization.get_or_none(
(UserOrganization.user == current_user.id)
& (UserOrganization.organization == org.id)
)
user_org = await get_user_organization(current_user.id, org.id)
if not user_org or user_org.role not in ["admin", "super-admin"]:
raise HTTPException(403, detail="Not authorized to remove members")
@@ -99,9 +100,9 @@ def remove_member(
@router.get("/users/{username}/orgs")
def list_user_organizations(username: str):
async def list_user_organizations(username: str):
"""List organizations a user belongs to."""
user = User.get_or_none(User.username == username)
user = await get_user_by_username(username)
if not user:
raise HTTPException(404, detail="User not found")
@@ -123,7 +124,7 @@ class UpdateMemberRolePayload(BaseModel):
@router.put("/{org_name}/members/{username}")
def update_member_role(
async def update_member_role(
org_name: str,
username: str,
payload: UpdateMemberRolePayload,
@@ -135,10 +136,7 @@ def update_member_role(
raise HTTPException(404, detail="Organization not found")
# Check if the current user is an admin of the organization
user_org = UserOrganization.get_or_none(
(UserOrganization.user == current_user.id)
& (UserOrganization.organization == org.id)
)
user_org = await get_user_organization(current_user.id, org.id)
if not user_org or user_org.role not in ["admin", "super-admin"]:
raise HTTPException(403, detail="Not authorized to update member roles")
@@ -147,20 +145,14 @@ def update_member_role(
@router.get("/{org_name}/members")
def list_organization_members(org_name: str):
async def list_organization_members(org_name: str):
"""List organization members."""
from ..db import Organization
org = Organization.get_or_none(Organization.name == org_name)
org = await get_organization(org_name)
if not org:
raise HTTPException(404, detail="Organization not found")
# Get all members
members = (
UserOrganization.select()
.join(User)
.where(UserOrganization.organization == org.id)
)
members = await list_org_members_async(org.id)
return {
"members": [