Files
KohakuHub/tests/base.py
2025-10-22 02:43:31 +08:00

313 lines
9.1 KiB
Python

"""Base test classes and utilities for KohakuHub API tests."""
import os
import tempfile
from pathlib import Path
from typing import Any
import requests
from huggingface_hub import HfApi
from tests.config import config
class HTTPClient:
    """Custom HTTP client for direct API testing.

    This client is used to test API endpoints directly without the HuggingFace
    client abstraction, ensuring our API matches the intended schema.

    A single ``requests.Session`` is shared by every call made through one
    client. Release it with :meth:`close`, or use the client as a context
    manager so the session is closed automatically.
    """

    def __init__(self, endpoint: str = None, token: str = None):
        """Initialize HTTP client.

        Args:
            endpoint: API endpoint URL. Falls back to ``config.endpoint`` when
                omitted or empty.
            token: API token for authentication. When given, it is sent as a
                ``Bearer`` token in the ``Authorization`` header.
        """
        self.endpoint = endpoint or config.endpoint
        self.token = token
        self.session = requests.Session()
        # NOTE(review): this default is attached to every request made through
        # the session, including calls that pass ``data=``; callers that need a
        # different content type must override it through ``headers``.
        self.session.headers.update({"Content-Type": "application/json"})
        if self.token:
            self.session.headers.update({"Authorization": f"Bearer {self.token}"})

    def close(self) -> None:
        """Close the underlying session."""
        self.session.close()

    def __enter__(self):
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session when the ``with`` block exits."""
        self.close()

    def request(
        self,
        method: str,
        path: str,
        json: dict = None,
        data: Any = None,
        headers: dict = None,
        params: dict = None,
        **kwargs,
    ) -> "requests.Response":
        """Make HTTP request.

        Args:
            method: HTTP method (GET, POST, DELETE, etc.)
            path: API path (will be joined with endpoint)
            json: JSON payload
            data: Raw data payload
            headers: Additional headers; these override the session defaults
                for this call only
            params: Query parameters
            **kwargs: Additional requests arguments. A ``timeout`` given here
                takes precedence over ``config.timeout``.

        Returns:
            Response object
        """
        # Join with exactly one slash regardless of how the two parts are written.
        url = f"{self.endpoint.rstrip('/')}/{path.lstrip('/')}"

        # Work on a copy so per-call overrides never modify the session defaults.
        request_headers = dict(self.session.headers)
        if headers:
            request_headers.update(headers)

        # Only apply the configured timeout when the caller did not choose one;
        # passing it unconditionally would collide with a ``timeout`` in kwargs.
        if "timeout" not in kwargs:
            kwargs["timeout"] = config.timeout

        return self.session.request(
            method=method,
            url=url,
            json=json,
            data=data,
            headers=request_headers,
            params=params,
            **kwargs,
        )

    def get(self, path: str, **kwargs) -> "requests.Response":
        """GET request."""
        return self.request("GET", path, **kwargs)

    def post(self, path: str, **kwargs) -> "requests.Response":
        """POST request."""
        return self.request("POST", path, **kwargs)

    def put(self, path: str, **kwargs) -> "requests.Response":
        """PUT request."""
        return self.request("PUT", path, **kwargs)

    def delete(self, path: str, **kwargs) -> "requests.Response":
        """DELETE request."""
        return self.request("DELETE", path, **kwargs)

    def head(self, path: str, **kwargs) -> "requests.Response":
        """HEAD request."""
        return self.request("HEAD", path, **kwargs)
class HFClientWrapper:
    """Thin testing facade over the HuggingFace Hub ``HfApi`` client.

    Every method forwards its arguments, by keyword, to the ``HfApi`` instance
    stored in ``self.api``, which is bound to the endpoint under test.
    """

    def __init__(self, endpoint: str = None, token: str = None):
        """Bind an ``HfApi`` client to the endpoint under test.

        Args:
            endpoint: API endpoint URL; ``config.endpoint`` is used when this
                is omitted or empty.
            token: API token for authentication.
        """
        base_url = endpoint if endpoint else config.endpoint
        # A trailing slash would produce a double slash in the URLs built later.
        self.endpoint = base_url.rstrip("/")
        self.token = token

        # Export the settings process-wide for the HuggingFace client. An
        # existing HF_TOKEN value is left as-is when no token was supplied.
        exported = {"HF_ENDPOINT": self.endpoint}
        if self.token:
            exported["HF_TOKEN"] = self.token
        os.environ.update(exported)

        self.api = HfApi(endpoint=self.endpoint, token=self.token)

    def create_repo(
        self, repo_id: str, repo_type: str = "model", private: bool = False
    ):
        """Create a repository and return the client's result."""
        call_args = {"repo_id": repo_id, "repo_type": repo_type, "private": private}
        return self.api.create_repo(**call_args)

    def delete_repo(self, repo_id: str, repo_type: str = "model"):
        """Delete a repository and return the client's result."""
        call_args = {"repo_id": repo_id, "repo_type": repo_type}
        return self.api.delete_repo(**call_args)

    def repo_info(self, repo_id: str, repo_type: str = "model", revision: str = None):
        """Fetch repository info, optionally at a specific revision."""
        call_args = {"repo_id": repo_id, "repo_type": repo_type, "revision": revision}
        return self.api.repo_info(**call_args)

    def list_repo_files(
        self, repo_id: str, repo_type: str = "model", revision: str = None
    ):
        """List the files of a repository, optionally at a specific revision."""
        call_args = {"repo_id": repo_id, "repo_type": repo_type, "revision": revision}
        return self.api.list_repo_files(**call_args)

    def upload_file(
        self,
        path_or_fileobj,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Upload one file; an empty commit message becomes "Upload file"."""
        message = commit_message if commit_message else "Upload file"
        call_args = {
            "path_or_fileobj": path_or_fileobj,
            "path_in_repo": path_in_repo,
            "repo_id": repo_id,
            "repo_type": repo_type,
            "commit_message": message,
        }
        return self.api.upload_file(**call_args)

    def upload_folder(
        self,
        folder_path: str,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Upload a folder; an empty commit message becomes "Upload folder"."""
        message = commit_message if commit_message else "Upload folder"
        call_args = {
            "folder_path": folder_path,
            "path_in_repo": path_in_repo,
            "repo_id": repo_id,
            "repo_type": repo_type,
            "commit_message": message,
        }
        return self.api.upload_folder(**call_args)

    def download_file(
        self,
        repo_id: str,
        filename: str,
        repo_type: str = "model",
        revision: str = None,
    ) -> str:
        """Download a file through ``hf_hub_download`` and return its result."""
        call_args = {
            "repo_id": repo_id,
            "filename": filename,
            "repo_type": repo_type,
            "revision": revision,
        }
        return self.api.hf_hub_download(**call_args)

    def delete_file(
        self,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Delete one file; an empty commit message becomes "Delete file"."""
        message = commit_message if commit_message else "Delete file"
        call_args = {
            "path_in_repo": path_in_repo,
            "repo_id": repo_id,
            "repo_type": repo_type,
            "commit_message": message,
        }
        return self.api.delete_file(**call_args)
class BaseTestCase:
    """Base test case with common utilities.

    Provides a shared HTTP client, a per-class temporary directory, helpers
    for creating test files, response assertions, and repository ID builders.
    """

    @classmethod
    def setup_class(cls):
        """Setup test class.

        Stores the test config, creates an ``HTTPClient`` with its default
        arguments, and creates a fresh temporary directory for the class.
        """
        cls.config = config
        cls.http_client = HTTPClient()
        cls.temp_dir = tempfile.mkdtemp(prefix="kohakuhub_test_")

    @classmethod
    def teardown_class(cls):
        """Cleanup test class.

        Removes the temporary directory (when it exists) and closes the HTTP
        session created in ``setup_class`` so its connections are released.
        """
        import shutil

        try:
            if hasattr(cls, "temp_dir") and Path(cls.temp_dir).exists():
                shutil.rmtree(cls.temp_dir)
        finally:
            # Close the session even if removing the directory failed. The
            # lookups are defensive: setup may not have run to completion.
            client = getattr(cls, "http_client", None)
            session = getattr(client, "session", None)
            if session is not None:
                session.close()

    def create_temp_file(self, name: str, content: bytes) -> str:
        """Create temporary file for testing.

        Args:
            name: File name, relative to the class temporary directory;
                missing parent directories are created
            content: File content

        Returns:
            Path to created file
        """
        path = Path(self.temp_dir) / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(content)
        return str(path)

    def create_random_file(self, name: str, size_mb: float) -> str:
        """Create temporary file with random content.

        Args:
            name: File name
            size_mb: File size in decimal megabytes (1 MB = 1,000,000 bytes)

        Returns:
            Path to created file
        """
        size_bytes = int(size_mb * 1000 * 1000)
        content = os.urandom(size_bytes)
        return self.create_temp_file(name, content)

    def assert_response_ok(
        self, response: "requests.Response", expected_status: int = 200
    ):
        """Assert HTTP response is successful.

        Args:
            response: HTTP response
            expected_status: Expected status code
        """
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}: {response.text}"

    def assert_response_error(
        self,
        response: "requests.Response",
        expected_status: int,
        expected_error_code: str = None,
    ):
        """Assert HTTP response is an error.

        Args:
            response: HTTP response
            expected_status: Expected status code
            expected_error_code: Expected HF error code (e.g., "RepoNotFound"),
                compared against the ``X-Error-Code`` response header; the
                header is not checked when this is omitted
        """
        # Include the body, as assert_response_ok does, so a wrong status can
        # be diagnosed from the failure message alone.
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}: {response.text}"
        if expected_error_code:
            error_code = response.headers.get("X-Error-Code")
            assert (
                error_code == expected_error_code
            ), f"Expected error code {expected_error_code}, got {error_code}"

    def get_test_repo_id(self, name: str) -> str:
        """Generate test repository ID.

        Args:
            name: Repository name

        Returns:
            Full repository ID with test prefix, in the form
            ``<username>/<repo_prefix>-<name>``
        """
        return f"{config.username}/{config.repo_prefix}-{name}"

    def get_test_org_repo_id(self, name: str) -> str:
        """Generate test organization repository ID.

        Args:
            name: Repository name

        Returns:
            Full repository ID with organization, in the form
            ``<org_name>/<repo_prefix>-<name>``
        """
        return f"{config.org_name}/{config.repo_prefix}-{name}"