Allow the dataset viewer to forward user auth (Authorization header / session cookie) to the backend /resolve endpoint

This commit is contained in:
Kohaku-Blueleaf
2025-10-24 23:55:42 +08:00
parent 349f939ece
commit 5bfaee0ae1
4 changed files with 137 additions and 57 deletions

View File

@@ -12,53 +12,64 @@ import duckdb
import httpx
from fsspec.implementations.http import HTTPFileSystem
from kohakuhub.config import cfg
from kohakuhub.datasetviewer.logger import get_logger
logger = get_logger("Parser")
async def resolve_url_redirects(url: str) -> str:
async def resolve_url_redirects(url: str, auth_headers: dict[str, str] = None) -> str:
"""
Resolve URL redirects by following 302 responses.
Resolve URL redirects by following 302 responses with authentication.
For resolve URLs that return 302, we need to follow the redirect
For /resolve URLs that return 302, we need to follow the redirect
and use the final S3 presigned URL for fsspec/DuckDB.
Otherwise, they keep hitting our backend for every range request.
Uses GET with streaming (no redirect following) to detect 302 responses
without actually downloading content.
Uses GET with manual redirect handling to get Location header
without downloading file content.
Args:
url: Original URL (may return 302)
url: Original URL (e.g., /datasets/.../resolve/main/file.csv or S3 URL)
auth_headers: Optional auth headers (Authorization, Cookie) from user request
Returns:
Final URL after following redirects (or original if no redirect)
Final S3 presigned URL after following redirects (or original if external URL)
"""
# If already an S3 URL (starts with http:// or https:// with s3/amazonaws), use as-is
if url.startswith("http://") or url.startswith("https://"):
# External URL, don't try to resolve
return url
# Internal /resolve URL - make authenticated request to get S3 URL
try:
# Build full URL (relative path to absolute)
full_url = f"{cfg.app.base_url}{url}"
async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client:
# Send GET request with streaming, but don't follow redirects
# This gives us the correct 302 response that GET would return
async with client.stream("GET", url) as response:
# Send GET request with auth headers, don't follow redirects
headers = auth_headers or {}
async with client.stream("GET", full_url, headers=headers) as response:
# Check for redirect status codes
if response.status_code in [301, 302, 303, 307, 308]:
location = response.headers.get("Location")
if location:
logger.debug(
f"Resolved redirect: {url[:50]}... -> {location[:50]}..."
f"Resolved internal URL: {url[:50]}... -> S3 presigned URL"
)
# Close stream immediately without reading content
await response.aclose()
return location
# For other status codes (200, 4xx, 5xx), use original URL
# Close stream without reading content
# For other status codes, close and return original
await response.aclose()
logger.warning(
f"Expected redirect for {url}, got {response.status_code}"
)
return url
except Exception as e:
# If request fails, fall back to original URL
# fsspec will handle the actual request
logger.warning(f"Could not resolve redirects for {url[:50]}...: {e}")
logger.error(f"Could not resolve internal URL {url[:50]}...: {e}")
return url
@@ -111,21 +122,25 @@ class CSVParser:
@staticmethod
async def parse(
url: str, max_rows: int = 1000, delimiter: str = ","
url: str,
max_rows: int = 1000,
delimiter: str = ",",
auth_headers: dict[str, str] = None,
) -> dict[str, Any]:
"""
Parse CSV file from URL using DuckDB.
Args:
url: File URL (presigned S3 URL)
url: File URL (internal /resolve path or S3 presigned URL)
max_rows: Maximum rows to return
delimiter: CSV delimiter
auth_headers: Optional auth headers for internal /resolve URLs
Returns:
Dict with columns, rows, total_rows, truncated, file_size
"""
# Resolve redirects first
resolved_url = await resolve_url_redirects(url)
# Resolve redirects first (handles internal /resolve URLs)
resolved_url = await resolve_url_redirects(url, auth_headers)
# Run in thread pool to avoid blocking event loop
return await asyncio.to_thread(
@@ -169,19 +184,22 @@ class JSONLParser:
}
@staticmethod
async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
async def parse(
url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None
) -> dict[str, Any]:
"""
Parse JSONL file from URL using DuckDB.
Args:
url: File URL
url: File URL (internal /resolve path or S3 presigned URL)
max_rows: Maximum rows to return
auth_headers: Optional auth headers for internal /resolve URLs
Returns:
Dict with columns, rows, total_rows, truncated, file_size
"""
# Resolve redirects first
resolved_url = await resolve_url_redirects(url)
# Resolve redirects first (handles internal /resolve URLs)
resolved_url = await resolve_url_redirects(url, auth_headers)
# Run in thread pool to avoid blocking event loop
return await asyncio.to_thread(JSONLParser._parse_sync, resolved_url, max_rows)
@@ -223,19 +241,22 @@ class JSONParser:
}
@staticmethod
async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
async def parse(
url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None
) -> dict[str, Any]:
"""
Parse JSON array file from URL using DuckDB.
Args:
url: File URL
url: File URL (internal /resolve path or S3 presigned URL)
max_rows: Maximum rows to return
auth_headers: Optional auth headers for internal /resolve URLs
Returns:
Dict with columns, rows, total_rows, truncated, file_size
"""
# Resolve redirects first
resolved_url = await resolve_url_redirects(url)
# Resolve redirects first (handles internal /resolve URLs)
resolved_url = await resolve_url_redirects(url, auth_headers)
# Run in thread pool to avoid blocking event loop
return await asyncio.to_thread(JSONParser._parse_sync, resolved_url, max_rows)
@@ -275,20 +296,23 @@ class ParquetParser:
}
@staticmethod
async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
async def parse(
url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None
) -> dict[str, Any]:
"""
Parse Parquet file from URL using DuckDB.
Args:
url: File URL (HTTP/HTTPS, including S3 presigned URLs)
url: File URL (internal /resolve path or S3 presigned URL)
max_rows: Maximum rows to return
auth_headers: Optional auth headers for internal /resolve URLs
Returns:
Dict with columns, rows, total_rows, truncated, file_size
"""
# Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
# Resolve redirects first (handles internal /resolve URLs)
# This prevents DuckDB from repeatedly hitting our backend
resolved_url = await resolve_url_redirects(url)
resolved_url = await resolve_url_redirects(url, auth_headers)
# Run in thread pool to avoid blocking event loop
return await asyncio.to_thread(
@@ -375,19 +399,22 @@ class TARParser:
raise ParserError(f"File not found in archive: {file_name}")
@staticmethod
async def extract_file(url: str, file_name: str) -> bytes:
async def extract_file(
url: str, file_name: str, auth_headers: dict[str, str] = None
) -> bytes:
"""
Extract single file from TAR archive using streaming.
Args:
url: TAR file URL
url: TAR file URL (internal /resolve path or S3 presigned URL)
file_name: Name of file to extract
auth_headers: Optional auth headers for internal /resolve URLs
Returns:
File content as bytes
"""
# Resolve redirects first
resolved_url = await resolve_url_redirects(url)
# Resolve redirects first (handles internal /resolve URLs)
resolved_url = await resolve_url_redirects(url, auth_headers)
# Run in thread pool to avoid blocking
return await asyncio.to_thread(

View File

@@ -6,7 +6,7 @@ Minimal, auth-free endpoints for previewing dataset files.
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException, Query, Request
from pydantic import BaseModel, HttpUrl
from kohakuhub.datasetviewer.logger import get_logger
@@ -32,6 +32,24 @@ logger = get_logger("Router")
router = APIRouter(prefix="/dataset-viewer", tags=["Dataset Viewer"])
def get_auth_headers(
session_id: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> dict[str, str]:
"""Extract authentication headers from request for internal /resolve requests."""
headers = {}
# Pass Authorization header if present
if authorization:
headers["Authorization"] = authorization
# Pass session cookie if present
if session_id:
headers["Cookie"] = f"session_id={session_id}"
return headers
class PreviewRequest(BaseModel):
"""Request to preview a file."""
@@ -47,7 +65,7 @@ class SQLQueryRequest(BaseModel):
url: HttpUrl # S3 presigned URL or any HTTP(S) URL
query: str # SQL query to execute
format: Optional[str] = None # Auto-detect if not provided
max_rows: int = 10000 # Safety limit
max_rows: int = 10 # Default: 10 rows (datasets can have many columns)
class PreviewResponse(BaseModel):
@@ -97,6 +115,7 @@ async def preview_file(
request: Request,
req: PreviewRequest,
identifier: str = Depends(check_rate_limit_dependency),
auth_headers: dict[str, str] = Depends(get_auth_headers),
):
"""
Preview a dataset file from URL.
@@ -138,15 +157,17 @@ async def preview_file(
f"Previewing {file_format} file: {url_str[:80]}... (max_rows={req.max_rows})"
)
# Parse file
# Parse file (pass auth headers for internal /resolve URLs)
try:
if file_format == "csv" or file_format == "tsv":
delimiter = "\t" if file_format == "tsv" else req.delimiter
result = await CSVParser.parse(url_str, req.max_rows, delimiter)
result = await CSVParser.parse(
url_str, req.max_rows, delimiter, auth_headers
)
elif file_format == "jsonl":
result = await JSONLParser.parse(url_str, req.max_rows)
result = await JSONLParser.parse(url_str, req.max_rows, auth_headers)
elif file_format == "parquet":
result = await ParquetParser.parse(url_str, req.max_rows)
result = await ParquetParser.parse(url_str, req.max_rows, auth_headers)
else:
raise HTTPException(
400,
@@ -174,6 +195,7 @@ async def list_tar_files(
request: Request,
req: TARListRequest,
identifier: str = Depends(check_rate_limit_dependency),
auth_headers: dict[str, str] = Depends(get_auth_headers),
):
"""
List files in a TAR archive.
@@ -188,7 +210,7 @@ async def list_tar_files(
try:
# Use streaming parser (doesn't load full TAR into memory!)
result = await TARStreamParser.list_files_streaming(str(req.url))
result = await TARStreamParser.list_files_streaming(str(req.url), auth_headers)
limiter.finish_request(identifier, 0) # Only reads headers, minimal data
return result
except Exception as e:
@@ -201,6 +223,7 @@ async def extract_tar_file(
request: Request,
req: TARExtractRequest,
identifier: str = Depends(check_rate_limit_dependency),
auth_headers: dict[str, str] = Depends(get_auth_headers),
):
"""
Extract a single file from TAR archive.
@@ -216,7 +239,9 @@ async def extract_tar_file(
limiter = get_rate_limiter()
try:
content = await TARParser.extract_file(str(req.url), req.file_name)
content = await TARParser.extract_file(
str(req.url), req.file_name, auth_headers
)
limiter.finish_request(identifier, len(content))
# Return raw bytes
@@ -235,6 +260,7 @@ async def preview_webdataset_tar(
req: TARListRequest,
max_samples: int = Query(100, description="Max samples to preview"),
identifier: str = Depends(check_rate_limit_dependency),
auth_headers: dict[str, str] = Depends(get_auth_headers),
):
"""
Preview TAR file in webdataset format.
@@ -252,7 +278,9 @@ async def preview_webdataset_tar(
limiter = get_rate_limiter()
try:
result = await WebDatasetTARParser.parse_streaming(str(req.url), max_samples)
result = await WebDatasetTARParser.parse_streaming(
str(req.url), max_samples, auth_headers
)
limiter.finish_request(identifier, 0)
return result
except Exception as e:
@@ -278,6 +306,7 @@ async def execute_sql(
request: Request,
req: SQLQueryRequest,
identifier: str = Depends(check_rate_limit_dependency),
auth_headers: dict[str, str] = Depends(get_auth_headers),
):
"""
Execute SQL query on dataset file using DuckDB.
@@ -326,7 +355,9 @@ async def execute_sql(
)
try:
result = await execute_sql_query(url_str, req.query, file_format, req.max_rows)
result = await execute_sql_query(
url_str, req.query, file_format, req.max_rows, auth_headers
)
limiter.finish_request(identifier, 0)

View File

@@ -112,7 +112,11 @@ def _execute_query_sync(url: str, query: str, file_format: str, max_rows: int):
async def execute_sql_query(
url: str, query: str, file_format: str = "parquet", max_rows: int = 1000
url: str,
query: str,
file_format: str = "parquet",
max_rows: int = 10,
auth_headers: dict[str, str] = None,
) -> dict[str, Any]:
"""
Execute SQL query on remote dataset using DuckDB.
@@ -121,10 +125,11 @@ async def execute_sql_query(
so it doesn't download the entire file!
Args:
url: File URL (HTTP/HTTPS, including S3 presigned URLs)
url: File URL (internal /resolve path or S3 presigned URL)
query: SQL query to execute
file_format: File format (csv, parquet, jsonl, json)
max_rows: Maximum rows to return (safety limit)
max_rows: Maximum rows to return (default: 10 for wide tables)
auth_headers: Optional auth headers for internal /resolve URLs
Returns:
{
@@ -136,13 +141,13 @@ async def execute_sql_query(
}
Example queries:
SELECT * FROM dataset LIMIT 100
SELECT * FROM dataset LIMIT 10
SELECT age, COUNT(*) as count FROM dataset GROUP BY age
SELECT * FROM dataset WHERE salary > 100000 ORDER BY salary DESC
"""
# Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
# Resolve redirects first (handles internal /resolve URLs with auth)
# This prevents DuckDB from repeatedly hitting our backend for range requests
resolved_url = await resolve_url_redirects(url)
resolved_url = await resolve_url_redirects(url, auth_headers)
# Run in thread pool (DuckDB is synchronous)
# Use resolved_url (not original url) to avoid repeated backend hits

View File

@@ -14,6 +14,8 @@ from typing import Any, Optional
import httpx
from kohakuhub.datasetviewer.parsers import resolve_url_redirects
class ParquetStreamParser:
"""
@@ -92,7 +94,9 @@ class WebDatasetTARParser:
"""
@staticmethod
async def parse_streaming(url: str, max_samples: int = 100) -> dict[str, Any]:
async def parse_streaming(
url: str, max_samples: int = 100, auth_headers: dict[str, str] = None
) -> dict[str, Any]:
"""
Parse webdataset TAR file using streaming.
@@ -100,8 +104,9 @@ class WebDatasetTARParser:
Does NOT load file content - only headers!
Args:
url: TAR file URL
url: TAR file URL (internal /resolve path or S3 URL)
max_samples: Maximum number of samples (IDs) to collect
auth_headers: Optional auth headers for internal requests
Returns:
{
@@ -114,8 +119,11 @@ class WebDatasetTARParser:
"truncated": bool
}
"""
# Resolve URL first if it's an internal path
resolved_url = await resolve_url_redirects(url, auth_headers)
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
async with client.stream("GET", url) as response:
async with client.stream("GET", resolved_url) as response:
response.raise_for_status()
# Stream TAR file
@@ -228,13 +236,19 @@ class TARStreamParser:
"""
@staticmethod
async def list_files_streaming(url: str) -> dict[str, Any]:
async def list_files_streaming(
url: str, auth_headers: dict[str, str] = None
) -> dict[str, Any]:
"""
List files in TAR by streaming headers only.
This is MUCH more efficient than current implementation
which loads entire TAR into memory!
Args:
url: TAR file URL (internal /resolve path or S3 URL)
auth_headers: Optional auth headers for internal requests
Returns:
{
"files": [
@@ -244,11 +258,14 @@ class TARStreamParser:
"total_size": N # Total TAR size
}
"""
# Resolve URL first if it's an internal path
resolved_url = await resolve_url_redirects(url, auth_headers)
files = []
offset = 0
async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client:
async with client.stream("GET", url) as response:
async with client.stream("GET", resolved_url) as response:
response.raise_for_status()
total_size = int(response.headers.get("content-length", 0))