diff --git a/src/kohakuboard/api/boards.py b/src/kohakuboard/api/boards.py index 8384b7f..0d08bec 100644 --- a/src/kohakuboard/api/boards.py +++ b/src/kohakuboard/api/boards.py @@ -126,7 +126,8 @@ async def get_scalar_data( board_dir = Path(cfg.app.board_data_dir) / board_id reader = BoardReader(board_dir) data = reader.get_scalar_data(metric, limit=limit) - return {"metric": metric, "data": data} + # data is now columnar format: {steps: [], global_steps: [], timestamps: [], values: []} + return {"metric": metric, **data} except FileNotFoundError as e: logger_api.warning(f"Board not found: {board_id}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/src/kohakuboard/api/runs.py b/src/kohakuboard/api/runs.py index 7064b63..c73469a 100644 --- a/src/kohakuboard/api/runs.py +++ b/src/kohakuboard/api/runs.py @@ -156,7 +156,8 @@ async def get_scalar_data( reader = BoardReader(run_path) data = reader.get_scalar_data(metric, limit=limit) - return {"metric": metric, "data": data} + # data is now columnar format: {steps: [], global_steps: [], timestamps: [], values: []} + return {"metric": metric, **data} @router.get("/projects/{project}/runs/{run_id}/media") diff --git a/src/kohakuboard/api/utils/board_reader.py b/src/kohakuboard/api/utils/board_reader.py index bd9b686..55b04cc 100644 --- a/src/kohakuboard/api/utils/board_reader.py +++ b/src/kohakuboard/api/utils/board_reader.py @@ -1,6 +1,7 @@ """Utility functions for reading board data from DuckDB files""" import json +import math from pathlib import Path from typing import Any, Dict, List, Optional @@ -19,9 +20,21 @@ class BoardReader: """ self.board_dir = Path(board_dir) self.metadata_path = self.board_dir / "metadata.json" - self.db_path = self.board_dir / "data" / "board.duckdb" self.media_dir = self.board_dir / "media" + # Multi-file DuckDB structure (NEW) + # Try new structure first, fall back to legacy single file + self.metrics_db = self.board_dir / "data" / "metrics.duckdb" + self.media_db = self.board_dir / "data" / "media.duckdb" + self.tables_db = self.board_dir / "data" / "tables.duckdb" + self.histograms_db = self.board_dir / "data" / "histograms.duckdb" + + # Legacy single file (for backward compatibility) + self.legacy_db = self.board_dir / "data" / "board.duckdb" + + # Determine which structure to use + self.use_legacy = self.legacy_db.exists() and not self.metrics_db.exists() + # Validate paths if not self.board_dir.exists(): raise FileNotFoundError(f"Board directory not found: {board_dir}") @@ -38,16 +51,53 @@ class BoardReader: with open(self.metadata_path, "r") as f: return json.load(f) - def _get_connection(self) -> duckdb.DuckDBPyConnection: - """Get read-only DuckDB connection + def _get_metrics_connection(self) -> duckdb.DuckDBPyConnection: + """Get read-only connection to metrics database Returns: DuckDB connection """ - if not self.db_path.exists(): - raise FileNotFoundError(f"Database file not found: {self.db_path}") + db_path = self.legacy_db if self.use_legacy else self.metrics_db + if not db_path.exists(): + raise FileNotFoundError(f"Metrics database not found: {db_path}") - return duckdb.connect(str(self.db_path), read_only=True) + return duckdb.connect(str(db_path), read_only=True) + + def _get_media_connection(self) -> duckdb.DuckDBPyConnection: + """Get read-only connection to media database + + Returns: + DuckDB connection + """ + db_path = self.legacy_db if self.use_legacy else self.media_db + if not db_path.exists(): + raise FileNotFoundError(f"Media database not found: {db_path}") + + return duckdb.connect(str(db_path), read_only=True) + + def _get_tables_connection(self) -> duckdb.DuckDBPyConnection: + """Get read-only connection to tables database + + Returns: + DuckDB connection + """ + db_path = self.legacy_db if self.use_legacy else self.tables_db + if not db_path.exists(): + raise FileNotFoundError(f"Tables database not found: {db_path}") + + return duckdb.connect(str(db_path), read_only=True) + + def _get_histograms_connection(self) -> duckdb.DuckDBPyConnection: + """Get read-only connection to histograms database + + Returns: + DuckDB connection + """ + db_path = self.legacy_db if self.use_legacy else self.histograms_db + if not db_path.exists(): + raise FileNotFoundError(f"Histograms database not found: {db_path}") + + return duckdb.connect(str(db_path), read_only=True) def get_available_metrics(self) -> List[str]: """Get list of available scalar metrics @@ -55,7 +105,7 @@ class BoardReader: Returns: List of metric names (INCLUDING step/global_step for x-axis selection) """ - conn = self._get_connection() + conn = self._get_metrics_connection() try: # Get all columns from metrics table result = conn.execute("PRAGMA table_info(metrics)").fetchall() @@ -79,7 +129,7 @@ class BoardReader: def get_scalar_data( self, metric: str, limit: Optional[int] = None - ) -> List[Dict[str, Any]]: + ) -> Dict[str, List]: """Get scalar data for a specific metric Args: @@ -87,37 +137,54 @@ class BoardReader: limit: Optional row limit Returns: - List of dicts with step, global_step, value + Dict with columnar format: {steps: [], global_steps: [], timestamps: [], values: []} """ - conn = self._get_connection() + conn = self._get_metrics_connection() try: - # Special handling for step/global_step/timestamp - they're already included # Escape metric name (convert "/" to "__" for DuckDB column name) escaped_metric = metric.replace("/", "__") - # Build query - always select step, global_step, timestamp, and the requested metric + # Build query - select ALL rows (don't filter out NaN!) + # This is critical: NaN values are data, not missing data query = f'SELECT step, global_step, timestamp, "{escaped_metric}" as value FROM metrics' - # For regular metrics, filter out NULLs - # For step/global_step/timestamp, include all rows (they're never NULL) - if metric not in ("step", "global_step", "timestamp"): - query += f' WHERE "{escaped_metric}" IS NOT NULL' - if limit: query += f" LIMIT {limit}" result = conn.execute(query).fetchall() - # Convert to list of dicts - return [ - { - "step": row[0], - "global_step": row[1], - "timestamp": row[2].isoformat() if row[2] else None, - "value": row[3], - } - for row in result - ] + # Convert to columnar format (more efficient than row-based) + # Format: {steps: [], global_steps: [], timestamps: [], values: []} + steps = [] + global_steps = [] + timestamps = [] + values = [] + + for row in result: + steps.append(row[0]) + global_steps.append(row[1]) + # Convert timestamp to Unix seconds (integer) for efficiency + timestamps.append(int(row[2].timestamp()) if row[2] else None) + + value = row[3] + # Convert special values to string markers for JSON compatibility + # null = sparse/missing data (not logged at this step) + # "NaN" = explicitly logged NaN value + # "Infinity"/"-Infinity" = explicitly logged inf values + if value is not None: + if math.isnan(value): + value = "NaN" + elif math.isinf(value): + value = "Infinity" if value > 0 else "-Infinity" + + values.append(value) + + return { + "steps": steps, + "global_steps": global_steps, + "timestamps": timestamps, + "values": values, + } finally: conn.close() @@ -127,7 +194,7 @@ class BoardReader: Returns: List of unique media log names """ - conn = self._get_connection() + conn = self._get_media_connection() try: result = conn.execute( "SELECT DISTINCT name FROM media ORDER BY name" @@ -151,7 +218,7 @@ class BoardReader: Returns: List of dicts with step, global_step, caption, media metadata """ - conn = self._get_connection() + conn = self._get_media_connection() try: query = f"SELECT * FROM media WHERE name = ?" if limit: @@ -173,7 +240,7 @@ class BoardReader: Returns: List of unique table log names """ - conn = self._get_connection() + conn = self._get_tables_connection() try: result = conn.execute( "SELECT DISTINCT name FROM tables ORDER BY name" @@ -197,7 +264,7 @@ class BoardReader: Returns: List of dicts with step, global_step, columns, column_types, rows """ - conn = self._get_connection() + conn = self._get_tables_connection() try: query = f"SELECT * FROM tables WHERE name = ?" if limit: @@ -233,7 +300,7 @@ class BoardReader: Returns: List of unique histogram log names """ - conn = self._get_connection() + conn = self._get_histograms_connection() try: result = conn.execute( "SELECT DISTINCT name FROM histograms ORDER BY name" @@ -257,7 +324,7 @@ class BoardReader: Returns: List of dicts with step, global_step, bins, values """ - conn = self._get_connection() + conn = self._get_histograms_connection() try: query = f"SELECT * FROM histograms WHERE name = ?" if limit: @@ -305,43 +372,55 @@ class BoardReader: Returns: Dict with metadata, metrics, media, tables counts """ - conn = self._get_connection() + metadata = self.get_metadata() + + # Count rows from each database (use separate connections) + metrics_count = 0 + media_count = 0 + tables_count = 0 + histograms_count = 0 + try: - metadata = self.get_metadata() - - # Count rows + conn = self._get_metrics_connection() metrics_count = conn.execute("SELECT COUNT(*) FROM metrics").fetchone()[0] - - try: - media_count = conn.execute("SELECT COUNT(*) FROM media").fetchone()[0] - except Exception: - media_count = 0 - - try: - tables_count = conn.execute("SELECT COUNT(*) FROM tables").fetchone()[0] - except Exception: - tables_count = 0 - - try: - histograms_count = conn.execute( - "SELECT COUNT(*) FROM histograms" - ).fetchone()[0] - except Exception: - histograms_count = 0 - - return { - "metadata": metadata, - "metrics_count": metrics_count, - "media_count": media_count, - "tables_count": tables_count, - "histograms_count": histograms_count, - "available_metrics": self.get_available_metrics(), - "available_media": self.get_available_media_names(), - "available_tables": self.get_available_table_names(), - "available_histograms": self.get_available_histogram_names(), - } - finally: conn.close() + except Exception as e: + logger.warning(f"Failed to count metrics: {e}") + + try: + conn = self._get_media_connection() + media_count = conn.execute("SELECT COUNT(*) FROM media").fetchone()[0] + conn.close() + except Exception as e: + logger.warning(f"Failed to count media: {e}") + + try: + conn = self._get_tables_connection() + tables_count = conn.execute("SELECT COUNT(*) FROM tables").fetchone()[0] + conn.close() + except Exception as e: + logger.warning(f"Failed to count tables: {e}") + + try: + conn = self._get_histograms_connection() + histograms_count = conn.execute( + "SELECT COUNT(*) FROM histograms" + ).fetchone()[0] + conn.close() + except Exception as e: + logger.warning(f"Failed to count histograms: {e}") + + return { + "metadata": metadata, + "metrics_count": metrics_count, + "media_count": media_count, + "tables_count": tables_count, + "histograms_count": histograms_count, + "available_metrics": self.get_available_metrics(), + "available_media": self.get_available_media_names(), + "available_tables": self.get_available_table_names(), + "available_histograms": self.get_available_histogram_names(), + } def list_boards(base_dir: Path) -> List[Dict[str, Any]]: diff --git a/src/kohakuboard/client/media.py b/src/kohakuboard/client/media.py index b6185c4..f4689e1 100644 --- a/src/kohakuboard/client/media.py +++ b/src/kohakuboard/client/media.py @@ -1,10 +1,12 @@ """Media handling utilities for images, videos, and audio""" import hashlib +import io import shutil from pathlib import Path from typing import Any, List, Union +import numpy as np from loguru import logger @@ -209,7 +211,6 @@ class MediaHandler: # Numpy array if hasattr(image, "__array__"): - import numpy as np arr = np.array(image) @@ -242,8 +243,6 @@ class MediaHandler: def _hash_media(self, pil_image) -> str: """Generate hash for image deduplication (also used as media ID)""" - import io - # Convert to bytes buf = io.BytesIO() pil_image.save(buf, format="PNG") diff --git a/src/kohakuboard/client/storage_duckdb.py b/src/kohakuboard/client/storage_duckdb.py index 418970a..792b0c5 100644 --- a/src/kohakuboard/client/storage_duckdb.py +++ b/src/kohakuboard/client/storage_duckdb.py @@ -1,45 +1,64 @@ -"""Storage backend using DuckDB for true incremental appends""" +"""Storage backend using DuckDB for true incremental appends + +NEW ARCHITECTURE (Multi-File): +- 4 separate DuckDB files (metrics, media, tables, histograms) +- Enables concurrent read while write (different files) +- Heavy logging isolated (histogram writes don't block scalar reads) +- Compatible with ThreadPoolExecutor for parallel writes +""" import json +import math +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Dict, List, Optional import duckdb -import pandas as pd +import numpy as np from loguru import logger class DuckDBStorage: - """DuckDB-based storage for metrics, media, and table logs + """DuckDB-based storage with multi-file architecture - Benefits over Parquet: - - True incremental append (connection.append()) - - No read overhead - - SQL queries natively - - ACID transactions - - Schema evolution via ALTER TABLE - - Single .duckdb file per board + Architecture: + - metrics.duckdb: Scalar metrics (step, global_step, timestamp, dynamic columns) + - media.duckdb: Media metadata (images, videos, audio) + - tables.duckdb: Table logs + - histograms.duckdb: Histogram data + + Benefits: + - Concurrent read/write (different files) + - Heavy logging isolation (separate queues/files) + - True incremental append + - NaN/inf preservation (direct INSERT) """ def __init__(self, base_dir: Path): - """Initialize DuckDB storage + """Initialize DuckDB storage with multi-file architecture Args: - base_dir: Base directory for database file + base_dir: Base directory for database files """ self.base_dir = base_dir self.base_dir.mkdir(parents=True, exist_ok=True) - # Single database file - self.db_file = base_dir / "board.duckdb" + # Separate database files (one per data type) + self.metrics_file = base_dir / "metrics.duckdb" + self.media_file = base_dir / "media.duckdb" + self.tables_file = base_dir / "tables.duckdb" + self.histograms_file = base_dir / "histograms.duckdb" - # Connect to database - self.conn = duckdb.connect(str(self.db_file)) + # Separate connections (one per file) + self.metrics_conn = duckdb.connect(str(self.metrics_file)) + self.media_conn = duckdb.connect(str(self.media_file)) + self.tables_conn = duckdb.connect(str(self.tables_file)) + self.histograms_conn = duckdb.connect(str(self.histograms_file)) - # Create tables + # Create tables in each database self._init_tables() - # In-memory buffers + # In-memory buffers (one per data type) self.metrics_buffer: List[Dict[str, Any]] = [] self.media_buffer: List[Dict[str, Any]] = [] self.tables_buffer: List[Dict[str, Any]] = [] @@ -50,14 +69,12 @@ class DuckDBStorage: # Flush thresholds self.flush_threshold = 10 # Metrics - self.histogram_flush_threshold = ( - 100 # Histograms (batch aggressively for performance) - ) + self.histogram_flush_threshold = 100 # Histograms (batch aggressively) def _init_tables(self): - """Initialize database tables""" + """Initialize database tables in separate files""" # Metrics table (dynamic columns added as needed) - self.conn.execute( + self.metrics_conn.execute( """ CREATE TABLE IF NOT EXISTS metrics ( step BIGINT NOT NULL, @@ -66,9 +83,10 @@ class DuckDBStorage: ) """ ) + self.metrics_conn.commit() # Media table (fixed schema) - self.conn.execute( + self.media_conn.execute( """ CREATE TABLE IF NOT EXISTS media ( step BIGINT NOT NULL, @@ -86,9 +104,10 @@ class DuckDBStorage: ) """ ) + self.media_conn.commit() # Tables table (fixed schema, JSON for table data) - self.conn.execute( + self.tables_conn.execute( """ CREATE TABLE IF NOT EXISTS tables ( step BIGINT NOT NULL, @@ -100,9 +119,10 @@ class DuckDBStorage: ) """ ) + self.tables_conn.commit() # Histograms table (pre-computed bins to save space) - self.conn.execute( + self.histograms_conn.execute( """ CREATE TABLE IF NOT EXISTS histograms ( step BIGINT NOT NULL, @@ -114,8 +134,7 @@ class DuckDBStorage: ) """ ) - - self.conn.commit() + self.histograms_conn.commit() def append_metrics( self, @@ -163,7 +182,7 @@ class DuckDBStorage: # Escape column name (replace "/" with "__" for DuckDB compatibility) escaped_col = col.replace("/", "__") # Add column as DOUBLE (works for most ML metrics) - self.conn.execute( + self.metrics_conn.execute( f'ALTER TABLE metrics ADD COLUMN IF NOT EXISTS "{escaped_col}" DOUBLE' ) self.known_metric_cols.add(col) @@ -171,7 +190,7 @@ class DuckDBStorage: except Exception as e: logger.error(f"Failed to add column {col}: {e}") - self.conn.commit() + self.metrics_conn.commit() def append_media( self, @@ -232,20 +251,26 @@ class DuckDBStorage: self.flush_tables() def flush_metrics(self): - """Flush metrics buffer to DuckDB (TRUE INCREMENTAL!)""" + """Flush metrics buffer to DuckDB (preserves NaN/inf)""" if not self.metrics_buffer: return try: - # Convert to DataFrame - df = pd.DataFrame(self.metrics_buffer) + # Direct SQL INSERT to preserve NaN/inf as IEEE 754 values (not NULL) + for row in self.metrics_buffer: + columns = list(row.keys()) + values = list(row.values()) - # TRUE INCREMENTAL APPEND - no read, just append! - self.conn.append("metrics", df, by_name=True) - self.conn.commit() + col_names = ", ".join(f'"{col}"' for col in columns) + placeholders = ", ".join("?" * len(columns)) + query = f"INSERT INTO metrics ({col_names}) VALUES ({placeholders})" + + self.metrics_conn.execute(query, values) + + self.metrics_conn.commit() logger.debug( - f"Appended {len(self.metrics_buffer)} metrics rows (INCREMENTAL)" + f"Appended {len(self.metrics_buffer)} metrics rows (preserving NaN/inf)" ) self.metrics_buffer.clear() @@ -257,16 +282,25 @@ class DuckDBStorage: logger.exception(e) def flush_media(self): - """Flush media buffer to DuckDB (TRUE INCREMENTAL!)""" + """Flush media buffer to DuckDB (direct INSERT)""" if not self.media_buffer: return try: - df = pd.DataFrame(self.media_buffer) - self.conn.append("media", df, by_name=True) - self.conn.commit() + # Direct SQL INSERT (consistent with metrics approach) + for row in self.media_buffer: + columns = list(row.keys()) + values = list(row.values()) - logger.debug(f"Appended {len(self.media_buffer)} media rows (INCREMENTAL)") + col_names = ", ".join(f'"{col}"' for col in columns) + placeholders = ", ".join("?" * len(columns)) + query = f"INSERT INTO media ({col_names}) VALUES ({placeholders})" + + self.media_conn.execute(query, values) + + self.media_conn.commit() + + logger.debug(f"Appended {len(self.media_buffer)} media rows") self.media_buffer.clear() except KeyboardInterrupt: @@ -274,18 +308,28 @@ class DuckDBStorage: self.media_buffer.clear() except Exception as e: logger.error(f"Failed to flush media: {e}") + logger.exception(e) def flush_tables(self): - """Flush tables buffer to DuckDB (TRUE INCREMENTAL!)""" + """Flush tables buffer to DuckDB (direct INSERT)""" if not self.tables_buffer: return try: - df = pd.DataFrame(self.tables_buffer) - self.conn.append("tables", df, by_name=True) - self.conn.commit() + # Direct SQL INSERT + for row in self.tables_buffer: + columns = list(row.keys()) + values = list(row.values()) - logger.debug(f"Appended {len(self.tables_buffer)} table rows (INCREMENTAL)") + col_names = ", ".join(f'"{col}"' for col in columns) + placeholders = ", ".join("?" * len(columns)) + query = f"INSERT INTO tables ({col_names}) VALUES ({placeholders})" + + self.tables_conn.execute(query, values) + + self.tables_conn.commit() + + logger.debug(f"Appended {len(self.tables_buffer)} table rows") self.tables_buffer.clear() except KeyboardInterrupt: @@ -293,6 +337,7 @@ class DuckDBStorage: self.tables_buffer.clear() except Exception as e: logger.error(f"Failed to flush tables: {e}") + logger.exception(e) def append_histogram( self, @@ -312,8 +357,6 @@ class DuckDBStorage: num_bins: Number of bins for histogram """ # Compute histogram (bins + counts) instead of storing raw values - import numpy as np - values_array = np.array(values, dtype=np.float32) counts, bin_edges = np.histogram(values_array, bins=num_bins) @@ -332,18 +375,25 @@ class DuckDBStorage: self.flush_histograms() def flush_histograms(self): - """Flush histograms buffer to DuckDB (TRUE INCREMENTAL!)""" + """Flush histograms buffer to DuckDB (direct INSERT)""" if not self.histograms_buffer: return try: - df = pd.DataFrame(self.histograms_buffer) - self.conn.append("histograms", df, by_name=True) - self.conn.commit() + # Direct SQL INSERT + for row in self.histograms_buffer: + columns = list(row.keys()) + values = list(row.values()) - logger.debug( - f"Appended {len(self.histograms_buffer)} histogram rows (INCREMENTAL)" - ) + col_names = ", ".join(f'"{col}"' for col in columns) + placeholders = ", ".join("?" * len(columns)) + query = f"INSERT INTO histograms ({col_names}) VALUES ({placeholders})" + + self.histograms_conn.execute(query, values) + + self.histograms_conn.commit() + + logger.debug(f"Appended {len(self.histograms_buffer)} histogram rows") self.histograms_buffer.clear() except KeyboardInterrupt: @@ -351,6 +401,7 @@ class DuckDBStorage: self.histograms_buffer.clear() except Exception as e: logger.error(f"Failed to flush histograms: {e}") + logger.exception(e) def flush_all(self): """Flush all buffers""" @@ -361,7 +412,13 @@ class DuckDBStorage: logger.info("Flushed all buffers to DuckDB") def close(self): - """Close database connection""" - if self.conn: - self.conn.close() - logger.debug("Closed DuckDB connection") + """Close all database connections""" + if hasattr(self, "metrics_conn") and self.metrics_conn: + self.metrics_conn.close() + if hasattr(self, "media_conn") and self.media_conn: + self.media_conn.close() + if hasattr(self, "tables_conn") and self.tables_conn: + self.tables_conn.close() + if hasattr(self, "histograms_conn") and self.histograms_conn: + self.histograms_conn.close() + logger.debug("Closed all DuckDB connections")