diff --git a/src/kohakuhub/api/files.py b/src/kohakuhub/api/files.py index f817f61..3e08ac7 100644 --- a/src/kohakuhub/api/files.py +++ b/src/kohakuhub/api/files.py @@ -28,7 +28,7 @@ from kohakuhub.auth.permissions import ( check_repo_read_permission, check_repo_write_permission, ) -from kohakuhub.utils.lakefs import get_lakefs_client, lakefs_repo_name +from kohakuhub.utils.lakefs import get_lakefs_client, lakefs_repo_name, resolve_revision from kohakuhub.utils.s3 import generate_download_presigned_url, parse_s3_uri from kohakuhub.api.fallback import with_repo_fallback from kohakuhub.api.xet import XET_ENABLE @@ -41,7 +41,6 @@ from kohakuhub.api.repo.utils.hf import ( hf_repo_not_found, hf_revision_not_found, hf_server_error, - is_lakefs_not_found_error, ) logger = get_logger("FILE") @@ -308,26 +307,13 @@ async def get_revision( lakefs_repo = lakefs_repo_name(repo_type.value, repo_id) client = get_lakefs_client() - # Get branch information + # Resolve revision (supports both branch names and commit hashes) try: - branch = await client.get_branch(repository=lakefs_repo, branch=revision) + commit_id, commit_info = await resolve_revision(client, lakefs_repo, revision) + except ValueError: + return hf_revision_not_found(repo_id, revision) except Exception as e: - if is_lakefs_not_found_error(e): - return hf_revision_not_found(repo_id, revision) - return hf_server_error(f"Failed to get branch: {str(e)}") - - commit_id = branch["commit_id"] - commit_info = None - - # Get commit details if available - if commit_id: - try: - commit_info = await client.get_commit( - repository=lakefs_repo, commit_id=commit_id - ) - except Exception as e: - # Log but don't fail if commit info unavailable - logger.warning(f"Could not get commit info: {e}") + return hf_server_error(f"Failed to resolve revision: {str(e)}") # Format last modified date last_modified = None diff --git a/src/kohakuhub/utils/lakefs.py b/src/kohakuhub/utils/lakefs.py index 8d834d9..d8ca0aa 100644 --- a/src/kohakuhub/utils/lakefs.py +++ b/src/kohakuhub/utils/lakefs.py @@ -17,6 +17,58 @@ def get_lakefs_client() -> LakeFSRestClient: return get_lakefs_rest_client() +async def resolve_revision( + client: LakeFSRestClient, lakefs_repo: str, revision: str +) -> tuple[str, dict | None]: + """Resolve a revision (branch name or commit hash) to commit ID and info. + + HuggingFace datasets library and other clients may use either branch names + (e.g., "main") or commit hashes as revision identifiers. This function + handles both cases by first trying to resolve as a branch, then as a commit. + + Args: + client: LakeFS REST client instance + lakefs_repo: LakeFS repository name + revision: Branch name or commit hash + + Returns: + Tuple of (commit_id, commit_info dict or None) + + Raises: + ValueError: If revision cannot be resolved as either branch or commit + """ + # Try resolving as a branch first + try: + branch = await client.get_branch(repository=lakefs_repo, branch=revision) + commit_id = branch["commit_id"] + # Get commit details + try: + commit_info = await client.get_commit( + repository=lakefs_repo, commit_id=commit_id + ) + except Exception: + commit_info = None + return commit_id, commit_info + except Exception as branch_error: + # Check if it's a "not found" error (branch doesn't exist) + error_str = str(branch_error).lower() + if "404" not in error_str and "not found" not in error_str: + # Some other error, re-raise + raise branch_error + + # Branch not found, try resolving as a commit hash + try: + commit_info = await client.get_commit( + repository=lakefs_repo, commit_id=revision + ) + return commit_info["id"], commit_info + except Exception as commit_error: + # Neither branch nor commit found + raise ValueError( + f"Revision '{revision}' not found as branch or commit" + ) from commit_error + + def _base36_encode(num: int) -> str: """Encode integer to base36 using numpy (C-optimized).