mirror of https://github.com/KohakuBlueleaf/KohakuHub.git (synced 2026-03-11 17:34:08 -05:00)

improve backend implementation
@@ -30,6 +30,7 @@ dependencies = [
     "pyarrow",
     "pydantic[email]",
     "python-multipart",
     "pytz",
     "pyyaml",
     "questionary",
     "requests",
@@ -84,27 +84,66 @@ def generate_value(col_type: str, row_id: int):


 def generate_column_schema(num_cols: int) -> list[dict]:
-    """Generate random column schema."""
-    columns = [{"name": "id", "type": "id"}]
+    """Generate column schema with meaningful names."""
+    # Predefined meaningful column names with types
+    predefined_columns = [
+        {"name": "id", "type": "id"},
+        {"name": "user_id", "type": "int"},
+        {"name": "age", "type": "int"},
+        {"name": "score", "type": "float"},
+        {"name": "rating", "type": "float"},
+        {"name": "is_active", "type": "bool"},
+        {"name": "is_verified", "type": "bool"},
+        {"name": "created_at", "type": "datetime"},
+        {"name": "updated_at", "type": "datetime"},
+        {"name": "birth_date", "type": "date"},
+        {"name": "username", "type": "short_text"},
+        {"name": "email", "type": "short_text"},
+        {"name": "name", "type": "short_text"},
+        {"name": "title", "type": "text"},
+        {"name": "description", "type": "text"},
+        {"name": "category", "type": "short_text"},
+        {"name": "status", "type": "short_text"},
+        {"name": "comment", "type": "long_text"},
+        {"name": "review", "type": "long_text"},
+        {"name": "content", "type": "very_long_text"},
+        {"name": "price", "type": "float"},
+        {"name": "quantity", "type": "int"},
+        {"name": "views", "type": "int"},
+        {"name": "likes", "type": "int"},
+        {"name": "tags", "type": "text"},
+        {"name": "metadata", "type": "text"},
+        {"name": "notes", "type": "long_text"},
+        {"name": "address", "type": "text"},
+        {"name": "city", "type": "short_text"},
+        {"name": "country", "type": "short_text"},
+    ]

-    # Mix of different column types
-    type_distribution = {
-        "int": 0.2,
-        "float": 0.15,
-        "bool": 0.1,
-        "date": 0.1,
-        "short_text": 0.15,
-        "text": 0.15,
-        "long_text": 0.1,
-        "very_long_text": 0.05,
-    }
+    columns = []

-    types = list(type_distribution.keys())
-    weights = list(type_distribution.values())
+    # Use predefined columns first
+    for i in range(min(num_cols, len(predefined_columns))):
+        columns.append(predefined_columns[i])

-    for i in range(1, num_cols):
-        col_type = random.choices(types, weights=weights)[0]
-        columns.append({"name": f"col_{i}_{col_type}", "type": col_type})
+    # If we need more columns, generate with numbered suffix
+    if num_cols > len(predefined_columns):
+        type_distribution = {
+            "int": 0.2,
+            "float": 0.15,
+            "bool": 0.1,
+            "date": 0.1,
+            "short_text": 0.15,
+            "text": 0.15,
+            "long_text": 0.1,
+            "very_long_text": 0.05,
+        }
+
+        types = list(type_distribution.keys())
+        weights = list(type_distribution.values())
+
+        for i in range(len(predefined_columns), num_cols):
+            col_type = random.choices(types, weights=weights)[0]
+            columns.append({"name": f"field_{i}_{col_type}", "type": col_type})

     return columns
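
As a quick illustration of the new behavior (hypothetical REPL session; assumes generate_column_schema is imported from this script):

    # With num_cols <= 30, only the predefined, meaningful names are used:
    generate_column_schema(4)
    # -> [{"name": "id", "type": "id"}, {"name": "user_id", "type": "int"},
    #     {"name": "age", "type": "int"}, {"name": "score", "type": "float"}]

    # Past the 30 predefined entries, extra columns get a weighted random type
    # and a numbered name such as "field_30_int" or "field_31_text":
    schema = generate_column_schema(32)
    assert schema[30]["name"].startswith("field_30_")
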
@@ -123,7 +162,8 @@ def generate_csv(output_path: Path, num_rows: int, num_cols: int):
     batch_size = 10000
     for batch_start in range(0, num_rows, batch_size):
         batch_end = min(batch_start + batch_size, num_rows)
-        if batch_start % 50000 == 0:
+        # Print progress every 5 batches (every 50k rows)
+        if batch_start % (batch_size * 5) == 0:
             print(f"  Writing rows {batch_start:,} to {batch_end:,}...")

         for row_id in range(batch_start, batch_end):
@@ -146,7 +186,8 @@ def generate_jsonl(output_path: Path, num_rows: int, num_cols: int):
     batch_size = 10000
     for batch_start in range(0, num_rows, batch_size):
         batch_end = min(batch_start + batch_size, num_rows)
-        if batch_start % 50000 == 0:
+        # Print progress every 5 batches (every 50k rows)
+        if batch_start % (batch_size * 5) == 0:
             print(f"  Writing rows {batch_start:,} to {batch_end:,}...")

         for row_id in range(batch_start, batch_end):
@@ -172,8 +213,10 @@ def generate_parquet(output_path: Path, num_rows: int, num_cols: int):

     for batch_start in range(0, num_rows, batch_size):
         batch_end = min(batch_start + batch_size, num_rows)
+        # Print progress every 100k rows (every 2 batches of 50k)
         if batch_start % 100000 == 0:
-            print(f"  Processing rows {batch_start:,} to {batch_end:,}...")
+            progress_end = min(batch_start + 100000, num_rows)
+            print(f"  Processing rows {batch_start:,} to {progress_end:,}...")

         # Generate batch data
         data = []
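
A sketch of the progress gating in isolation (assumes the Parquet writer uses batch_size = 50_000, as the new comment implies; the numbers are illustrative):

    num_rows = 230_000
    batch_size = 50_000
    for batch_start in range(0, num_rows, batch_size):
        if batch_start % 100_000 == 0:  # fires at 0, 100_000, 200_000
            progress_end = min(batch_start + 100_000, num_rows)
            print(f"  Processing rows {batch_start:,} to {progress_end:,}...")

Printing progress_end instead of batch_end makes each message span the full 100k window rather than a single 50k batch, so consecutive messages no longer appear to skip rows.
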
@@ -16,6 +16,49 @@ import pyarrow.parquet as pq
 from fsspec.implementations.http import HTTPFileSystem


+async def resolve_url_redirects(url: str) -> str:
+    """
+    Resolve URL redirects by following 302 responses.
+
+    For resolve URLs that return 302, we need to follow the redirect
+    and use the final S3 presigned URL for fsspec/DuckDB.
+    Otherwise, they keep hitting our backend for every range request.
+
+    Uses GET with streaming (no redirect following) to detect 302 responses
+    without actually downloading content.
+
+    Args:
+        url: Original URL (may return 302)
+
+    Returns:
+        Final URL after following redirects (or original if no redirect)
+    """
+    try:
+        async with httpx.AsyncClient(timeout=10.0, follow_redirects=False) as client:
+            # Send GET request with streaming, but don't follow redirects
+            # This gives us the correct 302 response that GET would return
+            async with client.stream("GET", url) as response:
+                # Check for redirect status codes
+                if response.status_code in [301, 302, 303, 307, 308]:
+                    location = response.headers.get("Location")
+                    if location:
+                        print(f"Resolved redirect: {url[:50]}... -> {location[:50]}...")
+                        # Close stream immediately without reading content
+                        await response.aclose()
+                        return location
+
+                # For other status codes (200, 4xx, 5xx), use original URL
+                # Close stream without reading content
+                await response.aclose()
+                return url
+
+    except Exception as e:
+        # If request fails, fall back to original URL
+        # fsspec will handle the actual request
+        print(f"Warning: Could not resolve redirects for {url[:50]}...: {e}")
+        return url
+
+
 class ParserError(Exception):
     """Base exception for parser errors."""
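
A minimal sketch of the intended call pattern for resolve_url_redirects (hypothetical URL; must run inside an event loop):

    import asyncio

    async def main():
        # Hypothetical resolve endpoint that 302-redirects to a presigned S3 URL
        url = "https://hub.example.com/datasets/org/demo/resolve/main/data.parquet"
        direct = await resolve_url_redirects(url)
        # `direct` is the presigned S3 URL on a 302, or the original URL otherwise,
        # so subsequent DuckDB/fsspec range requests bypass the backend.

    asyncio.run(main())
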
@@ -23,14 +66,17 @@ class ParserError(Exception):


 class CSVParser:
-    """Stream CSV files from URL."""
+    """Parse CSV files using DuckDB (non-blocking, efficient)."""

     @staticmethod
     async def parse(
         url: str, max_rows: int = 1000, delimiter: str = ","
     ) -> dict[str, Any]:
         """
-        Parse CSV file from URL.
+        Parse CSV file from URL using DuckDB.
+
+        DuckDB supports CSV with automatic type detection and HTTP range requests.
+        Much faster and more robust than manual parsing.

         Args:
             url: File URL (presigned S3 URL)
@@ -42,78 +88,72 @@ class CSVParser:
                 "columns": ["col1", "col2", ...],
                 "rows": [[val1, val2, ...], ...],
                 "total_rows": N,
-                "truncated": bool
+                "truncated": bool,
+                "file_size": N
             }
         """
-        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-            async with client.stream("GET", url) as response:
-                response.raise_for_status()
+        # Resolve redirects first
+        resolved_url = await resolve_url_redirects(url)

-                # Get file size if available
-                content_length = response.headers.get("content-length")
-                file_size = int(content_length) if content_length else None
+        def _parse_csv_sync(url: str, max_rows: int, delimiter: str) -> dict[str, Any]:
+            """Synchronous CSV parsing with DuckDB (runs in thread pool)."""
+            import duckdb

-                # Read in chunks
-                buffer = ""
-                rows = []
-                columns = None
-                line_count = 0
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")

-                async for chunk in response.aiter_text():
-                    buffer += chunk
+            try:
+                # Read CSV with DuckDB
+                query = f"""
+                    SELECT * FROM read_csv(
+                        '{url}',
+                        delim='{delimiter}',
+                        header=true,
+                        auto_detect=true
+                    )
+                    LIMIT {max_rows}
+                """

-                    # Process complete lines
-                    while "\n" in buffer:
-                        line, buffer = buffer.split("\n", 1)
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]

-                        if not line.strip():
-                            continue
+                # Get total row count
+                try:
+                    count_query = f"SELECT COUNT(*) FROM read_csv('{url}', delim='{delimiter}', header=true)"
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)

-                        # Parse CSV line
-                        try:
-                            parsed = list(csv.reader([line], delimiter=delimiter))[0]
-                        except Exception:
-                            continue
-
-                        # First line is header
-                        if columns is None:
-                            columns = parsed
-                            continue
-
-                        rows.append(parsed)
-                        line_count += 1
-
-                        if line_count >= max_rows:
-                            break
-
-                    if line_count >= max_rows:
-                        break
-
-                # Process remaining buffer
-                if buffer.strip() and line_count < max_rows and columns is not None:
-                    try:
-                        parsed = list(csv.reader([buffer], delimiter=delimiter))[0]
-                        rows.append(parsed)
-                        line_count += 1
-                    except Exception:
-                        pass
+                conn.close()

                 return {
-                    "columns": columns or [],
-                    "rows": rows,
-                    "total_rows": line_count,
-                    "truncated": line_count >= max_rows,
-                    "file_size": file_size,
+                    "columns": columns,
+                    "rows": [list(row) for row in result],
+                    "total_rows": total_rows,
+                    "truncated": len(result) >= max_rows,
+                    "file_size": None,
                 }

+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse CSV with DuckDB: {e}")
+
+        # Run in thread pool to avoid blocking event loop
+        return await asyncio.to_thread(
+            _parse_csv_sync, resolved_url, max_rows, delimiter
+        )


 class JSONLParser:
-    """Stream JSONL (newline-delimited JSON) files from URL."""
+    """Parse JSONL files using DuckDB (non-blocking, efficient)."""

     @staticmethod
     async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
         """
-        Parse JSONL file from URL.
+        Parse JSONL file from URL using DuckDB.
+
+        DuckDB's read_ndjson supports newline-delimited JSON with HTTP URLs.

         Args:
             url: File URL
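
Standalone sketch of what the new CSV path runs under the hood (assumes only that the duckdb package is installed; the URL is a placeholder):

    import duckdb

    conn = duckdb.connect(":memory:")
    conn.execute("INSTALL httpfs")
    conn.execute("LOAD httpfs")

    # read_csv fetches over HTTP and auto-detects column types;
    # LIMIT keeps the preview bounded.
    result = conn.execute(
        "SELECT * FROM read_csv('https://example.com/data.csv', "
        "delim=',', header=true, auto_detect=true) LIMIT 1000"
    ).fetchall()
    columns = [desc[0] for desc in conn.description]
    conn.close()
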
@@ -124,98 +164,67 @@ class JSONLParser:
                 "columns": ["col1", "col2", ...],
                 "rows": [[val1, val2, ...], ...],
                 "total_rows": N,
-                "truncated": bool
+                "truncated": bool,
+                "file_size": N
             }
         """
-        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-            async with client.stream("GET", url) as response:
-                response.raise_for_status()
+        # Resolve redirects first
+        resolved_url = await resolve_url_redirects(url)

-                content_length = response.headers.get("content-length")
-                file_size = int(content_length) if content_length else None
+        def _parse_jsonl_sync(url: str, max_rows: int) -> dict[str, Any]:
+            """Synchronous JSONL parsing with DuckDB (runs in thread pool)."""
+            import duckdb

-                buffer = ""
-                rows = []
-                columns = set()
-                line_count = 0
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")
+            conn.execute("INSTALL json")
+            conn.execute("LOAD json")

-                async for chunk in response.aiter_text():
-                    buffer += chunk
+            try:
+                # Read JSONL with DuckDB (newline-delimited JSON)
+                query = f"""
+                    SELECT * FROM read_ndjson('{url}')
+                    LIMIT {max_rows}
+                """

-                    while "\n" in buffer:
-                        line, buffer = buffer.split("\n", 1)
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]

-                        if not line.strip():
-                            continue
+                # Get total row count
+                try:
+                    count_query = f"SELECT COUNT(*) FROM read_ndjson('{url}')"
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)

-                        try:
-                            obj = json.loads(line)
-                            if isinstance(obj, dict):
-                                # Collect all keys
-                                columns.update(obj.keys())
-                                rows.append(obj)
-                                line_count += 1
-                        except json.JSONDecodeError:
-                            continue
-
-                        if line_count >= max_rows:
-                            break
-
-                    if line_count >= max_rows:
-                        break
-
-                # Process remaining buffer
-                if buffer.strip() and line_count < max_rows:
-                    try:
-                        obj = json.loads(buffer)
-                        if isinstance(obj, dict):
-                            columns.update(obj.keys())
-                            rows.append(obj)
-                            line_count += 1
-                    except json.JSONDecodeError:
-                        pass
-
-                # Convert to columnar format
-                # Sort columns by completeness (most complete first), then alphabetically
-                columns_list = sorted(columns)
-
-                # Calculate completeness for each column
-                column_completeness = {}
-                for col in columns_list:
-                    non_null_count = sum(1 for row in rows if row.get(col) is not None)
-                    column_completeness[col] = non_null_count
-
-                # Sort by completeness (descending), then alphabetically
-                columns_list = sorted(
-                    columns_list, key=lambda col: (-column_completeness[col], col)
-                )
-
-                rows_list = [[row.get(col) for col in columns_list] for row in rows]
+                conn.close()

                 return {
-                    "columns": columns_list,
-                    "rows": rows_list,
-                    "total_rows": line_count,
-                    "truncated": line_count >= max_rows,
-                    "file_size": file_size,
+                    "columns": columns,
+                    "rows": [list(row) for row in result],
+                    "total_rows": total_rows,
+                    "truncated": len(result) >= max_rows,
+                    "file_size": None,
                 }

+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse JSONL with DuckDB: {e}")
+
+        # Run in thread pool to avoid blocking event loop
+        return await asyncio.to_thread(_parse_jsonl_sync, resolved_url, max_rows)


 class JSONParser:
-    """
-    Parse JSON array files from URL.
-
-    NOTE: This parser is DEPRECATED and should not be used for large files!
-    It requires loading the entire file into memory.
-    Use JSONL format instead for streaming support.
-    """
+    """Parse JSON array files using DuckDB (non-blocking)."""

     @staticmethod
     async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
         """
-        Parse JSON file from URL.
+        Parse JSON array file from URL using DuckDB.

-        Expects format: [{"col1": val1, ...}, ...]
+        DuckDB's read_json supports JSON arrays with automatic schema detection.

         Args:
             url: File URL
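
From async code, the parsers are then used like this (sketch; placeholder URL, assumes the parser classes are imported):

    import asyncio

    async def preview():
        result = await JSONLParser.parse("https://example.com/data.jsonl", max_rows=100)
        # Shared result shape across all parsers:
        #   result["columns"]    -> column names from DuckDB's description
        #   result["rows"]       -> row tuples converted to lists
        #   result["total_rows"] -> COUNT(*) over the whole file when available
        #   result["truncated"]  -> True when the preview hit max_rows
        print(result["columns"], result["total_rows"])

    asyncio.run(preview())
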
@@ -224,75 +233,66 @@ class JSONParser:

         Returns:
             Same format as JSONLParser
         """
-        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
-            response = await client.get(url)
-            response.raise_for_status()
+        # Resolve redirects first
+        resolved_url = await resolve_url_redirects(url)

-            file_size = len(response.content)
+        def _parse_json_sync(url: str, max_rows: int) -> dict[str, Any]:
+            """Synchronous JSON parsing with DuckDB (runs in thread pool)."""
+            import duckdb
+
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")
+            conn.execute("INSTALL json")
+            conn.execute("LOAD json")

-            try:
-                data = response.json()
-            except json.JSONDecodeError as e:
-                raise ParserError(f"Invalid JSON: {e}")
+            try:
+                # Read JSON array with DuckDB
+                query = f"""
+                    SELECT * FROM read_json('{url}', format='array')
+                    LIMIT {max_rows}
+                """

-            if not isinstance(data, list):
-                raise ParserError("JSON must be an array of objects")
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]

-            # Limit rows
-            rows_data = data[:max_rows]
+                # Get total row count
+                try:
+                    count_query = (
+                        f"SELECT COUNT(*) FROM read_json('{url}', format='array')"
+                    )
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)

-            # Extract columns
-            columns = set()
-            for row in rows_data:
-                if isinstance(row, dict):
-                    columns.update(row.keys())
+                conn.close()

-            columns_list = sorted(columns)
+                return {
+                    "columns": columns,
+                    "rows": [list(row) for row in result],
+                    "total_rows": total_rows,
+                    "truncated": len(result) >= max_rows,
+                    "file_size": None,
+                }

-            # Calculate completeness for each column
-            column_completeness = {}
-            for col in columns_list:
-                non_null_count = sum(
-                    1
-                    for row in rows_data
-                    if isinstance(row, dict) and row.get(col) is not None
-                )
-                column_completeness[col] = non_null_count
+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse JSON with DuckDB: {e}")

-            # Sort by completeness (descending), then alphabetically
-            columns_list = sorted(
-                columns_list, key=lambda col: (-column_completeness[col], col)
-            )
-
-            rows_list = [
-                [
-                    row.get(col) if isinstance(row, dict) else None
-                    for col in columns_list
-                ]
-                for row in rows_data
-            ]
-
-            return {
-                "columns": columns_list,
-                "rows": rows_list,
-                "total_rows": len(rows_data),
-                "truncated": len(data) > max_rows,
-                "file_size": file_size,
-            }
+        # Run in thread pool to avoid blocking event loop
+        return await asyncio.to_thread(_parse_json_sync, resolved_url, max_rows)


 class ParquetParser:
-    """Parse Parquet files using PyArrow with fsspec for range requests."""
+    """Parse Parquet files using DuckDB (non-blocking, most efficient)."""

     @staticmethod
     async def parse(url: str, max_rows: int = 1000) -> dict[str, Any]:
         """
-        Parse Parquet file from URL using fsspec + PyArrow.
+        Parse Parquet file from URL using DuckDB.

-        fsspec provides HTTP file-like object with automatic range requests.
-        PyArrow reads only footer + first row group, not entire file!
-
-        Uses asyncio.to_thread to run blocking fsspec operations without blocking event loop.
+        DuckDB is the most efficient way to read Parquet with HTTP range requests.
+        It only reads the necessary row groups and columns, not the entire file.

         Args:
             url: File URL (HTTP/HTTPS, including S3 presigned URLs)
@@ -301,49 +301,51 @@ class ParquetParser:

         Returns:
             Same format as other parsers
         """
+        # Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
+        # This prevents DuckDB from repeatedly hitting our backend
+        resolved_url = await resolve_url_redirects(url)
+
         def _parse_parquet_sync(url: str, max_rows: int) -> dict[str, Any]:
-            """Synchronous Parquet parsing (runs in thread pool)."""
-            # Create sync HTTP filesystem
-            fs = HTTPFileSystem()
+            """Synchronous Parquet parsing with DuckDB (runs in thread pool)."""
+            import duckdb

-            # Open file with fsspec (handles range requests automatically!)
-            with fs.open(url, mode="rb") as f:
-                # PyArrow reads with range requests via fsspec
-                parquet_file = pq.ParquetFile(f)
+            conn = duckdb.connect(":memory:")
+            conn.execute("INSTALL httpfs")
+            conn.execute("LOAD httpfs")

-                # Read only first row group (not entire file!)
-                table = parquet_file.read_row_group(0)
+            try:
+                # Read Parquet with DuckDB (uses HTTP range requests automatically!)
+                query = f"SELECT * FROM read_parquet('{url}') LIMIT {max_rows}"

-                # Convert to list of dicts
-                data = table.to_pylist()
+                result = conn.execute(query).fetchall()
+                columns = [desc[0] for desc in conn.description]

-                # Limit rows
-                limited_data = data[:max_rows]
+                # Get total row count (efficient - reads only metadata)
+                try:
+                    count_query = f"SELECT COUNT(*) FROM read_parquet('{url}')"
+                    total_rows = conn.execute(count_query).fetchone()[0]
+                except Exception:
+                    total_rows = len(result)

-                # Get columns
-                columns = table.schema.names
-
-                # Convert to row format
-                rows = [[row[col] for col in columns] for row in limited_data]
-
-                # Get total rows from metadata
-                total_rows = parquet_file.metadata.num_rows
-
-                # Get file size
-                file_size = parquet_file.metadata.serialized_size
+                conn.close()

                 return {
                     "columns": columns,
-                    "rows": rows,
+                    "rows": [list(row) for row in result],
                     "total_rows": total_rows,
-                    "truncated": total_rows > max_rows,
-                    "file_size": file_size,
+                    "truncated": len(result) >= max_rows or total_rows > max_rows,
+                    "file_size": None,
                 }

+            except Exception as e:
+                conn.close()
+                raise ParserError(f"Failed to parse Parquet with DuckDB: {e}")
+
         try:
             # Run in thread pool to avoid blocking event loop
-            result = await asyncio.to_thread(_parse_parquet_sync, url, max_rows)
+            result = await asyncio.to_thread(
+                _parse_parquet_sync, resolved_url, max_rows
+            )
             return result

         except Exception as e:
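
The COUNT(*) here is cheap because DuckDB can typically answer it from the Parquet footer's row-group metadata rather than scanning data pages; a sketch of the same idea in isolation (placeholder URL):

    import duckdb

    conn = duckdb.connect(":memory:")
    conn.execute("INSTALL httpfs")
    conn.execute("LOAD httpfs")

    # Both queries go over HTTP range requests: the LIMIT query touches only
    # the row groups it needs, and COUNT(*) is usually resolved from metadata.
    preview = conn.execute(
        "SELECT * FROM read_parquet('https://example.com/data.parquet') LIMIT 10"
    ).fetchall()
    total = conn.execute(
        "SELECT COUNT(*) FROM read_parquet('https://example.com/data.parquet')"
    ).fetchone()[0]
    conn.close()
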
@@ -17,6 +17,8 @@ from typing import Any

 import duckdb

+from kohakuhub.datasetviewer.parsers import resolve_url_redirects
+

 class SQLQueryError(Exception):
     """SQL query execution error."""
@@ -131,10 +133,15 @@ async def execute_sql_query(
             conn.close()
             raise SQLQueryError(f"Query execution failed: {e}")

+    # Resolve redirects first (302 from resolve endpoint -> S3 presigned URL)
+    # This prevents DuckDB from repeatedly hitting our backend for range requests
+    resolved_url = await resolve_url_redirects(url)
+
     try:
         # Run in thread pool (DuckDB is synchronous)
+        # Use resolved_url (not original url) to avoid repeated backend hits
         result = await asyncio.to_thread(
-            _execute_query_sync, url, query, file_format, max_rows
+            _execute_query_sync, resolved_url, query, file_format, max_rows
         )
         return result
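
End-to-end, the SQL endpoint then behaves roughly like this (sketch; the signature is inferred from the call above, and the URL and query are placeholders):

    import asyncio

    async def run():
        # Redirects are resolved once up front, then the synchronous DuckDB
        # query runs in a worker thread against the presigned URL.
        result = await execute_sql_query(
            url="https://example.com/data.parquet",
            query="SELECT COUNT(*) AS n FROM data",  # table alias is an assumption
            file_format="parquet",
            max_rows=1000,
        )
        print(result)

    asyncio.run(run())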