mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
add local hosting for kohakuboard
This commit is contained in:
@@ -107,7 +107,7 @@ async def get_experiment_summary(experiment_id: str):
|
||||
"scalars": summary["available_metrics"],
|
||||
"media": summary["available_media"],
|
||||
"tables": summary["available_tables"],
|
||||
"histograms": [], # Not yet implemented in board storage
|
||||
"histograms": summary["available_histograms"],
|
||||
},
|
||||
}
|
||||
except FileNotFoundError:
|
||||
@@ -117,9 +117,13 @@ async def get_experiment_summary(experiment_id: str):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/scalars/{metric_name}")
|
||||
@router.get("/experiments/{experiment_id}/scalars/{metric_name:path}")
|
||||
async def get_scalar_metric(experiment_id: str, metric_name: str):
|
||||
"""Get scalar metric as step-value pairs"""
|
||||
"""Get scalar metric as step-value pairs
|
||||
|
||||
Note: metric_name can contain slashes (e.g., "train/loss")
|
||||
FastAPI path parameter automatically URL-decodes it
|
||||
"""
|
||||
logger_api.info(f"Fetching scalar '{metric_name}' for experiment: {experiment_id}")
|
||||
|
||||
try:
|
||||
@@ -143,7 +147,7 @@ async def get_scalar_metric(experiment_id: str, metric_name: str):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/media/{media_name}")
|
||||
@router.get("/experiments/{experiment_id}/media/{media_name:path}")
|
||||
async def get_media_log(experiment_id: str, media_name: str):
|
||||
"""Get media log entries"""
|
||||
logger_api.info(f"Fetching media '{media_name}' for experiment: {experiment_id}")
|
||||
@@ -180,7 +184,7 @@ async def get_media_log(experiment_id: str, media_name: str):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/tables/{table_name}")
|
||||
@router.get("/experiments/{experiment_id}/tables/{table_name:path}")
|
||||
async def get_table_log(experiment_id: str, table_name: str):
|
||||
"""Get table log entries"""
|
||||
logger_api.info(f"Fetching table '{table_name}' for experiment: {experiment_id}")
|
||||
@@ -203,12 +207,27 @@ async def get_table_log(experiment_id: str, table_name: str):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/histograms/{histogram_name:path}")
async def get_histogram_log(experiment_id: str, histogram_name: str):
    """Get histogram log entries

    Note: the route uses the ``:path`` converter, so histogram_name can
    contain slashes (e.g., a namespaced name such as "grad/weights").

    Args:
        experiment_id: Name of the experiment directory under the
            configured board data directory
        histogram_name: Histogram log name to fetch

    Returns:
        Dict with "experiment_id", "histogram_name" and "data" (the rows
        returned by BoardReader.get_histogram_data)

    Raises:
        HTTPException: 404 when a FileNotFoundError is raised while reading
            the board, 500 for any other failure
    """
    logger_api.info(
        f"Fetching histogram '{histogram_name}' for experiment: {experiment_id}"
    )

    try:
        board_dir = Path(cfg.app.board_data_dir) / experiment_id
        reader = BoardReader(board_dir)
        data = reader.get_histogram_data(histogram_name)

        return {
            "experiment_id": experiment_id,
            "histogram_name": histogram_name,
            "data": data,
        }
    except FileNotFoundError as exc:
        # Chain the cause so the original error stays visible in tracebacks
        raise HTTPException(status_code=404, detail="Experiment not found") from exc
    except Exception as e:
        logger_api.error(
            f"Failed to get histogram {histogram_name} for {experiment_id}: {e}"
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@@ -63,11 +63,12 @@ class BoardReader:
|
||||
|
||||
# Return all columns (including step, global_step, timestamp for x-axis)
|
||||
# step, global_step, timestamp should be first for better UX
|
||||
# Also convert "__" back to "/" for namespace support
|
||||
axis_cols = [
|
||||
col for col in columns if col in ("step", "global_step", "timestamp")
|
||||
]
|
||||
other_cols = [
|
||||
col
|
||||
col.replace("__", "/") # Convert back to namespace format
|
||||
for col in columns
|
||||
if col not in ("step", "global_step", "timestamp")
|
||||
]
|
||||
@@ -91,15 +92,16 @@ class BoardReader:
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
# Special handling for step/global_step/timestamp - they're already included
|
||||
# Escape metric name (convert "/" to "__" for DuckDB column name)
|
||||
escaped_metric = metric.replace("/", "__")
|
||||
|
||||
# Build query - always select step, global_step, timestamp, and the requested metric
|
||||
query = (
|
||||
f"SELECT step, global_step, timestamp, {metric} as value FROM metrics"
|
||||
)
|
||||
query = f'SELECT step, global_step, timestamp, "{escaped_metric}" as value FROM metrics'
|
||||
|
||||
# For regular metrics, filter out NULLs
|
||||
# For step/global_step/timestamp, include all rows (they're never NULL)
|
||||
if metric not in ("step", "global_step", "timestamp"):
|
||||
query += f" WHERE {metric} IS NOT NULL"
|
||||
query += f' WHERE "{escaped_metric}" IS NOT NULL'
|
||||
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
@@ -225,6 +227,64 @@ class BoardReader:
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_available_histogram_names(self) -> List[str]:
    """Get list of available histogram log names

    Best-effort: if the query fails for any reason, a warning is logged and
    an empty list is returned instead of raising.

    Returns:
        List of unique histogram log names, sorted by name
    """
    conn = self._get_connection()
    try:
        result = conn.execute(
            "SELECT DISTINCT name FROM histograms ORDER BY name"
        ).fetchall()
        # Each row is a 1-column record; keep only the name
        return [row[0] for row in result]
    except Exception as e:
        logger.warning(f"Failed to query histograms table: {e}")
        return []
    finally:
        # Always release the connection, on success and on failure
        conn.close()
|
||||
|
||||
def get_histogram_data(
    self, name: str, limit: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Get histogram data for a specific log name

    Args:
        name: Histogram log name
        limit: Optional row limit (a falsy value such as None or 0 means
            no limit is applied)

    Returns:
        List of dicts, one per row of the histograms table, keyed by column
        name. The "bins" and "counts" columns are decoded from JSON when
        they hold a non-empty value.
    """
    conn = self._get_connection()
    try:
        # The name is bound as a query parameter, never interpolated
        query = "SELECT * FROM histograms WHERE name = ?"
        if limit:
            # Coerce to int so only a number can ever reach the SQL text
            query += f" LIMIT {int(limit)}"

        result = conn.execute(query, [name]).fetchall()

        # Get column names
        columns = [desc[0] for desc in conn.description]

        # Convert to list of dicts, parsing JSON fields
        data = []
        for row in result:
            row_dict = dict(zip(columns, row))

            # Parse JSON bins and counts fields (skipped when empty/NULL)
            for field in ("bins", "counts"):
                if row_dict.get(field):
                    row_dict[field] = json.loads(row_dict[field])

            data.append(row_dict)

        return data
    finally:
        conn.close()
|
||||
|
||||
def get_media_file_path(self, filename: str) -> Optional[Path]:
|
||||
"""Get full path to media file
|
||||
|
||||
@@ -262,14 +322,23 @@ class BoardReader:
|
||||
except Exception:
|
||||
tables_count = 0
|
||||
|
||||
try:
|
||||
histograms_count = conn.execute(
|
||||
"SELECT COUNT(*) FROM histograms"
|
||||
).fetchone()[0]
|
||||
except Exception:
|
||||
histograms_count = 0
|
||||
|
||||
return {
|
||||
"metadata": metadata,
|
||||
"metrics_count": metrics_count,
|
||||
"media_count": media_count,
|
||||
"tables_count": tables_count,
|
||||
"histograms_count": histograms_count,
|
||||
"available_metrics": self.get_available_metrics(),
|
||||
"available_media": self.get_available_media_names(),
|
||||
"available_tables": self.get_available_table_names(),
|
||||
"available_histograms": self.get_available_histogram_names(),
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user