mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
Fix LakeFS pagination and S3 signature version problems
This commit is contained in:
@@ -48,6 +48,7 @@ services:
|
||||
- KOHAKU_HUB_S3_ACCESS_KEY=minioadmin
|
||||
- KOHAKU_HUB_S3_SECRET_KEY=minioadmin
|
||||
- KOHAKU_HUB_S3_BUCKET=hub-storage
|
||||
- KOHAKU_HUB_S3_SIGNATURE_VERSION=s3v2 # s3v2 for MinIO, s3v4 for R2/AWS S3
|
||||
|
||||
## ===== LakeFS Configuration =====
|
||||
- KOHAKU_HUB_LAKEFS_ENDPOINT=http://lakefs:28000
|
||||
|
||||
@@ -272,7 +272,8 @@ def generate_hub_api_service(config: dict) -> str:
|
||||
- KOHAKU_HUB_S3_ACCESS_KEY={config['s3_access_key']}
|
||||
- KOHAKU_HUB_S3_SECRET_KEY={config['s3_secret_key']}
|
||||
- KOHAKU_HUB_S3_BUCKET=hub-storage
|
||||
|
||||
- KOHAKU_HUB_S3_SIGNATURE_VERSION={config.get('s3_signature_version', 's3v2')} # s3v2 for MinIO, s3v4 for R2/AWS S3
|
||||
{s3_region_env}
|
||||
## ===== LakeFS Configuration =====
|
||||
- KOHAKU_HUB_LAKEFS_ENDPOINT=http://lakefs:28000
|
||||
- KOHAKU_HUB_LAKEFS_REPO_NAMESPACE=hf
|
||||
@@ -416,11 +417,15 @@ def load_config_file(config_path: Path) -> dict:
|
||||
"secret_key", fallback=generate_secret(48)
|
||||
) # 64 chars
|
||||
config["s3_region"] = s3.get("region", fallback="")
|
||||
config["s3_signature_version"] = s3.get(
|
||||
"signature_version", fallback="s3v2" if config["s3_builtin"] else "s3v4"
|
||||
) # s3v2 for MinIO, s3v4 for R2/AWS S3
|
||||
else:
|
||||
config["s3_builtin"] = True
|
||||
config["s3_endpoint"] = "http://minio:9000"
|
||||
config["s3_access_key"] = generate_secret(24) # 32 chars
|
||||
config["s3_secret_key"] = generate_secret(48) # 64 chars
|
||||
config["s3_signature_version"] = "s3v2" # Default for MinIO
|
||||
|
||||
# Security section
|
||||
if parser.has_section("security"):
|
||||
@@ -483,11 +488,13 @@ builtin = true
|
||||
# access_key = your-access-key
|
||||
# secret_key = your-secret-key
|
||||
# region = us-east-1
|
||||
# signature_version = s3v4 # s3v2 for MinIO, s3v4 for R2/AWS S3
|
||||
|
||||
# If builtin = true, MinIO credentials are auto-generated (recommended)
|
||||
# You can override by uncommenting and setting custom values:
|
||||
# access_key = your-custom-access-key
|
||||
# secret_key = your-custom-secret-key
|
||||
# signature_version = s3v2
|
||||
|
||||
[security]
|
||||
# Session and admin secrets (auto-generated if not specified)
|
||||
@@ -634,12 +641,22 @@ def interactive_config() -> dict:
|
||||
config["s3_secret_key"] = ask_string("MinIO secret key")
|
||||
|
||||
config["s3_endpoint"] = "http://minio:9000"
|
||||
config["s3_signature_version"] = "s3v2" # MinIO uses s3v2
|
||||
else:
|
||||
config["s3_endpoint"] = ask_string("S3 endpoint URL")
|
||||
config["s3_access_key"] = ask_string("S3 access key")
|
||||
config["s3_secret_key"] = ask_string("S3 secret key")
|
||||
config["s3_region"] = ask_string("S3 region", default="us-east-1")
|
||||
|
||||
# Ask about signature version for external S3
|
||||
print()
|
||||
print("Signature version:")
|
||||
print(" - s3v2: MinIO (legacy)")
|
||||
print(" - s3v4: Cloudflare R2, AWS S3 (recommended)")
|
||||
config["s3_signature_version"] = ask_string(
|
||||
"S3 signature version (s3v2 or s3v4)", default="s3v4"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
# Security Configuration
|
||||
|
||||
@@ -90,24 +90,37 @@ async def calculate_folder_stats(
|
||||
|
||||
try:
|
||||
client = get_lakefs_client()
|
||||
folder_contents = await client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=revision,
|
||||
prefix=folder_path,
|
||||
delimiter="", # No delimiter = recursive
|
||||
amount=1000,
|
||||
)
|
||||
|
||||
# Calculate total size and find latest modification
|
||||
for child_obj in folder_contents["results"]:
|
||||
if child_obj["path_type"] == "object":
|
||||
folder_size += child_obj.get("size_bytes") or 0
|
||||
if child_obj.get("mtime"):
|
||||
if (
|
||||
folder_latest_mtime is None
|
||||
or child_obj["mtime"] > folder_latest_mtime
|
||||
):
|
||||
folder_latest_mtime = child_obj["mtime"]
|
||||
# Paginate through all objects in folder
|
||||
after = ""
|
||||
has_more = True
|
||||
|
||||
while has_more:
|
||||
folder_contents = await client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=revision,
|
||||
prefix=folder_path,
|
||||
delimiter="", # No delimiter = recursive
|
||||
amount=1000,
|
||||
after=after,
|
||||
)
|
||||
|
||||
# Calculate total size and find latest modification
|
||||
for child_obj in folder_contents["results"]:
|
||||
if child_obj["path_type"] == "object":
|
||||
folder_size += child_obj.get("size_bytes") or 0
|
||||
if child_obj.get("mtime"):
|
||||
if (
|
||||
folder_latest_mtime is None
|
||||
or child_obj["mtime"] > folder_latest_mtime
|
||||
):
|
||||
folder_latest_mtime = child_obj["mtime"]
|
||||
|
||||
# Check pagination
|
||||
pagination = folder_contents.get("pagination", {})
|
||||
has_more = pagination.get("has_more", False)
|
||||
if has_more:
|
||||
after = pagination.get("next_offset", "")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not calculate stats for folder {folder_path}: {str(e)}")
|
||||
|
||||
@@ -352,24 +352,39 @@ async def check_commit_range_recoverability(
|
||||
|
||||
# Get all commits from target to HEAD
|
||||
try:
|
||||
commits_result = await client.log_commits(
|
||||
repository=lakefs_repo,
|
||||
ref=current_branch,
|
||||
amount=1000, # Should be enough for most cases
|
||||
)
|
||||
|
||||
commit_list = commits_result.get("results", [])
|
||||
|
||||
# Find target commit in the list
|
||||
# Paginate through commits to find target
|
||||
commit_list = []
|
||||
after = ""
|
||||
has_more = True
|
||||
target_index = None
|
||||
for i, commit in enumerate(commit_list):
|
||||
if commit["id"] == target_commit:
|
||||
target_index = i
|
||||
break
|
||||
|
||||
while has_more and target_index is None:
|
||||
commits_result = await client.log_commits(
|
||||
repository=lakefs_repo,
|
||||
ref=current_branch,
|
||||
amount=1000, # LakeFS maximum
|
||||
after=after,
|
||||
)
|
||||
|
||||
results = commits_result.get("results", [])
|
||||
current_batch_start = len(commit_list)
|
||||
commit_list.extend(results)
|
||||
|
||||
# Find target commit in current batch
|
||||
for i, commit in enumerate(results):
|
||||
if commit["id"] == target_commit:
|
||||
target_index = current_batch_start + i
|
||||
break
|
||||
|
||||
# Check pagination
|
||||
pagination = commits_result.get("pagination", {})
|
||||
has_more = pagination.get("has_more", False)
|
||||
if has_more:
|
||||
after = pagination.get("next_offset", "")
|
||||
|
||||
if target_index is None:
|
||||
logger.warning(
|
||||
f"Target commit {target_commit[:8]} not found in branch history"
|
||||
f"Target commit {target_commit[:8]} not found in branch history (checked {len(commit_list)} commits)"
|
||||
)
|
||||
return False, [], []
|
||||
|
||||
@@ -457,13 +472,29 @@ async def sync_file_table_with_commit(
|
||||
commit_id = branch_info["commit_id"]
|
||||
|
||||
# Get ALL objects at the commit (use commit ID to avoid staging issues)
|
||||
list_result = await client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=commit_id, # Use commit ID, not branch name!
|
||||
amount=10000, # Large enough for most repos
|
||||
)
|
||||
# LakeFS has max amount=1000, so we need to paginate
|
||||
all_objects = []
|
||||
after = ""
|
||||
has_more = True
|
||||
|
||||
all_objects = list_result.get("results", [])
|
||||
while has_more:
|
||||
list_result = await client.list_objects(
|
||||
repository=lakefs_repo,
|
||||
ref=commit_id, # Use commit ID, not branch name!
|
||||
amount=1000, # LakeFS maximum
|
||||
after=after,
|
||||
)
|
||||
|
||||
results = list_result.get("results", [])
|
||||
all_objects.extend(results)
|
||||
|
||||
# Check pagination
|
||||
pagination = list_result.get("pagination", {})
|
||||
has_more = pagination.get("has_more", False)
|
||||
if has_more:
|
||||
after = pagination.get("next_offset", "")
|
||||
|
||||
logger.info(f"Paginated through LakeFS, got {len(all_objects)} total objects")
|
||||
logger.info(
|
||||
f"Syncing {len(all_objects)} file(s) from ref {ref} (commit {commit_id[:8]})"
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ class S3Config(BaseModel):
|
||||
bucket: str = "test-bucket"
|
||||
region: str = "us-east-1"
|
||||
force_path_style: bool = True
|
||||
signature_version: str = "s3v4" # s3v4 (R2, AWS S3) or s3v2 (MinIO)
|
||||
|
||||
|
||||
class LakeFSConfig(BaseModel):
|
||||
|
||||
@@ -13,10 +13,16 @@ logger = get_logger("S3")
|
||||
|
||||
|
||||
def get_s3_client():
|
||||
"""Create configured S3 client with SigV4 signing.
|
||||
"""Create configured S3 client with configurable signature version.
|
||||
|
||||
Signature versions:
|
||||
- s3v4: AWS S3, Cloudflare R2 (default, more secure)
|
||||
- s3v2: MinIO (legacy, required for some MinIO setups)
|
||||
|
||||
Set via KOHAKU_HUB_S3_SIGNATURE_VERSION environment variable.
|
||||
|
||||
Returns:
|
||||
Configured boto3 S3 client using Signature Version 4.
|
||||
Configured boto3 S3 client.
|
||||
"""
|
||||
# Build S3-specific config
|
||||
s3_config = {}
|
||||
@@ -29,13 +35,16 @@ def get_s3_client():
|
||||
if cfg.s3.endpoint and ("/" in cfg.s3.endpoint.split("//", 1)[1]):
|
||||
# Endpoint has path - treat it as bucket endpoint
|
||||
s3_config["use_accelerate_endpoint"] = False
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"S3 endpoint contains path - using bucket_endpoint mode for R2 compatibility"
|
||||
)
|
||||
|
||||
# Always use Signature Version 4 (more secure than deprecated SigV2)
|
||||
# Use configured signature version (s3v4 or s3v2)
|
||||
sig_version = cfg.s3.signature_version
|
||||
logger.debug(f"Using S3 signature version: {sig_version}")
|
||||
|
||||
boto_config = BotoConfig(
|
||||
signature_version="s3v4",
|
||||
signature_version=sig_version,
|
||||
s3=s3_config,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user