Files
KohakuHub/tests/base.py
Kohaku-Blueleaf f667b4b7fe fix import format
2025-10-22 03:27:58 +08:00

312 lines
9.1 KiB
Python

"""Base test classes and utilities for KohakuHub API tests."""
import os
import shutil
import tempfile
from pathlib import Path
from typing import Any
import requests
from huggingface_hub import HfApi
from tests.config import config
class HTTPClient:
    """Custom HTTP client for direct API testing.

    This client is used to test API endpoints directly without the HuggingFace
    client abstraction, ensuring our API matches the intended schema.
    """

    def __init__(self, endpoint: str = None, token: str = None):
        """Initialize HTTP client.

        Args:
            endpoint: API endpoint URL. Falls back to ``config.endpoint`` when
                omitted or empty.
            token: API token for authentication. When truthy it is sent on
                every request as ``Authorization: Bearer <token>``.
        """
        self.endpoint = endpoint or config.endpoint
        self.token = token
        self.session = requests.Session()
        # Session-level defaults: applied to every request unless a call
        # overrides them through its own ``headers`` argument.
        self.session.headers.update({"Content-Type": "application/json"})
        if self.token:
            self.session.headers.update({"Authorization": f"Bearer {self.token}"})

    def request(
        self,
        method: str,
        path: str,
        json: dict = None,
        data: Any = None,
        headers: dict = None,
        params: dict = None,
        **kwargs,
    ) -> requests.Response:
        """Make HTTP request.

        Args:
            method: HTTP method (GET, POST, DELETE, etc.)
            path: API path (will be joined with endpoint)
            json: JSON payload
            data: Raw data payload
            headers: Additional headers; these take precedence over the
                session-level defaults for this call only
            params: Query parameters
            **kwargs: Additional requests arguments. A ``timeout`` given here
                overrides ``config.timeout``.

        Returns:
            Response object
        """
        # Normalise the slashes on both sides so exactly one separates the
        # endpoint from the path.
        url = f"{self.endpoint.rstrip('/')}/{path.lstrip('/')}"

        # Only fall back to the configured timeout when the caller did not
        # pick one; passing both used to raise TypeError for a duplicated
        # keyword argument.
        if "timeout" not in kwargs:
            kwargs["timeout"] = config.timeout

        # The session merges ``headers`` over its own defaults, so no manual
        # copy of the session headers is needed here.
        return self.session.request(
            method=method,
            url=url,
            json=json,
            data=data,
            headers=headers,
            params=params,
            **kwargs,
        )

    def get(self, path: str, **kwargs) -> requests.Response:
        """GET request."""
        return self.request("GET", path, **kwargs)

    def post(self, path: str, **kwargs) -> requests.Response:
        """POST request."""
        return self.request("POST", path, **kwargs)

    def put(self, path: str, **kwargs) -> requests.Response:
        """PUT request."""
        return self.request("PUT", path, **kwargs)

    def delete(self, path: str, **kwargs) -> requests.Response:
        """DELETE request."""
        return self.request("DELETE", path, **kwargs)

    def head(self, path: str, **kwargs) -> requests.Response:
        """HEAD request."""
        return self.request("HEAD", path, **kwargs)
class HFClientWrapper:
    """Thin facade over ``HfApi`` used by the tests.

    Every method forwards to the wrapped ``HfApi`` instance (``self.api``)
    using keyword arguments only, so the tests exercise the HuggingFace client
    against the configured endpoint.
    """

    def __init__(self, endpoint: str = None, token: str = None):
        """Build the wrapper and the underlying ``HfApi`` client.

        Args:
            endpoint: API endpoint URL; ``config.endpoint`` is used when this
                is omitted or empty.
            token: API token for authentication.
        """
        base_url = endpoint if endpoint else config.endpoint
        # A trailing slash would yield a double slash once paths are appended.
        self.endpoint = base_url.rstrip("/")
        self.token = token

        # Export the settings through the process environment as well.
        os.environ["HF_ENDPOINT"] = self.endpoint
        if self.token:
            os.environ["HF_TOKEN"] = self.token

        self.api = HfApi(endpoint=self.endpoint, token=self.token)

    def create_repo(
        self, repo_id: str, repo_type: str = "model", private: bool = False
    ):
        """Create a repository and return the client's result."""
        options = {"repo_id": repo_id, "repo_type": repo_type, "private": private}
        return self.api.create_repo(**options)

    def delete_repo(self, repo_id: str, repo_type: str = "model"):
        """Delete a repository and return the client's result."""
        options = {"repo_id": repo_id, "repo_type": repo_type}
        return self.api.delete_repo(**options)

    def repo_info(self, repo_id: str, repo_type: str = "model", revision: str = None):
        """Fetch repository info, optionally at a specific revision."""
        options = {"repo_id": repo_id, "repo_type": repo_type, "revision": revision}
        return self.api.repo_info(**options)

    def list_repo_files(
        self, repo_id: str, repo_type: str = "model", revision: str = None
    ):
        """List the files of a repository, optionally at a specific revision."""
        options = {"repo_id": repo_id, "repo_type": repo_type, "revision": revision}
        return self.api.list_repo_files(**options)

    def upload_file(
        self,
        path_or_fileobj,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Upload one file; the commit message defaults to "Upload file"."""
        message = commit_message if commit_message else "Upload file"
        options = {
            "path_or_fileobj": path_or_fileobj,
            "path_in_repo": path_in_repo,
            "repo_id": repo_id,
            "repo_type": repo_type,
            "commit_message": message,
        }
        return self.api.upload_file(**options)

    def upload_folder(
        self,
        folder_path: str,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Upload a folder; the commit message defaults to "Upload folder"."""
        message = commit_message if commit_message else "Upload folder"
        options = {
            "folder_path": folder_path,
            "path_in_repo": path_in_repo,
            "repo_id": repo_id,
            "repo_type": repo_type,
            "commit_message": message,
        }
        return self.api.upload_folder(**options)

    def download_file(
        self,
        repo_id: str,
        filename: str,
        repo_type: str = "model",
        revision: str = None,
    ) -> str:
        """Download one file and return whatever path the client reports."""
        options = {
            "repo_id": repo_id,
            "filename": filename,
            "repo_type": repo_type,
            "revision": revision,
        }
        return self.api.hf_hub_download(**options)

    def delete_file(
        self,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Delete one file; the commit message defaults to "Delete file"."""
        message = commit_message if commit_message else "Delete file"
        options = {
            "path_in_repo": path_in_repo,
            "repo_id": repo_id,
            "repo_type": repo_type,
            "commit_message": message,
        }
        return self.api.delete_file(**options)
class BaseTestCase:
    """Base test case with common utilities.

    Provides a shared ``HTTPClient``, a per-class temporary directory, helpers
    for creating test files, response assertions, and repository ID builders.
    """

    @classmethod
    def setup_class(cls):
        """Setup test class.

        Stores the shared config, opens an ``HTTPClient`` with default
        settings, and creates a temporary directory for test files.
        """
        cls.config = config
        cls.http_client = HTTPClient()
        cls.temp_dir = tempfile.mkdtemp(prefix="kohakuhub_test_")

    @classmethod
    def teardown_class(cls):
        """Cleanup test class.

        Closes the HTTP session opened in ``setup_class`` and removes the
        temporary directory. Both steps tolerate a partially completed setup.
        """
        try:
            # Release the pooled connections held by the shared session.
            client = getattr(cls, "http_client", None)
            if client is not None:
                client.session.close()
        finally:
            # Always attempt directory cleanup, even if closing the session
            # raised.
            temp_dir = getattr(cls, "temp_dir", None)
            if temp_dir is not None and Path(temp_dir).exists():
                shutil.rmtree(temp_dir)

    def create_temp_file(self, name: str, content: bytes) -> str:
        """Create temporary file for testing.

        Args:
            name: File name, relative to the class temp directory; missing
                parent directories are created
            content: File content

        Returns:
            Path to created file
        """
        path = Path(self.temp_dir) / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(content)
        return str(path)

    def create_random_file(self, name: str, size_mb: float) -> str:
        """Create temporary file with random content.

        Args:
            name: File name
            size_mb: File size in megabytes (1 MB = 1000 * 1000 bytes; the
                byte count is truncated to an integer)

        Returns:
            Path to created file
        """
        size_bytes = int(size_mb * 1000 * 1000)
        content = os.urandom(size_bytes)
        return self.create_temp_file(name, content)

    def assert_response_ok(
        self, response: requests.Response, expected_status: int = 200
    ):
        """Assert HTTP response is successful.

        Args:
            response: HTTP response
            expected_status: Expected status code
        """
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}: {response.text}"

    def assert_response_error(
        self,
        response: requests.Response,
        expected_status: int,
        expected_error_code: str = None,
    ):
        """Assert HTTP response is an error.

        Args:
            response: HTTP response
            expected_status: Expected status code
            expected_error_code: Expected HF error code (e.g., "RepoNotFound"),
                compared against the ``X-Error-Code`` response header
        """
        # Include the body in the failure message, as assert_response_ok does,
        # so an unexpected status can be diagnosed from the test output.
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}: {response.text}"
        if expected_error_code:
            error_code = response.headers.get("X-Error-Code")
            assert (
                error_code == expected_error_code
            ), f"Expected error code {expected_error_code}, got {error_code}"

    def get_test_repo_id(self, name: str) -> str:
        """Generate test repository ID.

        Args:
            name: Repository name

        Returns:
            Full repository ID with test prefix
        """
        return f"{config.username}/{config.repo_prefix}-{name}"

    def get_test_org_repo_id(self, name: str) -> str:
        """Generate test organization repository ID.

        Args:
            name: Repository name

        Returns:
            Full repository ID with organization
        """
        return f"{config.org_name}/{config.repo_prefix}-{name}"