mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-09 07:12:07 -05:00
fix bad usage of async function
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
*_old.*
|
||||
*_refactor.*
|
||||
CLAUDE.md
|
||||
.claude/
|
||||
example.md
|
||||
|
||||
@@ -14,6 +14,7 @@ from kohakuhub.api.utils.hf import (
|
||||
from kohakuhub.api.utils.lakefs import get_lakefs_client, lakefs_repo_name
|
||||
from kohakuhub.auth.dependencies import get_current_user
|
||||
from kohakuhub.auth.permissions import check_repo_delete_permission
|
||||
from kohakuhub.db_async import get_repository
|
||||
from kohakuhub.db import Repository, User
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -30,7 +31,7 @@ class CreateBranchPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/{repo_type}s/{namespace}/{name}/branch")
|
||||
def create_branch(
|
||||
async def create_branch(
|
||||
repo_type: str,
|
||||
namespace: str,
|
||||
name: str,
|
||||
@@ -90,7 +91,7 @@ def create_branch(
|
||||
|
||||
|
||||
@router.delete("/{repo_type}s/{namespace}/{name}/branch/{branch}")
|
||||
def delete_branch(
|
||||
async def delete_branch(
|
||||
repo_type: str,
|
||||
namespace: str,
|
||||
name: str,
|
||||
@@ -112,9 +113,7 @@ def delete_branch(
|
||||
repo_id = f"{namespace}/{name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -150,7 +149,7 @@ class CreateTagPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/{repo_type}s/{namespace}/{name}/tag")
|
||||
def create_tag(
|
||||
async def create_tag(
|
||||
repo_type: str,
|
||||
namespace: str,
|
||||
name: str,
|
||||
@@ -172,9 +171,7 @@ def create_tag(
|
||||
repo_id = f"{namespace}/{name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -210,7 +207,7 @@ def create_tag(
|
||||
|
||||
|
||||
@router.delete("/{repo_type}s/{namespace}/{name}/tag/{tag}")
|
||||
def delete_tag(
|
||||
async def delete_tag(
|
||||
repo_type: str,
|
||||
namespace: str,
|
||||
name: str,
|
||||
@@ -232,9 +229,7 @@ def delete_tag(
|
||||
repo_id = f"{namespace}/{name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
|
||||
@@ -6,9 +6,10 @@ from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from kohakuhub.api.utils.hf import hf_repo_not_found, hf_server_error
|
||||
from kohakuhub.api.utils.lakefs import get_lakefs_client, lakefs_repo_name
|
||||
from kohakuhub.async_utils import get_async_lakefs_client, run_in_executor
|
||||
from kohakuhub.async_utils import get_async_lakefs_client, run_in_lakefs_executor
|
||||
from kohakuhub.auth.dependencies import get_optional_user
|
||||
from kohakuhub.auth.permissions import check_repo_read_permission
|
||||
from kohakuhub.db_async import get_repository
|
||||
from kohakuhub.db import Repository, User
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -43,9 +44,7 @@ async def list_commits(
|
||||
repo_id = f"{namespace}/{name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -70,7 +69,7 @@ async def list_commits(
|
||||
if after:
|
||||
kwargs["after"] = after
|
||||
|
||||
log_result = await run_in_executor(
|
||||
log_result = await run_in_lakefs_executor(
|
||||
client.refs.log_commits,
|
||||
lakefs_repo, # repository (positional)
|
||||
branch, # ref (positional)
|
||||
|
||||
@@ -25,6 +25,7 @@ from kohakuhub.auth.permissions import (
|
||||
check_repo_write_permission,
|
||||
)
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db_async import execute_db_query, get_file, get_repository
|
||||
from kohakuhub.db import File, Repository, User
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -73,9 +74,7 @@ async def preupload(
|
||||
"""
|
||||
repo_id = f"{namespace}/{name}"
|
||||
# Verify repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type.value)
|
||||
)
|
||||
repo_row = await get_repository(repo_type.value, namespace, name)
|
||||
if not repo_row:
|
||||
raise HTTPException(404, detail={"error": "Repository not found"})
|
||||
|
||||
@@ -116,9 +115,7 @@ async def preupload(
|
||||
# Check for existing file with same content
|
||||
if sha256:
|
||||
# If sha256 provided, use it for comparison (most reliable)
|
||||
existing = File.get_or_none(
|
||||
(File.repo_full_id == repo_id) & (File.path_in_repo == path)
|
||||
)
|
||||
existing = await get_file(repo_id, path)
|
||||
if existing and existing.sha256 == sha256 and existing.size == size:
|
||||
should_ignore = True
|
||||
elif sample and upload_mode == "regular":
|
||||
@@ -195,9 +192,7 @@ async def get_revision(
|
||||
"""
|
||||
repo_id = f"{namespace}/{name}"
|
||||
# Check if repository exists in database first
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.repo_type == repo_type.value) & (Repository.full_id == repo_id)
|
||||
)
|
||||
repo_row = await get_repository(repo_type.value, namespace, name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type.value)
|
||||
@@ -276,9 +271,12 @@ async def _get_file_metadata(
|
||||
repo_id = f"{namespace}/{name}"
|
||||
|
||||
# Check repository exists and read permission
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
def _get_repo():
|
||||
return Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
|
||||
repo_row = await execute_db_query(_get_repo)
|
||||
if repo_row:
|
||||
check_repo_read_permission(repo_row, user)
|
||||
|
||||
@@ -311,7 +309,7 @@ async def _get_file_metadata(
|
||||
bucket, key = parse_s3_uri(physical_address)
|
||||
|
||||
# Generate presigned download URL
|
||||
presigned_url = generate_download_presigned_url(
|
||||
presigned_url = await generate_download_presigned_url(
|
||||
bucket=bucket,
|
||||
key=key,
|
||||
expires_in=3600, # 1 hour
|
||||
@@ -323,9 +321,7 @@ async def _get_file_metadata(
|
||||
|
||||
# Get correct checksum from database
|
||||
# sha256 column stores: git blob SHA1 for non-LFS, SHA256 for LFS
|
||||
file_record = File.get_or_none(
|
||||
(File.repo_full_id == repo_id) & (File.path_in_repo == path)
|
||||
)
|
||||
file_record = await get_file(repo_id, path)
|
||||
|
||||
# HuggingFace expects plain SHA256 hex (64 characters, unquoted)
|
||||
# For non-LFS: use git blob SHA1, for LFS: use SHA256
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Git LFS Batch API implementation.
|
||||
"""Git LFS Batch API implementation - Refactored version.
|
||||
|
||||
This module implements the Git LFS Batch API specification for handling
|
||||
large file uploads (>10MB). It provides presigned S3 URLs for direct uploads.
|
||||
@@ -6,7 +6,7 @@ large file uploads (>10MB). It provides presigned S3 URLs for direct uploads.
|
||||
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -18,13 +18,14 @@ from kohakuhub.api.utils.s3 import (
|
||||
get_object_metadata,
|
||||
object_exists,
|
||||
)
|
||||
from kohakuhub.auth.dependencies import get_current_user, get_optional_user
|
||||
from kohakuhub.auth.dependencies import get_optional_user
|
||||
from kohakuhub.auth.permissions import (
|
||||
check_repo_read_permission,
|
||||
check_repo_write_permission,
|
||||
)
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db import File, Repository, User
|
||||
from kohakuhub.db_async import execute_db_query, get_file_by_sha256
|
||||
from kohakuhub.db import Repository, User
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
logger = get_logger("LFS")
|
||||
@@ -42,17 +43,9 @@ class LFSBatchRequest(BaseModel):
|
||||
"""LFS batch API request."""
|
||||
|
||||
operation: str # "upload" or "download"
|
||||
transfers: Optional[List[str]] = ["basic"]
|
||||
objects: List[LFSObject]
|
||||
hash_algo: Optional[str] = "sha256"
|
||||
|
||||
|
||||
class LFSAction(BaseModel):
|
||||
"""LFS upload/download action."""
|
||||
|
||||
href: str
|
||||
header: Optional[dict] = None
|
||||
expires_at: Optional[str] = None
|
||||
transfers: list[str] | None = ["basic"]
|
||||
objects: list[LFSObject]
|
||||
hash_algo: str | None = "sha256"
|
||||
|
||||
|
||||
class LFSError(BaseModel):
|
||||
@@ -67,26 +60,172 @@ class LFSObjectResponse(BaseModel):
|
||||
|
||||
oid: str
|
||||
size: int
|
||||
authenticated: Optional[bool] = True
|
||||
actions: Optional[dict] = None # Contains "upload", "verify", "download"
|
||||
error: Optional[LFSError] = None # Must use LFSError model
|
||||
authenticated: bool | None = True
|
||||
actions: dict | None = None # Contains "upload", "verify", "download"
|
||||
error: LFSError | None = None # Must use LFSError model
|
||||
|
||||
|
||||
class LFSBatchResponse(BaseModel):
|
||||
"""LFS batch API response."""
|
||||
|
||||
transfer: str = "basic"
|
||||
objects: List[LFSObjectResponse]
|
||||
objects: list[LFSObjectResponse]
|
||||
hash_algo: str = "sha256"
|
||||
|
||||
|
||||
def get_lfs_key(oid: str) -> str:
|
||||
"""Generate S3 key for LFS object.
|
||||
|
||||
Args:
|
||||
oid: SHA256 hash
|
||||
|
||||
Returns:
|
||||
S3 key with balanced directory structure
|
||||
"""
|
||||
return f"lfs/{oid[:2]}/{oid[2:4]}/{oid}"
|
||||
|
||||
|
||||
async def process_upload_object(oid: str, size: int, repo_id: str) -> LFSObjectResponse:
|
||||
"""Process single LFS object for upload operation.
|
||||
|
||||
Args:
|
||||
oid: Object ID (SHA256)
|
||||
size: File size in bytes
|
||||
repo_id: Repository ID
|
||||
|
||||
Returns:
|
||||
LFS object response with upload actions or error
|
||||
"""
|
||||
lfs_key = get_lfs_key(oid)
|
||||
|
||||
# Check if object already exists (deduplication)
|
||||
existing = await get_file_by_sha256(oid)
|
||||
|
||||
if existing and existing.size == size:
|
||||
# Object exists, no upload needed
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
authenticated=True,
|
||||
# No actions = already exists
|
||||
)
|
||||
|
||||
# Check if multipart upload required (>5GB)
|
||||
multipart_threshold = 5 * 1024 * 1024 * 1024 # 5GB
|
||||
|
||||
if size > multipart_threshold:
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(
|
||||
code=501, # Not Implemented
|
||||
message="Multipart upload not yet implemented for files >5GB",
|
||||
),
|
||||
)
|
||||
|
||||
# Single PUT upload
|
||||
try:
|
||||
# Convert SHA256 hex to base64 for S3 checksum verification
|
||||
checksum_sha256 = base64.b64encode(bytes.fromhex(oid)).decode("utf-8")
|
||||
|
||||
upload_info = await generate_upload_presigned_url(
|
||||
bucket=cfg.s3.bucket,
|
||||
key=lfs_key,
|
||||
expires_in=3600, # 1 hour
|
||||
content_type="application/octet-stream",
|
||||
checksum_sha256=checksum_sha256,
|
||||
)
|
||||
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
authenticated=True,
|
||||
actions={
|
||||
"upload": {
|
||||
"href": upload_info["url"],
|
||||
"expires_at": upload_info["expires_at"],
|
||||
"header": upload_info.get("headers", {}),
|
||||
},
|
||||
"verify": {
|
||||
"href": f"{cfg.app.base_url}/api/{repo_id}.git/info/lfs/verify",
|
||||
"expires_at": upload_info["expires_at"],
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(
|
||||
code=500,
|
||||
message=f"Failed to generate upload URL: {str(e)}",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def process_download_object(oid: str, size: int) -> LFSObjectResponse:
|
||||
"""Process single LFS object for download operation.
|
||||
|
||||
Args:
|
||||
oid: Object ID (SHA256)
|
||||
size: File size in bytes
|
||||
|
||||
Returns:
|
||||
LFS object response with download actions or error
|
||||
"""
|
||||
lfs_key = get_lfs_key(oid)
|
||||
|
||||
# Check if object exists
|
||||
existing = await get_file_by_sha256(oid)
|
||||
|
||||
if not existing:
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(code=404, message="Object not found"),
|
||||
)
|
||||
|
||||
# Object exists, provide download URL
|
||||
try:
|
||||
download_url = await generate_download_presigned_url(
|
||||
bucket=cfg.s3.bucket,
|
||||
key=lfs_key,
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=3600)
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
authenticated=True,
|
||||
actions={
|
||||
"download": {
|
||||
"href": download_url,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
return LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(
|
||||
code=500,
|
||||
message=f"Failed to generate download URL: {str(e)}",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{repo_type}s/{namespace}/{name}.git/info/lfs/objects/batch")
|
||||
@router.post("/{namespace}/{name}.git/info/lfs/objects/batch")
|
||||
async def lfs_batch(
|
||||
namespace: str,
|
||||
name: str,
|
||||
request: Request,
|
||||
repo_type: Optional[str] = "model",
|
||||
repo_type: str | None = "model",
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Git LFS Batch API endpoint.
|
||||
@@ -98,6 +237,7 @@ async def lfs_batch(
|
||||
namespace: Repository namespace
|
||||
name: Repository name
|
||||
request: FastAPI request with LFS batch payload
|
||||
repo_type: Repository type (default: model)
|
||||
user: Current authenticated user (optional for downloads)
|
||||
|
||||
Returns:
|
||||
@@ -116,178 +256,56 @@ async def lfs_batch(
|
||||
raise HTTPException(400, detail={"error": f"Invalid LFS batch request: {e}"})
|
||||
|
||||
# Check repository exists and permissions
|
||||
# Note: We need to infer repo_type from context or use a default
|
||||
# For LFS, we'll check all repo types
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
def _get_repo():
|
||||
return Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
|
||||
repo_row = await execute_db_query(_get_repo)
|
||||
|
||||
if repo_row:
|
||||
operation = batch_req.operation
|
||||
if operation == "upload":
|
||||
# Upload requires authentication and write permission
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
401, detail={"error": "Authentication required for upload"}
|
||||
)
|
||||
check_repo_write_permission(repo_row, user)
|
||||
elif operation == "download":
|
||||
# Download requires read permission (may be public)
|
||||
check_repo_read_permission(repo_row, user)
|
||||
|
||||
match operation:
|
||||
case "upload":
|
||||
# Upload requires authentication and write permission
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
401, detail={"error": "Authentication required for upload"}
|
||||
)
|
||||
check_repo_write_permission(repo_row, user)
|
||||
|
||||
case "download":
|
||||
# Download requires read permission (may be public)
|
||||
check_repo_read_permission(repo_row, user)
|
||||
|
||||
if cfg.app.debug_log_payloads:
|
||||
logger.debug("==== LFS Batch Request ====")
|
||||
logger.debug(body)
|
||||
|
||||
operation = batch_req.operation
|
||||
# Process objects using match-case
|
||||
objects_response = []
|
||||
|
||||
# LFS files stored in: s3://bucket/lfs/{oid[:2]}/{oid[2:4]}/{oid}
|
||||
# This provides a balanced directory structure
|
||||
|
||||
for obj in batch_req.objects:
|
||||
oid = obj.oid
|
||||
size = obj.size
|
||||
|
||||
# S3 key for LFS object (content-addressable)
|
||||
lfs_key = f"lfs/{oid[:2]}/{oid[2:4]}/{oid}"
|
||||
match batch_req.operation:
|
||||
case "upload":
|
||||
response_obj = await process_upload_object(oid, size, repo_id)
|
||||
objects_response.append(response_obj)
|
||||
|
||||
# Check if object already exists (deduplication)
|
||||
existing = File.get_or_none(File.sha256 == oid)
|
||||
|
||||
if operation == "upload":
|
||||
if existing and existing.size == size:
|
||||
# Object exists, no upload needed
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
authenticated=True,
|
||||
# No actions = already exists
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Object needs upload
|
||||
# Check if multipart upload required (>5GB)
|
||||
multipart_threshold = 5 * 1024 * 1024 * 1024 # 5GB
|
||||
|
||||
if size > multipart_threshold:
|
||||
# Return error for multipart uploads (not yet implemented)
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(
|
||||
code=501, # Not Implemented
|
||||
message="Multipart upload not yet implemented for files >5GB",
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Single PUT upload
|
||||
try:
|
||||
# Convert SHA256 hex to base64 for S3 checksum verification
|
||||
checksum_sha256 = base64.b64encode(bytes.fromhex(oid)).decode(
|
||||
"utf-8"
|
||||
)
|
||||
|
||||
upload_info = generate_upload_presigned_url(
|
||||
bucket=cfg.s3.bucket,
|
||||
key=lfs_key,
|
||||
expires_in=3600, # 1 hour
|
||||
content_type="application/octet-stream",
|
||||
checksum_sha256=checksum_sha256,
|
||||
)
|
||||
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
authenticated=True,
|
||||
actions={
|
||||
"upload": {
|
||||
"href": upload_info["url"],
|
||||
"expires_at": upload_info["expires_at"],
|
||||
"header": upload_info.get("headers", {}),
|
||||
},
|
||||
"verify": {
|
||||
"href": f"{cfg.app.base_url}/api/{repo_id}.git/info/lfs/verify",
|
||||
"expires_at": upload_info["expires_at"],
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Error generating presigned URL
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(
|
||||
code=500,
|
||||
message=f"Failed to generate upload URL: {str(e)}",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif operation == "download":
|
||||
if not existing:
|
||||
# Object not found - use proper error format
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(code=404, message="Object not found"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Object exists, provide download URL
|
||||
try:
|
||||
download_url = generate_download_presigned_url(
|
||||
bucket=cfg.s3.bucket,
|
||||
key=lfs_key,
|
||||
expires_in=3600,
|
||||
)
|
||||
|
||||
expires_at = (
|
||||
datetime.now(timezone.utc) + timedelta(seconds=3600)
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
authenticated=True,
|
||||
actions={
|
||||
"download": {
|
||||
"href": download_url,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Error generating download URL
|
||||
objects_response.append(
|
||||
LFSObjectResponse(
|
||||
oid=oid,
|
||||
size=size,
|
||||
error=LFSError(
|
||||
code=500,
|
||||
message=f"Failed to generate download URL: {str(e)}",
|
||||
),
|
||||
)
|
||||
)
|
||||
case "download":
|
||||
response_obj = await process_download_object(oid, size)
|
||||
objects_response.append(response_obj)
|
||||
|
||||
# Return response with exclude_none to omit null fields
|
||||
# This ensures "error": null is not included in the JSON
|
||||
response = LFSBatchResponse(
|
||||
transfer="basic",
|
||||
objects=objects_response,
|
||||
hash_algo="sha256",
|
||||
)
|
||||
|
||||
# Use JSONResponse to ensure proper serialization with exclude_none
|
||||
return JSONResponse(
|
||||
content=response.model_dump(exclude_none=True),
|
||||
media_type="application/vnd.git-lfs+json",
|
||||
@@ -301,7 +319,8 @@ async def lfs_verify(namespace: str, name: str, request: Request):
|
||||
Called by client after successful upload to confirm the file.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID
|
||||
namespace: Repository namespace
|
||||
name: Repository name
|
||||
request: FastAPI request with verification data
|
||||
|
||||
Returns:
|
||||
@@ -319,15 +338,15 @@ async def lfs_verify(namespace: str, name: str, request: Request):
|
||||
if not oid:
|
||||
raise HTTPException(400, detail={"error": "Missing OID"})
|
||||
|
||||
lfs_key = f"lfs/{oid[:2]}/{oid[2:4]}/{oid}"
|
||||
lfs_key = get_lfs_key(oid)
|
||||
|
||||
if not object_exists(cfg.s3.bucket, lfs_key):
|
||||
if not await object_exists(cfg.s3.bucket, lfs_key):
|
||||
raise HTTPException(404, detail={"error": "Object not found in storage"})
|
||||
|
||||
# Optionally verify size
|
||||
if size:
|
||||
try:
|
||||
metadata = get_object_metadata(cfg.s3.bucket, lfs_key)
|
||||
metadata = await get_object_metadata(cfg.s3.bucket, lfs_key)
|
||||
if metadata["size"] != size:
|
||||
raise HTTPException(400, detail={"error": "Size mismatch"})
|
||||
except Exception:
|
||||
|
||||
@@ -20,6 +20,7 @@ from kohakuhub.auth.permissions import (
|
||||
check_repo_delete_permission,
|
||||
)
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db_async import execute_db_query, get_repository
|
||||
from kohakuhub.db import File, Repository, StagingUpload, User, init_db
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -41,7 +42,7 @@ class CreateRepoPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/repos/create")
|
||||
def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_user)):
|
||||
async def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_user)):
|
||||
"""Create a new repository.
|
||||
|
||||
Args:
|
||||
@@ -62,9 +63,8 @@ def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_use
|
||||
full_id = f"{namespace}/{payload.name}"
|
||||
lakefs_repo = lakefs_repo_name(payload.type, full_id)
|
||||
|
||||
if Repository.get_or_none(
|
||||
(Repository.full_id == full_id) & (Repository.repo_type == payload.type)
|
||||
):
|
||||
existing_repo = await get_repository(payload.type, namespace, payload.name)
|
||||
if existing_repo:
|
||||
return hf_error_response(
|
||||
400,
|
||||
HFErrorCode.REPO_EXISTS,
|
||||
@@ -88,13 +88,16 @@ def create_repo(payload: CreateRepoPayload, user: User = Depends(get_current_use
|
||||
return hf_server_error(f"LakeFS repository creation failed: {str(e)}")
|
||||
|
||||
# Store in database for listing/metadata
|
||||
Repository.get_or_create(
|
||||
repo_type=payload.type,
|
||||
namespace=namespace,
|
||||
name=payload.name,
|
||||
full_id=full_id,
|
||||
defaults={"private": payload.private},
|
||||
)
|
||||
def _create_repo():
|
||||
Repository.get_or_create(
|
||||
repo_type=payload.type,
|
||||
namespace=namespace,
|
||||
name=payload.name,
|
||||
full_id=full_id,
|
||||
defaults={"private": payload.private, "owner_id": user.id},
|
||||
)
|
||||
|
||||
await execute_db_query(_create_repo)
|
||||
|
||||
return {
|
||||
"url": f"{cfg.app.base_url}/{payload.type}s/{full_id}",
|
||||
@@ -133,9 +136,7 @@ async def delete_repo(
|
||||
lakefs_repo = lakefs_repo_name(repo_type, full_id)
|
||||
|
||||
# 1. Check if repository exists in database
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == full_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, payload.name)
|
||||
|
||||
if not repo_row:
|
||||
# NOTE: HuggingFace client expects 400 for delete repo not found
|
||||
@@ -160,14 +161,20 @@ async def delete_repo(
|
||||
logger.info(f"LakeFS repository {lakefs_repo} not found/already deleted (OK)")
|
||||
|
||||
# 3. Delete related metadata from database (manual cascade)
|
||||
def _delete_db_records():
|
||||
try:
|
||||
# Delete related file records first
|
||||
File.delete().where(File.repo_full_id == full_id).execute()
|
||||
StagingUpload.delete().where(StagingUpload.repo_full_id == full_id).execute()
|
||||
repo_row.delete_instance()
|
||||
logger.success(f"Successfully deleted database records for: {full_id}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Database deletion failed for {full_id}", e)
|
||||
raise
|
||||
|
||||
try:
|
||||
# Delete related file records first
|
||||
File.delete().where(File.repo_full_id == full_id).execute()
|
||||
StagingUpload.delete().where(StagingUpload.repo_full_id == full_id).execute()
|
||||
repo_row.delete_instance()
|
||||
logger.success(f"Successfully deleted database records for: {full_id}")
|
||||
await execute_db_query(_delete_db_records)
|
||||
except Exception as e:
|
||||
logger.exception(f"Database deletion failed for {full_id}", e)
|
||||
return hf_server_error(f"Database deletion failed for {full_id}: {str(e)}")
|
||||
|
||||
# 4. Return success response (200 OK with a simple message)
|
||||
@@ -184,7 +191,7 @@ class MoveRepoPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/repos/move")
|
||||
def move_repo(
|
||||
async def move_repo(
|
||||
payload: MoveRepoPayload,
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -204,9 +211,12 @@ def move_repo(
|
||||
repo_type = payload.type
|
||||
|
||||
# Check if source repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == from_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
from_parts = from_id.split("/", 1)
|
||||
if len(from_parts) != 2:
|
||||
return hf_error_response(400, HFErrorCode.INVALID_REPO_ID, "Invalid source repository ID")
|
||||
|
||||
from_namespace, from_name = from_parts
|
||||
repo_row = await get_repository(repo_type, from_namespace, from_name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(from_id, repo_type)
|
||||
@@ -215,9 +225,12 @@ def move_repo(
|
||||
check_repo_delete_permission(repo_row, user)
|
||||
|
||||
# Check if destination already exists
|
||||
existing = Repository.get_or_none(
|
||||
(Repository.full_id == to_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
to_parts = to_id.split("/", 1)
|
||||
if len(to_parts) != 2:
|
||||
return hf_error_response(400, HFErrorCode.INVALID_REPO_ID, "Invalid destination repository ID")
|
||||
|
||||
to_namespace, to_name = to_parts
|
||||
existing = await get_repository(repo_type, to_namespace, to_name)
|
||||
if existing:
|
||||
return hf_error_response(
|
||||
400,
|
||||
@@ -225,34 +238,27 @@ def move_repo(
|
||||
f"Repository {to_id} already exists",
|
||||
)
|
||||
|
||||
# Parse destination namespace and name
|
||||
if "/" not in to_id:
|
||||
return hf_error_response(
|
||||
400,
|
||||
HFErrorCode.INVALID_REPO_ID,
|
||||
"Invalid repository ID format (must be namespace/name)",
|
||||
)
|
||||
|
||||
to_namespace, to_name = to_id.split("/", 1)
|
||||
|
||||
# Check if user has permission to use destination namespace
|
||||
check_namespace_permission(to_namespace, user)
|
||||
|
||||
# Update database records
|
||||
# Update repository record
|
||||
Repository.update(
|
||||
namespace=to_namespace,
|
||||
name=to_name,
|
||||
full_id=to_id,
|
||||
).where(Repository.id == repo_row.id).execute()
|
||||
def _update_db_records():
|
||||
# Update repository record
|
||||
Repository.update(
|
||||
namespace=to_namespace,
|
||||
name=to_name,
|
||||
full_id=to_id,
|
||||
).where(Repository.id == repo_row.id).execute()
|
||||
|
||||
# Update related file records
|
||||
File.update(repo_full_id=to_id).where(File.repo_full_id == from_id).execute()
|
||||
# Update related file records
|
||||
File.update(repo_full_id=to_id).where(File.repo_full_id == from_id).execute()
|
||||
|
||||
# Update staging uploads
|
||||
StagingUpload.update(repo_full_id=to_id).where(
|
||||
StagingUpload.repo_full_id == from_id
|
||||
).execute()
|
||||
# Update staging uploads
|
||||
StagingUpload.update(repo_full_id=to_id).where(
|
||||
StagingUpload.repo_full_id == from_id
|
||||
).execute()
|
||||
|
||||
await execute_db_query(_update_db_records)
|
||||
|
||||
# Note: LakeFS repository rename not implemented yet
|
||||
# Would require creating new LakeFS repo and migrating data
|
||||
|
||||
@@ -15,6 +15,7 @@ from kohakuhub.async_utils import get_async_lakefs_client
|
||||
from kohakuhub.auth.dependencies import get_optional_user
|
||||
from kohakuhub.auth.permissions import check_repo_read_permission
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db_async import execute_db_query, get_organization, get_repository, get_user_by_username
|
||||
from kohakuhub.db import Organization, Repository, User, UserOrganization
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -71,9 +72,7 @@ async def get_repo_info(
|
||||
)
|
||||
|
||||
# Check if repository exists in database
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, repo_name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -217,16 +216,19 @@ async def list_repos(
|
||||
)
|
||||
|
||||
# Query database
|
||||
q = Repository.select().where(Repository.repo_type == rt)
|
||||
def _query_repos():
|
||||
q = Repository.select().where(Repository.repo_type == rt)
|
||||
|
||||
# Filter by author if specified
|
||||
if author:
|
||||
q = q.where(Repository.namespace == author)
|
||||
# Filter by author if specified
|
||||
if author:
|
||||
q = q.where(Repository.namespace == author)
|
||||
|
||||
# Apply privacy filtering
|
||||
q = _filter_repos_by_privacy(q, user, author)
|
||||
# Apply privacy filtering
|
||||
q = _filter_repos_by_privacy(q, user, author)
|
||||
|
||||
rows = q.limit(limit)
|
||||
return list(q.limit(limit))
|
||||
|
||||
rows = await execute_db_query(_query_repos)
|
||||
|
||||
# Format response with lastModified from LakeFS
|
||||
client = get_lakefs_client()
|
||||
@@ -297,10 +299,10 @@ async def list_user_repos(
|
||||
Dict with models, datasets, and spaces lists
|
||||
"""
|
||||
# Check if the username exists
|
||||
target_user = User.get_or_none(User.username == username)
|
||||
target_user = await get_user_by_username(username)
|
||||
if not target_user:
|
||||
# Could also be an organization
|
||||
target_org = Organization.get_or_none(Organization.name == username)
|
||||
target_org = await get_organization(username)
|
||||
if not target_org:
|
||||
return hf_error_response(
|
||||
404,
|
||||
@@ -315,14 +317,17 @@ async def list_user_repos(
|
||||
}
|
||||
|
||||
for repo_type in ["model", "dataset", "space"]:
|
||||
q = Repository.select().where(
|
||||
(Repository.repo_type == repo_type) & (Repository.namespace == username)
|
||||
)
|
||||
def _query_repos():
|
||||
q = Repository.select().where(
|
||||
(Repository.repo_type == repo_type) & (Repository.namespace == username)
|
||||
)
|
||||
|
||||
# Apply privacy filtering
|
||||
q = _filter_repos_by_privacy(q, user, username)
|
||||
# Apply privacy filtering
|
||||
q = _filter_repos_by_privacy(q, user, username)
|
||||
|
||||
rows = q.limit(limit)
|
||||
return list(q.limit(limit))
|
||||
|
||||
rows = await execute_db_query(_query_repos)
|
||||
|
||||
key = repo_type + "s"
|
||||
repos_list = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Repository tree listing and path information endpoints."""
|
||||
"""Repository tree listing and path information endpoints - Refactored version."""
|
||||
|
||||
from typing import List, Literal
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, Form
|
||||
|
||||
@@ -16,6 +17,7 @@ from kohakuhub.async_utils import get_async_lakefs_client
|
||||
from kohakuhub.auth.dependencies import get_optional_user
|
||||
from kohakuhub.auth.permissions import check_repo_read_permission
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db_async import execute_db_query, get_file
|
||||
from kohakuhub.db import File, Repository, User
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -25,6 +27,189 @@ router = APIRouter()
|
||||
RepoType = Literal["model", "dataset", "space"]
|
||||
|
||||
|
||||
async def fetch_lakefs_objects(
|
||||
lakefs_repo: str, revision: str, prefix: str, recursive: bool
|
||||
) -> list:
|
||||
"""Fetch all objects from LakeFS with pagination.
|
||||
|
||||
Args:
|
||||
lakefs_repo: LakeFS repository name
|
||||
revision: Branch or commit
|
||||
prefix: Path prefix
|
||||
recursive: Whether to list recursively
|
||||
|
||||
Returns:
|
||||
List of all LakeFS objects
|
||||
|
||||
Raises:
|
||||
Exception: If listing fails
|
||||
"""
|
||||
async_client = get_async_lakefs_client()
|
||||
|
||||
all_results = []
|
||||
after = ""
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
result = await async_client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=revision,
|
||||
prefix=prefix,
|
||||
delimiter="" if recursive else "/",
|
||||
amount=1000, # Max per request
|
||||
after=after,
|
||||
)
|
||||
|
||||
all_results.extend(result.results)
|
||||
|
||||
# Check pagination
|
||||
if result.pagination and result.pagination.has_more:
|
||||
after = result.pagination.next_offset
|
||||
has_more = True
|
||||
else:
|
||||
has_more = False
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
async def calculate_folder_stats(
|
||||
lakefs_repo: str, revision: str, folder_path: str
|
||||
) -> tuple[int, float | None]:
|
||||
"""Calculate folder size and latest modification time.
|
||||
|
||||
Args:
|
||||
lakefs_repo: LakeFS repository name
|
||||
revision: Branch or commit
|
||||
folder_path: Full folder path
|
||||
|
||||
Returns:
|
||||
Tuple of (total_size, latest_mtime)
|
||||
"""
|
||||
folder_size = 0
|
||||
folder_latest_mtime = None
|
||||
|
||||
try:
|
||||
async_client = get_async_lakefs_client()
|
||||
folder_contents = await async_client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=revision,
|
||||
prefix=folder_path,
|
||||
delimiter="", # No delimiter = recursive
|
||||
amount=1000,
|
||||
)
|
||||
|
||||
# Calculate total size and find latest modification
|
||||
for child_obj in folder_contents.results:
|
||||
if child_obj.path_type == "object":
|
||||
folder_size += child_obj.size_bytes or 0
|
||||
if hasattr(child_obj, "mtime") and child_obj.mtime:
|
||||
if (
|
||||
folder_latest_mtime is None
|
||||
or child_obj.mtime > folder_latest_mtime
|
||||
):
|
||||
folder_latest_mtime = child_obj.mtime
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not calculate stats for folder {folder_path}: {str(e)}")
|
||||
|
||||
return folder_size, folder_latest_mtime
|
||||
|
||||
|
||||
async def convert_file_object(
|
||||
obj, repo_id: str, prefix_len: int
|
||||
) -> dict:
|
||||
"""Convert LakeFS file object to HuggingFace format.
|
||||
|
||||
Args:
|
||||
obj: LakeFS object
|
||||
repo_id: Repository ID
|
||||
prefix_len: Length of path prefix to remove
|
||||
|
||||
Returns:
|
||||
HuggingFace formatted file object
|
||||
"""
|
||||
is_lfs = obj.size_bytes > cfg.app.lfs_threshold_bytes
|
||||
|
||||
# Remove prefix from path to get relative path
|
||||
relative_path = obj.path[prefix_len:] if prefix_len else obj.path
|
||||
|
||||
# Get correct checksum from database
|
||||
file_record = await get_file(repo_id, obj.path)
|
||||
|
||||
checksum = (
|
||||
file_record.sha256
|
||||
if file_record and file_record.sha256
|
||||
else obj.checksum
|
||||
)
|
||||
|
||||
file_obj = {
|
||||
"type": "file",
|
||||
"oid": checksum, # Git blob SHA1 for non-LFS, SHA256 for LFS
|
||||
"size": obj.size_bytes,
|
||||
"path": relative_path,
|
||||
}
|
||||
|
||||
# Add last modified info if available
|
||||
if hasattr(obj, "mtime") and obj.mtime:
|
||||
file_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
)
|
||||
|
||||
# Add LFS metadata if it's an LFS file
|
||||
if is_lfs:
|
||||
file_obj["lfs"] = {
|
||||
"oid": checksum, # SHA256 for LFS files
|
||||
"size": obj.size_bytes,
|
||||
"pointerSize": 134, # Standard Git LFS pointer size
|
||||
}
|
||||
|
||||
return file_obj
|
||||
|
||||
|
||||
async def convert_directory_object(
|
||||
obj, lakefs_repo: str, revision: str, prefix_len: int
|
||||
) -> dict:
|
||||
"""Convert LakeFS directory object to HuggingFace format.
|
||||
|
||||
Args:
|
||||
obj: LakeFS common_prefix object
|
||||
lakefs_repo: LakeFS repository name
|
||||
revision: Branch or commit
|
||||
prefix_len: Length of path prefix to remove
|
||||
|
||||
Returns:
|
||||
HuggingFace formatted directory object
|
||||
"""
|
||||
# Remove prefix from path to get relative path
|
||||
relative_path = obj.path[prefix_len:] if prefix_len else obj.path
|
||||
|
||||
# Calculate folder stats
|
||||
folder_size, folder_latest_mtime = await calculate_folder_stats(
|
||||
lakefs_repo, revision, obj.path
|
||||
)
|
||||
|
||||
dir_obj = {
|
||||
"type": "directory",
|
||||
"oid": (
|
||||
obj.checksum if hasattr(obj, "checksum") and obj.checksum else ""
|
||||
),
|
||||
"size": folder_size,
|
||||
"path": relative_path.rstrip("/"), # Remove trailing slash
|
||||
}
|
||||
|
||||
# Add last modified info
|
||||
if folder_latest_mtime:
|
||||
dir_obj["lastModified"] = datetime.fromtimestamp(
|
||||
folder_latest_mtime
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
elif hasattr(obj, "mtime") and obj.mtime:
|
||||
dir_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
)
|
||||
|
||||
return dir_obj
|
||||
|
||||
|
||||
@router.get("/{repo_type}s/{namespace}/{repo_name}/tree/{revision}{path:path}")
|
||||
async def list_repo_tree(
|
||||
repo_type: RepoType,
|
||||
@@ -40,18 +225,6 @@ async def list_repo_tree(
|
||||
|
||||
Returns a flat list of files and folders in HuggingFace format.
|
||||
|
||||
Response format matches HuggingFace API:
|
||||
[
|
||||
{
|
||||
"type": "file", # or "directory"
|
||||
"oid": "sha256_hash",
|
||||
"size": 1234,
|
||||
"path": "relative/path/to/file.txt",
|
||||
"lfs": {"oid": "...", "size": ..., "pointerSize": ...} # if LFS file
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Args:
|
||||
repo_type: Type of repository
|
||||
namespace: Repository namespace
|
||||
@@ -69,9 +242,12 @@ async def list_repo_tree(
|
||||
repo_id = f"{namespace}/{repo_name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
def _get_repo():
|
||||
return Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
|
||||
repo_row = await execute_db_query(_get_repo)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -86,34 +262,11 @@ async def list_repo_tree(
|
||||
if prefix and not prefix.endswith("/"):
|
||||
prefix += "/"
|
||||
|
||||
# Fetch all objects from LakeFS
|
||||
try:
|
||||
# List objects from LakeFS with pagination support
|
||||
async_client = get_async_lakefs_client()
|
||||
|
||||
# Collect all results with pagination
|
||||
all_results = []
|
||||
after = ""
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
result = await async_client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=revision,
|
||||
prefix=prefix,
|
||||
delimiter="" if recursive else "/",
|
||||
amount=1000, # Max per request
|
||||
after=after,
|
||||
)
|
||||
|
||||
all_results.extend(result.results)
|
||||
|
||||
# Check pagination
|
||||
if result.pagination and result.pagination.has_more:
|
||||
after = result.pagination.next_offset
|
||||
has_more = True
|
||||
else:
|
||||
has_more = False
|
||||
|
||||
all_results = await fetch_lakefs_objects(
|
||||
lakefs_repo, revision, prefix, recursive
|
||||
)
|
||||
except Exception as e:
|
||||
# Check for specific error types
|
||||
if is_lakefs_not_found_error(e):
|
||||
@@ -127,115 +280,23 @@ async def list_repo_tree(
|
||||
logger.exception(f"Failed to list objects for {repo_id}", e)
|
||||
return hf_server_error(f"Failed to list objects: {str(e)}")
|
||||
|
||||
# Convert LakeFS objects to HuggingFace format (flat list)
|
||||
# Convert LakeFS objects to HuggingFace format
|
||||
result_list = []
|
||||
prefix_len = len(prefix)
|
||||
|
||||
for obj in all_results:
|
||||
if obj.path_type == "object":
|
||||
# File object
|
||||
is_lfs = obj.size_bytes > cfg.app.lfs_threshold_bytes
|
||||
match obj.path_type:
|
||||
case "object":
|
||||
# File object
|
||||
file_obj = await convert_file_object(obj, repo_id, prefix_len)
|
||||
result_list.append(file_obj)
|
||||
|
||||
# Remove prefix from path to get relative path
|
||||
relative_path = obj.path[prefix_len:] if prefix else obj.path
|
||||
|
||||
# Get correct checksum from database
|
||||
# sha256 column stores: git blob SHA1 for non-LFS, SHA256 for LFS
|
||||
file_record = File.get_or_none(
|
||||
(File.repo_full_id == repo_id) & (File.path_in_repo == obj.path)
|
||||
)
|
||||
|
||||
checksum = (
|
||||
file_record.sha256
|
||||
if file_record and file_record.sha256
|
||||
else obj.checksum
|
||||
)
|
||||
|
||||
file_obj = {
|
||||
"type": "file",
|
||||
"oid": checksum, # Git blob SHA1 for non-LFS, SHA256 for LFS
|
||||
"size": obj.size_bytes,
|
||||
"path": relative_path,
|
||||
}
|
||||
|
||||
# Add last modified info if available from LakeFS
|
||||
if hasattr(obj, "mtime") and obj.mtime:
|
||||
from datetime import datetime
|
||||
|
||||
file_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
case "common_prefix":
|
||||
# Directory object
|
||||
dir_obj = await convert_directory_object(
|
||||
obj, lakefs_repo, revision, prefix_len
|
||||
)
|
||||
|
||||
# Add LFS metadata if it's an LFS file
|
||||
if is_lfs:
|
||||
file_obj["lfs"] = {
|
||||
"oid": checksum, # SHA256 for LFS files
|
||||
"size": obj.size_bytes,
|
||||
"pointerSize": 134, # Standard Git LFS pointer size
|
||||
}
|
||||
|
||||
result_list.append(file_obj)
|
||||
|
||||
elif obj.path_type == "common_prefix":
|
||||
# Directory object
|
||||
# Remove prefix from path to get relative path
|
||||
relative_path = obj.path[prefix_len:] if prefix else obj.path
|
||||
|
||||
# Calculate folder stats by listing its contents recursively
|
||||
folder_size = 0
|
||||
folder_latest_mtime = None
|
||||
|
||||
try:
|
||||
# List all objects in this folder recursively
|
||||
async_client = get_async_lakefs_client()
|
||||
folder_contents = await async_client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=revision,
|
||||
prefix=obj.path, # Use full path as prefix
|
||||
delimiter="", # No delimiter = recursive
|
||||
amount=1000,
|
||||
)
|
||||
|
||||
# Calculate total size and find latest modification
|
||||
for child_obj in folder_contents.results:
|
||||
if child_obj.path_type == "object":
|
||||
folder_size += child_obj.size_bytes or 0
|
||||
if hasattr(child_obj, "mtime") and child_obj.mtime:
|
||||
if (
|
||||
folder_latest_mtime is None
|
||||
or child_obj.mtime > folder_latest_mtime
|
||||
):
|
||||
folder_latest_mtime = child_obj.mtime
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Could not calculate stats for folder {obj.path}: {str(e)}"
|
||||
)
|
||||
|
||||
dir_obj = {
|
||||
"type": "directory",
|
||||
"oid": (
|
||||
obj.checksum if hasattr(obj, "checksum") and obj.checksum else ""
|
||||
),
|
||||
"size": folder_size,
|
||||
"path": relative_path.rstrip("/"), # Remove trailing slash
|
||||
}
|
||||
|
||||
# Add last modified info
|
||||
if folder_latest_mtime:
|
||||
from datetime import datetime
|
||||
|
||||
dir_obj["lastModified"] = datetime.fromtimestamp(
|
||||
folder_latest_mtime
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
|
||||
elif hasattr(obj, "mtime") and obj.mtime:
|
||||
from datetime import datetime
|
||||
|
||||
dir_obj["lastModified"] = datetime.fromtimestamp(obj.mtime).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.%fZ"
|
||||
)
|
||||
|
||||
result_list.append(dir_obj)
|
||||
result_list.append(dir_obj)
|
||||
|
||||
return result_list
|
||||
|
||||
@@ -246,18 +307,13 @@ async def get_paths_info(
|
||||
namespace: str,
|
||||
repo_name: str,
|
||||
revision: str,
|
||||
paths: List[str] = Form(...),
|
||||
paths: list[str] = Form(...),
|
||||
expand: bool = Form(False),
|
||||
user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get information about specific paths in a repository.
|
||||
|
||||
This endpoint matches HuggingFace Hub API format:
|
||||
POST /api/{repo_type}s/{namespace}/{repo_name}/paths-info/{revision}
|
||||
|
||||
Form data:
|
||||
paths: List of paths to query
|
||||
expand: Whether to include extended metadata
|
||||
This endpoint matches HuggingFace Hub API format.
|
||||
|
||||
Args:
|
||||
repo_type: Type of repository (model/dataset/space)
|
||||
@@ -275,9 +331,12 @@ async def get_paths_info(
|
||||
repo_id = f"{namespace}/{repo_name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
def _get_repo():
|
||||
return Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
|
||||
repo_row = await execute_db_query(_get_repo)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -307,10 +366,7 @@ async def get_paths_info(
|
||||
is_lfs = obj_stats.size_bytes > cfg.app.lfs_threshold_bytes
|
||||
|
||||
# Get correct checksum from database
|
||||
# sha256 column stores: git blob SHA1 for non-LFS, SHA256 for LFS
|
||||
file_record = File.get_or_none(
|
||||
(File.repo_full_id == repo_id) & (File.path_in_repo == clean_path)
|
||||
)
|
||||
file_record = await get_file(repo_id, clean_path)
|
||||
|
||||
checksum = (
|
||||
file_record.sha256
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, EmailStr
|
||||
from kohakuhub.api.utils.hf import hf_repo_not_found
|
||||
from kohakuhub.auth.dependencies import get_current_user
|
||||
from kohakuhub.auth.permissions import check_repo_delete_permission
|
||||
from kohakuhub.db_async import execute_db_query, get_organization, get_repository, get_user_organization
|
||||
from kohakuhub.db import Organization, Repository, User, UserOrganization
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
@@ -28,7 +29,7 @@ class UpdateUserSettingsRequest(BaseModel):
|
||||
|
||||
|
||||
@router.put("/users/{username}/settings")
|
||||
def update_user_settings(
|
||||
async def update_user_settings(
|
||||
username: str,
|
||||
req: UpdateUserSettingsRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -50,13 +51,19 @@ def update_user_settings(
|
||||
# Update fields if provided
|
||||
if req.email is not None:
|
||||
# Check if email is already taken by another user
|
||||
existing = User.get_or_none((User.email == req.email) & (User.id != user.id))
|
||||
def _check_email():
|
||||
return User.get_or_none((User.email == req.email) & (User.id != user.id))
|
||||
|
||||
existing = await execute_db_query(_check_email)
|
||||
if existing:
|
||||
raise HTTPException(400, detail="Email already in use")
|
||||
|
||||
User.update(email=req.email, email_verified=False).where(
|
||||
User.id == user.id
|
||||
).execute()
|
||||
def _update_email():
|
||||
User.update(email=req.email, email_verified=False).where(
|
||||
User.id == user.id
|
||||
).execute()
|
||||
|
||||
await execute_db_query(_update_email)
|
||||
# TODO: Send new verification email
|
||||
|
||||
return {"success": True, "message": "User settings updated successfully"}
|
||||
@@ -72,7 +79,7 @@ class UpdateOrganizationSettingsRequest(BaseModel):
|
||||
|
||||
|
||||
@router.put("/organizations/{org_name}/settings")
|
||||
def update_organization_settings(
|
||||
async def update_organization_settings(
|
||||
org_name: str,
|
||||
req: UpdateOrganizationSettingsRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -87,14 +94,12 @@ def update_organization_settings(
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
org = Organization.get_or_none(Organization.name == org_name)
|
||||
org = await get_organization(org_name)
|
||||
if not org:
|
||||
raise HTTPException(404, detail="Organization not found")
|
||||
|
||||
# Check if user is admin of the organization
|
||||
user_org = UserOrganization.get_or_none(
|
||||
(UserOrganization.user == user.id) & (UserOrganization.organization == org.id)
|
||||
)
|
||||
user_org = await get_user_organization(user.id, org.id)
|
||||
if not user_org or user_org.role not in ["admin", "super-admin"]:
|
||||
raise HTTPException(
|
||||
403, detail="Not authorized to update organization settings"
|
||||
@@ -102,9 +107,12 @@ def update_organization_settings(
|
||||
|
||||
# Update fields if provided
|
||||
if req.description is not None:
|
||||
Organization.update(description=req.description).where(
|
||||
Organization.id == org.id
|
||||
).execute()
|
||||
def _update_org():
|
||||
Organization.update(description=req.description).where(
|
||||
Organization.id == org.id
|
||||
).execute()
|
||||
|
||||
await execute_db_query(_update_org)
|
||||
|
||||
return {"success": True, "message": "Organization settings updated successfully"}
|
||||
|
||||
@@ -122,7 +130,7 @@ class UpdateRepoSettingsPayload(BaseModel):
|
||||
|
||||
|
||||
@router.put("/{repo_type}s/{namespace}/{name}/settings")
|
||||
def update_repo_settings(
|
||||
async def update_repo_settings(
|
||||
repo_type: str,
|
||||
namespace: str,
|
||||
name: str,
|
||||
@@ -146,9 +154,7 @@ def update_repo_settings(
|
||||
repo_id = f"{namespace}/{name}"
|
||||
|
||||
# Check if repository exists
|
||||
repo_row = Repository.get_or_none(
|
||||
(Repository.full_id == repo_id) & (Repository.repo_type == repo_type)
|
||||
)
|
||||
repo_row = await get_repository(repo_type, namespace, name)
|
||||
|
||||
if not repo_row:
|
||||
return hf_repo_not_found(repo_id, repo_type)
|
||||
@@ -158,9 +164,12 @@ def update_repo_settings(
|
||||
|
||||
# Update fields if provided
|
||||
if payload.private is not None:
|
||||
Repository.update(private=payload.private).where(
|
||||
Repository.id == repo_row.id
|
||||
).execute()
|
||||
def _update_private():
|
||||
Repository.update(private=payload.private).where(
|
||||
Repository.id == repo_row.id
|
||||
).execute()
|
||||
|
||||
await execute_db_query(_update_private)
|
||||
|
||||
# Note: gated functionality not yet implemented in database schema
|
||||
# Would require adding a 'gated' field to Repository model
|
||||
|
||||
@@ -6,6 +6,24 @@ from fastapi import APIRouter, HTTPException, Response, Depends
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
from ..config import cfg
|
||||
from ..db_async import (
|
||||
create_email_verification,
|
||||
create_session,
|
||||
create_token,
|
||||
create_user,
|
||||
delete_email_verification,
|
||||
delete_session,
|
||||
delete_token,
|
||||
execute_db_query,
|
||||
get_email_verification,
|
||||
get_session,
|
||||
get_token_by_hash,
|
||||
get_user_by_email,
|
||||
get_user_by_id,
|
||||
get_user_by_username,
|
||||
list_user_tokens,
|
||||
update_user,
|
||||
)
|
||||
from ..db import User, EmailVerification, Session, Token
|
||||
from ..logger import get_logger
|
||||
|
||||
@@ -41,22 +59,22 @@ class CreateTokenRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
def register(req: RegisterRequest):
|
||||
async def register(req: RegisterRequest):
|
||||
"""Register new user."""
|
||||
|
||||
logger.info(f"Registration attempt for username: {req.username}")
|
||||
|
||||
# Check if username or email already exists
|
||||
if User.get_or_none(User.username == req.username):
|
||||
if await get_user_by_username(req.username):
|
||||
logger.warning(f"Registration failed: username '{req.username}' already exists")
|
||||
raise HTTPException(400, detail="Username already exists")
|
||||
|
||||
if User.get_or_none(User.email == req.email):
|
||||
if await get_user_by_email(req.email):
|
||||
logger.warning(f"Registration failed: email '{req.email}' already exists")
|
||||
raise HTTPException(400, detail="Email already exists")
|
||||
|
||||
# Create user
|
||||
user = User.create(
|
||||
user = await create_user(
|
||||
username=req.username,
|
||||
email=req.email,
|
||||
password_hash=hash_password(req.password),
|
||||
@@ -68,8 +86,8 @@ def register(req: RegisterRequest):
|
||||
# Send verification email if required
|
||||
if cfg.auth.require_email_verification:
|
||||
token = generate_token()
|
||||
EmailVerification.create(
|
||||
user=user.id, token=token, expires_at=get_expiry_time(24)
|
||||
await create_email_verification(
|
||||
user_id=user.id, token=token, expires_at=get_expiry_time(24)
|
||||
)
|
||||
|
||||
if not send_verification_email(req.email, req.username, token):
|
||||
@@ -93,18 +111,15 @@ def register(req: RegisterRequest):
|
||||
|
||||
|
||||
@router.get("/verify-email")
|
||||
def verify_email(token: str, response: Response):
|
||||
async def verify_email(token: str, response: Response):
|
||||
"""Verify email with token and automatically log in user."""
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
logger.info(f"Email verification attempt with token: {token[:8]}...")
|
||||
|
||||
verification = EmailVerification.get_or_none(
|
||||
(EmailVerification.token == token)
|
||||
& (EmailVerification.expires_at > datetime.now(timezone.utc))
|
||||
)
|
||||
verification = await get_email_verification(token)
|
||||
|
||||
if not verification:
|
||||
if not verification or verification.expires_at <= datetime.now(timezone.utc):
|
||||
logger.warning(f"Invalid or expired verification token: {token[:8]}...")
|
||||
# Redirect to login with error message
|
||||
return RedirectResponse(
|
||||
@@ -113,16 +128,16 @@ def verify_email(token: str, response: Response):
|
||||
)
|
||||
|
||||
# Get user
|
||||
user = User.get_or_none(User.id == verification.user)
|
||||
user = await get_user_by_id(verification.user)
|
||||
if not user:
|
||||
logger.error(f"User not found for verification token: {token[:8]}...")
|
||||
return RedirectResponse(url="/?error=user_not_found", status_code=302)
|
||||
|
||||
# Update user email verification status
|
||||
User.update(email_verified=True).where(User.id == verification.user).execute()
|
||||
await update_user(user, email_verified=True)
|
||||
|
||||
# Delete verification token
|
||||
EmailVerification.delete().where(EmailVerification.id == verification.id).execute()
|
||||
await delete_email_verification(verification)
|
||||
|
||||
logger.success(f"Email verified for user: {user.username}")
|
||||
|
||||
@@ -130,7 +145,7 @@ def verify_email(token: str, response: Response):
|
||||
session_id = generate_token()
|
||||
session_secret = generate_session_secret()
|
||||
|
||||
Session.create(
|
||||
await create_session(
|
||||
session_id=session_id,
|
||||
user_id=user.id,
|
||||
secret=session_secret,
|
||||
@@ -155,12 +170,12 @@ def verify_email(token: str, response: Response):
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
def login(req: LoginRequest, response: Response):
|
||||
async def login(req: LoginRequest, response: Response):
|
||||
"""Login and create session."""
|
||||
|
||||
logger.info(f"Login attempt for user: {req.username}")
|
||||
|
||||
user = User.get_or_none(User.username == req.username)
|
||||
user = await get_user_by_username(req.username)
|
||||
|
||||
if not user or not verify_password(req.password, user.password_hash):
|
||||
logger.warning(f"Failed login attempt for: {req.username}")
|
||||
@@ -178,7 +193,7 @@ def login(req: LoginRequest, response: Response):
|
||||
session_id = generate_token()
|
||||
session_secret = generate_session_secret()
|
||||
|
||||
Session.create(
|
||||
await create_session(
|
||||
session_id=session_id,
|
||||
user_id=user.id,
|
||||
secret=session_secret,
|
||||
@@ -205,13 +220,16 @@ def login(req: LoginRequest, response: Response):
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(response: Response, user: User = Depends(get_current_user)):
|
||||
async def logout(response: Response, user: User = Depends(get_current_user)):
|
||||
"""Logout and destroy session."""
|
||||
|
||||
logger.info(f"Logout request for user: {user.username}")
|
||||
|
||||
# Delete all user sessions
|
||||
deleted_count = Session.delete().where(Session.user_id == user.id).execute()
|
||||
def _delete_sessions():
|
||||
return Session.delete().where(Session.user_id == user.id).execute()
|
||||
|
||||
deleted_count = await execute_db_query(_delete_sessions)
|
||||
|
||||
# Clear cookie
|
||||
response.delete_cookie(key="session_id")
|
||||
@@ -237,10 +255,10 @@ def get_me(user: User = Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.get("/tokens")
|
||||
def list_tokens(user: User = Depends(get_current_user)):
|
||||
async def list_tokens(user: User = Depends(get_current_user)):
|
||||
"""List user's API tokens."""
|
||||
|
||||
tokens = Token.select().where(Token.user_id == user.id)
|
||||
tokens = await list_user_tokens(user.id)
|
||||
|
||||
return {
|
||||
"tokens": [
|
||||
@@ -256,7 +274,7 @@ def list_tokens(user: User = Depends(get_current_user)):
|
||||
|
||||
|
||||
@router.post("/tokens/create")
|
||||
def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)):
|
||||
async def create_token_endpoint(req: CreateTokenRequest, user: User = Depends(get_current_user)):
|
||||
"""Create new API token."""
|
||||
|
||||
# Generate token
|
||||
@@ -264,10 +282,13 @@ def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)
|
||||
token_hash_val = hash_token(token_str)
|
||||
|
||||
# Save to database
|
||||
token = Token.create(user_id=user.id, token_hash=token_hash_val, name=req.name)
|
||||
token = await create_token(user_id=user.id, token_hash=token_hash_val, name=req.name)
|
||||
|
||||
# Get session secret for encryption (if in web session)
|
||||
session = Session.get_or_none(Session.user_id == user.id)
|
||||
def _get_session():
|
||||
return Session.get_or_none(Session.user_id == user.id)
|
||||
|
||||
session = await execute_db_query(_get_session)
|
||||
session_secret = session.secret if session else None
|
||||
|
||||
return {
|
||||
@@ -280,14 +301,17 @@ def create_token(req: CreateTokenRequest, user: User = Depends(get_current_user)
|
||||
|
||||
|
||||
@router.delete("/tokens/{token_id}")
|
||||
def revoke_token(token_id: int, user: User = Depends(get_current_user)):
|
||||
async def revoke_token(token_id: int, user: User = Depends(get_current_user)):
|
||||
"""Revoke an API token."""
|
||||
|
||||
token = Token.get_or_none((Token.id == token_id) & (Token.user_id == user.id))
|
||||
def _get_token():
|
||||
return Token.get_or_none((Token.id == token_id) & (Token.user_id == user.id))
|
||||
|
||||
token = await execute_db_query(_get_token)
|
||||
|
||||
if not token:
|
||||
raise HTTPException(404, detail="Token not found")
|
||||
|
||||
token.delete_instance()
|
||||
await delete_token(token.id)
|
||||
|
||||
return {"success": True, "message": "Token revoked successfully"}
|
||||
|
||||
@@ -3,11 +3,16 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from kohakuhub.db_async import (
|
||||
execute_db_query,
|
||||
get_organization,
|
||||
get_user_by_username,
|
||||
get_user_organization,
|
||||
list_organization_members as list_org_members_async,
|
||||
)
|
||||
from kohakuhub.db import User, UserOrganization
|
||||
from kohakuhub.auth.dependencies import get_current_user
|
||||
from kohakuhub.logger import get_logger
|
||||
|
||||
logger = get_logger("ORG")
|
||||
from kohakuhub.org.utils import (
|
||||
create_organization as create_org_util,
|
||||
get_organization_details as get_org_details_util,
|
||||
@@ -17,6 +22,8 @@ from kohakuhub.org.utils import (
|
||||
update_member_role as update_member_role_util,
|
||||
)
|
||||
|
||||
logger = get_logger("ORG")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -26,7 +33,7 @@ class CreateOrganizationPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
def create_organization(
|
||||
async def create_organization(
|
||||
payload: CreateOrganizationPayload, user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new organization."""
|
||||
@@ -35,7 +42,7 @@ def create_organization(
|
||||
|
||||
|
||||
@router.get("/{org_name}")
|
||||
def get_organization(org_name: str):
|
||||
async def get_organization(org_name: str):
|
||||
"""Get organization details."""
|
||||
org = get_org_details_util(org_name)
|
||||
if not org:
|
||||
@@ -53,7 +60,7 @@ class AddMemberPayload(BaseModel):
|
||||
|
||||
|
||||
@router.post("/{org_name}/members")
|
||||
def add_member(
|
||||
async def add_member(
|
||||
org_name: str,
|
||||
payload: AddMemberPayload,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -64,10 +71,7 @@ def add_member(
|
||||
raise HTTPException(404, detail="Organization not found")
|
||||
|
||||
# Check if the current user is an admin of the organization
|
||||
user_org = UserOrganization.get_or_none(
|
||||
(UserOrganization.user == current_user.id)
|
||||
& (UserOrganization.organization == org.id)
|
||||
)
|
||||
user_org = await get_user_organization(current_user.id, org.id)
|
||||
if not user_org or user_org.role not in ["admin", "super-admin"]:
|
||||
raise HTTPException(403, detail="Not authorized to add members")
|
||||
|
||||
@@ -76,7 +80,7 @@ def add_member(
|
||||
|
||||
|
||||
@router.delete("/{org_name}/members/{username}")
|
||||
def remove_member(
|
||||
async def remove_member(
|
||||
org_name: str,
|
||||
username: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -87,10 +91,7 @@ def remove_member(
|
||||
raise HTTPException(404, detail="Organization not found")
|
||||
|
||||
# Check if the current user is an admin of the organization
|
||||
user_org = UserOrganization.get_or_none(
|
||||
(UserOrganization.user == current_user.id)
|
||||
& (UserOrganization.organization == org.id)
|
||||
)
|
||||
user_org = await get_user_organization(current_user.id, org.id)
|
||||
if not user_org or user_org.role not in ["admin", "super-admin"]:
|
||||
raise HTTPException(403, detail="Not authorized to remove members")
|
||||
|
||||
@@ -99,9 +100,9 @@ def remove_member(
|
||||
|
||||
|
||||
@router.get("/users/{username}/orgs")
|
||||
def list_user_organizations(username: str):
|
||||
async def list_user_organizations(username: str):
|
||||
"""List organizations a user belongs to."""
|
||||
user = User.get_or_none(User.username == username)
|
||||
user = await get_user_by_username(username)
|
||||
if not user:
|
||||
raise HTTPException(404, detail="User not found")
|
||||
|
||||
@@ -123,7 +124,7 @@ class UpdateMemberRolePayload(BaseModel):
|
||||
|
||||
|
||||
@router.put("/{org_name}/members/{username}")
|
||||
def update_member_role(
|
||||
async def update_member_role(
|
||||
org_name: str,
|
||||
username: str,
|
||||
payload: UpdateMemberRolePayload,
|
||||
@@ -135,10 +136,7 @@ def update_member_role(
|
||||
raise HTTPException(404, detail="Organization not found")
|
||||
|
||||
# Check if the current user is an admin of the organization
|
||||
user_org = UserOrganization.get_or_none(
|
||||
(UserOrganization.user == current_user.id)
|
||||
& (UserOrganization.organization == org.id)
|
||||
)
|
||||
user_org = await get_user_organization(current_user.id, org.id)
|
||||
if not user_org or user_org.role not in ["admin", "super-admin"]:
|
||||
raise HTTPException(403, detail="Not authorized to update member roles")
|
||||
|
||||
@@ -147,20 +145,14 @@ def update_member_role(
|
||||
|
||||
|
||||
@router.get("/{org_name}/members")
|
||||
def list_organization_members(org_name: str):
|
||||
async def list_organization_members(org_name: str):
|
||||
"""List organization members."""
|
||||
from ..db import Organization
|
||||
|
||||
org = Organization.get_or_none(Organization.name == org_name)
|
||||
org = await get_organization(org_name)
|
||||
if not org:
|
||||
raise HTTPException(404, detail="Organization not found")
|
||||
|
||||
# Get all members
|
||||
members = (
|
||||
UserOrganization.select()
|
||||
.join(User)
|
||||
.where(UserOrganization.organization == org.id)
|
||||
)
|
||||
members = await list_org_members_async(org.id)
|
||||
|
||||
return {
|
||||
"members": [
|
||||
|
||||
Reference in New Issue
Block a user