diff --git a/docs/api/tree.md b/docs/api/tree.md index b2b3107..6abb2e5 100644 --- a/docs/api/tree.md +++ b/docs/api/tree.md @@ -65,7 +65,7 @@ Returns a flat list of file and folder objects in HuggingFace-compatible format. **Field Descriptions:** - `type`: Object type (`file` or `directory`) -- `path`: Relative path from the specified prefix +- `path`: Relative path from the repository root - `size`: File size in bytes (for directories, sum of all contents) - `oid`: Object identifier (SHA256 for LFS files, SHA1 for regular files, tree hash for directories) - `lastModified`: ISO 8601 timestamp of last modification diff --git a/src/kohakuhub/api/files.py b/src/kohakuhub/api/files.py index f817f61..3e08ac7 100644 --- a/src/kohakuhub/api/files.py +++ b/src/kohakuhub/api/files.py @@ -28,7 +28,7 @@ from kohakuhub.auth.permissions import ( check_repo_read_permission, check_repo_write_permission, ) -from kohakuhub.utils.lakefs import get_lakefs_client, lakefs_repo_name +from kohakuhub.utils.lakefs import get_lakefs_client, lakefs_repo_name, resolve_revision from kohakuhub.utils.s3 import generate_download_presigned_url, parse_s3_uri from kohakuhub.api.fallback import with_repo_fallback from kohakuhub.api.xet import XET_ENABLE @@ -41,7 +41,6 @@ from kohakuhub.api.repo.utils.hf import ( hf_repo_not_found, hf_revision_not_found, hf_server_error, - is_lakefs_not_found_error, ) logger = get_logger("FILE") @@ -308,26 +307,13 @@ async def get_revision( lakefs_repo = lakefs_repo_name(repo_type.value, repo_id) client = get_lakefs_client() - # Get branch information + # Resolve revision (supports both branch names and commit hashes) try: - branch = await client.get_branch(repository=lakefs_repo, branch=revision) + commit_id, commit_info = await resolve_revision(client, lakefs_repo, revision) + except ValueError: + return hf_revision_not_found(repo_id, revision) except Exception as e: - if is_lakefs_not_found_error(e): - return hf_revision_not_found(repo_id, revision) - return hf_server_error(f"Failed to get branch: {str(e)}") - - commit_id = branch["commit_id"] - commit_info = None - - # Get commit details if available - if commit_id: - try: - commit_info = await client.get_commit( - repository=lakefs_repo, commit_id=commit_id - ) - except Exception as e: - # Log but don't fail if commit info unavailable - logger.warning(f"Could not get commit info: {e}") + return hf_server_error(f"Failed to resolve revision: {str(e)}") # Format last modified date last_modified = None diff --git a/src/kohakuhub/api/repo/routers/tree.py b/src/kohakuhub/api/repo/routers/tree.py index 9033d3f..55d8cc1 100644 --- a/src/kohakuhub/api/repo/routers/tree.py +++ b/src/kohakuhub/api/repo/routers/tree.py @@ -130,13 +130,12 @@ async def calculate_folder_stats( return folder_size, folder_latest_mtime -async def convert_file_object(obj, repository: Repository, prefix_len: int) -> dict: +async def convert_file_object(obj, repository: Repository) -> dict: """Convert LakeFS file object to HuggingFace format. Args: obj: LakeFS object dict repository: Repository object (FK) - prefix_len: Length of path prefix to remove Returns: HuggingFace formatted file object @@ -144,8 +143,8 @@ async def convert_file_object(obj, repository: Repository, prefix_len: int) -> d # Use repo-specific LFS settings is_lfs = should_use_lfs(repository, obj["path"], obj["size_bytes"]) - # Remove prefix from path to get relative path - relative_path = obj["path"][prefix_len:] if prefix_len else obj["path"] + # Use full path relative to repository root (HuggingFace spec) + file_path = obj["path"] # Get correct checksum from database using repository FK file_record = get_file(repository, obj["path"]) @@ -158,7 +157,7 @@ async def convert_file_object(obj, repository: Repository, prefix_len: int) -> d "type": "file", "oid": checksum, # Git blob SHA1 for non-LFS, SHA256 for LFS "size": obj["size_bytes"], - "path": relative_path, + "path": file_path, } # Add last modified info if available @@ -179,7 +178,7 @@ async def convert_file_object(obj, repository: Repository, prefix_len: int) -> d async def convert_directory_object( - obj, lakefs_repo: str, revision: str, prefix_len: int + obj, lakefs_repo: str, revision: str ) -> dict: """Convert LakeFS directory object to HuggingFace format. @@ -187,13 +186,12 @@ async def convert_directory_object( obj: LakeFS common_prefix object dict lakefs_repo: LakeFS repository name revision: Branch or commit - prefix_len: Length of path prefix to remove Returns: HuggingFace formatted directory object """ - # Remove prefix from path to get relative path - relative_path = obj["path"][prefix_len:] if prefix_len else obj["path"] + # Use full path relative to repository root (HuggingFace spec) + dir_path = obj["path"] # Calculate folder stats folder_size, folder_latest_mtime = await calculate_folder_stats( @@ -204,7 +202,7 @@ async def convert_directory_object( "type": "directory", "oid": obj.get("checksum", ""), "size": folder_size, - "path": relative_path.rstrip("/"), # Remove trailing slash + "path": dir_path.rstrip("/"), # Remove trailing slash } # Add last modified info @@ -290,19 +288,18 @@ async def list_repo_tree( # Convert LakeFS objects to HuggingFace format result_list = [] - prefix_len = len(prefix) for obj in all_results: match obj["path_type"]: case "object": # File object - pass Repository FK instead of repo_id - file_obj = await convert_file_object(obj, repo_row, prefix_len) + file_obj = await convert_file_object(obj, repo_row) result_list.append(file_obj) case "common_prefix": # Directory object dir_obj = await convert_directory_object( - obj, lakefs_repo, revision, prefix_len + obj, lakefs_repo, revision ) result_list.append(dir_obj) diff --git a/src/kohakuhub/utils/lakefs.py b/src/kohakuhub/utils/lakefs.py index 8d834d9..d8ca0aa 100644 --- a/src/kohakuhub/utils/lakefs.py +++ b/src/kohakuhub/utils/lakefs.py @@ -17,6 +17,58 @@ def get_lakefs_client() -> LakeFSRestClient: return get_lakefs_rest_client() +async def resolve_revision( + client: LakeFSRestClient, lakefs_repo: str, revision: str +) -> tuple[str, dict | None]: + """Resolve a revision (branch name or commit hash) to commit ID and info. + + HuggingFace datasets library and other clients may use either branch names + (e.g., "main") or commit hashes as revision identifiers. This function + handles both cases by first trying to resolve as a branch, then as a commit. + + Args: + client: LakeFS REST client instance + lakefs_repo: LakeFS repository name + revision: Branch name or commit hash + + Returns: + Tuple of (commit_id, commit_info dict or None) + + Raises: + ValueError: If revision cannot be resolved as either branch or commit + """ + # Try resolving as a branch first + try: + branch = await client.get_branch(repository=lakefs_repo, branch=revision) + commit_id = branch["commit_id"] + # Get commit details + try: + commit_info = await client.get_commit( + repository=lakefs_repo, commit_id=commit_id + ) + except Exception: + commit_info = None + return commit_id, commit_info + except Exception as branch_error: + # Check if it's a "not found" error (branch doesn't exist) + error_str = str(branch_error).lower() + if "404" not in error_str and "not found" not in error_str: + # Some other error, re-raise + raise branch_error + + # Branch not found, try resolving as a commit hash + try: + commit_info = await client.get_commit( + repository=lakefs_repo, commit_id=revision + ) + return commit_info["id"], commit_info + except Exception as commit_error: + # Neither branch nor commit found + raise ValueError( + f"Revision '{revision}' not found as branch or commit" + ) from commit_error + + def _base36_encode(num: int) -> str: """Encode integer to base36 using numpy (C-optimized).