diff --git a/src/kohakuhub/api/settings.py b/src/kohakuhub/api/settings.py index c453275..733c19f 100644 --- a/src/kohakuhub/api/settings.py +++ b/src/kohakuhub/api/settings.py @@ -296,6 +296,7 @@ class UpdateRepoSettingsPayload(BaseModel): """Payload for repository settings update.""" private: Optional[bool] = None + visibility: Optional[str] = None gated: Optional[str] = None # "auto", "manual", or False/None lfs_threshold_bytes: Optional[int] = None # NULL = use server default lfs_keep_versions: Optional[int] = None # NULL = use server default @@ -374,9 +375,26 @@ async def update_repo_settings( # Store as JSON string update_fields["lfs_suffix_rules"] = json.dumps(payload.lfs_suffix_rules) - if payload.private is not None: + requested_private = payload.private + if requested_private is None and payload.visibility is not None: + if payload.visibility == "private": + requested_private = True + elif payload.visibility == "public": + requested_private = False + else: + raise HTTPException( + 400, + detail={ + "error": ( + "Unsupported repository visibility. " + "Only 'public' and 'private' are supported." + ) + }, + ) + + if requested_private is not None: # Check if visibility is actually changing - if repo_row.private != payload.private: + if repo_row.private != requested_private: # Calculate repository storage logger.info( f"Checking quota for visibility change: {repo_id} from " @@ -395,7 +413,7 @@ async def update_repo_settings( allowed, error_msg = check_quota( namespace=namespace, additional_bytes=repo_size, - is_private=payload.private, + is_private=requested_private, is_org=is_org, ) @@ -412,7 +430,7 @@ async def update_repo_settings( ) # Update repository visibility - update_fields["private"] = payload.private + update_fields["private"] = requested_private # Apply all updates if there are any if update_fields: diff --git a/test/kohakuhub/api/test_huggingface_hub_compat.py b/test/kohakuhub/api/test_huggingface_hub_compat.py index f3e8b98..837e3bb 100644 --- a/test/kohakuhub/api/test_huggingface_hub_compat.py +++ b/test/kohakuhub/api/test_huggingface_hub_compat.py @@ -24,6 +24,20 @@ async def _create_hf_token(client, name: str) -> str: return response.json()["token"] +def _set_repo_private(api: HfApi, repo_id: str, private: bool) -> None: + update_settings = getattr(api, "update_repo_settings", None) + if callable(update_settings): + update_settings(repo_id, private=private) + return + + update_visibility = getattr(api, "update_repo_visibility", None) + if callable(update_visibility): + update_visibility(repo_id, private=private) + return + + pytest.skip("huggingface_hub does not expose repository visibility updates") + + @pytest.fixture async def member_hf_api_token(member_client): return await _create_hf_token(member_client, "hf-api-member") @@ -358,10 +372,11 @@ async def test_hf_api_likes_visibility_move_delete_and_list_liked_repos( reliked = await asyncio.to_thread(lambda: api.list_liked_repos("owner")) assert "owner/demo-model" in reliked.models - update_visibility = getattr(api, "update_repo_visibility", None) - if callable(update_visibility): + if callable(getattr(api, "update_repo_settings", None)) or callable( + getattr(api, "update_repo_visibility", None) + ): await asyncio.to_thread( - lambda: update_visibility("owner/hf-lifecycle-compat", private=True) + lambda: _set_repo_private(api, "owner/hf-lifecycle-compat", True) ) private_info = await asyncio.to_thread( lambda: api.repo_info("owner/hf-lifecycle-compat") diff --git a/test/kohakuhub/api/test_settings.py b/test/kohakuhub/api/test_settings.py index 61816ed..1b2893b 100644 --- a/test/kohakuhub/api/test_settings.py +++ b/test/kohakuhub/api/test_settings.py @@ -42,6 +42,18 @@ async def test_update_repo_lfs_settings_and_read_effective_values(owner_client): assert ".gguf" in payload["lfs_suffix_rules_effective"] +async def test_update_repo_visibility_accepts_huggingface_settings_payload(owner_client): + update_response = await owner_client.put( + "/api/models/owner/demo-model/settings", + json={"visibility": "private"}, + ) + assert update_response.status_code == 200 + + info_response = await owner_client.get("/api/models/owner/demo-model") + assert info_response.status_code == 200 + assert info_response.json()["private"] is True + + async def test_namespace_type_social_links_and_repo_setting_validation(owner_client): update_response = await owner_client.put( "/api/users/owner/settings", diff --git a/test/kohakuhub/api/test_settings_unit.py b/test/kohakuhub/api/test_settings_unit.py index 70b57df..7327f4a 100644 --- a/test/kohakuhub/api/test_settings_unit.py +++ b/test/kohakuhub/api/test_settings_unit.py @@ -304,6 +304,16 @@ async def test_repo_settings_and_lfs_settings_cover_validation_quota_and_default ) assert bad_suffix.value.status_code == 400 + with pytest.raises(HTTPException) as bad_visibility: + await settings_api.update_repo_settings( + "model", + "alice", + "demo", + settings_api.UpdateRepoSettingsPayload(visibility="protected"), + user=SimpleNamespace(username="alice"), + ) + assert bad_visibility.value.status_code == 400 + monkeypatch.setattr( settings_api, "calculate_repository_storage", @@ -359,6 +369,17 @@ async def test_repo_settings_and_lfs_settings_cover_validation_quota_and_default "private": True, } + repo_row.private = False + updated_visibility = await settings_api.update_repo_settings( + "model", + "alice", + "demo", + settings_api.UpdateRepoSettingsPayload(visibility="public"), + user=SimpleNamespace(username="alice"), + ) + assert updated_visibility["success"] is True + assert update_calls[-1] == {"private": False} + monkeypatch.setattr(settings_api, "get_repository", lambda repo_type, namespace, name: None) lfs_not_found = await settings_api.get_repo_lfs_settings( "model",