diff --git a/src/kohakuboard/api/routers/experiments.py b/src/kohakuboard/api/routers/experiments.py index ef21691..45e740b 100644 --- a/src/kohakuboard/api/routers/experiments.py +++ b/src/kohakuboard/api/routers/experiments.py @@ -107,7 +107,7 @@ async def get_experiment_summary(experiment_id: str): "scalars": summary["available_metrics"], "media": summary["available_media"], "tables": summary["available_tables"], - "histograms": [], # Not yet implemented in board storage + "histograms": summary["available_histograms"], }, } except FileNotFoundError: @@ -117,9 +117,13 @@ async def get_experiment_summary(experiment_id: str): raise HTTPException(status_code=500, detail=str(e)) -@router.get("/experiments/{experiment_id}/scalars/{metric_name}") +@router.get("/experiments/{experiment_id}/scalars/{metric_name:path}") async def get_scalar_metric(experiment_id: str, metric_name: str): - """Get scalar metric as step-value pairs""" + """Get scalar metric as step-value pairs + + Note: metric_name can contain slashes (e.g., "train/loss") + FastAPI path parameter automatically URL-decodes it + """ logger_api.info(f"Fetching scalar '{metric_name}' for experiment: {experiment_id}") try: @@ -143,7 +147,7 @@ async def get_scalar_metric(experiment_id: str, metric_name: str): raise HTTPException(status_code=500, detail=str(e)) -@router.get("/experiments/{experiment_id}/media/{media_name}") +@router.get("/experiments/{experiment_id}/media/{media_name:path}") async def get_media_log(experiment_id: str, media_name: str): """Get media log entries""" logger_api.info(f"Fetching media '{media_name}' for experiment: {experiment_id}") @@ -180,7 +184,7 @@ async def get_media_log(experiment_id: str, media_name: str): raise HTTPException(status_code=500, detail=str(e)) -@router.get("/experiments/{experiment_id}/tables/{table_name}") +@router.get("/experiments/{experiment_id}/tables/{table_name:path}") async def get_table_log(experiment_id: str, table_name: str): """Get table log entries""" logger_api.info(f"Fetching table '{table_name}' for experiment: {experiment_id}") @@ -203,12 +207,27 @@ async def get_table_log(experiment_id: str, table_name: str): raise HTTPException(status_code=500, detail=str(e)) -@router.get("/experiments/{experiment_id}/histograms/{histogram_name}") +@router.get("/experiments/{experiment_id}/histograms/{histogram_name:path}") async def get_histogram_log(experiment_id: str, histogram_name: str): """Get histogram log entries""" logger_api.info( f"Fetching histogram '{histogram_name}' for experiment: {experiment_id}" ) - # Histograms not yet implemented in board storage - raise HTTPException(status_code=501, detail="Histograms not yet implemented") + try: + board_dir = Path(cfg.app.board_data_dir) / experiment_id + reader = BoardReader(board_dir) + data = reader.get_histogram_data(histogram_name) + + return { + "experiment_id": experiment_id, + "histogram_name": histogram_name, + "data": data, + } + except FileNotFoundError: + raise HTTPException(status_code=404, detail="Experiment not found") + except Exception as e: + logger_api.error( + f"Failed to get histogram {histogram_name} for {experiment_id}: {e}" + ) + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/kohakuboard/api/utils/board_reader.py b/src/kohakuboard/api/utils/board_reader.py index 1a914f6..bd9b686 100644 --- a/src/kohakuboard/api/utils/board_reader.py +++ b/src/kohakuboard/api/utils/board_reader.py @@ -63,11 +63,12 @@ class BoardReader: # Return all columns (including step, global_step, timestamp for x-axis) # step, global_step, timestamp should be first for better UX + # Also convert "__" back to "/" for namespace support axis_cols = [ col for col in columns if col in ("step", "global_step", "timestamp") ] other_cols = [ - col + col.replace("__", "/") # Convert back to namespace format for col in columns if col not in ("step", "global_step", "timestamp") ] @@ -91,15 +92,16 @@ class BoardReader: conn = self._get_connection() try: # Special handling for step/global_step/timestamp - they're already included + # Escape metric name (convert "/" to "__" for DuckDB column name) + escaped_metric = metric.replace("/", "__") + # Build query - always select step, global_step, timestamp, and the requested metric - query = ( - f"SELECT step, global_step, timestamp, {metric} as value FROM metrics" - ) + query = f'SELECT step, global_step, timestamp, "{escaped_metric}" as value FROM metrics' # For regular metrics, filter out NULLs # For step/global_step/timestamp, include all rows (they're never NULL) if metric not in ("step", "global_step", "timestamp"): - query += f" WHERE {metric} IS NOT NULL" + query += f' WHERE "{escaped_metric}" IS NOT NULL' if limit: query += f" LIMIT {limit}" @@ -225,6 +227,64 @@ class BoardReader: finally: conn.close() + def get_available_histogram_names(self) -> List[str]: + """Get list of available histogram log names + + Returns: + List of unique histogram log names + """ + conn = self._get_connection() + try: + result = conn.execute( + "SELECT DISTINCT name FROM histograms ORDER BY name" + ).fetchall() + return [row[0] for row in result] + except Exception as e: + logger.warning(f"Failed to query histograms table: {e}") + return [] + finally: + conn.close() + + def get_histogram_data( + self, name: str, limit: Optional[int] = None + ) -> List[Dict[str, Any]]: + """Get histogram data for a specific log name + + Args: + name: Histogram log name + limit: Optional row limit + + Returns: + List of dicts with step, global_step, bins, values + """ + conn = self._get_connection() + try: + query = f"SELECT * FROM histograms WHERE name = ?" + if limit: + query += f" LIMIT {limit}" + + result = conn.execute(query, [name]).fetchall() + + # Get column names + columns = [desc[0] for desc in conn.description] + + # Convert to list of dicts, parsing JSON fields + data = [] + for row in result: + row_dict = dict(zip(columns, row)) + + # Parse JSON bins and counts fields + if row_dict.get("bins"): + row_dict["bins"] = json.loads(row_dict["bins"]) + if row_dict.get("counts"): + row_dict["counts"] = json.loads(row_dict["counts"]) + + data.append(row_dict) + + return data + finally: + conn.close() + def get_media_file_path(self, filename: str) -> Optional[Path]: """Get full path to media file @@ -262,14 +322,23 @@ class BoardReader: except Exception: tables_count = 0 + try: + histograms_count = conn.execute( + "SELECT COUNT(*) FROM histograms" + ).fetchone()[0] + except Exception: + histograms_count = 0 + return { "metadata": metadata, "metrics_count": metrics_count, "media_count": media_count, "tables_count": tables_count, + "histograms_count": histograms_count, "available_metrics": self.get_available_metrics(), "available_media": self.get_available_media_names(), "available_tables": self.get_available_table_names(), + "available_histograms": self.get_available_histogram_names(), } finally: conn.close()