add local hosting for kohakuboard

This commit is contained in:
Kohaku-Blueleaf
2025-10-27 00:33:41 +08:00
parent 13ee449b23
commit 52c20668bb
2 changed files with 101 additions and 13 deletions

View File

@@ -107,7 +107,7 @@ async def get_experiment_summary(experiment_id: str):
"scalars": summary["available_metrics"],
"media": summary["available_media"],
"tables": summary["available_tables"],
"histograms": [], # Not yet implemented in board storage
"histograms": summary["available_histograms"],
},
}
except FileNotFoundError:
@@ -117,9 +117,13 @@ async def get_experiment_summary(experiment_id: str):
raise HTTPException(status_code=500, detail=str(e))
@router.get("/experiments/{experiment_id}/scalars/{metric_name}")
@router.get("/experiments/{experiment_id}/scalars/{metric_name:path}")
async def get_scalar_metric(experiment_id: str, metric_name: str):
"""Get scalar metric as step-value pairs"""
"""Get scalar metric as step-value pairs
Note: metric_name can contain slashes (e.g., "train/loss")
FastAPI path parameter automatically URL-decodes it
"""
logger_api.info(f"Fetching scalar '{metric_name}' for experiment: {experiment_id}")
try:
@@ -143,7 +147,7 @@ async def get_scalar_metric(experiment_id: str, metric_name: str):
raise HTTPException(status_code=500, detail=str(e))
@router.get("/experiments/{experiment_id}/media/{media_name}")
@router.get("/experiments/{experiment_id}/media/{media_name:path}")
async def get_media_log(experiment_id: str, media_name: str):
"""Get media log entries"""
logger_api.info(f"Fetching media '{media_name}' for experiment: {experiment_id}")
@@ -180,7 +184,7 @@ async def get_media_log(experiment_id: str, media_name: str):
raise HTTPException(status_code=500, detail=str(e))
@router.get("/experiments/{experiment_id}/tables/{table_name}")
@router.get("/experiments/{experiment_id}/tables/{table_name:path}")
async def get_table_log(experiment_id: str, table_name: str):
"""Get table log entries"""
logger_api.info(f"Fetching table '{table_name}' for experiment: {experiment_id}")
@@ -203,12 +207,27 @@ async def get_table_log(experiment_id: str, table_name: str):
raise HTTPException(status_code=500, detail=str(e))
@router.get("/experiments/{experiment_id}/histograms/{histogram_name}")
@router.get("/experiments/{experiment_id}/histograms/{histogram_name:path}")
async def get_histogram_log(experiment_id: str, histogram_name: str):
"""Get histogram log entries"""
logger_api.info(
f"Fetching histogram '{histogram_name}' for experiment: {experiment_id}"
)
# Histograms not yet implemented in board storage
raise HTTPException(status_code=501, detail="Histograms not yet implemented")
try:
board_dir = Path(cfg.app.board_data_dir) / experiment_id
reader = BoardReader(board_dir)
data = reader.get_histogram_data(histogram_name)
return {
"experiment_id": experiment_id,
"histogram_name": histogram_name,
"data": data,
}
except FileNotFoundError:
raise HTTPException(status_code=404, detail="Experiment not found")
except Exception as e:
logger_api.error(
f"Failed to get histogram {histogram_name} for {experiment_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -63,11 +63,12 @@ class BoardReader:
# Return all columns (including step, global_step, timestamp for x-axis)
# step, global_step, timestamp should be first for better UX
# Also convert "__" back to "/" for namespace support
axis_cols = [
col for col in columns if col in ("step", "global_step", "timestamp")
]
other_cols = [
col
col.replace("__", "/") # Convert back to namespace format
for col in columns
if col not in ("step", "global_step", "timestamp")
]
@@ -91,15 +92,16 @@ class BoardReader:
conn = self._get_connection()
try:
# Special handling for step/global_step/timestamp - they're already included
# Escape metric name (convert "/" to "__" for DuckDB column name)
escaped_metric = metric.replace("/", "__")
# Build query - always select step, global_step, timestamp, and the requested metric
query = (
f"SELECT step, global_step, timestamp, {metric} as value FROM metrics"
)
query = f'SELECT step, global_step, timestamp, "{escaped_metric}" as value FROM metrics'
# For regular metrics, filter out NULLs
# For step/global_step/timestamp, include all rows (they're never NULL)
if metric not in ("step", "global_step", "timestamp"):
query += f" WHERE {metric} IS NOT NULL"
query += f' WHERE "{escaped_metric}" IS NOT NULL'
if limit:
query += f" LIMIT {limit}"
@@ -225,6 +227,64 @@ class BoardReader:
finally:
conn.close()
def get_available_histogram_names(self) -> List[str]:
"""Get list of available histogram log names
Returns:
List of unique histogram log names
"""
conn = self._get_connection()
try:
result = conn.execute(
"SELECT DISTINCT name FROM histograms ORDER BY name"
).fetchall()
return [row[0] for row in result]
except Exception as e:
logger.warning(f"Failed to query histograms table: {e}")
return []
finally:
conn.close()
def get_histogram_data(
self, name: str, limit: Optional[int] = None
) -> List[Dict[str, Any]]:
"""Get histogram data for a specific log name
Args:
name: Histogram log name
limit: Optional row limit
Returns:
List of dicts with step, global_step, bins, values
"""
conn = self._get_connection()
try:
query = f"SELECT * FROM histograms WHERE name = ?"
if limit:
query += f" LIMIT {limit}"
result = conn.execute(query, [name]).fetchall()
# Get column names
columns = [desc[0] for desc in conn.description]
# Convert to list of dicts, parsing JSON fields
data = []
for row in result:
row_dict = dict(zip(columns, row))
# Parse JSON bins and counts fields
if row_dict.get("bins"):
row_dict["bins"] = json.loads(row_dict["bins"])
if row_dict.get("counts"):
row_dict["counts"] = json.loads(row_dict["counts"])
data.append(row_dict)
return data
finally:
conn.close()
def get_media_file_path(self, filename: str) -> Optional[Path]:
"""Get full path to media file
@@ -262,14 +322,23 @@ class BoardReader:
except Exception:
tables_count = 0
try:
histograms_count = conn.execute(
"SELECT COUNT(*) FROM histograms"
).fetchone()[0]
except Exception:
histograms_count = 0
return {
"metadata": metadata,
"metrics_count": metrics_count,
"media_count": media_count,
"tables_count": tables_count,
"histograms_count": histograms_count,
"available_metrics": self.get_available_metrics(),
"available_media": self.get_available_media_names(),
"available_tables": self.get_available_table_names(),
"available_histograms": self.get_available_histogram_names(),
}
finally:
conn.close()