diff --git a/src/kohakuhub/datasetviewer/parsers.py b/src/kohakuhub/datasetviewer/parsers.py index cdec308..ba97bc1 100644 --- a/src/kohakuhub/datasetviewer/parsers.py +++ b/src/kohakuhub/datasetviewer/parsers.py @@ -12,53 +12,64 @@ import duckdb import httpx from fsspec.implementations.http import HTTPFileSystem +from kohakuhub.config import cfg from kohakuhub.datasetviewer.logger import get_logger logger = get_logger("Parser") -async def resolve_url_redirects(url: str) -> str: +async def resolve_url_redirects(url: str, auth_headers: dict[str, str] = None) -> str: """ - Resolve URL redirects by following 302 responses. + Resolve URL redirects by following 302 responses with authentication. - For resolve URLs that return 302, we need to follow the redirect + For /resolve URLs that return 302, we need to follow the redirect and use the final S3 presigned URL for fsspec/DuckDB. - Otherwise, they keep hitting our backend for every range request. - Uses GET with streaming (no redirect following) to detect 302 responses - without actually downloading content. + Uses GET with manual redirect handling to get Location header + without downloading file content. Args: - url: Original URL (may return 302) + url: Original URL (e.g., /datasets/.../resolve/main/file.csv or S3 URL) + auth_headers: Optional auth headers (Authorization, Cookie) from user request Returns: - Final URL after following redirects (or original if no redirect) + Final S3 presigned URL after following redirects (or original if external URL) """ + # If already an S3 URL (starts with http:// or https:// with s3/amazonaws), use as-is + if url.startswith("http://") or url.startswith("https://"): + # External URL, don't try to resolve + return url + + # Internal /resolve URL - make authenticated request to get S3 URL try: + # Build full URL (relative path to absolute) + full_url = f"{cfg.app.base_url}{url}" + async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client: - # Send GET request with streaming, but don't follow redirects - # This gives us the correct 302 response that GET would return - async with client.stream("GET", url) as response: + # Send GET request with auth headers, don't follow redirects + headers = auth_headers or {} + async with client.stream("GET", full_url, headers=headers) as response: # Check for redirect status codes if response.status_code in [301, 302, 303, 307, 308]: location = response.headers.get("Location") if location: logger.debug( - f"Resolved redirect: {url[:50]}... -> {location[:50]}..." + f"Resolved internal URL: {url[:50]}... -> S3 presigned URL" ) # Close stream immediately without reading content await response.aclose() return location - # For other status codes (200, 4xx, 5xx), use original URL - # Close stream without reading content + # For other status codes, close and return original await response.aclose() + logger.warning( + f"Expected redirect for {url}, got {response.status_code}" + ) return url except Exception as e: # If request fails, fall back to original URL - # fsspec will handle the actual request - logger.warning(f"Could not resolve redirects for {url[:50]}...: {e}") + logger.error(f"Could not resolve internal URL {url[:50]}...: {e}") return url @@ -111,21 +122,25 @@ class CSVParser: @staticmethod async def parse( - url: str, max_rows: int = 1000, delimiter: str = "," + url: str, + max_rows: int = 1000, + delimiter: str = ",", + auth_headers: dict[str, str] = None, ) -> dict[str, Any]: """ Parse CSV file from URL using DuckDB. Args: - url: File URL (presigned S3 URL) + url: File URL (internal /resolve path or S3 presigned URL) max_rows: Maximum rows to return delimiter: CSV delimiter + auth_headers: Optional auth headers for internal /resolve URLs Returns: Dict with columns, rows, total_rows, truncated, file_size """ - # Resolve redirects first - resolved_url = await resolve_url_redirects(url) + # Resolve redirects first (handles internal /resolve URLs) + resolved_url = await resolve_url_redirects(url, auth_headers) # Run in thread pool to avoid blocking event loop return await asyncio.to_thread( @@ -169,19 +184,22 @@ class JSONLParser: } @staticmethod - async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]: + async def parse( + url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None + ) -> dict[str, Any]: """ Parse JSONL file from URL using DuckDB. Args: - url: File URL + url: File URL (internal /resolve path or S3 presigned URL) max_rows: Maximum rows to return + auth_headers: Optional auth headers for internal /resolve URLs Returns: Dict with columns, rows, total_rows, truncated, file_size """ - # Resolve redirects first - resolved_url = await resolve_url_redirects(url) + # Resolve redirects first (handles internal /resolve URLs) + resolved_url = await resolve_url_redirects(url, auth_headers) # Run in thread pool to avoid blocking event loop return await asyncio.to_thread(JSONLParser._parse_sync, resolved_url, max_rows) @@ -223,19 +241,22 @@ class JSONParser: } @staticmethod - async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]: + async def parse( + url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None + ) -> dict[str, Any]: """ Parse JSON array file from URL using DuckDB. Args: - url: File URL + url: File URL (internal /resolve path or S3 presigned URL) max_rows: Maximum rows to return + auth_headers: Optional auth headers for internal /resolve URLs Returns: Dict with columns, rows, total_rows, truncated, file_size """ - # Resolve redirects first - resolved_url = await resolve_url_redirects(url) + # Resolve redirects first (handles internal /resolve URLs) + resolved_url = await resolve_url_redirects(url, auth_headers) # Run in thread pool to avoid blocking event loop return await asyncio.to_thread(JSONParser._parse_sync, resolved_url, max_rows) @@ -275,20 +296,23 @@ class ParquetParser: } @staticmethod - async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]: + async def parse( + url: str, max_rows: int = 1000, auth_headers: dict[str, str] = None + ) -> dict[str, Any]: """ Parse Parquet file from URL using DuckDB. Args: - url: File URL (HTTP/HTTPS, including S3 presigned URLs) + url: File URL (internal /resolve path or S3 presigned URL) max_rows: Maximum rows to return + auth_headers: Optional auth headers for internal /resolve URLs Returns: Dict with columns, rows, total_rows, truncated, file_size """ - # Resolve redirects first (302 from resolve endpoint -> S3 presigned URL) + # Resolve redirects first (handles internal /resolve URLs) # This prevents DuckDB from repeatedly hitting our backend - resolved_url = await resolve_url_redirects(url) + resolved_url = await resolve_url_redirects(url, auth_headers) # Run in thread pool to avoid blocking event loop return await asyncio.to_thread( @@ -375,19 +399,22 @@ class TARParser: raise ParserError(f"File not found in archive: {file_name}") @staticmethod - async def extract_file(url: str, file_name: str) -> bytes: + async def extract_file( + url: str, file_name: str, auth_headers: dict[str, str] = None + ) -> bytes: """ Extract single file from TAR archive using streaming. Args: - url: TAR file URL + url: TAR file URL (internal /resolve path or S3 presigned URL) file_name: Name of file to extract + auth_headers: Optional auth headers for internal /resolve URLs Returns: File content as bytes """ - # Resolve redirects first - resolved_url = await resolve_url_redirects(url) + # Resolve redirects first (handles internal /resolve URLs) + resolved_url = await resolve_url_redirects(url, auth_headers) # Run in thread pool to avoid blocking return await asyncio.to_thread( diff --git a/src/kohakuhub/datasetviewer/router.py b/src/kohakuhub/datasetviewer/router.py index e9c7a28..a0878d8 100644 --- a/src/kohakuhub/datasetviewer/router.py +++ b/src/kohakuhub/datasetviewer/router.py @@ -6,7 +6,7 @@ Minimal, auth-free endpoints for previewing dataset files. from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Cookie, Depends, Header, HTTPException, Query, Request from pydantic import BaseModel, HttpUrl from kohakuhub.datasetviewer.logger import get_logger @@ -32,6 +32,24 @@ logger = get_logger("Router") router = APIRouter(prefix="/dataset-viewer", tags=["Dataset Viewer"]) +def get_auth_headers( + session_id: Optional[str] = Cookie(None), + authorization: Optional[str] = Header(None), +) -> dict[str, str]: + """Extract authentication headers from request for internal /resolve requests.""" + headers = {} + + # Pass Authorization header if present + if authorization: + headers["Authorization"] = authorization + + # Pass session cookie if present + if session_id: + headers["Cookie"] = f"session_id={session_id}" + + return headers + + class PreviewRequest(BaseModel): """Request to preview a file.""" @@ -47,7 +65,7 @@ class SQLQueryRequest(BaseModel): url: HttpUrl # S3 presigned URL or any HTTP(S) URL query: str # SQL query to execute format: Optional[str] = None # Auto-detect if not provided - max_rows: int = 10000 # Safety limit + max_rows: int = 10 # Default: 10 rows (datasets can have many columns) class PreviewResponse(BaseModel): @@ -97,6 +115,7 @@ async def preview_file( request: Request, req: PreviewRequest, identifier: str = Depends(check_rate_limit_dependency), + auth_headers: dict[str, str] = Depends(get_auth_headers), ): """ Preview a dataset file from URL. @@ -138,15 +157,17 @@ async def preview_file( f"Previewing {file_format} file: {url_str[:80]}... (max_rows={req.max_rows})" ) - # Parse file + # Parse file (pass auth headers for internal /resolve URLs) try: if file_format == "csv" or file_format == "tsv": delimiter = "\t" if file_format == "tsv" else req.delimiter - result = await CSVParser.parse(url_str, req.max_rows, delimiter) + result = await CSVParser.parse( + url_str, req.max_rows, delimiter, auth_headers + ) elif file_format == "jsonl": - result = await JSONLParser.parse(url_str, req.max_rows) + result = await JSONLParser.parse(url_str, req.max_rows, auth_headers) elif file_format == "parquet": - result = await ParquetParser.parse(url_str, req.max_rows) + result = await ParquetParser.parse(url_str, req.max_rows, auth_headers) else: raise HTTPException( 400, @@ -174,6 +195,7 @@ async def list_tar_files( request: Request, req: TARListRequest, identifier: str = Depends(check_rate_limit_dependency), + auth_headers: dict[str, str] = Depends(get_auth_headers), ): """ List files in a TAR archive. @@ -188,7 +210,7 @@ async def list_tar_files( try: # Use streaming parser (doesn't load full TAR into memory!) - result = await TARStreamParser.list_files_streaming(str(req.url)) + result = await TARStreamParser.list_files_streaming(str(req.url), auth_headers) limiter.finish_request(identifier, 0) # Only reads headers, minimal data return result except Exception as e: @@ -201,6 +223,7 @@ async def extract_tar_file( request: Request, req: TARExtractRequest, identifier: str = Depends(check_rate_limit_dependency), + auth_headers: dict[str, str] = Depends(get_auth_headers), ): """ Extract a single file from TAR archive. @@ -216,7 +239,9 @@ async def extract_tar_file( limiter = get_rate_limiter() try: - content = await TARParser.extract_file(str(req.url), req.file_name) + content = await TARParser.extract_file( + str(req.url), req.file_name, auth_headers + ) limiter.finish_request(identifier, len(content)) # Return raw bytes @@ -235,6 +260,7 @@ async def preview_webdataset_tar( req: TARListRequest, max_samples: int = Query(100, description="Max samples to preview"), identifier: str = Depends(check_rate_limit_dependency), + auth_headers: dict[str, str] = Depends(get_auth_headers), ): """ Preview TAR file in webdataset format. @@ -252,7 +278,9 @@ async def preview_webdataset_tar( limiter = get_rate_limiter() try: - result = await WebDatasetTARParser.parse_streaming(str(req.url), max_samples) + result = await WebDatasetTARParser.parse_streaming( + str(req.url), max_samples, auth_headers + ) limiter.finish_request(identifier, 0) return result except Exception as e: @@ -278,6 +306,7 @@ async def execute_sql( request: Request, req: SQLQueryRequest, identifier: str = Depends(check_rate_limit_dependency), + auth_headers: dict[str, str] = Depends(get_auth_headers), ): """ Execute SQL query on dataset file using DuckDB. @@ -326,7 +355,9 @@ async def execute_sql( ) try: - result = await execute_sql_query(url_str, req.query, file_format, req.max_rows) + result = await execute_sql_query( + url_str, req.query, file_format, req.max_rows, auth_headers + ) limiter.finish_request(identifier, 0) diff --git a/src/kohakuhub/datasetviewer/sql_query.py b/src/kohakuhub/datasetviewer/sql_query.py index 1c3d97d..58a3ca9 100644 --- a/src/kohakuhub/datasetviewer/sql_query.py +++ b/src/kohakuhub/datasetviewer/sql_query.py @@ -112,7 +112,11 @@ def _execute_query_sync(url: str, query: str, file_format: str, max_rows: int): async def execute_sql_query( - url: str, query: str, file_format: str = "parquet", max_rows: int = 1000 + url: str, + query: str, + file_format: str = "parquet", + max_rows: int = 10, + auth_headers: dict[str, str] = None, ) -> dict[str, Any]: """ Execute SQL query on remote dataset using DuckDB. @@ -121,10 +125,11 @@ async def execute_sql_query( so it doesn't download the entire file! Args: - url: File URL (HTTP/HTTPS, including S3 presigned URLs) + url: File URL (internal /resolve path or S3 presigned URL) query: SQL query to execute file_format: File format (csv, parquet, jsonl, json) - max_rows: Maximum rows to return (safety limit) + max_rows: Maximum rows to return (default: 10 for wide tables) + auth_headers: Optional auth headers for internal /resolve URLs Returns: { @@ -136,13 +141,13 @@ async def execute_sql_query( } Example queries: - SELECT * FROM dataset LIMIT 100 + SELECT * FROM dataset LIMIT 10 SELECT age, COUNT(*) as count FROM dataset GROUP BY age SELECT * FROM dataset WHERE salary > 100000 ORDER BY salary DESC """ - # Resolve redirects first (302 from resolve endpoint -> S3 presigned URL) + # Resolve redirects first (handles internal /resolve URLs with auth) # This prevents DuckDB from repeatedly hitting our backend for range requests - resolved_url = await resolve_url_redirects(url) + resolved_url = await resolve_url_redirects(url, auth_headers) # Run in thread pool (DuckDB is synchronous) # Use resolved_url (not original url) to avoid repeated backend hits diff --git a/src/kohakuhub/datasetviewer/streaming_parsers.py b/src/kohakuhub/datasetviewer/streaming_parsers.py index bdf58af..4023d5c 100644 --- a/src/kohakuhub/datasetviewer/streaming_parsers.py +++ b/src/kohakuhub/datasetviewer/streaming_parsers.py @@ -14,6 +14,8 @@ from typing import Any, Optional import httpx +from kohakuhub.datasetviewer.parsers import resolve_url_redirects + class ParquetStreamParser: """ @@ -92,7 +94,9 @@ class WebDatasetTARParser: """ @staticmethod - async def parse_streaming(url: str, max_samples: int = 100) -> dict[str, Any]: + async def parse_streaming( + url: str, max_samples: int = 100, auth_headers: dict[str, str] = None + ) -> dict[str, Any]: """ Parse webdataset TAR file using streaming. @@ -100,8 +104,9 @@ class WebDatasetTARParser: Does NOT load file content - only headers! Args: - url: TAR file URL + url: TAR file URL (internal /resolve path or S3 URL) max_samples: Maximum number of samples (IDs) to collect + auth_headers: Optional auth headers for internal requests Returns: { @@ -114,8 +119,11 @@ class WebDatasetTARParser: "truncated": bool } """ + # Resolve URL first if it's an internal path + resolved_url = await resolve_url_redirects(url, auth_headers) + async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client: - async with client.stream("GET", url) as response: + async with client.stream("GET", resolved_url) as response: response.raise_for_status() # Stream TAR file @@ -228,13 +236,19 @@ class TARStreamParser: """ @staticmethod - async def list_files_streaming(url: str) -> dict[str, Any]: + async def list_files_streaming( + url: str, auth_headers: dict[str, str] = None + ) -> dict[str, Any]: """ List files in TAR by streaming headers only. This is MUCH more efficient than current implementation which loads entire TAR into memory! + Args: + url: TAR file URL (internal /resolve path or S3 URL) + auth_headers: Optional auth headers for internal requests + Returns: { "files": [ @@ -244,11 +258,14 @@ class TARStreamParser: "total_size": N # Total TAR size } """ + # Resolve URL first if it's an internal path + resolved_url = await resolve_url_redirects(url, auth_headers) + files = [] offset = 0 async with httpx.AsyncClient(timeout=60.0, follow_redirects=True) as client: - async with client.stream("GET", url) as response: + async with client.stream("GET", resolved_url) as response: response.raise_for_status() total_size = int(response.headers.get("content-length", 0))