Mirror of https://github.com/KohakuBlueleaf/KohakuHub.git, synced 2026-04-30 09:28:35 -05:00
Fix LakeFS pagination and S3 signature-version problems
@@ -48,6 +48,7 @@ services:
       - KOHAKU_HUB_S3_ACCESS_KEY=minioadmin
       - KOHAKU_HUB_S3_SECRET_KEY=minioadmin
       - KOHAKU_HUB_S3_BUCKET=hub-storage
+      - KOHAKU_HUB_S3_SIGNATURE_VERSION=s3v2 # s3v2 for MinIO, s3v4 for R2/AWS S3

       ## ===== LakeFS Configuration =====
       - KOHAKU_HUB_LAKEFS_ENDPOINT=http://lakefs:28000
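The `KOHAKU_HUB_S3_*` variables above feed the `S3Config` model extended later in this commit. A minimal sketch of how the new variable could be read and sanity-checked inside the service, assuming a plain environment lookup (the project's actual loader may differ):

import os

# Assumed mapping: KOHAKU_HUB_S3_SIGNATURE_VERSION -> S3Config.signature_version
signature_version = os.environ.get("KOHAKU_HUB_S3_SIGNATURE_VERSION", "s3v4")
assert signature_version in ("s3v2", "s3v4"), f"unexpected value: {signature_version}"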
@@ -272,7 +272,8 @@ def generate_hub_api_service(config: dict) -> str:
       - KOHAKU_HUB_S3_ACCESS_KEY={config['s3_access_key']}
       - KOHAKU_HUB_S3_SECRET_KEY={config['s3_secret_key']}
       - KOHAKU_HUB_S3_BUCKET=hub-storage
+      - KOHAKU_HUB_S3_SIGNATURE_VERSION={config.get('s3_signature_version', 's3v2')} # s3v2 for MinIO, s3v4 for R2/AWS S3
 {s3_region_env}
       ## ===== LakeFS Configuration =====
       - KOHAKU_HUB_LAKEFS_ENDPOINT=http://lakefs:28000
       - KOHAKU_HUB_LAKEFS_REPO_NAMESPACE=hf
@@ -416,11 +417,15 @@ def load_config_file(config_path: Path) -> dict:
             "secret_key", fallback=generate_secret(48)
         )  # 64 chars
         config["s3_region"] = s3.get("region", fallback="")
+        config["s3_signature_version"] = s3.get(
+            "signature_version", fallback="s3v2" if config["s3_builtin"] else "s3v4"
+        )  # s3v2 for MinIO, s3v4 for R2/AWS S3
     else:
         config["s3_builtin"] = True
         config["s3_endpoint"] = "http://minio:9000"
         config["s3_access_key"] = generate_secret(24)  # 32 chars
         config["s3_secret_key"] = generate_secret(48)  # 64 chars
+        config["s3_signature_version"] = "s3v2"  # Default for MinIO

     # Security section
     if parser.has_section("security"):
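A minimal sketch of the fallback behavior above, using Python's standard configparser; the section and option names mirror the template in the next hunk, and the INI content here is purely illustrative:

import configparser

parser = configparser.ConfigParser()
parser.read_string("""
[s3]
builtin = true
""")

s3 = parser["s3"]
s3_builtin = s3.getboolean("builtin", fallback=True)
# No explicit signature_version: builtin MinIO falls back to s3v2,
# external S3 (builtin = false) would fall back to s3v4.
sig = s3.get("signature_version", fallback="s3v2" if s3_builtin else "s3v4")
print(sig)  # -> s3v2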
@@ -483,11 +488,13 @@ builtin = true
 # access_key = your-access-key
 # secret_key = your-secret-key
 # region = us-east-1
+# signature_version = s3v4 # s3v2 for MinIO, s3v4 for R2/AWS S3

 # If builtin = true, MinIO credentials are auto-generated (recommended)
 # You can override by uncommenting and setting custom values:
 # access_key = your-custom-access-key
 # secret_key = your-custom-secret-key
+# signature_version = s3v2

 [security]
 # Session and admin secrets (auto-generated if not specified)
@@ -634,12 +641,22 @@ def interactive_config() -> dict:
             config["s3_secret_key"] = ask_string("MinIO secret key")

         config["s3_endpoint"] = "http://minio:9000"
+        config["s3_signature_version"] = "s3v2"  # MinIO uses s3v2
     else:
         config["s3_endpoint"] = ask_string("S3 endpoint URL")
         config["s3_access_key"] = ask_string("S3 access key")
         config["s3_secret_key"] = ask_string("S3 secret key")
         config["s3_region"] = ask_string("S3 region", default="us-east-1")

+        # Ask about signature version for external S3
+        print()
+        print("Signature version:")
+        print("  - s3v2: MinIO (legacy)")
+        print("  - s3v4: Cloudflare R2, AWS S3 (recommended)")
+        config["s3_signature_version"] = ask_string(
+            "S3 signature version (s3v2 or s3v4)", default="s3v4"
+        )
+
     print()

     # Security Configuration
@@ -90,24 +90,37 @@ async def calculate_folder_stats(

     try:
         client = get_lakefs_client()
-        folder_contents = await client.list_objects(
-            repository=lakefs_repo,
-            ref=revision,
-            prefix=folder_path,
-            delimiter="",  # No delimiter = recursive
-            amount=1000,
-        )
-
-        # Calculate total size and find latest modification
-        for child_obj in folder_contents["results"]:
-            if child_obj["path_type"] == "object":
-                folder_size += child_obj.get("size_bytes") or 0
-                if child_obj.get("mtime"):
-                    if (
-                        folder_latest_mtime is None
-                        or child_obj["mtime"] > folder_latest_mtime
-                    ):
-                        folder_latest_mtime = child_obj["mtime"]
+        # Paginate through all objects in folder
+        after = ""
+        has_more = True
+
+        while has_more:
+            folder_contents = await client.list_objects(
+                repository=lakefs_repo,
+                ref=revision,
+                prefix=folder_path,
+                delimiter="",  # No delimiter = recursive
+                amount=1000,
+                after=after,
+            )
+
+            # Calculate total size and find latest modification
+            for child_obj in folder_contents["results"]:
+                if child_obj["path_type"] == "object":
+                    folder_size += child_obj.get("size_bytes") or 0
+                    if child_obj.get("mtime"):
+                        if (
+                            folder_latest_mtime is None
+                            or child_obj["mtime"] > folder_latest_mtime
+                        ):
+                            folder_latest_mtime = child_obj["mtime"]
+
+            # Check pagination
+            pagination = folder_contents.get("pagination", {})
+            has_more = pagination.get("has_more", False)
+            if has_more:
+                after = pagination.get("next_offset", "")

     except Exception as e:
         logger.debug(f"Could not calculate stats for folder {folder_path}: {str(e)}")
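The paginate-and-accumulate loop added here reappears in the next two hunks. A possible factoring into a shared helper, sketched under the assumption that `client.list_objects` keeps the signature used above; the helper name `list_all_objects` is hypothetical and not part of this commit:

async def list_all_objects(client, repository, ref, prefix="", delimiter=""):
    """Yield every object at ref by following LakeFS pagination."""
    after = ""
    while True:
        page = await client.list_objects(
            repository=repository,
            ref=ref,
            prefix=prefix,
            delimiter=delimiter,
            amount=1000,  # LakeFS maximum page size
            after=after,
        )
        for obj in page["results"]:
            yield obj
        pagination = page.get("pagination", {})
        if not pagination.get("has_more", False):
            break
        after = pagination.get("next_offset", "")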
@@ -352,24 +352,39 @@ async def check_commit_range_recoverability(

     # Get all commits from target to HEAD
     try:
-        commits_result = await client.log_commits(
-            repository=lakefs_repo,
-            ref=current_branch,
-            amount=1000,  # Should be enough for most cases
-        )
-        commit_list = commits_result.get("results", [])
-
-        # Find target commit in the list
+        # Paginate through commits to find target
+        commit_list = []
+        after = ""
+        has_more = True
+
         target_index = None
-        for i, commit in enumerate(commit_list):
-            if commit["id"] == target_commit:
-                target_index = i
-                break
+
+        while has_more and target_index is None:
+            commits_result = await client.log_commits(
+                repository=lakefs_repo,
+                ref=current_branch,
+                amount=1000,  # LakeFS maximum
+                after=after,
+            )
+
+            results = commits_result.get("results", [])
+            current_batch_start = len(commit_list)
+            commit_list.extend(results)
+
+            # Find target commit in current batch
+            for i, commit in enumerate(results):
+                if commit["id"] == target_commit:
+                    target_index = current_batch_start + i
+                    break
+
+            # Check pagination
+            pagination = commits_result.get("pagination", {})
+            has_more = pagination.get("has_more", False)
+            if has_more:
+                after = pagination.get("next_offset", "")

         if target_index is None:
             logger.warning(
-                f"Target commit {target_commit[:8]} not found in branch history"
+                f"Target commit {target_commit[:8]} not found in branch history (checked {len(commit_list)} commits)"
             )
             return False, [], []
@@ -457,13 +472,29 @@ async def sync_file_table_with_commit(
         commit_id = branch_info["commit_id"]

         # Get ALL objects at the commit (use commit ID to avoid staging issues)
-        list_result = await client.list_objects(
-            repository=lakefs_repo,
-            ref=commit_id,  # Use commit ID, not branch name!
-            amount=10000,  # Large enough for most repos
-        )
-
-        all_objects = list_result.get("results", [])
+        # LakeFS has max amount=1000, so we need to paginate
+        all_objects = []
+        after = ""
+        has_more = True
+
+        while has_more:
+            list_result = await client.list_objects(
+                repository=lakefs_repo,
+                ref=commit_id,  # Use commit ID, not branch name!
+                amount=1000,  # LakeFS maximum
+                after=after,
+            )
+
+            results = list_result.get("results", [])
+            all_objects.extend(results)
+
+            # Check pagination
+            pagination = list_result.get("pagination", {})
+            has_more = pagination.get("has_more", False)
+            if has_more:
+                after = pagination.get("next_offset", "")
+
+        logger.info(f"Paginated through LakeFS, got {len(all_objects)} total objects")
         logger.info(
             f"Syncing {len(all_objects)} file(s) from ref {ref} (commit {commit_id[:8]})"
         )
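With the hypothetical `list_all_objects` helper sketched after the calculate_folder_stats hunk, the accumulation loop in this hunk could collapse to a single async comprehension (same `client`, `lakefs_repo`, and `commit_id` as above):

all_objects = [
    obj async for obj in list_all_objects(client, lakefs_repo, commit_id)
]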
@@ -18,6 +18,7 @@ class S3Config(BaseModel):
     bucket: str = "test-bucket"
     region: str = "us-east-1"
     force_path_style: bool = True
+    signature_version: str = "s3v4"  # s3v4 (R2, AWS S3) or s3v2 (MinIO)


 class LakeFSConfig(BaseModel):
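Optional hardening, not part of this commit: since only two values are handled downstream, the new field could be constrained with a `Literal` so a typo fails at config load rather than at request signing. A minimal pydantic sketch:

from typing import Literal

from pydantic import BaseModel


class S3Config(BaseModel):
    bucket: str = "test-bucket"
    region: str = "us-east-1"
    force_path_style: bool = True
    # Restrict to the two values this codebase handles
    signature_version: Literal["s3v4", "s3v2"] = "s3v4"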
@@ -13,10 +13,16 @@ logger = get_logger("S3")


 def get_s3_client():
-    """Create configured S3 client with SigV4 signing.
+    """Create configured S3 client with configurable signature version.
+
+    Signature versions:
+    - s3v4: AWS S3, Cloudflare R2 (default, more secure)
+    - s3v2: MinIO (legacy, required for some MinIO setups)
+
+    Set via KOHAKU_HUB_S3_SIGNATURE_VERSION environment variable.

     Returns:
-        Configured boto3 S3 client using Signature Version 4.
+        Configured boto3 S3 client.
     """
     # Build S3-specific config
     s3_config = {}
@@ -29,13 +35,16 @@ def get_s3_client():
     if cfg.s3.endpoint and ("/" in cfg.s3.endpoint.split("//", 1)[1]):
         # Endpoint has path - treat it as bucket endpoint
         s3_config["use_accelerate_endpoint"] = False
-        logger.info(
+        logger.debug(
             "S3 endpoint contains path - using bucket_endpoint mode for R2 compatibility"
         )

-    # Always use Signature Version 4 (more secure than deprecated SigV2)
+    # Use configured signature version (s3v4 or s3v2)
+    sig_version = cfg.s3.signature_version
+    logger.debug(f"Using S3 signature version: {sig_version}")
+
     boto_config = BotoConfig(
-        signature_version="s3v4",
+        signature_version=sig_version,
         s3=s3_config,
     )
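One hedged way to observe the change end to end: the query string of a presigned URL from the configured client reflects the active signer. The bucket and key below are placeholders, not real objects in any deployment:

client = get_s3_client()
url = client.generate_presigned_url(
    "get_object",
    Params={"Bucket": "hub-storage", "Key": "example/object.bin"},
    ExpiresIn=3600,
)
# SigV4 URLs carry X-Amz-Algorithm/X-Amz-Signature query parameters;
# legacy S3 signing yields Signature/Expires/AWSAccessKeyId instead.
print(url)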