mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-05-07 03:57:42 -05:00
Fix huggingface_hub repo settings compatibility
This commit is contained in:
@@ -296,6 +296,7 @@ class UpdateRepoSettingsPayload(BaseModel):
|
||||
"""Payload for repository settings update."""
|
||||
|
||||
private: Optional[bool] = None
|
||||
visibility: Optional[str] = None
|
||||
gated: Optional[str] = None # "auto", "manual", or False/None
|
||||
lfs_threshold_bytes: Optional[int] = None # NULL = use server default
|
||||
lfs_keep_versions: Optional[int] = None # NULL = use server default
|
||||
@@ -374,9 +375,26 @@ async def update_repo_settings(
|
||||
# Store as JSON string
|
||||
update_fields["lfs_suffix_rules"] = json.dumps(payload.lfs_suffix_rules)
|
||||
|
||||
if payload.private is not None:
|
||||
requested_private = payload.private
|
||||
if requested_private is None and payload.visibility is not None:
|
||||
if payload.visibility == "private":
|
||||
requested_private = True
|
||||
elif payload.visibility == "public":
|
||||
requested_private = False
|
||||
else:
|
||||
raise HTTPException(
|
||||
400,
|
||||
detail={
|
||||
"error": (
|
||||
"Unsupported repository visibility. "
|
||||
"Only 'public' and 'private' are supported."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if requested_private is not None:
|
||||
# Check if visibility is actually changing
|
||||
if repo_row.private != payload.private:
|
||||
if repo_row.private != requested_private:
|
||||
# Calculate repository storage
|
||||
logger.info(
|
||||
f"Checking quota for visibility change: {repo_id} from "
|
||||
@@ -395,7 +413,7 @@ async def update_repo_settings(
|
||||
allowed, error_msg = check_quota(
|
||||
namespace=namespace,
|
||||
additional_bytes=repo_size,
|
||||
is_private=payload.private,
|
||||
is_private=requested_private,
|
||||
is_org=is_org,
|
||||
)
|
||||
|
||||
@@ -412,7 +430,7 @@ async def update_repo_settings(
|
||||
)
|
||||
|
||||
# Update repository visibility
|
||||
update_fields["private"] = payload.private
|
||||
update_fields["private"] = requested_private
|
||||
|
||||
# Apply all updates if there are any
|
||||
if update_fields:
|
||||
|
||||
@@ -24,6 +24,20 @@ async def _create_hf_token(client, name: str) -> str:
|
||||
return response.json()["token"]
|
||||
|
||||
|
||||
def _set_repo_private(api: HfApi, repo_id: str, private: bool) -> None:
|
||||
update_settings = getattr(api, "update_repo_settings", None)
|
||||
if callable(update_settings):
|
||||
update_settings(repo_id, private=private)
|
||||
return
|
||||
|
||||
update_visibility = getattr(api, "update_repo_visibility", None)
|
||||
if callable(update_visibility):
|
||||
update_visibility(repo_id, private=private)
|
||||
return
|
||||
|
||||
pytest.skip("huggingface_hub does not expose repository visibility updates")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def member_hf_api_token(member_client):
|
||||
return await _create_hf_token(member_client, "hf-api-member")
|
||||
@@ -358,10 +372,11 @@ async def test_hf_api_likes_visibility_move_delete_and_list_liked_repos(
|
||||
reliked = await asyncio.to_thread(lambda: api.list_liked_repos("owner"))
|
||||
assert "owner/demo-model" in reliked.models
|
||||
|
||||
update_visibility = getattr(api, "update_repo_visibility", None)
|
||||
if callable(update_visibility):
|
||||
if callable(getattr(api, "update_repo_settings", None)) or callable(
|
||||
getattr(api, "update_repo_visibility", None)
|
||||
):
|
||||
await asyncio.to_thread(
|
||||
lambda: update_visibility("owner/hf-lifecycle-compat", private=True)
|
||||
lambda: _set_repo_private(api, "owner/hf-lifecycle-compat", True)
|
||||
)
|
||||
private_info = await asyncio.to_thread(
|
||||
lambda: api.repo_info("owner/hf-lifecycle-compat")
|
||||
|
||||
@@ -42,6 +42,18 @@ async def test_update_repo_lfs_settings_and_read_effective_values(owner_client):
|
||||
assert ".gguf" in payload["lfs_suffix_rules_effective"]
|
||||
|
||||
|
||||
async def test_update_repo_visibility_accepts_huggingface_settings_payload(owner_client):
|
||||
update_response = await owner_client.put(
|
||||
"/api/models/owner/demo-model/settings",
|
||||
json={"visibility": "private"},
|
||||
)
|
||||
assert update_response.status_code == 200
|
||||
|
||||
info_response = await owner_client.get("/api/models/owner/demo-model")
|
||||
assert info_response.status_code == 200
|
||||
assert info_response.json()["private"] is True
|
||||
|
||||
|
||||
async def test_namespace_type_social_links_and_repo_setting_validation(owner_client):
|
||||
update_response = await owner_client.put(
|
||||
"/api/users/owner/settings",
|
||||
|
||||
@@ -304,6 +304,16 @@ async def test_repo_settings_and_lfs_settings_cover_validation_quota_and_default
|
||||
)
|
||||
assert bad_suffix.value.status_code == 400
|
||||
|
||||
with pytest.raises(HTTPException) as bad_visibility:
|
||||
await settings_api.update_repo_settings(
|
||||
"model",
|
||||
"alice",
|
||||
"demo",
|
||||
settings_api.UpdateRepoSettingsPayload(visibility="protected"),
|
||||
user=SimpleNamespace(username="alice"),
|
||||
)
|
||||
assert bad_visibility.value.status_code == 400
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings_api,
|
||||
"calculate_repository_storage",
|
||||
@@ -359,6 +369,17 @@ async def test_repo_settings_and_lfs_settings_cover_validation_quota_and_default
|
||||
"private": True,
|
||||
}
|
||||
|
||||
repo_row.private = False
|
||||
updated_visibility = await settings_api.update_repo_settings(
|
||||
"model",
|
||||
"alice",
|
||||
"demo",
|
||||
settings_api.UpdateRepoSettingsPayload(visibility="public"),
|
||||
user=SimpleNamespace(username="alice"),
|
||||
)
|
||||
assert updated_visibility["success"] is True
|
||||
assert update_calls[-1] == {"private": False}
|
||||
|
||||
monkeypatch.setattr(settings_api, "get_repository", lambda repo_type, namespace, name: None)
|
||||
lfs_not_found = await settings_api.get_repo_lfs_settings(
|
||||
"model",
|
||||
|
||||
Reference in New Issue
Block a user