mirror of https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-12 01:45:35 -05:00

continue cleanup
@@ -653,7 +653,7 @@ onMounted(async () => {
 - Search functionality
 - Git push support

-**KohakuBoard** (Standalone Sub-Project):
+**KohakuBoard** (Standalone Repository - https://github.com/KohakuBlueleaf/KohakuBoard):
 - Remote server mode with authentication (WIP)
 - Sync protocol for uploading local boards (WIP)
 - Frontend UI improvements (WIP)
@@ -696,7 +696,8 @@ We're especially looking for help in:
 - Advanced repository features
 - Search functionality

-### 📊 KohakuBoard (Sub-Project)
+### 📊 KohakuBoard (Standalone Repository)
+- See https://github.com/KohakuBlueleaf/KohakuBoard for contributing to KohakuBoard
 - Remote server authentication system
 - Sync protocol implementation
 - Frontend chart improvements
@@ -43,13 +43,15 @@ Self-hosted HuggingFace alternative with Git-like versioning for AI models and datasets
 - **Trending & Likes** - Repository popularity tracking
 - **Pure Python Git Server** - No native dependencies, memory-efficient

-### KohakuBoard (Experiment Tracking) - Standalone Sub-Project
+### KohakuBoard (Experiment Tracking) - Standalone Repository

+**Repository:** https://github.com/KohakuBlueleaf/KohakuBoard
+
 - **Non-Blocking Logging** - Background writer process, zero training overhead
 - **Rich Data Types** - Scalars, images, videos, tables, histograms
 - **Hybrid Storage** - Lance (columnar) + SQLite (row-oriented) for optimal performance
 - **Local-First** - View experiments locally with `kobo open`, no server required
-- See [src/kohakuboard/README.md](./src/kohakuboard/README.md) for details
+- See the KohakuBoard repository for full documentation

 ## Quick Start
411 tests/README.md
@@ -1,411 +0,0 @@
# KohakuHub API Tests

Comprehensive test suite for the KohakuHub API, validating HuggingFace Hub compatibility and custom endpoints.

## Test Strategy

This test suite uses a **dual-client approach** to ensure API correctness:

1. **HuggingFace Hub Client** (`huggingface_hub`): Tests HF API compatibility
2. **Custom HTTP Client** (`requests`): Tests custom endpoints and validates the API schema

### Why Dual Testing?

- **HF Client tests**: Ensure compatibility with the existing HF ecosystem
- **HTTP Client tests**: Validate custom endpoints and catch reverse-engineering errors

If the HF client fails but the HTTP client succeeds → our reverse-engineering of the HF API is wrong
If both fail → our implementation is broken
If both succeed → ✅ Perfect compatibility
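
The decision matrix above can be sketched as a tiny helper (hypothetical, not part of the suite — `diagnose` and its return strings are illustrative only):

```python
# Hypothetical helper mirroring the dual-client decision matrix above.
# `diagnose` is an illustrative name, not part of the test suite.
def diagnose(hf_ok: bool, http_ok: bool) -> str:
    """Map the two clients' pass/fail results to a diagnosis."""
    if hf_ok and http_ok:
        return "perfect compatibility"
    if not hf_ok and http_ok:
        return "HF reverse-engineering is wrong"
    if not hf_ok and not http_ok:
        return "implementation is broken"
    # HF passes but the raw schema check fails: custom endpoint/schema bug
    return "custom endpoint or schema is wrong"
```

The fourth case (HF passes, raw HTTP fails) is the one the matrix above leaves implicit: it points at a custom-endpoint or schema regression rather than an HF-compatibility problem.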

## Prerequisites

### 1. Deploy KohakuHub Server

Tests require a running KohakuHub instance (via docker-compose):

```bash
# From project root
cp docker-compose.example.yml docker-compose.yml
# Edit docker-compose.yml with your configuration

# Build and start
npm install --prefix ./src/kohaku-hub-ui
npm run build --prefix ./src/kohaku-hub-ui
docker-compose up -d --build

# Verify server is running
curl http://localhost:28080/api/version
```

**Important**: Tests connect to `http://localhost:28080` (nginx port) by default.

### 2. Install Test Dependencies

```bash
# Install test requirements
pip install pytest pytest-xdist requests huggingface_hub

# Or from project root
pip install -e ".[test]"
```

## Configuration

Tests are configured via environment variables or defaults in `config.py`:

| Variable | Default | Description |
|----------|---------|-------------|
| `TEST_ENDPOINT` | `http://localhost:28080` | KohakuHub API endpoint (use nginx port!) |
| `TEST_USERNAME` | `testuser` | Test user username |
| `TEST_EMAIL` | `test@example.com` | Test user email |
| `TEST_PASSWORD` | `testpass123` | Test user password |
| `TEST_ORG_NAME` | `testorg` | Test organization name |
| `TEST_REPO_PREFIX` | `tst` | Prefix for test repositories |
| `TEST_TIMEOUT` | `30` | HTTP request timeout (seconds) |
| `TEST_CLEANUP` | `true` | Clean up resources after tests |

### Example: Test Against a Custom Endpoint

```bash
export TEST_ENDPOINT=http://my-server.com:28080
export TEST_USERNAME=myuser
export TEST_PASSWORD=mypass
pytest tests/
```

## Running Tests

### Run All Tests

```bash
# From project root
pytest tests/

# With verbose output
pytest tests/ -v

# With coverage
pytest tests/ --cov=kohakuhub --cov-report=html
```

### Run Specific Test Files

```bash
# Authentication tests only
pytest tests/test_auth.py -v

# Repository CRUD tests
pytest tests/test_repo_crud.py -v

# File operations
pytest tests/test_file_ops.py -v

# LFS operations
pytest tests/test_lfs.py -v
```

### Run Specific Tests

```bash
# Run single test
pytest tests/test_auth.py::TestAuthentication::test_version_check -v

# Run tests matching pattern
pytest tests/ -k "upload" -v

# Run tests with specific marker
pytest tests/ -m lfs -v
```

### Test Markers

Tests are marked for easier filtering:

```bash
# Run only LFS tests
pytest tests/ -m lfs

# Skip slow tests
pytest tests/ -m "not slow"

# Run tests for specific repo type
pytest tests/test_repo_crud.py::test_create_different_repo_types[model] -v
```

Available markers:

- `lfs` - Tests requiring LFS (large files >10MB)
- `slow` - Slow-running tests (files >50MB)
- `repo_type(type)` - Tests for specific repository types
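
In test code, these markers are attached with standard `pytest.mark` decorators; a sketch (the test names and bodies here are illustrative, not real tests from the suite):

```python
import pytest


# Illustrative stubs showing how the markers listed above are attached.
@pytest.mark.lfs
@pytest.mark.slow
def test_upload_60mb_file():
    """Would exercise an LFS upload well above the 10MB threshold."""


@pytest.mark.repo_type("dataset")
def test_dataset_tree():
    """Would run against a dataset repo via the temp_repo fixture."""
```

With markers attached this way, `pytest tests/ -m "lfs and not slow"` selects only the fast LFS tests.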

### Parallel Execution

```bash
# Run tests in parallel (4 workers)
pytest tests/ -n 4

# Auto-detect CPU cores
pytest tests/ -n auto
```

**Note**: Some tests may not be thread-safe. Use with caution.

## Test Structure

```
tests/
├── __init__.py          # Package init
├── conftest.py          # Pytest fixtures and configuration
├── config.py            # Test configuration
├── base.py              # Base test classes and utilities
├── test_auth.py         # Authentication tests
├── test_repo_crud.py    # Repository CRUD operations
├── test_file_ops.py     # File upload/download/delete
├── test_lfs.py          # Large file storage tests
└── README.md            # This file
```

## Test Categories

### 1. Authentication Tests (`test_auth.py`)

- User registration
- Login/logout flow
- Session management
- API token creation/deletion
- Token-based authentication
- `whoami` endpoint

### 2. Repository CRUD Tests (`test_repo_crud.py`)

- Repository creation (model, dataset, space)
- Repository deletion
- Repository listing
- Repository info retrieval
- Repository move/rename
- Private repository handling
- Duplicate detection

### 3. File Operations Tests (`test_file_ops.py`)

- Small file upload/download (<10MB)
- Folder upload
- File deletion
- File listing (tree API)
- File metadata (HEAD request)
- Content integrity verification
- Commit messages

### 4. LFS Tests (`test_lfs.py`)

- Large file upload (>10MB)
- LFS batch API
- File size threshold (10MB boundary)
- LFS deduplication
- Mixed file sizes
- Content integrity for large files
- LFS metadata in tree API

## Fixtures

### Session-Scoped Fixtures

Created once per test session:

- `test_config`: Test configuration
- `api_token`: API token for the test user
- `authenticated_http_client`: HTTP client with auth
- `hf_client`: HuggingFace Hub client wrapper
- `test_org`: Test organization

### Function-Scoped Fixtures

Created per test function:

- `http_client`: Unauthenticated HTTP client (function-scoped to prevent session contamination)
- `random_user`: Fresh user per test, as `(username, token, hf_client_wrapper)`
- `temp_repo`: Temporary repository (auto-cleanup), as `(repo_id, repo_type, hf_client)`

### Example Usage

```python
def test_something(temp_repo):
    """Test with a temporary repository and its HF client."""
    repo_id, repo_type, hf_client = temp_repo

    # Upload file
    hf_client.upload_file(
        path_or_fileobj="test.txt",
        path_in_repo="test.txt",
        repo_id=repo_id,
        repo_type=repo_type,
    )

    # Repository will be cleaned up automatically
```
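
The same file-listing check can be expressed through the raw HTTP client; a sketch of the request it would issue (the repo id and token are placeholders, the tree path is an assumption based on the HF-style API, and nothing is actually sent, so the example runs offline):

```python
# Build the request authenticated_http_client would send to list files.
# Endpoint, repo id, and token below are placeholder values.
endpoint = "http://localhost:28080/"
path = "/api/models/testuser/test-repo/tree/main"  # assumed HF-style tree path

# Same join rule as HTTPClient.request: strip the duplicated slash.
url = f"{endpoint.rstrip('/')}/{path.lstrip('/')}"
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer YOUR_TOKEN",
}
```

Sending `GET url` with those headers (e.g. via `authenticated_http_client.get(path)`) lets a test assert on the raw JSON schema rather than on what `huggingface_hub` parses out of it.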

## Debugging

### Enable Verbose Logging

```bash
# Pytest verbose mode
pytest tests/ -vv

# Show print statements
pytest tests/ -s

# Show local variables on failure
pytest tests/ -l
```

### Run Failed Tests Only

```bash
# Run last failed tests
pytest tests/ --lf

# Run failed first, then others
pytest tests/ --ff
```

### Stop on First Failure

```bash
pytest tests/ -x
```

### Interactive Debugging

```bash
# Drop into debugger on failure
pytest tests/ --pdb

# Drop into debugger on first failure
pytest tests/ -x --pdb
```

## Common Issues

### 1. Connection Refused

**Problem**: `ConnectionRefusedError` or `Connection refused` to `localhost:28080`

**Solution**: Ensure KohakuHub is running:

```bash
docker-compose ps
curl http://localhost:28080/api/version
```

### 2. Wrong Port (48888)

**Problem**: Tests connect to the backend port instead of nginx

**Solution**: Use `TEST_ENDPOINT=http://localhost:28080` (nginx port)

### 3. Authentication Errors

**Problem**: `401 Unauthorized` errors

**Solution**: Check the test user credentials, or recreate the test user:

```bash
# Delete old user from database if needed
docker-compose exec hub-api python -c "from kohakuhub.db import User; User.delete().where(User.username == 'testuser').execute()"
```

### 4. File Permission Errors

**Problem**: Cannot create temporary files

**Solution**: Check disk space and permissions for the temp directory

### 5. LFS Upload Failures

**Problem**: Large file uploads time out

**Solution**:

- Increase the `TEST_TIMEOUT` environment variable
- Check MinIO is running: `docker-compose ps minio`
- Verify S3 credentials in `docker-compose.yml`

## Cleanup

Tests automatically clean up resources if `TEST_CLEANUP=true` (default):

- Temporary repositories are deleted
- Temporary files are removed
- API tokens are revoked (optional)

### Manual Cleanup

If tests fail and leave resources behind:

```bash
# List test repositories
curl http://localhost:28080/api/models?author=testuser

# Delete manually via API
curl -X DELETE http://localhost:28080/api/repos/delete \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"type": "model", "name": "test-repo-name"}'

# Or use kohub-cli
kohub-cli repo delete testuser/test-repo-name --type model
```

## Continuous Integration

### GitHub Actions Example

```yaml
name: API Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          pip install pytest requests huggingface_hub

      - name: Start KohakuHub
        run: |
          cp docker-compose.example.yml docker-compose.yml
          npm install --prefix ./src/kohaku-hub-ui
          npm run build --prefix ./src/kohaku-hub-ui
          docker-compose up -d --build
          sleep 30  # Wait for services to start

      - name: Run tests
        run: pytest tests/ -v

      - name: Stop services
        run: docker-compose down
```

## Contributing

When adding new tests:

1. Follow the existing patterns in the test files
2. Use both the HF client and the HTTP client where applicable
3. Add appropriate markers (`@pytest.mark.lfs`, etc.)
4. Ensure cleanup in teardown, or use fixtures
5. Document the test's purpose in its docstring
6. Update this README when adding new test categories

## Support

- **Issues**: https://github.com/KohakuBlueleaf/KohakuHub/issues
- **Discord**: https://discord.gg/xWYrkyvJ2s
- **Documentation**: `../docs/`
@@ -1,5 +0,0 @@
"""Unit tests for KohakuHub API.

This test suite validates HuggingFace API compatibility and custom endpoints
using both the huggingface_hub client and direct HTTP requests.
"""
311 tests/base.py
@@ -1,311 +0,0 @@
"""Base test classes and utilities for KohakuHub API tests."""

import os
import shutil
import tempfile
from pathlib import Path
from typing import Any

import requests
from huggingface_hub import HfApi

from tests.config import config


class HTTPClient:
    """Custom HTTP client for direct API testing.

    This client is used to test API endpoints directly without the HuggingFace
    client abstraction, ensuring our API matches the intended schema.
    """

    def __init__(self, endpoint: str = None, token: str = None):
        """Initialize HTTP client.

        Args:
            endpoint: API endpoint URL
            token: API token for authentication
        """
        self.endpoint = endpoint or config.endpoint
        self.token = token
        self.session = requests.Session()
        self.session.headers.update({"Content-Type": "application/json"})
        if self.token:
            self.session.headers.update({"Authorization": f"Bearer {self.token}"})

    def request(
        self,
        method: str,
        path: str,
        json: dict = None,
        data: Any = None,
        headers: dict = None,
        params: dict = None,
        **kwargs,
    ) -> requests.Response:
        """Make HTTP request.

        Args:
            method: HTTP method (GET, POST, DELETE, etc.)
            path: API path (will be joined with endpoint)
            json: JSON payload
            data: Raw data payload
            headers: Additional headers
            params: Query parameters
            **kwargs: Additional requests arguments

        Returns:
            Response object
        """
        url = f"{self.endpoint.rstrip('/')}/{path.lstrip('/')}"
        request_headers = dict(self.session.headers)
        if headers:
            request_headers.update(headers)

        return self.session.request(
            method=method,
            url=url,
            json=json,
            data=data,
            headers=request_headers,
            params=params,
            timeout=config.timeout,
            **kwargs,
        )

    def get(self, path: str, **kwargs) -> requests.Response:
        """GET request."""
        return self.request("GET", path, **kwargs)

    def post(self, path: str, **kwargs) -> requests.Response:
        """POST request."""
        return self.request("POST", path, **kwargs)

    def put(self, path: str, **kwargs) -> requests.Response:
        """PUT request."""
        return self.request("PUT", path, **kwargs)

    def delete(self, path: str, **kwargs) -> requests.Response:
        """DELETE request."""
        return self.request("DELETE", path, **kwargs)

    def head(self, path: str, **kwargs) -> requests.Response:
        """HEAD request."""
        return self.request("HEAD", path, **kwargs)


class HFClientWrapper:
    """Wrapper around HuggingFace Hub client for testing.

    This wrapper ensures we're using the correct endpoint and provides
    utilities for testing HF API compatibility.
    """

    def __init__(self, endpoint: str = None, token: str = None):
        """Initialize HF client wrapper.

        Args:
            endpoint: API endpoint URL
            token: API token for authentication
        """
        # Remove trailing slash from endpoint to avoid double slash in URLs
        self.endpoint = (endpoint or config.endpoint).rstrip("/")
        self.token = token

        # Set environment variables for the HuggingFace client
        os.environ["HF_ENDPOINT"] = self.endpoint
        if self.token:
            os.environ["HF_TOKEN"] = self.token

        self.api = HfApi(endpoint=self.endpoint, token=self.token)

    def create_repo(
        self, repo_id: str, repo_type: str = "model", private: bool = False
    ):
        """Create repository."""
        return self.api.create_repo(
            repo_id=repo_id, repo_type=repo_type, private=private
        )

    def delete_repo(self, repo_id: str, repo_type: str = "model"):
        """Delete repository."""
        return self.api.delete_repo(repo_id=repo_id, repo_type=repo_type)

    def repo_info(self, repo_id: str, repo_type: str = "model", revision: str = None):
        """Get repository info."""
        return self.api.repo_info(
            repo_id=repo_id, repo_type=repo_type, revision=revision
        )

    def list_repo_files(
        self, repo_id: str, repo_type: str = "model", revision: str = None
    ):
        """List repository files."""
        return self.api.list_repo_files(
            repo_id=repo_id, repo_type=repo_type, revision=revision
        )

    def upload_file(
        self,
        path_or_fileobj,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Upload single file."""
        return self.api.upload_file(
            path_or_fileobj=path_or_fileobj,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=commit_message or "Upload file",
        )

    def upload_folder(
        self,
        folder_path: str,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Upload folder."""
        return self.api.upload_folder(
            folder_path=folder_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=commit_message or "Upload folder",
        )

    def download_file(
        self,
        repo_id: str,
        filename: str,
        repo_type: str = "model",
        revision: str = None,
    ) -> str:
        """Download file and return local path."""
        return self.api.hf_hub_download(
            repo_id=repo_id, filename=filename, repo_type=repo_type, revision=revision
        )

    def delete_file(
        self,
        path_in_repo: str,
        repo_id: str,
        repo_type: str = "model",
        commit_message: str = None,
    ):
        """Delete file."""
        return self.api.delete_file(
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=commit_message or "Delete file",
        )


class BaseTestCase:
    """Base test case with common utilities."""

    @classmethod
    def setup_class(cls):
        """Setup test class."""
        cls.config = config
        cls.http_client = HTTPClient()
        cls.temp_dir = tempfile.mkdtemp(prefix="kohakuhub_test_")

    @classmethod
    def teardown_class(cls):
        """Cleanup test class."""
        if hasattr(cls, "temp_dir") and Path(cls.temp_dir).exists():
            shutil.rmtree(cls.temp_dir)

    def create_temp_file(self, name: str, content: bytes) -> str:
        """Create temporary file for testing.

        Args:
            name: File name
            content: File content

        Returns:
            Path to created file
        """
        path = Path(self.temp_dir) / name
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(content)
        return str(path)

    def create_random_file(self, name: str, size_mb: float) -> str:
        """Create temporary file with random content.

        Args:
            name: File name
            size_mb: File size in megabytes

        Returns:
            Path to created file
        """
        size_bytes = int(size_mb * 1000 * 1000)
        content = os.urandom(size_bytes)
        return self.create_temp_file(name, content)

    def assert_response_ok(
        self, response: requests.Response, expected_status: int = 200
    ):
        """Assert HTTP response is successful.

        Args:
            response: HTTP response
            expected_status: Expected status code
        """
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}: {response.text}"

    def assert_response_error(
        self,
        response: requests.Response,
        expected_status: int,
        expected_error_code: str = None,
    ):
        """Assert HTTP response is an error.

        Args:
            response: HTTP response
            expected_status: Expected status code
            expected_error_code: Expected HF error code (e.g., "RepoNotFound")
        """
        assert (
            response.status_code == expected_status
        ), f"Expected {expected_status}, got {response.status_code}"

        if expected_error_code:
            error_code = response.headers.get("X-Error-Code")
            assert (
                error_code == expected_error_code
            ), f"Expected error code {expected_error_code}, got {error_code}"

    def get_test_repo_id(self, name: str) -> str:
        """Generate test repository ID.

        Args:
            name: Repository name

        Returns:
            Full repository ID with test prefix
        """
        return f"{config.username}/{config.repo_prefix}-{name}"

    def get_test_org_repo_id(self, name: str) -> str:
        """Generate test organization repository ID.

        Args:
            name: Repository name

        Returns:
            Full repository ID with organization
        """
        return f"{config.org_name}/{config.repo_prefix}-{name}"
@@ -1,48 +0,0 @@
"""Test configuration for KohakuHub API tests.

Configuration can be set via environment variables:
- TEST_ENDPOINT: API endpoint URL (default: http://localhost:28080)
- TEST_USERNAME: Test user username (default: testuser)
- TEST_EMAIL: Test user email (default: test@example.com)
- TEST_PASSWORD: Test user password (default: testpass123)
- TEST_ORG_NAME: Test organization name (default: testorg)
"""

import os
from dataclasses import dataclass


@dataclass
class TestConfig:
    """Test configuration."""

    # API endpoint (should be the nginx port, not the backend port)
    endpoint: str = os.getenv("TEST_ENDPOINT", "http://localhost:28080")

    # Test user credentials (for session-scoped shared user)
    username: str = os.getenv("TEST_USERNAME", "testuser")
    email: str = os.getenv("TEST_EMAIL", "test@example.com")
    password: str = os.getenv("TEST_PASSWORD", "testpass123")

    # Test organization (lowercase, matches ^[a-z0-9][a-z0-9-]{2,62}$)
    org_name: str = os.getenv("TEST_ORG_NAME", "testorg")

    # Test repository prefix (to avoid conflicts, lowercase)
    repo_prefix: str = os.getenv("TEST_REPO_PREFIX", "tst")

    # Timeout for HTTP requests
    timeout: int = int(os.getenv("TEST_TIMEOUT", "30"))

    # Cleanup after tests
    cleanup_on_success: bool = os.getenv("TEST_CLEANUP", "true").lower() == "true"

    def __post_init__(self):
        """Validate configuration."""
        if not self.endpoint:
            raise ValueError("TEST_ENDPOINT must be set")
        if not self.username or not self.password:
            raise ValueError("TEST_USERNAME and TEST_PASSWORD must be set")


# Global config instance
config = TestConfig()
@@ -1,247 +0,0 @@
"""Pytest configuration and fixtures for KohakuHub API tests."""

import os
import pytest
import requests
from huggingface_hub import HfApi

from tests.base import HFClientWrapper, HTTPClient
from tests.config import config


@pytest.fixture(scope="session")
def test_config():
    """Test configuration fixture."""
    return config


@pytest.fixture(scope="function")
def http_client():
    """HTTP client fixture for direct API testing (unauthenticated).

    Function-scoped to prevent session contamination between tests.
    """
    return HTTPClient()


@pytest.fixture(scope="session")
def api_token():
    """Create and return an API token for the test user.

    This fixture:
    1. Registers a test user (if not exists)
    2. Logs in to get a session
    3. Creates an API token
    4. Returns the token for use in tests
    """
    from kohakuhub.auth.routes import RegisterRequest, LoginRequest, CreateTokenRequest

    # Create a dedicated HTTP client for this session setup
    setup_client = HTTPClient()

    # Try to register (will fail if user exists, which is fine)
    try:
        payload = RegisterRequest(
            username=config.username, email=config.email, password=config.password
        )
        resp = setup_client.post("/api/auth/register", json=payload.model_dump())
        if resp.status_code == 200:
            print(f"✓ Registered test user: {config.username}")
    except Exception as e:
        print(f"User registration skipped (may already exist): {e}")

    # Login to get session
    login_payload = LoginRequest(username=config.username, password=config.password)
    resp = setup_client.post("/api/auth/login", json=login_payload.model_dump())
    assert resp.status_code == 200, f"Login failed: {resp.text}"

    # Update session cookies
    setup_client.session.cookies.update(resp.cookies)

    # Create API token
    import uuid

    token_id = uuid.uuid4().hex[:6]
    token_payload = CreateTokenRequest(name=f"tok-{token_id}")
    resp = setup_client.post("/api/auth/tokens/create", json=token_payload.model_dump())
    assert resp.status_code == 200, f"Token creation failed: {resp.text}"

    token_data = resp.json()
    token = token_data["token"]
    print("✓ Created API token for testing")

    yield token

    # Cleanup: revoke token (optional)
    if config.cleanup_on_success:
        try:
            token_id = token_data.get("token_id")
            if token_id:
                setup_client.delete(f"/api/auth/tokens/{token_id}")
                print("✓ Revoked test token")
        except Exception as e:
            print(f"Token cleanup failed: {e}")


@pytest.fixture(scope="session")
def authenticated_http_client(api_token):
    """HTTP client with authentication."""
    return HTTPClient(token=api_token)


@pytest.fixture(scope="session")
def hf_client(api_token):
    """HuggingFace Hub client fixture."""
    return HFClientWrapper(token=api_token)


@pytest.fixture(scope="function")
def random_user():
    """Create a random user for testing.

    Returns:
        Tuple of (username, token, hf_client_wrapper)

    Each test gets a fresh user to avoid conflicts.
    """
    import uuid
    from kohakuhub.auth.routes import RegisterRequest, LoginRequest, CreateTokenRequest

    # Generate short random ID (6 chars, lowercase hex)
    unique_id = uuid.uuid4().hex[:6]
    username = f"user-{unique_id}"  # Matches ^[a-z0-9][a-z0-9-]{2,62}$
    email = f"test-{unique_id}@example.com"
    password = "testpass123"

    # Create HTTP client for setup
    setup_client = HTTPClient()

    # Register user
    payload = RegisterRequest(username=username, email=email, password=password)
    resp = setup_client.post("/api/auth/register", json=payload.model_dump())
    assert resp.status_code == 200, f"Registration failed: {resp.text}"

    # Login
    login_payload = LoginRequest(username=username, password=password)
    resp = setup_client.post("/api/auth/login", json=login_payload.model_dump())
    assert resp.status_code == 200, f"Login failed: {resp.text}"

    # Update session cookies
    setup_client.session.cookies.update(resp.cookies)

    # Create API token
    token_payload = CreateTokenRequest(name=f"tok-{unique_id}")
    resp = setup_client.post("/api/auth/tokens/create", json=token_payload.model_dump())
    assert resp.status_code == 200, f"Token creation failed: {resp.text}"

    token = resp.json()["token"]

    # Create HF client for this user
    user_hf_client = HFClientWrapper(token=token)

    yield username, token, user_hf_client

    # Cleanup happens automatically when the user is deleted (if needed)


@pytest.fixture(scope="function")
def temp_repo(random_user, request):
    """Create temporary repository for testing.

    Usage:
        def test_something(temp_repo):
            repo_id, repo_type, hf_client = temp_repo
            # repo will be cleaned up after the test
    """
    username, token, hf_client = random_user

    # Get repo type from test marker or default to "model"
    marker = request.node.get_closest_marker("repo_type")
    repo_type = marker.args[0] if marker else "model"

    # Generate unique repo name (short, lowercase)
    import uuid

    unique_id = uuid.uuid4().hex[:6]
    repo_name = f"repo-{unique_id}"  # Matches ^[a-z0-9][a-z0-9-]{2,62}$
    repo_id = f"{username}/{repo_name}"

    # Create repository
    try:
        hf_client.create_repo(repo_id=repo_id, repo_type=repo_type, private=False)
        print(f"✓ Created test repo: {repo_id}")
    except Exception as e:
        pytest.fail(f"Failed to create test repo: {e}")

    yield repo_id, repo_type, hf_client

    # Cleanup
    if config.cleanup_on_success:
        try:
            hf_client.delete_repo(repo_id=repo_id, repo_type=repo_type)
            print(f"✓ Cleaned up test repo: {repo_id}")
        except Exception as e:
            print(f"Repo cleanup failed: {e}")


@pytest.fixture(scope="session")
def test_org(authenticated_http_client):
    """Create test organization.

    This fixture creates an organization for testing org-related features.
    """
    from kohakuhub.api.org.router import CreateOrganizationPayload

    client = authenticated_http_client

    # Try to create organization using the actual model
    try:
        payload = CreateOrganizationPayload(
            name=config.org_name, description="Organization for testing"
        )
        resp = client.post("/org/create", json=payload.model_dump())
        if resp.status_code == 200:
            print(f"✓ Created test organization: {config.org_name}")
        elif resp.status_code == 400 and "exists" in resp.text.lower():
            print(f"✓ Test organization already exists: {config.org_name}")
        else:
            print(f"Warning: Org creation returned {resp.status_code}: {resp.text}")
    except Exception as e:
        print(f"Org creation skipped: {e}")

    yield config.org_name

    # Cleanup: delete organization (if the API supports it)
    # Note: the current API may not support org deletion
    if config.cleanup_on_success:
        try:
            # Organization deletion endpoint may not exist yet
            # resp = client.delete(f"/org/{config.org_name}")
            # print("✓ Deleted test organization")
            pass
        except Exception as e:
            print(f"Org cleanup skipped: {e}")


def pytest_configure(config):
    """Pytest configuration hook."""
    # Add custom markers
    config.addinivalue_line(
        "markers", "repo_type(type): mark test to use specific repo type"
    )
    config.addinivalue_line("markers", "slow: mark test as slow running")
    config.addinivalue_line("markers", "lfs: mark test as requiring LFS")

    # Print test configuration
    print("\n" + "=" * 70)
|
||||
print("KohakuHub API Test Configuration")
|
||||
print("=" * 70)
|
||||
print(f"Endpoint: {test_config.endpoint}")
|
||||
print(f"Username: {test_config.username}")
|
||||
print(f"Org Name: {test_config.org_name}")
|
||||
print(f"Cleanup: {test_config.cleanup_on_success}")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
|
||||
# Get test_config for pytest_configure
|
||||
test_config = config
|
||||
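The fixtures above build unique repo, user, and token names meant to satisfy the server's name pattern quoted in the comments. A small self-contained sketch of that naming scheme (`make_name` and `NAME_RE` are illustrative names, not part of the test suite):

```python
import re
import uuid

# Pattern quoted in the fixture comments: ^[a-z0-9][a-z0-9-]{2,62}$
NAME_RE = re.compile(r"^[a-z0-9][a-z0-9-]{2,62}$")


def make_name(prefix: str) -> str:
    """Build a short, lowercase, unique name the way the fixtures do."""
    return f"{prefix}-{uuid.uuid4().hex[:6]}"


# Every prefix the fixtures use yields a valid name
for prefix in ("repo", "user", "tok"):
    name = make_name(prefix)
    assert NAME_RE.fullmatch(name), name
```

Hex digests are lowercase `[0-9a-f]`, so the generated suffix can never violate the character class; only the prefix needs to start with a lowercase letter or digit.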
@@ -1,14 +0,0 @@
# Test requirements for KohakuHub API tests

# Core testing framework
pytest>=7.0.0
pytest-cov>=4.0.0
pytest-xdist>=3.0.0
pytest-timeout>=2.1.0

# HTTP client and HuggingFace Hub client
requests>=2.31.0
huggingface_hub>=0.20.0

# Additional utilities
pytest-mock>=3.12.0
@@ -1,200 +0,0 @@
"""Authentication API tests.

Tests user registration, login, logout, token management, and whoami endpoints.
Uses actual Pydantic models from the source code.
"""

import uuid

from tests.base import HTTPClient
from tests.config import config
from kohakuhub.auth.routes import CreateTokenRequest, LoginRequest, RegisterRequest
from kohakuhub.api.repo.routers.crud import CreateRepoPayload


class TestAuthentication:
    """Test authentication endpoints."""

    def test_version_check(self, http_client):
        """Test API version endpoint."""
        resp = http_client.get("/api/version")
        assert resp.status_code == 200
        data = resp.json()
        assert "api" in data
        assert data["api"] == "kohakuhub"
        assert "version" in data

    def test_register_login_logout_flow(self, http_client):
        """Test complete user registration, login, and logout flow."""
        # Use unique username for this test
        unique_id = uuid.uuid4().hex[:6]
        test_username = f"user-{unique_id}"  # Matches ^[a-z0-9][a-z0-9-]{2,62}$
        test_email = f"test-{unique_id}@example.com"
        test_password = "testpass123"

        # 1. Register new user using actual model
        payload = RegisterRequest(
            username=test_username, email=test_email, password=test_password
        )

        resp = http_client.post("/api/auth/register", json=payload.model_dump())
        assert resp.status_code == 200, f"Registration failed: {resp.text}"
        data = resp.json()
        assert data["success"] is True

        # 2. Login using actual model
        login_payload = LoginRequest(username=test_username, password=test_password)

        resp = http_client.post("/api/auth/login", json=login_payload.model_dump())
        assert resp.status_code == 200, f"Login failed: {resp.text}"
        data = resp.json()
        assert "username" in data
        assert data["username"] == test_username

        # Save session
        session_cookies = resp.cookies

        # 3. Get current user (with session)
        client_with_session = HTTPClient()
        client_with_session.session.cookies.update(session_cookies)

        resp = client_with_session.get("/api/auth/me")
        assert resp.status_code == 200, f"Get user failed: {resp.text}"
        data = resp.json()
        assert data["username"] == test_username

        # 4. Logout
        resp = client_with_session.post("/api/auth/logout")
        assert resp.status_code == 200

        # 5. Verify session is cleared
        resp = client_with_session.get("/api/auth/me")
        assert resp.status_code == 401, "Session should be cleared after logout"

    def test_token_creation_and_usage(self, authenticated_http_client):
        """Test API token creation and usage."""
        # 1. Create token using actual model
        unique_id = uuid.uuid4().hex[:6]
        token_payload = CreateTokenRequest(name=f"token-{unique_id}")

        resp = authenticated_http_client.post(
            "/api/auth/tokens/create", json=token_payload.model_dump()
        )
        assert resp.status_code == 200, f"Token creation failed: {resp.text}"
        data = resp.json()
        assert "token" in data
        assert "token_id" in data
        token = data["token"]
        token_id = data["token_id"]

        # 2. Use token for authentication
        client_with_token = HTTPClient(token=token)
        resp = client_with_token.get("/api/auth/me")
        assert resp.status_code == 200, f"Token auth failed: {resp.text}"
        user_data = resp.json()
        assert user_data["username"] == config.username

        # 3. List tokens
        resp = authenticated_http_client.get("/api/auth/tokens")
        assert resp.status_code == 200
        tokens_response = resp.json()
        assert "tokens" in tokens_response
        tokens = tokens_response["tokens"]
        assert isinstance(tokens, list)
        # Our token should be in the list
        token_names = [t["name"] for t in tokens]
        assert token_payload.name in token_names

        # 4. Revoke token
        resp = authenticated_http_client.delete(f"/api/auth/tokens/{token_id}")
        assert resp.status_code == 200

        # 5. Verify token is revoked
        resp = client_with_token.get("/api/auth/me")
        assert resp.status_code == 401, "Token should be revoked"

    def test_whoami_endpoint(self, authenticated_http_client):
        """Test whoami-v2 endpoint for detailed user info."""
        resp = authenticated_http_client.get("/api/whoami-v2")
        assert resp.status_code == 200, f"whoami-v2 failed: {resp.text}"
        data = resp.json()

        # Validate response structure (based on actual implementation in misc.py)
        assert "type" in data
        assert data["type"] == "user"
        assert "name" in data
        assert data["name"] == config.username
        assert "email" in data
        assert "orgs" in data
        assert isinstance(data["orgs"], list)
        assert "emailVerified" in data
        assert "auth" in data

    def test_invalid_credentials(self, http_client):
        """Test login with invalid credentials."""
        payload = LoginRequest(username="nonexistent", password="wrongpass")

        resp = http_client.post("/api/auth/login", json=payload.model_dump())
        assert resp.status_code == 401, "Should reject invalid credentials"

    def test_unauthenticated_access(self, http_client):
        """Test accessing protected endpoints without authentication."""
        # Try to access protected endpoint
        resp = http_client.get("/api/auth/me")
        assert resp.status_code == 401, "Should require authentication"

        # Try to create repo without auth
        payload = CreateRepoPayload(type="model", name="test-repo")

        resp = http_client.post("/api/repos/create", json=payload.model_dump())
        assert resp.status_code in [401, 403], "Should require authentication"

    def test_duplicate_registration(self, http_client):
        """Test that duplicate usernames are rejected."""
        unique_id = uuid.uuid4().hex[:6]
        test_username = f"dup-{unique_id}"
        test_email = f"dup-{unique_id}@example.com"

        payload = RegisterRequest(
            username=test_username, email=test_email, password="testpass123"
        )

        # First registration
        resp = http_client.post("/api/auth/register", json=payload.model_dump())
        assert resp.status_code == 200

        # Second registration with same username (different email)
        payload2 = RegisterRequest(
            username=test_username,
            email=f"different_{unique_id}@example.com",
            password="testpass123",
        )

        resp = http_client.post("/api/auth/register", json=payload2.model_dump())
        assert resp.status_code == 400
        assert "username" in resp.text.lower() or "exist" in resp.text.lower()

    def test_duplicate_email_registration(self, http_client):
        """Test that duplicate emails are rejected."""
        unique_id = uuid.uuid4().hex[:6]
        test_username = f"email-{unique_id}"
        test_email = f"email-{unique_id}@example.com"

        payload = RegisterRequest(
            username=test_username, email=test_email, password="testpass123"
        )

        # First registration
        resp = http_client.post("/api/auth/register", json=payload.model_dump())
        assert resp.status_code == 200

        # Second registration with same email (different username; hyphenated so
        # it stays valid under ^[a-z0-9][a-z0-9-]{2,62}$ and the 400 comes from
        # the duplicate email, not username validation)
        payload2 = RegisterRequest(
            username=f"different-{unique_id}",
            email=test_email,
            password="testpass123",
        )

        resp = http_client.post("/api/auth/register", json=payload2.model_dump())
        assert resp.status_code == 400
        assert "email" in resp.text.lower() or "exist" in resp.text.lower()
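The flow exercised above is: register, log in (session cookie), mint a token, then authenticate with the token instead of cookies. A small offline sketch of the request bodies involved, using plain dicts in place of the Pydantic models (field names are taken from the calls above; nothing is POSTed here):

```python
import json
import uuid

unique_id = uuid.uuid4().hex[:6]

# Equivalent of RegisterRequest(...).model_dump()
register = {
    "username": f"user-{unique_id}",
    "email": f"test-{unique_id}@example.com",
    "password": "testpass123",
}
# Equivalent of LoginRequest(...).model_dump()
login = {"username": register["username"], "password": register["password"]}
# Equivalent of CreateTokenRequest(...).model_dump()
token_req = {"name": f"token-{unique_id}"}

# Each body serializes cleanly, which is all the `json=` kwarg needs
for body in (register, login, token_req):
    json.dumps(body)
```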
@@ -1,403 +0,0 @@
"""Branch operation tests - revert, reset, and merge with LFS GC.

Tests the complete workflow of:
- Revert operations (latest, non-conflicting)
- Reset operations (creating new commits, LFS checks)
- Branch creation and merging
- LFS recoverability checks
"""

import base64
import hashlib
import json
import os
import shutil
import tempfile
import time
from pathlib import Path

import pytest
import requests

from tests.base import HTTPClient


class TestBranchRevert:
    """Test revert functionality."""

    def test_revert_latest_commit(self, temp_repo):
        """Revert the latest commit (should succeed - no conflicts)."""
        repo_id, repo_type, hf_client = temp_repo

        # Create files with random content
        test_content = os.urandom(100)  # Random bytes
        lfs_content = os.urandom(2000000)  # 2MB random LFS file

        temp_dir = Path(tempfile.mkdtemp())
        (temp_dir / "revert-test.txt").write_bytes(test_content)
        (temp_dir / "revert-test-lfs.bin").write_bytes(lfs_content)

        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="",  # Upload to root
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Add revert test files",
        )

        # Verify files exist
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "revert-test.txt" in files
        assert "revert-test-lfs.bin" in files

        # Get latest commit
        http_client = HTTPClient()
        resp = http_client.get(f"/api/{repo_type}s/{repo_id}/commits/main")
        assert resp.status_code == 200
        commits_data = resp.json()
        latest_commit = commits_data["commits"][0]["id"]

        # Revert the commit (need authenticated client)
        http_auth = HTTPClient(token=hf_client.token)
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/branch/main/revert",
            json={
                "ref": latest_commit,
                "parent_number": 1,
                "force": False,
                "allow_empty": False,
            },
        )
        assert resp.status_code == 200, f"Revert failed: {resp.text}"

        # Verify files were removed
        time.sleep(1)
        files_after = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "revert-test.txt" not in files_after
        assert "revert-test-lfs.bin" not in files_after

        # Cleanup
        shutil.rmtree(temp_dir)

    def test_revert_non_conflicting(self, temp_repo):
        """Revert non-latest but non-conflicting commit with LFS."""
        repo_id, repo_type, hf_client = temp_repo

        temp_dir = Path(tempfile.mkdtemp())

        # Commit 1: Add file1 with LFS (random content)
        lfs1 = os.urandom(2000000)  # Random 2MB
        (temp_dir / "file1.txt").write_bytes(os.urandom(100))
        (temp_dir / "file1-lfs.bin").write_bytes(lfs1)

        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="set1/",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Add file1 with LFS",
        )

        # Get commit 1 ID
        http_client = HTTPClient()
        resp = http_client.get(f"/api/{repo_type}s/{repo_id}/commits/main")
        commits_data = resp.json()
        commit1 = commits_data["commits"][0]["id"]

        # Commit 2: Add file2 (different path, random content)
        (temp_dir / "file1.txt").unlink()
        (temp_dir / "file1-lfs.bin").unlink()

        lfs2 = os.urandom(2000000)  # Random 2MB
        (temp_dir / "file2.txt").write_bytes(os.urandom(100))
        (temp_dir / "file2-lfs.bin").write_bytes(lfs2)

        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="set2/",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Add file2 with LFS",
        )

        # Revert commit 1 (authenticated)
        http_auth = HTTPClient(token=hf_client.token)
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/branch/main/revert",
            json={"ref": commit1, "parent_number": 1, "force": False},
        )
        assert resp.status_code == 200, f"Revert failed: {resp.text}"

        # Verify file2 still exists, file1 removed
        time.sleep(1)
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "set2/file2.txt" in files
        assert "set2/file2-lfs.bin" in files
        assert "set1/file1.txt" not in files

        # Cleanup
        shutil.rmtree(temp_dir)


class TestBranchReset:
    """Test reset functionality."""

    def test_reset_creates_new_commit(self, temp_repo):
        """Reset should create new commit, not delete history."""
        repo_id, repo_type, hf_client = temp_repo

        temp_dir = Path(tempfile.mkdtemp())

        # Create 3 commits with random content
        for i in range(1, 4):
            content = os.urandom(1000)  # Random content each time
            (temp_dir / "test.txt").write_bytes(content)

            hf_client.upload_folder(
                folder_path=str(temp_dir),
                path_in_repo="",
                repo_id=repo_id,
                repo_type=repo_type,
                commit_message=f"Version {i}",
            )
            time.sleep(0.5)

        # Get commits before reset
        http_client = HTTPClient()
        resp = http_client.get(f"/api/{repo_type}s/{repo_id}/commits/main")
        commits_data = resp.json()
        commits_before = commits_data["commits"]
        assert len(commits_before) >= 3

        # Reset to version 1
        target_commit = commits_before[2]["id"]  # Third from top (oldest)

        http_auth = HTTPClient(token=hf_client.token)
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/branch/main/reset",
            json={"ref": target_commit, "force": True},
        )
        assert resp.status_code == 200, f"Reset failed: {resp.text}"

        # Verify new commit was created (history preserved)
        time.sleep(1)
        resp = http_client.get(f"/api/{repo_type}s/{repo_id}/commits/main")
        commits_after_data = resp.json()
        commits_after = commits_after_data["commits"]

        # Should have MORE commits (original 3 + reset commit)
        assert len(commits_after) >= len(commits_before)

        # Verify file exists after reset
        downloaded = hf_client.download_file(
            repo_id=repo_id, filename="test.txt", repo_type=repo_type
        )
        # Just verify file exists and can be downloaded
        assert Path(downloaded).exists()

        # Cleanup
        shutil.rmtree(temp_dir)

    @pytest.mark.lfs
    def test_reset_with_lfs_files(self, temp_repo):
        """Reset with LFS files (should preserve LFS objects)."""
        repo_id, repo_type, hf_client = temp_repo

        temp_dir = Path(tempfile.mkdtemp())

        # Create 3 versions with LFS (random content)
        for i in range(1, 4):
            content = os.urandom(2000000)  # Random 2MB each time
            (temp_dir / "large.bin").write_bytes(content)

            hf_client.upload_folder(
                folder_path=str(temp_dir),
                path_in_repo="",
                repo_id=repo_id,
                repo_type=repo_type,
                commit_message=f"LFS Version {i}",
            )
            time.sleep(0.5)

        # Get commits
        http_client = HTTPClient()
        resp = http_client.get(f"/api/{repo_type}s/{repo_id}/commits/main")
        commits_data = resp.json()
        commits = commits_data["commits"]
        target_commit = commits[2]["id"]  # First version

        # Reset to version 1
        http_auth = HTTPClient(token=hf_client.token)
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/branch/main/reset",
            json={"ref": target_commit, "force": True},
        )
        assert resp.status_code == 200, f"Reset failed: {resp.text}"

        # Verify LFS file exists after reset
        time.sleep(1)
        downloaded = hf_client.download_file(
            repo_id=repo_id, filename="large.bin", repo_type=repo_type
        )
        # Just verify file exists and can be downloaded
        assert Path(downloaded).exists()
        assert Path(downloaded).stat().st_size == 2000000  # 2MB

        # Cleanup
        shutil.rmtree(temp_dir)


class TestBranchMerge:
    """Test merge functionality."""

    def test_merge_branches(self, temp_repo):
        """Merge dev branch into main."""
        repo_id, repo_type, hf_client = temp_repo

        temp_dir = Path(tempfile.mkdtemp())

        # Create initial file on main with random content
        (temp_dir / "main.txt").write_bytes(os.urandom(100))
        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Add main file",
        )

        # Create dev branch
        http_auth = HTTPClient(token=hf_client.token)
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/branch",
            json={"branch": "dev", "revision": "main"},
        )
        assert resp.status_code == 200

        # Upload different files to dev branch using direct API
        dev_txt_content = os.urandom(200)  # Random content
        dev_lfs_content = os.urandom(2000000)  # Random 2MB LFS

        # Upload to dev branch using commit API
        ndjson_lines = []

        # Header
        ndjson_lines.append(
            json.dumps(
                {
                    "key": "header",
                    "value": {"summary": "Add dev files", "description": ""},
                }
            )
        )

        # Dev text file
        ndjson_lines.append(
            json.dumps(
                {
                    "key": "file",
                    "value": {
                        "path": "dev.txt",
                        "content": base64.b64encode(dev_txt_content).decode(),
                        "encoding": "base64",
                    },
                }
            )
        )

        # Dev LFS file - upload to S3 first
        sha256 = hashlib.sha256(dev_lfs_content).hexdigest()

        # Get LFS upload URL
        lfs_resp = http_auth.post(
            f"/{repo_type}s/{repo_id}.git/info/lfs/objects/batch",
            json={
                "operation": "upload",
                "objects": [{"oid": sha256, "size": len(dev_lfs_content)}],
            },
        )
        assert lfs_resp.status_code == 200

        upload_url = lfs_resp.json()["objects"][0]["actions"]["upload"]["href"]

        # Upload to S3
        s3_resp = requests.put(upload_url, data=dev_lfs_content)
        assert s3_resp.status_code in (200, 204)

        # Add LFS file to commit
        ndjson_lines.append(
            json.dumps(
                {
                    "key": "lfsFile",
                    "value": {
                        "path": "dev-lfs.bin",
                        "oid": sha256,
                        "size": len(dev_lfs_content),
                        "algo": "sha256",
                    },
                }
            )
        )

        # Commit to dev branch
        commit_resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/commit/dev",
            data="\n".join(ndjson_lines),
            headers={"Content-Type": "application/x-ndjson"},
        )
        assert commit_resp.status_code == 200

        # Merge dev into main
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/merge/dev/into/main",
            json={"strategy": "source-wins"},
        )
        assert resp.status_code == 200, f"Merge failed: {resp.text}"

        # Verify merged files on main
        time.sleep(1)
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "main.txt" in files
        assert "dev.txt" in files
        assert "dev-lfs.bin" in files

        # Cleanup
        shutil.rmtree(temp_dir)


@pytest.mark.lfs
class TestBranchOperationsWithLFS:
    """Test branch operations with LFS garbage collection."""

    def test_create_and_delete_branch(self, temp_repo):
        """Create and delete branches."""
        repo_id, repo_type, hf_client = temp_repo

        http_auth = HTTPClient(token=hf_client.token)

        # Create branch
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/branch",
            json={"branch": "feature", "revision": "main"},
        )
        assert resp.status_code == 200

        # Delete branch
        resp = http_auth.delete(f"/api/{repo_type}s/{repo_id}/branch/feature")
        assert resp.status_code == 200

    def test_create_and_delete_tag(self, temp_repo):
        """Create and delete tags."""
        repo_id, repo_type, hf_client = temp_repo

        http_auth = HTTPClient(token=hf_client.token)

        # Create tag
        resp = http_auth.post(
            f"/api/{repo_type}s/{repo_id}/tag",
            json={"tag": "v1.0", "revision": "main"},
        )
        assert resp.status_code == 200

        # Delete tag
        resp = http_auth.delete(f"/api/{repo_type}s/{repo_id}/tag/v1.0")
        assert resp.status_code == 200
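The merge test above uploads an LFS object by hand: hash the bytes, ask the LFS batch endpoint for an upload URL, then reference the object in an NDJSON commit. A minimal offline sketch of the payloads involved (endpoint body and the `header`/`lfsFile` record shapes mirror the test; no server is contacted, and the 1 KiB blob stands in for the random 2 MB one):

```python
import hashlib
import json

content = b"\x00" * 1024  # stand-in for the random 2 MB blob
oid = hashlib.sha256(content).hexdigest()

# Body for POST /{repo_type}s/{repo_id}.git/info/lfs/objects/batch
batch = {
    "operation": "upload",
    "objects": [{"oid": oid, "size": len(content)}],
}

# NDJSON commit body: one header record, then one record per file
ndjson = "\n".join(
    json.dumps(rec)
    for rec in (
        {"key": "header", "value": {"summary": "Add dev files", "description": ""}},
        {
            "key": "lfsFile",
            "value": {"path": "dev-lfs.bin", "oid": oid, "size": len(content), "algo": "sha256"},
        },
    )
)

assert len(oid) == 64  # sha256 hex digest
assert len(ndjson.splitlines()) == 2
```

The server only ever sees the oid and size in the commit; the bytes themselves travel over the presigned URL returned by the batch call, which is why the test PUTs to S3 before committing.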
@@ -1,333 +0,0 @@
|
||||
"""File operation tests.
|
||||
|
||||
Tests file upload, download, deletion, and listing.
|
||||
Covers both small files (inline) and large files (LFS).
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.base import HTTPClient
|
||||
|
||||
|
||||
class TestFileOperations:
|
||||
"""Test file upload, download, and deletion operations."""
|
||||
|
||||
def test_upload_small_file_hf_client(self, temp_repo):
|
||||
"""Test uploading small file (<10MB) using HuggingFace Hub client."""
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Create small test file in temp directory
|
||||
test_content = b"Hello, KohakuHub! This is a small test file."
|
||||
test_file = (
|
||||
Path(tempfile.gettempdir()) / f"test_small_{os.urandom(4).hex()}.txt"
|
||||
)
|
||||
test_file.write_bytes(test_content)
|
||||
|
||||
# Upload file
|
||||
hf_client.upload_file(
|
||||
path_or_fileobj=str(test_file),
|
||||
path_in_repo="test_small.txt",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
commit_message="Add small test file",
|
||||
)
|
||||
|
||||
# Download and verify
|
||||
downloaded = hf_client.download_file(
|
||||
repo_id=repo_id, filename="test_small.txt", repo_type=repo_type
|
||||
)
|
||||
assert Path(downloaded).exists()
|
||||
content = Path(downloaded).read_bytes()
|
||||
assert content == test_content
|
||||
|
||||
# Cleanup temp file
|
||||
test_file.unlink(missing_ok=True)
|
||||
|
||||
def test_upload_folder_hf_client(self, temp_repo):
|
||||
"""Test uploading folder using HuggingFace Hub client."""
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Create temp folder with files
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
(temp_dir / "file1.txt").write_bytes(b"File 1 content")
|
||||
(temp_dir / "file2.txt").write_bytes(b"File 2 content")
|
||||
(temp_dir / "subdir").mkdir()
|
||||
(temp_dir / "subdir" / "file3.txt").write_bytes(b"File 3 content")
|
||||
|
||||
# Upload folder
|
||||
hf_client.upload_folder(
|
||||
folder_path=str(temp_dir),
|
||||
path_in_repo="uploaded_folder/",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
commit_message="Upload folder",
|
||||
)
|
||||
|
||||
# Verify files exist
|
||||
files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
|
||||
assert "uploaded_folder/file1.txt" in files
|
||||
assert "uploaded_folder/file2.txt" in files
|
||||
assert "uploaded_folder/subdir/file3.txt" in files
|
||||
|
||||
# Cleanup
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def test_download_file_hf_client(self, temp_repo):
|
||||
"""Test downloading file using HuggingFace Hub client."""
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Upload a file first
|
||||
test_content = b"Download test content"
|
||||
test_file = (
|
||||
Path(tempfile.gettempdir()) / f"test_download_{os.urandom(4).hex()}.txt"
|
||||
)
|
||||
test_file.write_bytes(test_content)
|
||||
|
||||
hf_client.upload_file(
|
||||
path_or_fileobj=str(test_file),
|
||||
path_in_repo="test_download.txt",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
# Download file
|
||||
downloaded = hf_client.download_file(
|
||||
repo_id=repo_id, filename="test_download.txt", repo_type=repo_type
|
||||
)
|
||||
|
||||
# Verify content
|
||||
assert Path(downloaded).exists()
|
||||
content = Path(downloaded).read_bytes()
|
||||
assert content == test_content
|
||||
|
||||
# Cleanup
|
||||
test_file.unlink(missing_ok=True)
|
||||
|
||||
def test_delete_file_hf_client(self, temp_repo):
|
||||
"""Test deleting file using HuggingFace Hub client."""
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Upload a file first
|
||||
test_content = b"File to be deleted"
|
||||
test_file = (
|
||||
Path(tempfile.gettempdir()) / f"test_delete_{os.urandom(4).hex()}.txt"
|
||||
)
|
||||
test_file.write_bytes(test_content)
|
||||
|
||||
hf_client.upload_file(
|
||||
path_or_fileobj=str(test_file),
|
||||
path_in_repo="test_delete.txt",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
# Verify file exists
|
||||
files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
|
||||
assert "test_delete.txt" in files
|
||||
|
||||
# Delete file
|
||||
hf_client.delete_file(
|
||||
path_in_repo="test_delete.txt",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
commit_message="Delete test file",
|
||||
)
|
||||
|
||||
# Verify file is deleted
|
||||
files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
|
||||
assert "test_delete.txt" not in files
|
||||
|
||||
# Cleanup
|
||||
test_file.unlink(missing_ok=True)
|
||||
|
||||
def test_list_repo_files(self, temp_repo):
|
||||
"""Test listing repository files."""
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Upload multiple files
|
||||
temp_dir = Path(tempfile.mkdtemp())
|
||||
files_to_upload = {
|
||||
"file1.txt": b"Content 1",
|
||||
"file2.txt": b"Content 2",
|
||||
"subdir/file3.txt": b"Content 3",
|
||||
}
|
||||
|
||||
for file_path, content in files_to_upload.items():
|
||||
full_path = temp_dir / file_path
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
full_path.write_bytes(content)
|
||||
|
||||
hf_client.upload_folder(
|
||||
folder_path=str(temp_dir),
|
||||
path_in_repo="",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
# List files
|
||||
files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
|
||||
assert isinstance(files, list)
|
||||
assert "file1.txt" in files
|
||||
assert "file2.txt" in files
|
||||
assert "subdir/file3.txt" in files
|
||||
|
||||
# Cleanup
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def test_file_metadata_head_request(self, random_user, temp_repo):
|
||||
"""Test getting file metadata via HEAD request."""
|
||||
username, token, _ = random_user
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Upload a file
|
||||
test_content = b"Metadata test content"
|
||||
test_file = (
|
||||
Path(tempfile.gettempdir()) / f"test_metadata_{os.urandom(4).hex()}.txt"
|
||||
)
|
||||
test_file.write_bytes(test_content)
|
||||
|
||||
hf_client.upload_file(
|
||||
path_or_fileobj=str(test_file),
|
||||
path_in_repo="test_metadata.txt",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
# HEAD request to get metadata using repo owner's token
|
||||
user_http_client = HTTPClient(token=token)
|
||||
|
||||
namespace, repo_name = repo_id.split("/")
|
||||
resp = user_http_client.head(
|
||||
f"/{repo_type}s/{namespace}/{repo_name}/resolve/main/test_metadata.txt"
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Check headers
|
||||
assert "X-Repo-Commit" in resp.headers or "ETag" in resp.headers
|
||||
if "Content-Length" in resp.headers:
|
||||
assert int(resp.headers["Content-Length"]) == len(test_content)
|
||||
|
||||
# Cleanup
|
||||
test_file.unlink(missing_ok=True)
|
||||

    def test_upload_with_commit_message(self, temp_repo):
        """Test uploading file with custom commit message."""
        repo_id, repo_type, hf_client = temp_repo

        # Upload file with custom message
        test_content = b"Commit message test"
        test_file = (
            Path(tempfile.gettempdir()) / f"test_commit_{os.urandom(4).hex()}.txt"
        )
        test_file.write_bytes(test_content)

        commit_message = "Custom commit message for testing"
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="test_commit.txt",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=commit_message,
        )

        # Note: Verifying commit message would require commit history API
        # Just verify file was uploaded
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "test_commit.txt" in files

        # Cleanup
        test_file.unlink(missing_ok=True)

    def test_file_content_integrity(self, temp_repo):
        """Test that file content integrity is preserved through upload/download."""
        repo_id, repo_type, hf_client = temp_repo

        # Create file with random content
        test_content = os.urandom(100 * 1000)  # 100KB random data
        original_hash = hashlib.sha256(test_content).hexdigest()

        test_file = (
            Path(tempfile.gettempdir()) / f"test_integrity_{os.urandom(4).hex()}.bin"
        )
        test_file.write_bytes(test_content)

        # Upload
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="test_integrity.bin",
            repo_id=repo_id,
            repo_type=repo_type,
        )

        # Download
        downloaded = hf_client.download_file(
            repo_id=repo_id, filename="test_integrity.bin", repo_type=repo_type
        )

        # Verify integrity
        downloaded_content = Path(downloaded).read_bytes()
        downloaded_hash = hashlib.sha256(downloaded_content).hexdigest()
        assert (
            downloaded_hash == original_hash
        ), "File content corrupted during upload/download"

        # Cleanup
        test_file.unlink(missing_ok=True)

    def test_nonexistent_file_download(self, temp_repo):
        """Test downloading non-existent file."""
        repo_id, repo_type, hf_client = temp_repo

        # Try to download non-existent file
        with pytest.raises(Exception) as exc_info:
            hf_client.download_file(
                repo_id=repo_id, filename="nonexistent.txt", repo_type=repo_type
            )

        # Should be an error (404 or file not found)
        error_msg = str(exc_info.value).lower()
        assert (
            "404" in error_msg or "not found" in error_msg or "cannot find" in error_msg
        )

    def test_tree_endpoint(self, random_user, temp_repo):
        """Test tree endpoint for listing files."""
        username, token, _ = random_user
        repo_id, repo_type, hf_client = temp_repo

        # Upload some files
        temp_dir = Path(tempfile.mkdtemp())
        (temp_dir / "file1.txt").write_bytes(b"Content 1")
        (temp_dir / "dir1").mkdir()
        (temp_dir / "dir1" / "file2.txt").write_bytes(b"Content 2")

        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="",
            repo_id=repo_id,
            repo_type=repo_type,
        )

        # Query tree endpoint using repo owner's token
        user_http_client = HTTPClient(token=token)

        namespace, repo_name = repo_id.split("/")
        resp = user_http_client.get(
            f"/api/{repo_type}s/{namespace}/{repo_name}/tree/main/"
        )
        assert resp.status_code == 200
        tree_data = resp.json()
        assert isinstance(tree_data, list)

        # Check files are in tree
        paths = [item["path"] for item in tree_data]
        assert "file1.txt" in paths

        # Cleanup
        shutil.rmtree(temp_dir)
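The integrity checks in this file all follow the same pattern: hash before upload, hash after download, compare. A standalone sketch of that round trip, with a local copy standing in for the actual upload/download:

```python
import hashlib
import os
import tempfile
from pathlib import Path


def file_sha256(path: Path) -> str:
    """Hex SHA-256 digest of a file's contents."""
    return hashlib.sha256(path.read_bytes()).hexdigest()


# Simulate the round trip with a local copy instead of a real upload/download.
src = Path(tempfile.gettempdir()) / f"src_{os.urandom(4).hex()}.bin"
dst = Path(tempfile.gettempdir()) / f"dst_{os.urandom(4).hex()}.bin"
src.write_bytes(os.urandom(1000))
dst.write_bytes(src.read_bytes())  # stand-in for upload + download

assert file_sha256(src) == file_sha256(dst), "content corrupted in transit"
src.unlink()
dst.unlink()
```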
@@ -1,338 +0,0 @@

"""LFS (Large File Storage) operation tests.

Tests large file upload/download using the Git LFS protocol.
Files >10MB should use LFS; files <=10MB should use regular upload.
"""
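The size-based routing described in the docstring can be sketched in a few lines (the 10MB threshold and decimal megabytes are taken from this file's tests; the helper name is illustrative, not KohakuHub's API):

```python
LFS_THRESHOLD_BYTES = 10 * 1000 * 1000  # 10MB, decimal, per the tests in this file


def uses_lfs(size_bytes: int) -> bool:
    """Files strictly larger than the threshold go through LFS."""
    return size_bytes > LFS_THRESHOLD_BYTES


assert not uses_lfs(5 * 1000 * 1000)    # 5MB -> regular upload
assert not uses_lfs(10 * 1000 * 1000)   # exactly 10MB -> regular upload
assert uses_lfs(15 * 1000 * 1000)       # 15MB -> LFS
```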

import hashlib
import os
import shutil
import tempfile
from pathlib import Path

import pytest

from tests.base import HTTPClient


class TestLFSOperations:
    """Test LFS file operations for large files."""

    @pytest.mark.lfs
    def test_upload_large_file_15mb(self, temp_repo):
        """Test uploading 15MB file (should use LFS)."""
        repo_id, repo_type, hf_client = temp_repo

        # Create 15MB file
        size_mb = 15
        test_content = os.urandom(size_mb * 1000 * 1000)
        original_hash = hashlib.sha256(test_content).hexdigest()

        test_file = Path(tempfile.gettempdir()) / f"test_15mb_{os.urandom(4).hex()}.bin"
        test_file.write_bytes(test_content)

        # Upload file (should trigger LFS)
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="large/test_15mb.bin",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload 15MB file via LFS",
        )

        # Download and verify
        downloaded = hf_client.download_file(
            repo_id=repo_id, filename="large/test_15mb.bin", repo_type=repo_type
        )

        # Verify content integrity
        downloaded_content = Path(downloaded).read_bytes()
        downloaded_hash = hashlib.sha256(downloaded_content).hexdigest()
        assert (
            downloaded_hash == original_hash
        ), "Large file content corrupted during LFS upload/download"

        # Cleanup
        test_file.unlink(missing_ok=True)

    @pytest.mark.lfs
    @pytest.mark.slow
    def test_upload_large_file_50mb(self, temp_repo):
        """Test uploading 50MB file (should use LFS)."""
        repo_id, repo_type, hf_client = temp_repo

        # Create 50MB file
        size_mb = 50
        test_content = os.urandom(size_mb * 1000 * 1000)
        original_hash = hashlib.sha256(test_content).hexdigest()

        test_file = Path(tempfile.gettempdir()) / f"test_50mb_{os.urandom(4).hex()}.bin"
        test_file.write_bytes(test_content)

        # Upload file (should trigger LFS)
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="large/test_50mb.bin",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload 50MB file via LFS",
        )

        # Download and verify
        downloaded = hf_client.download_file(
            repo_id=repo_id, filename="large/test_50mb.bin", repo_type=repo_type
        )

        # Verify content integrity
        downloaded_content = Path(downloaded).read_bytes()
        downloaded_hash = hashlib.sha256(downloaded_content).hexdigest()
        assert (
            downloaded_hash == original_hash
        ), "50MB file content corrupted during LFS upload/download"

        # Cleanup
        test_file.unlink(missing_ok=True)

    def test_small_file_uses_regular_upload(self, temp_repo):
        """Test that small file (<10MB) uses regular upload, not LFS."""
        repo_id, repo_type, hf_client = temp_repo

        # Create 5MB file (below LFS threshold)
        size_mb = 5
        test_content = os.urandom(size_mb * 1000 * 1000)

        test_file = Path(tempfile.gettempdir()) / f"test_5mb_{os.urandom(4).hex()}.bin"
        test_file.write_bytes(test_content)

        # Upload file (should NOT use LFS)
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="small/test_5mb.bin",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload 5MB file (regular)",
        )

        # Verify file was uploaded
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "small/test_5mb.bin" in files

        # Cleanup
        test_file.unlink(missing_ok=True)

    @pytest.mark.lfs
    def test_lfs_deduplication(self, temp_repo):
        """Test LFS deduplication - same file uploaded twice should use same storage."""
        repo_id, repo_type, hf_client = temp_repo

        # Create 12MB file
        size_mb = 12
        test_content = os.urandom(size_mb * 1000 * 1000)
        test_file = (
            Path(tempfile.gettempdir()) / f"test_dedup_{os.urandom(4).hex()}.bin"
        )
        test_file.write_bytes(test_content)

        # Upload file first time
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="dedup/file1.bin",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload file 1",
        )

        # Upload same file with different path (should be deduplicated)
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="dedup/file2.bin",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload file 2 (same content)",
        )

        # Both files should exist
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "dedup/file1.bin" in files
        assert "dedup/file2.bin" in files

        # Download both and verify they're identical
        downloaded1 = hf_client.download_file(
            repo_id=repo_id, filename="dedup/file1.bin", repo_type=repo_type
        )
        downloaded2 = hf_client.download_file(
            repo_id=repo_id, filename="dedup/file2.bin", repo_type=repo_type
        )

        content1 = Path(downloaded1).read_bytes()
        content2 = Path(downloaded2).read_bytes()
        assert content1 == content2 == test_content

        # Cleanup
        test_file.unlink(missing_ok=True)
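The deduplication test relies on LFS identifying content by its SHA-256 OID, so two uploads of identical bytes map to one stored object. A sketch of that keying, with a plain dict standing in for the S3/LFS object store:

```python
import hashlib

store: dict = {}  # oid -> content, stand-in for LFS object storage


def put(content: bytes) -> str:
    """Store content under its SHA-256 OID; identical bytes are stored once."""
    oid = hashlib.sha256(content).hexdigest()
    store.setdefault(oid, content)  # second upload of same bytes is a no-op
    return oid


payload = b"same bytes" * 1000
oid1 = put(payload)  # dedup/file1.bin
oid2 = put(payload)  # dedup/file2.bin, identical content

assert oid1 == oid2
assert len(store) == 1  # only one object actually stored
```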

    @pytest.mark.lfs
    def test_lfs_batch_api(self, random_user, temp_repo):
        """Test LFS batch API endpoint directly."""
        username, token, _ = random_user
        repo_id, repo_type, hf_client = temp_repo

        # Create HTTP client with the same user's token
        user_http_client = HTTPClient(token=token)

        # Prepare LFS batch request
        fake_oid = "a" * 64  # SHA256 hex
        fake_size = 15 * 1000 * 1000  # 15MB

        batch_request = {
            "operation": "upload",
            "transfers": ["basic"],
            "objects": [{"oid": fake_oid, "size": fake_size}],
            "hash_algo": "sha256",
        }

        # Send LFS batch request using the repo owner's token
        resp = user_http_client.post(
            f"/{repo_id}.git/info/lfs/objects/batch", json=batch_request
        )
        assert resp.status_code == 200, f"LFS batch request failed: {resp.text}"

        data = resp.json()
        assert "objects" in data
        assert len(data["objects"]) == 1

        lfs_object = data["objects"][0]
        assert lfs_object["oid"] == fake_oid
        assert lfs_object["size"] == fake_size

        # Check if upload action is provided
        # If file doesn't exist, should have "actions" with "upload"
        # If file exists (deduplicated), no "actions"
        if "actions" in lfs_object:
            assert "upload" in lfs_object["actions"]
            assert "href" in lfs_object["actions"]["upload"]
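The batch API check above hinges on whether the server returned an "actions" entry for the object. A small sketch of that interpretation, with response shapes following the git-lfs batch protocol as exercised by the test (the helper name and sample URL are illustrative):

```python
def needs_upload(lfs_object: dict) -> bool:
    """Per the LFS batch protocol, a missing object carries an upload action;
    an already-stored (deduplicated) object has no 'actions' key."""
    return "upload" in lfs_object.get("actions", {})


missing = {
    "oid": "a" * 64,
    "size": 15_000_000,
    "actions": {"upload": {"href": "https://storage.example/upload"}},
}
stored = {"oid": "b" * 64, "size": 15_000_000}

assert needs_upload(missing)      # server wants the bytes
assert not needs_upload(stored)   # deduplicated, nothing to send
```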

    @pytest.mark.lfs
    def test_mixed_file_sizes_upload(self, temp_repo):
        """Test uploading folder with mixed file sizes (some LFS, some regular)."""
        repo_id, repo_type, hf_client = temp_repo

        # Create temp folder with mixed sizes
        temp_dir = Path(tempfile.mkdtemp())

        # Small files (regular upload)
        (temp_dir / "small1.txt").write_bytes(b"Small file 1" * 100)
        (temp_dir / "small2.txt").write_bytes(b"Small file 2" * 100)

        # Large file (LFS)
        (temp_dir / "large.bin").write_bytes(os.urandom(12 * 1000 * 1000))  # 12MB

        # Upload folder
        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="mixed/",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload mixed file sizes",
        )

        # Verify all files exist
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "mixed/small1.txt" in files
        assert "mixed/small2.txt" in files
        assert "mixed/large.bin" in files

        # Download and verify large file
        downloaded = hf_client.download_file(
            repo_id=repo_id, filename="mixed/large.bin", repo_type=repo_type
        )
        assert Path(downloaded).stat().st_size == 12 * 1000 * 1000

        # Cleanup
        shutil.rmtree(temp_dir)

    @pytest.mark.lfs
    def test_lfs_file_metadata(self, random_user, temp_repo):
        """Test that LFS files have proper metadata in tree API."""
        username, token, _ = random_user
        repo_id, repo_type, hf_client = temp_repo

        # Upload LFS file
        test_content = os.urandom(15 * 1000 * 1000)  # 15MB
        test_file = (
            Path(tempfile.gettempdir()) / f"test_lfs_meta_{os.urandom(4).hex()}.bin"
        )
        test_file.write_bytes(test_content)

        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="test_lfs_meta.bin",
            repo_id=repo_id,
            repo_type=repo_type,
        )

        # Query tree with expand=true to get LFS metadata using repo owner's token
        user_http_client = HTTPClient(token=token)

        namespace, repo_name = repo_id.split("/")
        resp = user_http_client.get(
            f"/api/{repo_type}s/{namespace}/{repo_name}/tree/main/",
            params={"expand": "true"},
        )
        assert resp.status_code == 200
        tree_data = resp.json()

        # Find our file in tree
        lfs_file = None
        for item in tree_data:
            if item.get("path") == "test_lfs_meta.bin":
                lfs_file = item
                break

        assert lfs_file is not None, "LFS file not found in tree"

        # Check if LFS metadata is present
        if "lfs" in lfs_file:
            assert "oid" in lfs_file["lfs"]
            assert "size" in lfs_file["lfs"]

        # Cleanup
        test_file.unlink(missing_ok=True)

    def test_boundary_file_size_10mb(self, temp_repo):
        """Test uploading file exactly at 10MB boundary."""
        repo_id, repo_type, hf_client = temp_repo

        # Create exactly 10MB file
        size_bytes = 10 * 1000 * 1000
        test_content = os.urandom(size_bytes)

        test_file = (
            Path(tempfile.gettempdir()) / f"test_10mb_exact_{os.urandom(4).hex()}.bin"
        )
        test_file.write_bytes(test_content)

        # Upload file
        hf_client.upload_file(
            path_or_fileobj=str(test_file),
            path_in_repo="boundary/test_10mb_exact.bin",
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message="Upload exactly 10MB file",
        )

        # Verify file exists
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert "boundary/test_10mb_exact.bin" in files

        # Download and verify
        downloaded = hf_client.download_file(
            repo_id=repo_id,
            filename="boundary/test_10mb_exact.bin",
            repo_type=repo_type,
        )
        assert Path(downloaded).stat().st_size == size_bytes

        # Cleanup
        test_file.unlink(missing_ok=True)
@@ -1,169 +0,0 @@

"""Organization management tests.

Tests organization creation, member management, and organization listing.
Uses actual Pydantic models from source code.
"""

import uuid

from kohakuhub.api.org.router import (
    AddMemberPayload,
    CreateOrganizationPayload,
    UpdateMemberRolePayload,
)
from kohakuhub.auth.routes import RegisterRequest
from tests.config import config


class TestOrganization:
    """Test organization operations."""

    def test_create_organization(self, authenticated_http_client):
        """Test organization creation."""
        unique_id = uuid.uuid4().hex[:6]
        org_name = f"org-{unique_id}"

        # Create organization using actual model
        payload = CreateOrganizationPayload(
            name=org_name, description="Test organization"
        )

        resp = authenticated_http_client.post("/org/create", json=payload.model_dump())
        assert resp.status_code == 200, f"Create org failed: {resp.text}"
        data = resp.json()
        assert data["success"] is True
        assert data["name"] == org_name
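The assertions in this file repeatedly pull a member record out of the `/org/{name}/members` response. That lookup can be sketched against a sample payload (the `user`/`role` field names and the sample roles are assumed from the tests, not from a documented schema):

```python
def find_member(members_response: dict, username: str):
    """Return the member record for `username`, or None if absent."""
    return next(
        (m for m in members_response["members"] if m["user"] == username), None
    )


sample = {
    "members": [
        {"user": "alice", "role": "super-admin"},
        {"user": "bob", "role": "member"},
    ]
}

assert find_member(sample, "bob")["role"] == "member"
assert find_member(sample, "carol") is None  # removed or never added
```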

    def test_get_organization_info(self, authenticated_http_client, test_org):
        """Test getting organization information."""
        resp = authenticated_http_client.get(f"/org/{test_org}")
        assert resp.status_code == 200, f"Get org info failed: {resp.text}"
        data = resp.json()
        assert "name" in data
        assert data["name"] == test_org

    def test_list_user_organizations(self, authenticated_http_client):
        """Test listing user's organizations."""
        resp = authenticated_http_client.get(f"/org/users/{config.username}/orgs")
        assert resp.status_code == 200
        data = resp.json()
        assert "organizations" in data
        assert isinstance(data["organizations"], list)

    def test_add_remove_member(self, authenticated_http_client, test_org):
        """Test adding and removing organization members."""
        # Create a new user to add as member
        unique_id = uuid.uuid4().hex[:6]
        member_username = f"mem-{unique_id}"
        member_email = f"mem-{unique_id}@example.com"

        register_payload = RegisterRequest(
            username=member_username, email=member_email, password="testpass123"
        )

        resp = authenticated_http_client.post(
            "/api/auth/register", json=register_payload.model_dump()
        )
        assert resp.status_code == 200, f"Member registration failed: {resp.text}"

        # Add member to organization using actual model
        add_payload = AddMemberPayload(username=member_username, role="member")

        resp = authenticated_http_client.post(
            f"/org/{test_org}/members", json=add_payload.model_dump()
        )
        assert resp.status_code == 200, f"Add member failed: {resp.text}"

        # Verify member was added
        resp = authenticated_http_client.get(f"/org/{test_org}/members")
        assert resp.status_code == 200
        data = resp.json()
        assert "members" in data
        member_usernames = [m["user"] for m in data["members"]]
        assert member_username in member_usernames

        # Remove member
        resp = authenticated_http_client.delete(
            f"/org/{test_org}/members/{member_username}"
        )
        assert resp.status_code == 200

        # Verify member was removed
        resp = authenticated_http_client.get(f"/org/{test_org}/members")
        assert resp.status_code == 200
        data = resp.json()
        member_usernames = [m["user"] for m in data["members"]]
        assert member_username not in member_usernames

    def test_update_member_role(self, authenticated_http_client, test_org):
        """Test updating organization member role."""
        # Create a new user
        unique_id = uuid.uuid4().hex[:6]
        member_username = f"mem-{unique_id}"
        member_email = f"mem-{unique_id}@example.com"

        register_payload = RegisterRequest(
            username=member_username, email=member_email, password="testpass123"
        )
        resp = authenticated_http_client.post(
            "/api/auth/register", json=register_payload.model_dump()
        )
        assert resp.status_code == 200

        # Add member as 'member' role
        add_payload = AddMemberPayload(username=member_username, role="member")
        resp = authenticated_http_client.post(
            f"/org/{test_org}/members", json=add_payload.model_dump()
        )
        assert resp.status_code == 200

        # Update member role to 'admin' using actual model
        update_payload = UpdateMemberRolePayload(role="admin")
        resp = authenticated_http_client.put(
            f"/org/{test_org}/members/{member_username}",
            json=update_payload.model_dump(),
        )
        assert resp.status_code == 200

        # Verify role was updated
        resp = authenticated_http_client.get(f"/org/{test_org}/members")
        assert resp.status_code == 200
        data = resp.json()
        member_data = next(
            (m for m in data["members"] if m["user"] == member_username), None
        )
        assert member_data is not None
        assert member_data["role"] == "admin"

        # Cleanup
        resp = authenticated_http_client.delete(
            f"/org/{test_org}/members/{member_username}"
        )

    def test_list_organization_members(self, authenticated_http_client, test_org):
        """Test listing organization members."""
        resp = authenticated_http_client.get(f"/org/{test_org}/members")
        assert resp.status_code == 200
        data = resp.json()
        assert "members" in data
        assert isinstance(data["members"], list)
        # Creator should be in the list as super-admin
        usernames = [m["user"] for m in data["members"]]
        assert config.username in usernames

    def test_duplicate_organization(self, authenticated_http_client, test_org):
        """Test that creating duplicate organization fails."""
        # Try to create organization with same name
        payload = CreateOrganizationPayload(name=test_org, description="Duplicate org")

        resp = authenticated_http_client.post("/org/create", json=payload.model_dump())
        assert resp.status_code == 400
        assert "exist" in resp.text.lower() or "already" in resp.text.lower()

    def test_nonexistent_organization(self, authenticated_http_client):
        """Test accessing non-existent organization."""
        unique_id = uuid.uuid4().hex[:6]
        fake_org = f"noorg-{unique_id}"

        resp = authenticated_http_client.get(f"/org/{fake_org}")
        assert resp.status_code == 404
@@ -1,265 +0,0 @@

"""Repository CRUD operation tests.

Tests repository creation, deletion, listing, moving, and info retrieval.
Uses actual Pydantic models from source code and validates HF API compatibility.
"""

import uuid

import pytest
from huggingface_hub.utils import HfHubHTTPError

from kohakuhub.api.repo.routers.crud import (
    CreateRepoPayload,
    DeleteRepoPayload,
    MoveRepoPayload,
)
from tests.config import config


class TestRepositoryCRUD:
    """Test repository CRUD operations."""

    def test_create_repo_hf_client(self, hf_client):
        """Test repository creation using HuggingFace Hub client."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/hfc-{unique_id}"  # hfc = hf-create

        # Create repository
        result = hf_client.create_repo(
            repo_id=repo_id, repo_type="model", private=False
        )
        assert result is not None

        # Verify repository exists
        info = hf_client.repo_info(repo_id=repo_id, repo_type="model")
        assert info is not None
        # Check if repo_id or id field contains our repo
        repo_field = getattr(info, "id", getattr(info, "repo_id", None))
        assert repo_id in str(repo_field)

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")
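The `repo_info()` result may expose its identifier as `id` or `repo_id` depending on the huggingface_hub client version, hence the `getattr` fallback used in these tests. A standalone sketch of that fallback (`SimpleNamespace` stands in for the info object):

```python
from types import SimpleNamespace


def repo_identifier(info):
    """Read `id` if present, else `repo_id`, else None."""
    return getattr(info, "id", getattr(info, "repo_id", None))


assert repo_identifier(SimpleNamespace(id="alice/m1")) == "alice/m1"
assert repo_identifier(SimpleNamespace(repo_id="alice/m2")) == "alice/m2"
assert repo_identifier(SimpleNamespace()) is None
```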

    def test_create_repo_http_client(self, authenticated_http_client):
        """Test repository creation using custom HTTP client with Pydantic model."""
        unique_id = uuid.uuid4().hex[:6]
        repo_name = f"htc-{unique_id}"  # htc = http-create

        # Create repository using actual Pydantic model
        payload = CreateRepoPayload(
            type="model",
            name=repo_name,
            organization=None,  # Use user's own namespace
            private=False,
        )

        resp = authenticated_http_client.post(
            "/api/repos/create", json=payload.model_dump()
        )
        assert resp.status_code == 200, f"Create repo failed: {resp.text}"
        data = resp.json()
        assert "url" in data or "repo_id" in data

        # Verify via GET endpoint
        resp = authenticated_http_client.get(
            f"/api/models/{config.username}/{repo_name}"
        )
        assert resp.status_code == 200, f"Get repo info failed: {resp.text}"

        # Cleanup using actual delete model
        delete_payload = DeleteRepoPayload(
            type="model", name=repo_name, organization=None
        )

        resp = authenticated_http_client.delete(
            "/api/repos/delete", json=delete_payload.model_dump()
        )
        assert resp.status_code == 200

    def test_create_duplicate_repo(self, hf_client):
        """Test that creating duplicate repository fails."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/dup-{unique_id}"

        # Create first time
        hf_client.create_repo(repo_id=repo_id, repo_type="model", private=False)

        # Try to create again
        with pytest.raises(Exception) as exc_info:
            hf_client.create_repo(repo_id=repo_id, repo_type="model", private=False)

        # Verify it's the right error
        assert "exist" in str(exc_info.value).lower() or "400" in str(exc_info.value)

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")

    def test_create_private_repo(self, hf_client):
        """Test creating private repository."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/prv-{unique_id}"

        # Create private repository
        hf_client.create_repo(repo_id=repo_id, repo_type="model", private=True)

        # Verify it exists
        info = hf_client.repo_info(repo_id=repo_id, repo_type="model")
        assert info is not None

        # Check if private field exists (may vary by HF client version)
        if hasattr(info, "private"):
            assert info.private is True

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")

    def test_delete_repo_hf_client(self, hf_client):
        """Test repository deletion using HuggingFace Hub client."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/del-{unique_id}"

        # Create repository
        hf_client.create_repo(repo_id=repo_id, repo_type="model", private=False)

        # Delete repository
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")

        # Verify it's deleted
        with pytest.raises(HfHubHTTPError) as exc_info:
            hf_client.repo_info(repo_id=repo_id, repo_type="model")
        assert exc_info.value.response.status_code == 404

    def test_delete_nonexistent_repo(self, hf_client):
        """Test deleting non-existent repository."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/nonexistent-repo-{unique_id}"

        # Try to delete non-existent repo
        with pytest.raises(HfHubHTTPError) as exc_info:
            hf_client.delete_repo(repo_id=repo_id, repo_type="model")
        assert exc_info.value.response.status_code == 404

    def test_list_repos(self, authenticated_http_client, hf_client):
        """Test listing repositories."""
        # Create test repos with unique names
        unique_id = uuid.uuid4().hex[:6]
        repo_ids = [
            f"{config.username}/lst-{unique_id}-1",
            f"{config.username}/lst-{unique_id}-2",
        ]

        for repo_id in repo_ids:
            hf_client.create_repo(repo_id=repo_id, repo_type="model", private=False)

        # List repositories
        resp = authenticated_http_client.get(
            "/api/models", params={"author": config.username, "limit": 100}
        )
        assert resp.status_code == 200
        repos = resp.json()
        assert isinstance(repos, list)

        # Verify our repos are in the list
        repo_ids_in_list = [repo.get("id") or repo.get("repo_id") for repo in repos]
        for repo_id in repo_ids:
            assert repo_id in repo_ids_in_list, f"{repo_id} not found in repo list"

        # Cleanup
        for repo_id in repo_ids:
            hf_client.delete_repo(repo_id=repo_id, repo_type="model")

    def test_get_repo_info(self, hf_client):
        """Test getting repository information."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/inf-{unique_id}"

        # Create repository
        hf_client.create_repo(repo_id=repo_id, repo_type="model", private=False)

        # Get repository info
        info = hf_client.repo_info(repo_id=repo_id, repo_type="model")
        assert info is not None

        # Check basic fields
        repo_field = getattr(info, "id", getattr(info, "repo_id", None))
        assert repo_id in str(repo_field)

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")

    def test_move_repo(self, authenticated_http_client, hf_client):
        """Test moving/renaming repository."""
        unique_id = uuid.uuid4().hex[:6]
        old_name = f"old-{unique_id}"
        new_name = f"new-{unique_id}"
        old_repo_id = f"{config.username}/{old_name}"
        new_repo_id = f"{config.username}/{new_name}"

        # Create repository with old name
        hf_client.create_repo(repo_id=old_repo_id, repo_type="model", private=False)

        # Move repository using actual model
        payload = MoveRepoPayload(
            fromRepo=old_repo_id, toRepo=new_repo_id, type="model"
        )

        resp = authenticated_http_client.post(
            "/api/repos/move", json=payload.model_dump()
        )
        assert resp.status_code == 200, f"Move repo failed: {resp.text}"

        # Verify new name exists
        resp = authenticated_http_client.get(
            f"/api/models/{config.username}/{new_name}"
        )
        assert resp.status_code == 200, "New repo name should exist"

        # Verify old name doesn't exist (or redirects)
        resp = authenticated_http_client.get(
            f"/api/models/{config.username}/{old_name}", allow_redirects=False
        )
        # Should be 404 or 301/302 redirect
        assert resp.status_code in [301, 302, 404]

        # Cleanup
        hf_client.delete_repo(repo_id=new_repo_id, repo_type="model")

    @pytest.mark.parametrize("repo_type", ["model", "dataset"])
    def test_create_different_repo_types(self, hf_client, repo_type):
        """Test creating different repository types."""
        unique_id = uuid.uuid4().hex[:6]
        repo_id = f"{config.username}/typ-{unique_id}"  # Different types

        # Create repository
        hf_client.create_repo(repo_id=repo_id, repo_type=repo_type, private=False)

        # Verify it exists
        info = hf_client.repo_info(repo_id=repo_id, repo_type=repo_type)
        assert info is not None

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type=repo_type)

    def test_create_org_repo(self, authenticated_http_client, hf_client, test_org):
        """Test creating repository under organization."""
        unique_id = uuid.uuid4().hex[:6]
        repo_name = f"org-{unique_id}"

        # Create repository under organization using actual model
        payload = CreateRepoPayload(
            type="model", name=repo_name, organization=test_org, private=False
        )

        resp = authenticated_http_client.post(
            "/api/repos/create", json=payload.model_dump()
        )
        assert resp.status_code == 200, f"Create org repo failed: {resp.text}"

        # Verify it exists under organization
        repo_id = f"{test_org}/{repo_name}"
        info = hf_client.repo_info(repo_id=repo_id, repo_type="model")
        assert info is not None

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")
@@ -1,261 +0,0 @@

"""Repository information and listing tests.

Tests repository metadata, listing, filtering, and privacy.
"""

import shutil
import tempfile
import uuid
from pathlib import Path

import pytest

from tests.base import HTTPClient
from tests.config import config


class TestRepositoryInfo:
    """Test repository information and listing endpoints."""

    def test_get_repo_info_hf_client(self, temp_repo):
        """Test getting repository info using HF client."""
        repo_id, repo_type, hf_client = temp_repo

        # Get repository info
        info = hf_client.repo_info(repo_id=repo_id, repo_type=repo_type)
        assert info is not None

        # Check basic fields
        repo_field = getattr(info, "id", getattr(info, "repo_id", None))
        assert repo_id in str(repo_field)

    def test_get_repo_info_http_client(self, random_user, temp_repo):
        """Test getting repository info using HTTP client."""
        username, token, _ = random_user
        repo_id, repo_type, hf_client = temp_repo
        namespace, repo_name = repo_id.split("/")

        # Get repository info using repo owner's token
        user_http_client = HTTPClient(token=token)
        resp = user_http_client.get(f"/api/{repo_type}s/{namespace}/{repo_name}")
        assert resp.status_code == 200, f"Get repo info failed: {resp.text}"

        data = resp.json()
        assert isinstance(data, dict)
        # Should contain repository metadata

    def test_list_repos_by_author(self, random_user):
        """Test listing repositories by author."""
        username, token, hf_client = random_user

        unique_id = uuid.uuid4().hex[:6]
        # Create test repository
        repo_id = f"{username}/lst-{unique_id}"
        hf_client.create_repo(repo_id=repo_id, repo_type="model", private=False)

        # List repos by author
        http_client = HTTPClient(token=token)
        resp = http_client.get("/api/models", params={"author": username, "limit": 100})
        assert resp.status_code == 200
        repos = resp.json()
        assert isinstance(repos, list)

        # Our repo should be in the list
        repo_ids = [r.get("id") or r.get("repo_id") for r in repos]
        assert repo_id in repo_ids

        # Cleanup
        hf_client.delete_repo(repo_id=repo_id, repo_type="model")
|
||||
def test_list_repos_with_limit(self, authenticated_http_client):
|
||||
"""Test listing repositories with limit parameter."""
|
||||
# List repos with small limit
|
||||
resp = authenticated_http_client.get("/api/models", params={"limit": 5})
|
||||
assert resp.status_code == 200
|
||||
repos = resp.json()
|
||||
assert isinstance(repos, list)
|
||||
assert len(repos) <= 5
|
||||
|
||||
def test_list_namespace_repos(self, random_user):
|
||||
username, token, hf_client = random_user
|
||||
|
||||
unique_id = uuid.uuid4().hex[:6]
|
||||
# Create test repos of different types
|
||||
model_id = f"{username}/nsm-{unique_id}" # namespace-model
|
||||
dataset_id = f"{username}/nsd-{unique_id}" # namespace-dataset
|
||||
|
||||
hf_client.create_repo(repo_id=model_id, repo_type="model", private=False)
|
||||
hf_client.create_repo(repo_id=dataset_id, repo_type="dataset", private=False)
|
||||
|
||||
# List all repos for namespace
|
||||
http_client = HTTPClient(token=token)
|
||||
resp = http_client.get(f"/api/users/{username}/repos")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, dict)
|
||||
|
||||
# Should have models and datasets grouped
|
||||
if "models" in data:
|
||||
model_ids = [r.get("id") or r.get("repo_id") for r in data["models"]]
|
||||
assert model_id in model_ids
|
||||
|
||||
if "datasets" in data:
|
||||
dataset_ids = [r.get("id") or r.get("repo_id") for r in data["datasets"]]
|
||||
assert dataset_id in dataset_ids
|
||||
|
||||
# Cleanup
|
||||
hf_client.delete_repo(repo_id=model_id, repo_type="model")
|
||||
hf_client.delete_repo(repo_id=dataset_id, repo_type="dataset")
|
||||
|
||||
def test_private_repo_visibility(self, random_user):
|
||||
"""Test that private repositories are only visible to owner."""
|
||||
username, token, hf_client = random_user
|
||||
|
||||
unique_id = uuid.uuid4().hex[:6]
|
||||
# Create private repository
|
||||
repo_id = f"{username}/prv-{unique_id}"
|
||||
hf_client.create_repo(repo_id=repo_id, repo_type="model", private=True)
|
||||
|
||||
# Owner should see it
|
||||
owner_client = HTTPClient(token=token)
|
||||
resp = owner_client.get("/api/models", params={"author": username})
|
||||
assert resp.status_code == 200
|
||||
repos = resp.json()
|
||||
repo_ids = [r.get("id") or r.get("repo_id") for r in repos]
|
||||
assert repo_id in repo_ids, "Owner should see private repo"
|
||||
|
||||
# Unauthenticated user should NOT see it
|
||||
unauth_client = HTTPClient()
|
||||
resp = unauth_client.get("/api/models", params={"author": username})
|
||||
assert resp.status_code == 200
|
||||
repos = resp.json()
|
||||
repo_ids = [r.get("id") or r.get("repo_id") for r in repos]
|
||||
# Private repo should NOT be in list for unauthenticated user
|
||||
assert (
|
||||
repo_id not in repo_ids
|
||||
), "Unauthenticated user should NOT see private repo"
|
||||
|
||||
# Cleanup
|
||||
hf_client.delete_repo(repo_id=repo_id, repo_type="model")
|
||||
|
||||
def test_repo_revision_info(self, random_user, temp_repo):
|
||||
"""Test getting repository info for specific revision."""
|
||||
username, token, _ = random_user
|
||||
repo_id, repo_type, hf_client = temp_repo
|
||||
|
||||
# Upload file to create commit
|
||||
temp_file = Path(tempfile.mktemp())
|
||||
temp_file.write_bytes(b"Test content")
|
||||
|
||||
hf_client.upload_file(
|
||||
path_or_fileobj=str(temp_file),
|
||||
path_in_repo="test.txt",
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
# Get info for specific revision (main) using repo owner's token
|
||||
user_http_client = HTTPClient(token=token)
|
||||
namespace, repo_name = repo_id.split("/")
|
||||
resp = user_http_client.get(
|
||||
f"/api/{repo_type}s/{namespace}/{repo_name}/revision/main"
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, dict)
|
||||
|
||||
# Cleanup
|
||||
temp_file.unlink(missing_ok=True)
|
||||
|
||||
    @pytest.mark.skip(reason="Form(list[str]) encoding not yet working - returns 422")
    def test_repo_paths_info(self, random_user, temp_repo):
        """Test getting info for specific paths in repository.

        SKIPPED: FastAPI Form(list[str]) requires special encoding that current
        requests library usage doesn't handle correctly.

        Error: {"detail":[{"type":"missing","loc":["body","paths"],"msg":"Field required"}]}

        TODO: Need to determine correct multipart/form-data encoding for list[str].
        Endpoint signature: paths: list[str] = Form(...), expand: bool = Form(False)
        """
        pass
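A hedged note on the TODO in the skipped test above: with `application/x-www-form-urlencoded` bodies, FastAPI's `Form(list[str])` is conventionally fed by repeating the field once per value (`paths=a&paths=b`), not by any list syntax. The sketch below only demonstrates that wire encoding using the standard library; whether this particular endpoint accepts it is an assumption that still has to be verified against the server.

```python
from urllib.parse import urlencode

# doseq=True expands each list value into repeated fields - the shape
# FastAPI collects back into a list for `paths: list[str] = Form(...)`.
body = urlencode(
    {"paths": ["test.txt", "data/file.txt"], "expand": "false"},
    doseq=True,
)
print(body)  # paths=test.txt&paths=data%2Ffile.txt&expand=false
```

With `requests`, passing a list via `data={"paths": [...]}` produces the same repeated-field body; if the endpoint instead requires `multipart/form-data`, the analogous approach is repeating the key in a list of `(key, value)` tuples.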

    def test_nonexistent_repo_info(self, authenticated_http_client):
        """Test getting info for non-existent repository."""
        namespace = config.username
        repo_name = "nonexistent-repo-xyz"

        resp = authenticated_http_client.get(f"/api/models/{namespace}/{repo_name}")
        assert resp.status_code == 404

        # Check for HF error headers
        error_code = resp.headers.get("X-Error-Code")
        if error_code:
            assert error_code == "RepoNotFound"

    def test_list_repo_files(self, temp_repo):
        """Test listing files in repository."""
        repo_id, repo_type, hf_client = temp_repo

        # Upload some files
        temp_dir = Path(tempfile.mkdtemp())
        (temp_dir / "README.md").write_bytes(b"# Test Repo")
        (temp_dir / "config.json").write_bytes(b'{"key": "value"}')
        (temp_dir / "data").mkdir()
        (temp_dir / "data" / "file.txt").write_bytes(b"Data file")

        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="",
            repo_id=repo_id,
            repo_type=repo_type,
        )

        # List files
        files = hf_client.list_repo_files(repo_id=repo_id, repo_type=repo_type)
        assert isinstance(files, list)
        assert "README.md" in files
        assert "config.json" in files
        assert "data/file.txt" in files

        # Cleanup
        shutil.rmtree(temp_dir)

    def test_tree_recursive_listing(self, random_user, temp_repo):
        """Test recursive tree listing."""
        username, token, _ = random_user
        repo_id, repo_type, hf_client = temp_repo

        # Upload nested structure
        temp_dir = Path(tempfile.mkdtemp())
        (temp_dir / "level1").mkdir()
        (temp_dir / "level1" / "file1.txt").write_bytes(b"File 1")
        (temp_dir / "level1" / "level2").mkdir()
        (temp_dir / "level1" / "level2" / "file2.txt").write_bytes(b"File 2")

        hf_client.upload_folder(
            folder_path=str(temp_dir),
            path_in_repo="",
            repo_id=repo_id,
            repo_type=repo_type,
        )

        # Query tree with recursive=true using repo owner's token
        user_http_client = HTTPClient(token=token)
        namespace, repo_name = repo_id.split("/")
        resp = user_http_client.get(
            f"/api/{repo_type}s/{namespace}/{repo_name}/tree/main/",
            params={"recursive": "true"},
        )
        assert resp.status_code == 200
        tree_data = resp.json()
        assert isinstance(tree_data, list)

        # Should include nested files
        paths = [item["path"] for item in tree_data]
        assert any("level1/file1.txt" in p for p in paths)
        assert any("level1/level2/file2.txt" in p for p in paths)

        # Cleanup
        shutil.rmtree(temp_dir)