diff --git a/pyproject.toml b/pyproject.toml
index 66c0d03..652b3eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "pyarrow",
     "pydantic[email]",
     "python-multipart",
+    "pytz",
     "pyyaml",
     "questionary",
     "requests",
diff --git a/scripts/generate_test_dataset.py b/scripts/generate_test_dataset.py
index f14ea77..052bd77 100644
--- a/scripts/generate_test_dataset.py
+++ b/scripts/generate_test_dataset.py
@@ -84,27 +84,66 @@ def generate_value(col_type: str, row_id: int):
 
 
 def generate_column_schema(num_cols: int) -> list[dict]:
-    """Generate random column schema."""
-    columns = [{"name": "id", "type": "id"}]
+    """Generate column schema with meaningful names."""
+    # Predefined meaningful column names with types
+    predefined_columns = [
+        {"name": "id", "type": "id"},
+        {"name": "user_id", "type": "int"},
+        {"name": "age", "type": "int"},
+        {"name": "score", "type": "float"},
+        {"name": "rating", "type": "float"},
+        {"name": "is_active", "type": "bool"},
+        {"name": "is_verified", "type": "bool"},
+        {"name": "created_at", "type": "datetime"},
+        {"name": "updated_at", "type": "datetime"},
+        {"name": "birth_date", "type": "date"},
+        {"name": "username", "type": "short_text"},
+        {"name": "email", "type": "short_text"},
+        {"name": "name", "type": "short_text"},
+        {"name": "title", "type": "text"},
+        {"name": "description", "type": "text"},
+        {"name": "category", "type": "short_text"},
+        {"name": "status", "type": "short_text"},
+        {"name": "comment", "type": "long_text"},
+        {"name": "review", "type": "long_text"},
+        {"name": "content", "type": "very_long_text"},
+        {"name": "price", "type": "float"},
+        {"name": "quantity", "type": "int"},
+        {"name": "views", "type": "int"},
+        {"name": "likes", "type": "int"},
+        {"name": "tags", "type": "text"},
+        {"name": "metadata", "type": "text"},
+        {"name": "notes", "type": "long_text"},
+        {"name": "address", "type": "text"},
+        {"name": "city", "type": "short_text"},
+        {"name": "country", "type": "short_text"},
+    ]
 
-    # Mix of different column types
-    type_distribution = {
-        "int": 0.2,
-        "float": 0.15,
-        "bool": 0.1,
-        "date": 0.1,
-        "short_text": 0.15,
-        "text": 0.15,
-        "long_text": 0.1,
-        "very_long_text": 0.05,
-    }
+    columns = []
 
-    types = list(type_distribution.keys())
-    weights = list(type_distribution.values())
+    # Use predefined columns first
+    for i in range(min(num_cols, len(predefined_columns))):
+        columns.append(predefined_columns[i])
 
-    for i in range(1, num_cols):
-        col_type = random.choices(types, weights=weights)[0]
-        columns.append({"name": f"col_{i}_{col_type}", "type": col_type})
+    # If we need more columns, generate with numbered suffix
+    if num_cols > len(predefined_columns):
+        type_distribution = {
+            "int": 0.2,
+            "float": 0.15,
+            "bool": 0.1,
+            "date": 0.1,
+            "short_text": 0.15,
+            "text": 0.15,
+            "long_text": 0.1,
+            "very_long_text": 0.05,
+        }
+
+        types = list(type_distribution.keys())
+        weights = list(type_distribution.values())
+
+        for i in range(len(predefined_columns), num_cols):
+            col_type = random.choices(types, weights=weights)[0]
+            columns.append({"name": f"field_{i}_{col_type}", "type": col_type})
 
     return columns
 
@@ -123,7 +162,8 @@ def generate_csv(output_path: Path, num_rows: int, num_cols: int):
     batch_size = 10000
     for batch_start in range(0, num_rows, batch_size):
         batch_end = min(batch_start + batch_size, num_rows)
-        if batch_start % 50000 == 0:
+        # Print progress every 5 batches (every 50k rows)
+        if batch_start % (batch_size * 5) == 0:
             print(f"  Writing rows {batch_start:,} to {batch_end:,}...")
 
         for row_id in range(batch_start, batch_end):
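For illustration, a minimal standalone sketch of the cadence the rewritten modulo check produces (this snippet is not part of the patch; it just demonstrates the arithmetic with the script's `batch_size` of 10000):

```python
# The new check fires every 5 batches, i.e. every 50k rows -- same output as
# the old hard-coded `% 50000`, but it stays correct if batch_size is tuned.
batch_size = 10_000
num_rows = 120_000

for batch_start in range(0, num_rows, batch_size):
    if batch_start % (batch_size * 5) == 0:
        print(f"  Writing rows from {batch_start:,}...")  # 0, 50,000, 100,000
```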
@@ -146,7 +186,8 @@ def generate_jsonl(output_path: Path, num_rows: int, num_cols: int):
     batch_size = 10000
     for batch_start in range(0, num_rows, batch_size):
         batch_end = min(batch_start + batch_size, num_rows)
-        if batch_start % 50000 == 0:
+        # Print progress every 5 batches (every 50k rows)
+        if batch_start % (batch_size * 5) == 0:
             print(f"  Writing rows {batch_start:,} to {batch_end:,}...")
 
         for row_id in range(batch_start, batch_end):
@@ -172,8 +213,10 @@ def generate_parquet(output_path: Path, num_rows: int, num_cols: int):
 
     for batch_start in range(0, num_rows, batch_size):
         batch_end = min(batch_start + batch_size, num_rows)
+        # Print progress every 100k rows (every 2 batches of 50k)
         if batch_start % 100000 == 0:
-            print(f"  Processing rows {batch_start:,} to {batch_end:,}...")
+            progress_end = min(batch_start + 100000, num_rows)
+            print(f"  Processing rows {batch_start:,} to {progress_end:,}...")
 
         # Generate batch data
         data = []
diff --git a/src/kohakuhub/datasetviewer/parsers.py b/src/kohakuhub/datasetviewer/parsers.py
index 7d8b107..08cd105 100644
--- a/src/kohakuhub/datasetviewer/parsers.py
+++ b/src/kohakuhub/datasetviewer/parsers.py
@@ -16,6 +16,49 @@ import pyarrow.parquet as pq
 from fsspec.implementations.http import HTTPFileSystem
 
 
+async def resolve_url_redirects(url: str) -> str:
+    """
+    Resolve URL redirects by following 302 responses.
+
+    For resolve URLs that return 302, we need to follow the redirect
+    and use the final S3 presigned URL for fsspec/DuckDB.
+    Otherwise, they keep hitting our backend for every range request.
+
+    Uses GET with streaming (no redirect following) to detect 302 responses
+    without actually downloading content.
+
+    Args:
+        url: Original URL (may return 302)
+
+    Returns:
+        Final URL after following redirects (or original if no redirect)
+    """
+    try:
+        async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client:
+            # Send GET request with streaming, but don't follow redirects
+            # This gives us the correct 302 response that GET would return
+            async with client.stream("GET", url) as response:
+                # Check for redirect status codes
+                if response.status_code in [301, 302, 303, 307, 308]:
+                    location = response.headers.get("Location")
+                    if location:
+                        print(f"Resolved redirect: {url[:50]}... -> {location[:50]}...")
+                        # Close stream immediately without reading content
+                        await response.aclose()
+                        return location
+
+                # For other status codes (200, 4xx, 5xx), use original URL
+                # Close stream without reading content
+                await response.aclose()
+                return url
+
+    except Exception as e:
+        # If request fails, fall back to original URL
+        # fsspec will handle the actual request
+        print(f"Warning: Could not resolve redirects for {url[:50]}...: {e}")
+        return url
+
+
 class ParserError(Exception):
     """Base exception for parser errors."""
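The helper probes with a streamed GET rather than HEAD because, as the docstring notes, it wants exactly the 302 a GET would receive (presigned URLs and some backends can answer HEAD differently). A hypothetical caller sketch, with a made-up URL; only `resolve_url_redirects` itself comes from the patch:

```python
import asyncio

from kohakuhub.datasetviewer.parsers import resolve_url_redirects


async def main() -> None:
    # Hypothetical resolve-endpoint URL; any URL answering 302 behaves the same.
    url = "https://hub.example.com/datasets/org/repo/resolve/main/data.parquet"
    final_url = await resolve_url_redirects(url)
    # final_url is the S3 presigned target on a 302, or the original URL
    # unchanged on 200/4xx/5xx or on a network failure.
    print(final_url)


asyncio.run(main())
```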
@@ -23,14 +66,17 @@ class ParserError(Exception):
 
 
 class CSVParser:
-    """Stream CSV files from URL."""
+    """Parse CSV files using DuckDB (non-blocking, efficient)."""
 
     @staticmethod
     async def parse(
         url: str, max_rows: int = 1000, delimiter: str = ","
    ) -> dict[str, Any]:
         """
-        Parse CSV file from URL.
+        Parse CSV file from URL using DuckDB.
+
+        DuckDB supports CSV with automatic type detection and HTTP range requests.
+        Much faster and more robust than manual parsing.
 
         Args:
             url: File URL (presigned S3 URL)
@@ -42,78 +88,72 @@ class CSVParser:
                 "columns": ["col1", "col2", ...],
                 "rows": [[val1, val2, ...], ...],
                 "total_rows": N,
-                "truncated": bool
+                "truncated": bool,
+                "file_size": N
             }
         """
-        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-            async with client.stream("GET", url) as response:
-                response.raise_for_status()
+        # Resolve redirects first
+        resolved_url = await resolve_url_redirects(url)
 
-                # Get file size if available
-                content_length = response.headers.get("content-length")
-                file_size = int(content_length) if content_length else None
+        def _parse_csv_sync(url: str, max_rows: int, delimiter: str) -> dict[str, Any]:
+            """Synchronous CSV parsing with DuckDB (runs in thread pool)."""
+            import duckdb
 
-                # Read in chunks
-                buffer = ""
-                rows = []
-                columns = None
-                line_count = 0
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")
 
-                async for chunk in response.aiter_text():
-                    buffer += chunk
+            try:
+                # Read CSV with DuckDB
+                query = f"""
+                    SELECT * FROM read_csv(
+                        '{url}',
+                        delim='{delimiter}',
+                        header=true,
+                        auto_detect=true
+                    )
+                    LIMIT {max_rows}
+                """
 
-                    # Process complete lines
-                    while "\n" in buffer:
-                        line, buffer = buffer.split("\n", 1)
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]
 
-                        if not line.strip():
-                            continue
+                # Get total row count
+                try:
+                    count_query = f"SELECT COUNT(*) FROM read_csv('{url}', delim='{delimiter}', header=true)"
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)
 
-                        # Parse CSV line
-                        try:
-                            parsed = list(csv.reader([line], delimiter=delimiter))[0]
-                        except Exception:
-                            continue
-
-                        # First line is header
-                        if columns is None:
-                            columns = parsed
-                            continue
-
-                        rows.append(parsed)
-                        line_count += 1
-
-                        if line_count >= max_rows:
-                            break
-
-                    if line_count >= max_rows:
-                        break
-
-                # Process remaining buffer
-                if buffer.strip() and line_count < max_rows and columns is not None:
-                    try:
-                        parsed = list(csv.reader([buffer], delimiter=delimiter))[0]
-                        rows.append(parsed)
-                        line_count += 1
-                    except Exception:
-                        pass
+                conn.close()
 
                 return {
-                    "columns": columns or [],
-                    "rows": rows,
-                    "total_rows": line_count,
-                    "truncated": line_count >= max_rows,
-                    "file_size": file_size,
+                    "columns": columns,
+                    "rows": [list(row) for row in result],
+                    "total_rows": total_rows,
+                    "truncated": len(result) >= max_rows,
+                    "file_size": None,
                 }
 
+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse CSV with DuckDB: {e}")
+
+        # Run in thread pool to avoid blocking event loop
+        return await asyncio.to_thread(
+            _parse_csv_sync, resolved_url, max_rows, delimiter
+        )
+
 
 class JSONLParser:
-    """Stream JSONL (newline-delimited JSON) files from URL."""
+    """Parse JSONL files using DuckDB (non-blocking, efficient)."""
 
     @staticmethod
     async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
         """
-        Parse JSONL file from URL.
+        Parse JSONL file from URL using DuckDB.
+
+        DuckDB's read_ndjson supports newline-delimited JSON with HTTP URLs.
 
         Args:
             url: File URL
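All four parsers share the same wrapper shape: DuckDB's Python client is synchronous, so the blocking work is pushed onto a worker thread with `asyncio.to_thread` and awaited, keeping the event loop free to serve other requests. A condensed sketch of just that pattern (the names and URL here are illustrative, not from the patch):

```python
import asyncio

import duckdb


def _preview_sync(url: str, max_rows: int) -> list[tuple]:
    # Blocking DuckDB work: runs on a worker thread, never on the event loop.
    conn = duckdb.connect(":memory:")
    conn.execute("INSTALL httpfs")
    conn.execute("LOAD httpfs")
    try:
        return conn.execute(
            f"SELECT * FROM read_csv('{url}', header=true, auto_detect=true) "
            f"LIMIT {max_rows}"
        ).fetchall()
    finally:
        conn.close()


async def preview(url: str, max_rows: int = 1000) -> list[tuple]:
    return await asyncio.to_thread(_preview_sync, url, max_rows)
```

The `try/finally` here is just a compact equivalent of the patch's explicit close-on-both-paths handling.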
@@ -124,98 +164,67 @@ class JSONLParser:
                 "columns": ["col1", "col2", ...],
                 "rows": [[val1, val2, ...], ...],
                 "total_rows": N,
-                "truncated": bool
+                "truncated": bool,
+                "file_size": N
             }
         """
-        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-            async with client.stream("GET", url) as response:
-                response.raise_for_status()
+        # Resolve redirects first
+        resolved_url = await resolve_url_redirects(url)
 
-                content_length = response.headers.get("content-length")
-                file_size = int(content_length) if content_length else None
+        def _parse_jsonl_sync(url: str, max_rows: int) -> dict[str, Any]:
+            """Synchronous JSONL parsing with DuckDB (runs in thread pool)."""
+            import duckdb
 
-                buffer = ""
-                rows = []
-                columns = set()
-                line_count = 0
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")
+            conn.execute("INSTALL json")
+            conn.execute("LOAD json")
 
-                async for chunk in response.aiter_text():
-                    buffer += chunk
+            try:
+                # Read JSONL with DuckDB (newline-delimited JSON)
+                query = f"""
+                    SELECT * FROM read_ndjson('{url}')
+                    LIMIT {max_rows}
+                """
 
-                    while "\n" in buffer:
-                        line, buffer = buffer.split("\n", 1)
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]
 
-                        if not line.strip():
-                            continue
+                # Get total row count
+                try:
+                    count_query = f"SELECT COUNT(*) FROM read_ndjson('{url}')"
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)
 
-                        try:
-                            obj = json.loads(line)
-                            if isinstance(obj, dict):
-                                # Collect all keys
-                                columns.update(obj.keys())
-                                rows.append(obj)
-                                line_count += 1
-                        except json.JSONDecodeError:
-                            continue
-
-                        if line_count >= max_rows:
-                            break
-
-                    if line_count >= max_rows:
-                        break
-
-                # Process remaining buffer
-                if buffer.strip() and line_count < max_rows:
-                    try:
-                        obj = json.loads(buffer)
-                        if isinstance(obj, dict):
-                            columns.update(obj.keys())
-                            rows.append(obj)
-                            line_count += 1
-                    except json.JSONDecodeError:
-                        pass
-
-                # Convert to columnar format
-                # Sort columns by completeness (most complete first), then alphabetically
-                columns_list = sorted(columns)
-
-                # Calculate completeness for each column
-                column_completeness = {}
-                for col in columns_list:
-                    non_null_count = sum(1 for row in rows if row.get(col) is not None)
-                    column_completeness[col] = non_null_count
-
-                # Sort by completeness (descending), then alphabetically
-                columns_list = sorted(
-                    columns_list, key=lambda col: (-column_completeness[col], col)
-                )
-
-                rows_list = [[row.get(col) for col in columns_list] for row in rows]
+                conn.close()
 
                 return {
-                    "columns": columns_list,
-                    "rows": rows_list,
-                    "total_rows": line_count,
-                    "truncated": line_count >= max_rows,
-                    "file_size": file_size,
+                    "columns": columns,
+                    "rows": [list(row) for row in result],
+                    "total_rows": total_rows,
+                    "truncated": len(result) >= max_rows,
+                    "file_size": None,
                 }
 
+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse JSONL with DuckDB: {e}")
+
+        # Run in thread pool to avoid blocking event loop
+        return await asyncio.to_thread(_parse_jsonl_sync, resolved_url, max_rows)
+
 
 class JSONParser:
-    """
-    Parse JSON array files from URL.
-
-    NOTE: This parser is DEPRECATED and should not be used for large files!
-    It requires loading the entire file into memory.
-    Use JSONL format instead for streaming support.
-    """
+    """Parse JSON array files using DuckDB (non-blocking)."""
 
     @staticmethod
     async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
         """
-        Parse JSON file from URL.
+        Parse JSON array file from URL using DuckDB.
 
-        Expects format: [{"col1": val1, ...}, ...]
+        DuckDB's read_json supports JSON arrays with automatic schema detection.
 
         Args:
             url: File URL
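One behavioral difference worth noting: the old JSONL parser collected the union of keys by hand and ordered columns by completeness, while `read_ndjson` infers a single unified schema itself (missing keys become NULL), typically in order of first appearance rather than completeness. A tiny local demo of that unification (file path and expected output are illustrative):

```python
import tempfile

import duckdb

# Write two records with overlapping but different keys.
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write('{"a": 1, "b": "x"}\n')
    f.write('{"a": 2, "c": true}\n')
    path = f.name

conn = duckdb.connect(":memory:")
conn.execute("INSTALL json")
conn.execute("LOAD json")
rows = conn.execute(f"SELECT * FROM read_ndjson('{path}')").fetchall()
print([d[0] for d in conn.description])  # expected: ['a', 'b', 'c']
print(rows)  # expected: (1, 'x', None) and (2, None, True)
conn.close()
```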
- """ + """Parse JSON array files using DuckDB (non-blocking).""" @staticmethod async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]: """ - Parse JSON file from URL. + Parse JSON array file from URL using DuckDB. - Expects format: [{"col1": val1, ...}, ...] + DuckDB's read_json supports JSON arrays with automatic schema detection. Args: url: File URL @@ -224,75 +233,66 @@ class JSONParser: Returns: Same format as JSONLParser """ - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: - response = await client.get(url) - response.raise_for_status() + # Resolve redirects first + resolved_url = await resolve_url_redirects(url) - file_size = len(response.content) + def _parse_json_sync(url: str, max_rows: int) -> dict[str, Any]: + """Synchronous JSON parsing with DuckDB (runs in thread pool).""" + import duckdb + + conn = duckdb.connect(":memory:") + conn.execute("INSTALL httpfs") + conn.execute("LOAD httpfs") + conn.execute("INSTALL json") + conn.execute("LOAD json") try: - data = response.json() - except json.JSONDecodeError as e: - raise ParserError(f"Invalid JSON: {e}") + # Read JSON array with DuckDB + query = f""" + SELECT * FROM read_json('{url}', format='array') + LIMIT {max_rows} + """ - if not isinstance(data, list): - raise ParserError("JSON must be an array of objects") + result = conn.execute(query).fetchall() + columns = [desc[0] for desc in conn.description] - # Limit rows - rows_data = data[:max_rows] + # Get total row count + try: + count_query = ( + f"SELECT COUNT(*) FROM read_json('{url}', format='array')" + ) + total_rows = conn.execute(count_query).fetchone()[0] + except Exception: + total_rows = len(result) - # Extract columns - columns = set() - for row in rows_data: - if isinstance(row, dict): - columns.update(row.keys()) + conn.close() - columns_list = sorted(columns) + return { + "columns": columns, + "rows": [list(row) for row in result], + "total_rows": total_rows, + "truncated": len(result) >= max_rows, + "file_size": None, + } - # Calculate completeness for each column - column_completeness = {} - for col in columns_list: - non_null_count = sum( - 1 - for row in rows_data - if isinstance(row, dict) and row.get(col) is not None - ) - column_completeness[col] = non_null_count + except Exception as e: + conn.close() + raise ParserError(f"Failed to parse JSON with DuckDB: {e}") - # Sort by completeness (descending), then alphabetically - columns_list = sorted( - columns_list, key=lambda col: (-column_completeness[col], col) - ) - - rows_list = [ - [ - row.get(col) if isinstance(row, dict) else None - for col in columns_list - ] - for row in rows_data - ] - - return { - "columns": columns_list, - "rows": rows_list, - "total_rows": len(rows_data), - "truncated": len(data) > max_rows, - "file_size": file_size, - } + # Run in thread pool to avoid blocking event loop + return await asyncio.to_thread(_parse_json_sync, resolved_url, max_rows) class ParquetParser: - """Parse Parquet files using PyArrow with fsspec for range requests.""" + """Parse Parquet files using DuckDB (non-blocking, most efficient).""" @staticmethod async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]: """ - Parse Parquet file from URL using fsspec + PyArrow. + Parse Parquet file from URL using DuckDB. - fsspec provides HTTP file-like object with automatic range requests. - PyArrow reads only footer + first row group, not entire file! - - Uses asyncio.to_thread to run blocking fsspec operations without blocking event loop. 
@@ -301,49 +301,51 @@ class ParquetParser:
         Returns:
             Same format as other parsers
         """
+        # Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
+        # This prevents DuckDB from repeatedly hitting our backend
+        resolved_url = await resolve_url_redirects(url)
 
         def _parse_parquet_sync(url: str, max_rows: int) -> dict[str, Any]:
-            """Synchronous Parquet parsing (runs in thread pool)."""
-            # Create sync HTTP filesystem
-            fs = HTTPFileSystem()
+            """Synchronous Parquet parsing with DuckDB (runs in thread pool)."""
+            import duckdb
 
-            # Open file with fsspec (handles range requests automatically!)
-            with fs.open(url, mode="rb") as f:
-                # PyArrow reads with range requests via fsspec
-                parquet_file = pq.ParquetFile(f)
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")
 
-                # Read only first row group (not entire file!)
-                table = parquet_file.read_row_group(0)
+            try:
+                # Read Parquet with DuckDB (uses HTTP range requests automatically!)
+                query = f"SELECT * FROM read_parquet('{url}') LIMIT {max_rows}"
 
-                # Convert to list of dicts
-                data = table.to_pylist()
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]
 
-                # Limit rows
-                limited_data = data[:max_rows]
+                # Get total row count (efficient - reads only metadata)
+                try:
+                    count_query = f"SELECT COUNT(*) FROM read_parquet('{url}')"
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)
 
-                # Get columns
-                columns = table.schema.names
-
-                # Convert to row format
-                rows = [[row[col] for col in columns] for row in limited_data]
-
-                # Get total rows from metadata
-                total_rows = parquet_file.metadata.num_rows
-
-                # Get file size
-                file_size = parquet_file.metadata.serialized_size
+                conn.close()
 
                 return {
                     "columns": columns,
-                    "rows": rows,
+                    "rows": [list(row) for row in result],
                     "total_rows": total_rows,
-                    "truncated": total_rows > max_rows,
-                    "file_size": file_size,
+                    "truncated": len(result) >= max_rows or total_rows > max_rows,
+                    "file_size": None,
                 }
 
+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse Parquet with DuckDB: {e}")
+
         try:
             # Run in thread pool to avoid blocking event loop
-            result = await asyncio.to_thread(_parse_parquet_sync, url, max_rows)
+            result = await asyncio.to_thread(
+                _parse_parquet_sync, resolved_url, max_rows
+            )
             return result
         except Exception as e:
diff --git a/src/kohakuhub/datasetviewer/sql_query.py b/src/kohakuhub/datasetviewer/sql_query.py
index 6c78609..e4911d9 100644
--- a/src/kohakuhub/datasetviewer/sql_query.py
+++ b/src/kohakuhub/datasetviewer/sql_query.py
@@ -17,6 +17,8 @@ from typing import Any
 
 import duckdb
 
+from kohakuhub.datasetviewer.parsers import resolve_url_redirects
+
 
 class SQLQueryError(Exception):
     """SQL query execution error."""
@@ -131,10 +133,15 @@ async def execute_sql_query(
         conn.close()
         raise SQLQueryError(f"Query execution failed: {e}")
 
+    # Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
+    # This prevents DuckDB from repeatedly hitting our backend for range requests
+    resolved_url = await resolve_url_redirects(url)
+
     try:
         # Run in thread pool (DuckDB is synchronous)
+        # Use resolved_url (not original url) to avoid repeated backend hits
         result = await asyncio.to_thread(
-            _execute_query_sync, url, query, file_format, max_rows
+            _execute_query_sync, resolved_url, query, file_format, max_rows
         )
         return result
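Taken together, the SQL path now resolves the redirect once and lets DuckDB talk to object storage directly. A condensed sketch of the combined flow under those assumptions (`preview_parquet` and `_run` are illustrative names; only `resolve_url_redirects` comes from the patch):

```python
import asyncio

import duckdb

from kohakuhub.datasetviewer.parsers import resolve_url_redirects


async def preview_parquet(url: str, limit: int = 100) -> list[tuple]:
    # Resolve the 302 once; every subsequent range request DuckDB issues
    # then goes straight to the presigned S3 URL, not the hub backend.
    resolved = await resolve_url_redirects(url)

    def _run(u: str) -> list[tuple]:
        conn = duckdb.connect(":memory:")
        conn.execute("INSTALL httpfs")
        conn.execute("LOAD httpfs")
        try:
            return conn.execute(
                f"SELECT * FROM read_parquet('{u}') LIMIT {limit}"
            ).fetchall()
        finally:
            conn.close()

    # DuckDB is synchronous, so the query runs on a worker thread.
    return await asyncio.to_thread(_run, resolved)
```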