feat(lakefs): add resolve_revision function to handle branch names and commit hashes

This commit is contained in:
migo
2026-02-04 11:04:58 +08:00
parent f41728fa60
commit 59a8a1646d
2 changed files with 58 additions and 20 deletions

View File

@@ -28,7 +28,7 @@ from kohakuhub.auth.permissions import (
check_repo_read_permission,
check_repo_write_permission,
)
from kohakuhub.utils.lakefs import get_lakefs_client, lakefs_repo_name
from kohakuhub.utils.lakefs import get_lakefs_client, lakefs_repo_name, resolve_revision
from kohakuhub.utils.s3 import generate_download_presigned_url, parse_s3_uri
from kohakuhub.api.fallback import with_repo_fallback
from kohakuhub.api.xet import XET_ENABLE
@@ -41,7 +41,6 @@ from kohakuhub.api.repo.utils.hf import (
hf_repo_not_found,
hf_revision_not_found,
hf_server_error,
is_lakefs_not_found_error,
)
logger = get_logger("FILE")
@@ -308,26 +307,13 @@ async def get_revision(
lakefs_repo = lakefs_repo_name(repo_type.value, repo_id)
client = get_lakefs_client()
# Get branch information
# Resolve revision (supports both branch names and commit hashes)
try:
branch = await client.get_branch(repository=lakefs_repo, branch=revision)
commit_id, commit_info = await resolve_revision(client, lakefs_repo, revision)
except ValueError:
return hf_revision_not_found(repo_id, revision)
except Exception as e:
if is_lakefs_not_found_error(e):
return hf_revision_not_found(repo_id, revision)
return hf_server_error(f"Failed to get branch: {str(e)}")
commit_id = branch["commit_id"]
commit_info = None
# Get commit details if available
if commit_id:
try:
commit_info = await client.get_commit(
repository=lakefs_repo, commit_id=commit_id
)
except Exception as e:
# Log but don't fail if commit info unavailable
logger.warning(f"Could not get commit info: {e}")
return hf_server_error(f"Failed to resolve revision: {str(e)}")
# Format last modified date
last_modified = None

View File

@@ -17,6 +17,58 @@ def get_lakefs_client() -> LakeFSRestClient:
return get_lakefs_rest_client()
async def resolve_revision(
client: LakeFSRestClient, lakefs_repo: str, revision: str
) -> tuple[str, dict | None]:
"""Resolve a revision (branch name or commit hash) to commit ID and info.
HuggingFace datasets library and other clients may use either branch names
(e.g., "main") or commit hashes as revision identifiers. This function
handles both cases by first trying to resolve as a branch, then as a commit.
Args:
client: LakeFS REST client instance
lakefs_repo: LakeFS repository name
revision: Branch name or commit hash
Returns:
Tuple of (commit_id, commit_info dict or None)
Raises:
ValueError: If revision cannot be resolved as either branch or commit
"""
# Try resolving as a branch first
try:
branch = await client.get_branch(repository=lakefs_repo, branch=revision)
commit_id = branch["commit_id"]
# Get commit details
try:
commit_info = await client.get_commit(
repository=lakefs_repo, commit_id=commit_id
)
except Exception:
commit_info = None
return commit_id, commit_info
except Exception as branch_error:
# Check if it's a "not found" error (branch doesn't exist)
error_str = str(branch_error).lower()
if "404" not in error_str and "not found" not in error_str:
# Some other error, re-raise
raise branch_error
# Branch not found, try resolving as a commit hash
try:
commit_info = await client.get_commit(
repository=lakefs_repo, commit_id=revision
)
return commit_info["id"], commit_info
except Exception as commit_error:
# Neither branch nor commit found
raise ValueError(
f"Revision '{revision}' not found as branch or commit"
) from commit_error
def _base36_encode(num: int) -> str:
"""Encode integer to base36 using numpy (C-optimized).