mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
allow datasetviewer to pass user auth to backend resolve
This commit is contained in:
@@ -12,53 +12,64 @@ import duckdb
|
||||
import httpx
|
||||
from fsspec.implementations.http import HTTPFileSystem
|
||||
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.datasetviewer.logger import get_logger
|
||||
|
||||
logger = get_logger("Parser")
|
||||
|
||||
|
||||
async def resolve_url_redirects(url: str) -> str:
|
||||
async def resolve_url_redirects(url: str, auth_headers: dict[str, str] = None) -> str:
|
||||
"""
|
||||
Resolve URL redirects by following 302 responses.
|
||||
Resolve URL redirects by following 302 responses with authentication.
|
||||
|
||||
For resolve URLs that return 302, we need to follow the redirect
|
||||
For /resolve URLs that return 302, we need to follow the redirect
|
||||
and use the final S3 presigned URL for fsspec/DuckDB.
|
||||
Otherwise, they keep hitting our backend for every range request.
|
||||
|
||||
Uses GET with streaming (no redirect following) to detect 302 responses
|
||||
without actually downloading content.
|
||||
Uses GET with manual redirect handling to get Location header
|
||||
without downloading file content.
|
||||
|
||||
Args:
|
||||
url: Original URL (may return 302)
|
||||
url: Original URL (e.g., /datasets/.../resolve/main/file.csv or S3 URL)
|
||||
auth_headers: Optional auth headers (Authorization, Cookie) from user request
|
||||
|
||||
Returns:
|
||||
Final URL after following redirects (or original if no redirect)
|
||||
Final S3 presigned URL after following redirects (or original if external URL)
|
||||
"""
|
||||
# If already an S3 URL (starts with http:// or https:// with s3/amazonaws), use as-is
|
||||
if url.startswith("http://") or url.startswith("https://"):
|
||||
# External URL, don't try to resolve
|
||||
return url
|
||||
|
||||
# Internal /resolve URL - make authenticated request to get S3 URL
|
||||
try:
|
||||
# Build full URL (relative path to absolute)
|
||||
full_url = f"{cfg.app.base_url}{url}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client:
|
||||
# Send GET request with streaming, but don't follow redirects
|
||||
# This gives us the correct 302 response that GET would return
|
||||
async with client.stream("GET", url) as response:
|
||||
# Send GET request with auth headers, don't follow redirects
|
||||
headers = auth_headers or {}
|
||||
async with client.stream("GET", full_url, headers=headers) as response:
|
||||
# Check for redirect status codes
|
||||
if response.status_code in [301, 302, 303, 307, 308]:
|
||||
location = response.headers.get("Location")
|
||||
if location:
|
||||
logger.debug(
|
||||
f"Resolved redirect: {url[:50]}... -> {location[:50]}..."
|
||||
f"Resolved internal URL: {url[:50]}... -> S3 presigned URL"
|
||||
)
|
||||
# Close stream immediately without reading content
|
||||
await response.aclose()
|
||||
return location
|
||||
|
||||
# For other status codes (200, 4xx, 5xx), use original URL
|
||||
# Close stream without reading content
|
||||
# For other status codes, close and return original
|
||||
await response.aclose()
|
||||
logger.warning(
|
||||
f"Expected redirect for {url}, got {response.status_code}"
|
||||
)
|
||||
return url
|
||||
|
||||
except Exception as e:
|
||||
# If request fails, fall back to original URL
|
||||
# fsspec will handle the actual request
|
||||
logger.warning(f"Could not resolve redirects for {url[:50]}...: {e}")
|
||||
logger.error(f"Could not resolve internal URL {url[:50]}...: {e}")
|
||||
return url
|
||||
|
||||
|
||||
@@ -111,21 +122,25 @@ class CSVParser:
|
||||
|
||||
@staticmethod
|
||||
async def parse(
|
||||
url: str, max_rows: int = 1000, delimiter: str = ","
|
||||
url: str,
|
||||
max_rows: int = 1000,
|
||||
delimiter: str = ",",
|
||||
auth_headers: dict[str, str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Parse CSV file from URL using DuckDB.
|
||||
|
||||
Args:
|
||||
url: File URL (presigned S3 URL)
|
||||
url: File URL (internal /resolve path or S3 presigned URL)
|
||||
max_rows: Maximum rows to return
|
||||
delimiter: CSV delimiter
|
||||
auth_headers: Optional auth headers for internal /resolve URLs
|
||||
|
||||
Returns:
|
||||
Dict with columns, rows, total_rows, truncated, file_size
|
||||
"""
|
||||
# Resolve redirects first
|
||||
resolved_url = await resolve_url_redirects(url)
|
||||
# Resolve redirects first (handles internal /resolve URLs)
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
# Run in thread pool to avoid blocking event loop
|
||||
return await asyncio.to_thread(
|
||||
@@ -169,19 +184,22 @@ class JSONLParser:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
|
||||
async def parse(
|
||||
url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Parse JSONL file from URL using DuckDB.
|
||||
|
||||
Args:
|
||||
url: File URL
|
||||
url: File URL (internal /resolve path or S3 presigned URL)
|
||||
max_rows: Maximum rows to return
|
||||
auth_headers: Optional auth headers for internal /resolve URLs
|
||||
|
||||
Returns:
|
||||
Dict with columns, rows, total_rows, truncated, file_size
|
||||
"""
|
||||
# Resolve redirects first
|
||||
resolved_url = await resolve_url_redirects(url)
|
||||
# Resolve redirects first (handles internal /resolve URLs)
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
# Run in thread pool to avoid blocking event loop
|
||||
return await asyncio.to_thread(JSONLParser._parse_sync, resolved_url, max_rows)
|
||||
@@ -223,19 +241,22 @@ class JSONParser:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
|
||||
async def parse(
|
||||
url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Parse JSON array file from URL using DuckDB.
|
||||
|
||||
Args:
|
||||
url: File URL
|
||||
url: File URL (internal /resolve path or S3 presigned URL)
|
||||
max_rows: Maximum rows to return
|
||||
auth_headers: Optional auth headers for internal /resolve URLs
|
||||
|
||||
Returns:
|
||||
Dict with columns, rows, total_rows, truncated, file_size
|
||||
"""
|
||||
# Resolve redirects first
|
||||
resolved_url = await resolve_url_redirects(url)
|
||||
# Resolve redirects first (handles internal /resolve URLs)
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
# Run in thread pool to avoid blocking event loop
|
||||
return await asyncio.to_thread(JSONParser._parse_sync, resolved_url, max_rows)
|
||||
@@ -275,20 +296,23 @@ class ParquetParser:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
|
||||
async def parse(
|
||||
url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Parse Parquet file from URL using DuckDB.
|
||||
|
||||
Args:
|
||||
url: File URL (HTTP/HTTPS, including S3 presigned URLs)
|
||||
url: File URL (internal /resolve path or S3 presigned URL)
|
||||
max_rows: Maximum rows to return
|
||||
auth_headers: Optional auth headers for internal /resolve URLs
|
||||
|
||||
Returns:
|
||||
Dict with columns, rows, total_rows, truncated, file_size
|
||||
"""
|
||||
# Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
|
||||
# Resolve redirects first (handles internal /resolve URLs)
|
||||
# This prevents DuckDB from repeatedly hitting our backend
|
||||
resolved_url = await resolve_url_redirects(url)
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
# Run in thread pool to avoid blocking event loop
|
||||
return await asyncio.to_thread(
|
||||
@@ -375,19 +399,22 @@ class TARParser:
|
||||
raise ParserError(f"File not found in archive: {file_name}")
|
||||
|
||||
@staticmethod
|
||||
async def extract_file(url: str, file_name: str) -> bytes:
|
||||
async def extract_file(
|
||||
url: str, file_name: str, auth_headers: dict[str, str] = None
|
||||
) -> bytes:
|
||||
"""
|
||||
Extract single file from TAR archive using streaming.
|
||||
|
||||
Args:
|
||||
url: TAR file URL
|
||||
url: TAR file URL (internal /resolve path or S3 presigned URL)
|
||||
file_name: Name of file to extract
|
||||
auth_headers: Optional auth headers for internal /resolve URLs
|
||||
|
||||
Returns:
|
||||
File content as bytes
|
||||
"""
|
||||
# Resolve redirects first
|
||||
resolved_url = await resolve_url_redirects(url)
|
||||
# Resolve redirects first (handles internal /resolve URLs)
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
# Run in thread pool to avoid blocking
|
||||
return await asyncio.to_thread(
|
||||
|
||||
@@ -6,7 +6,7 @@ Minimal, auth-free endpoints for previewing dataset files.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from kohakuhub.datasetviewer.logger import get_logger
|
||||
@@ -32,6 +32,24 @@ logger = get_logger("Router")
|
||||
router = APIRouter(prefix="/dataset-viewer", tags=["Dataset Viewer"])
|
||||
|
||||
|
||||
def get_auth_headers(
|
||||
session_id: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> dict[str, str]:
|
||||
"""Extract authentication headers from request for internal /resolve requests."""
|
||||
headers = {}
|
||||
|
||||
# Pass Authorization header if present
|
||||
if authorization:
|
||||
headers["Authorization"] = authorization
|
||||
|
||||
# Pass session cookie if present
|
||||
if session_id:
|
||||
headers["Cookie"] = f"session_id={session_id}"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
class PreviewRequest(BaseModel):
|
||||
"""Request to preview a file."""
|
||||
|
||||
@@ -47,7 +65,7 @@ class SQLQueryRequest(BaseModel):
|
||||
url: HttpUrl # S3 presigned URL or any HTTP(S) URL
|
||||
query: str # SQL query to execute
|
||||
format: Optional[str] = None # Auto-detect if not provided
|
||||
max_rows: int = 10000 # Safety limit
|
||||
max_rows: int = 10 # Default: 10 rows (datasets can have many columns)
|
||||
|
||||
|
||||
class PreviewResponse(BaseModel):
|
||||
@@ -97,6 +115,7 @@ async def preview_file(
|
||||
request: Request,
|
||||
req: PreviewRequest,
|
||||
identifier: str = Depends(check_rate_limit_dependency),
|
||||
auth_headers: dict[str, str] = Depends(get_auth_headers),
|
||||
):
|
||||
"""
|
||||
Preview a dataset file from URL.
|
||||
@@ -138,15 +157,17 @@ async def preview_file(
|
||||
f"Previewing {file_format} file: {url_str[:80]}... (max_rows={req.max_rows})"
|
||||
)
|
||||
|
||||
# Parse file
|
||||
# Parse file (pass auth headers for internal /resolve URLs)
|
||||
try:
|
||||
if file_format == "csv" or file_format == "tsv":
|
||||
delimiter = "\t" if file_format == "tsv" else req.delimiter
|
||||
result = await CSVParser.parse(url_str, req.max_rows, delimiter)
|
||||
result = await CSVParser.parse(
|
||||
url_str, req.max_rows, delimiter, auth_headers
|
||||
)
|
||||
elif file_format == "jsonl":
|
||||
result = await JSONLParser.parse(url_str, req.max_rows)
|
||||
result = await JSONLParser.parse(url_str, req.max_rows, auth_headers)
|
||||
elif file_format == "parquet":
|
||||
result = await ParquetParser.parse(url_str, req.max_rows)
|
||||
result = await ParquetParser.parse(url_str, req.max_rows, auth_headers)
|
||||
else:
|
||||
raise HTTPException(
|
||||
400,
|
||||
@@ -174,6 +195,7 @@ async def list_tar_files(
|
||||
request: Request,
|
||||
req: TARListRequest,
|
||||
identifier: str = Depends(check_rate_limit_dependency),
|
||||
auth_headers: dict[str, str] = Depends(get_auth_headers),
|
||||
):
|
||||
"""
|
||||
List files in a TAR archive.
|
||||
@@ -188,7 +210,7 @@ async def list_tar_files(
|
||||
|
||||
try:
|
||||
# Use streaming parser (doesn't load full TAR into memory!)
|
||||
result = await TARStreamParser.list_files_streaming(str(req.url))
|
||||
result = await TARStreamParser.list_files_streaming(str(req.url), auth_headers)
|
||||
limiter.finish_request(identifier, 0) # Only reads headers, minimal data
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -201,6 +223,7 @@ async def extract_tar_file(
|
||||
request: Request,
|
||||
req: TARExtractRequest,
|
||||
identifier: str = Depends(check_rate_limit_dependency),
|
||||
auth_headers: dict[str, str] = Depends(get_auth_headers),
|
||||
):
|
||||
"""
|
||||
Extract a single file from TAR archive.
|
||||
@@ -216,7 +239,9 @@ async def extract_tar_file(
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
try:
|
||||
content = await TARParser.extract_file(str(req.url), req.file_name)
|
||||
content = await TARParser.extract_file(
|
||||
str(req.url), req.file_name, auth_headers
|
||||
)
|
||||
limiter.finish_request(identifier, len(content))
|
||||
|
||||
# Return raw bytes
|
||||
@@ -235,6 +260,7 @@ async def preview_webdataset_tar(
|
||||
req: TARListRequest,
|
||||
max_samples: int = Query(100, description="Max samples to preview"),
|
||||
identifier: str = Depends(check_rate_limit_dependency),
|
||||
auth_headers: dict[str, str] = Depends(get_auth_headers),
|
||||
):
|
||||
"""
|
||||
Preview TAR file in webdataset format.
|
||||
@@ -252,7 +278,9 @@ async def preview_webdataset_tar(
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
try:
|
||||
result = await WebDatasetTARParser.parse_streaming(str(req.url), max_samples)
|
||||
result = await WebDatasetTARParser.parse_streaming(
|
||||
str(req.url), max_samples, auth_headers
|
||||
)
|
||||
limiter.finish_request(identifier, 0)
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -278,6 +306,7 @@ async def execute_sql(
|
||||
request: Request,
|
||||
req: SQLQueryRequest,
|
||||
identifier: str = Depends(check_rate_limit_dependency),
|
||||
auth_headers: dict[str, str] = Depends(get_auth_headers),
|
||||
):
|
||||
"""
|
||||
Execute SQL query on dataset file using DuckDB.
|
||||
@@ -326,7 +355,9 @@ async def execute_sql(
|
||||
)
|
||||
|
||||
try:
|
||||
result = await execute_sql_query(url_str, req.query, file_format, req.max_rows)
|
||||
result = await execute_sql_query(
|
||||
url_str, req.query, file_format, req.max_rows, auth_headers
|
||||
)
|
||||
|
||||
limiter.finish_request(identifier, 0)
|
||||
|
||||
|
||||
@@ -112,7 +112,11 @@ def _execute_query_sync(url: str, query: str, file_format: str, max_rows: int):
|
||||
|
||||
|
||||
async def execute_sql_query(
|
||||
url: str, query: str, file_format: str = "parquet", max_rows: int = 1000
|
||||
url: str,
|
||||
query: str,
|
||||
file_format: str = "parquet",
|
||||
max_rows: int = 10,
|
||||
auth_headers: dict[str, str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query on remote dataset using DuckDB.
|
||||
@@ -121,10 +125,11 @@ async def execute_sql_query(
|
||||
so it doesn't download the entire file!
|
||||
|
||||
Args:
|
||||
url: File URL (HTTP/HTTPS, including S3 presigned URLs)
|
||||
url: File URL (internal /resolve path or S3 presigned URL)
|
||||
query: SQL query to execute
|
||||
file_format: File format (csv, parquet, jsonl, json)
|
||||
max_rows: Maximum rows to return (safety limit)
|
||||
max_rows: Maximum rows to return (default: 10 for wide tables)
|
||||
auth_headers: Optional auth headers for internal /resolve URLs
|
||||
|
||||
Returns:
|
||||
{
|
||||
@@ -136,13 +141,13 @@ async def execute_sql_query(
|
||||
}
|
||||
|
||||
Example queries:
|
||||
SELECT * FROM dataset LIMIT 100
|
||||
SELECT * FROM dataset LIMIT 10
|
||||
SELECT age, COUNT(*) as count FROM dataset GROUP BY age
|
||||
SELECT * FROM dataset WHERE salary > 100000 ORDER BY salary DESC
|
||||
"""
|
||||
# Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
|
||||
# Resolve redirects first (handles internal /resolve URLs with auth)
|
||||
# This prevents DuckDB from repeatedly hitting our backend for range requests
|
||||
resolved_url = await resolve_url_redirects(url)
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
# Run in thread pool (DuckDB is synchronous)
|
||||
# Use resolved_url (not original url) to avoid repeated backend hits
|
||||
|
||||
@@ -14,6 +14,8 @@ from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from kohakuhub.datasetviewer.parsers import resolve_url_redirects
|
||||
|
||||
|
||||
class ParquetStreamParser:
|
||||
"""
|
||||
@@ -92,7 +94,9 @@ class WebDatasetTARParser:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def parse_streaming(url: str, max_samples: int = 100) -> dict[str, Any]:
|
||||
async def parse_streaming(
|
||||
url: str, max_samples: int = 100, auth_headers: dict[str, str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Parse webdataset TAR file using streaming.
|
||||
|
||||
@@ -100,8 +104,9 @@ class WebDatasetTARParser:
|
||||
Does NOT load file content - only headers!
|
||||
|
||||
Args:
|
||||
url: TAR file URL
|
||||
url: TAR file URL (internal /resolve path or S3 URL)
|
||||
max_samples: Maximum number of samples (IDs) to collect
|
||||
auth_headers: Optional auth headers for internal requests
|
||||
|
||||
Returns:
|
||||
{
|
||||
@@ -114,8 +119,11 @@ class WebDatasetTARParser:
|
||||
"truncated": bool
|
||||
}
|
||||
"""
|
||||
# Resolve URL first if it's an internal path
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
async with client.stream("GET", resolved_url) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# Stream TAR file
|
||||
@@ -228,13 +236,19 @@ class TARStreamParser:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def list_files_streaming(url: str) -> dict[str, Any]:
|
||||
async def list_files_streaming(
|
||||
url: str, auth_headers: dict[str, str] = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
List files in TAR by streaming headers only.
|
||||
|
||||
This is MUCH more efficient than current implementation
|
||||
which loads entire TAR into memory!
|
||||
|
||||
Args:
|
||||
url: TAR file URL (internal /resolve path or S3 URL)
|
||||
auth_headers: Optional auth headers for internal requests
|
||||
|
||||
Returns:
|
||||
{
|
||||
"files": [
|
||||
@@ -244,11 +258,14 @@ class TARStreamParser:
|
||||
"total_size": N # Total TAR size
|
||||
}
|
||||
"""
|
||||
# Resolve URL first if it's an internal path
|
||||
resolved_url = await resolve_url_redirects(url, auth_headers)
|
||||
|
||||
files = []
|
||||
offset = 0
|
||||
|
||||
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
|
||||
async with client.stream("GET", url) as response:
|
||||
async with client.stream("GET", resolved_url) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
|
||||
Reference in New Issue
Block a user