mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
clean up duplicated impl
This commit is contained in:
@@ -1,233 +0,0 @@
|
||||
"""Experiments API endpoints - serves real board data"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from kohakuboard.api.utils.board_reader import BoardReader, list_boards
|
||||
from kohakuboard.config import cfg
|
||||
from kohakuboard.logger import logger_api
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MetricsQuery(BaseModel):
|
||||
"""Query parameters for metrics"""
|
||||
|
||||
metric_names: Optional[List[str]] = None
|
||||
start_step: Optional[int] = None
|
||||
end_step: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/experiments")
|
||||
async def list_experiments():
|
||||
"""List all experiments (boards)"""
|
||||
logger_api.info("Fetching experiments list from boards")
|
||||
|
||||
try:
|
||||
boards = list_boards(Path(cfg.app.board_data_dir))
|
||||
|
||||
# Convert board format to experiment format
|
||||
experiments = []
|
||||
for board in boards:
|
||||
experiments.append(
|
||||
{
|
||||
"id": board["board_id"],
|
||||
"name": board["name"],
|
||||
"description": f"Config: {board.get('config', {})}",
|
||||
"status": "completed", # For now, all boards are considered completed
|
||||
"total_steps": 0, # Will be filled from actual data if needed
|
||||
"duration": "N/A",
|
||||
"created_at": board.get("created_at", ""),
|
||||
}
|
||||
)
|
||||
|
||||
logger_api.info(f"Found {len(experiments)} experiments")
|
||||
return experiments
|
||||
except Exception as e:
|
||||
logger_api.error(f"Failed to list experiments: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list experiments: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}")
|
||||
async def get_experiment(experiment_id: str):
|
||||
"""Get experiment details"""
|
||||
logger_api.info(f"Fetching experiment: {experiment_id}")
|
||||
|
||||
try:
|
||||
board_dir = Path(cfg.app.board_data_dir) / experiment_id
|
||||
reader = BoardReader(board_dir)
|
||||
metadata = reader.get_metadata()
|
||||
|
||||
return {
|
||||
"id": experiment_id,
|
||||
"name": metadata.get("name", experiment_id),
|
||||
"description": f"Config: {metadata.get('config', {})}",
|
||||
"status": "completed",
|
||||
"total_steps": 0, # TODO: Calculate from data
|
||||
"duration": "N/A",
|
||||
"created_at": metadata.get("created_at", ""),
|
||||
}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
except Exception as e:
|
||||
logger_api.error(f"Failed to get experiment {experiment_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/summary")
|
||||
async def get_experiment_summary(experiment_id: str):
|
||||
"""Get experiment summary with available data"""
|
||||
logger_api.info(f"Fetching summary for experiment: {experiment_id}")
|
||||
|
||||
try:
|
||||
board_dir = Path(cfg.app.board_data_dir) / experiment_id
|
||||
reader = BoardReader(board_dir)
|
||||
summary = reader.get_summary()
|
||||
|
||||
metadata = summary["metadata"]
|
||||
|
||||
return {
|
||||
"experiment_id": experiment_id,
|
||||
"experiment_info": {
|
||||
"id": experiment_id,
|
||||
"name": metadata.get("name", experiment_id),
|
||||
"description": f"Config: {metadata.get('config', {})}",
|
||||
"status": "completed",
|
||||
"total_steps": summary["metrics_count"],
|
||||
"duration": "N/A",
|
||||
"created_at": metadata.get("created_at", ""),
|
||||
},
|
||||
"total_steps": summary["metrics_count"],
|
||||
"available_data": {
|
||||
"scalars": summary["available_metrics"],
|
||||
"media": summary["available_media"],
|
||||
"tables": summary["available_tables"],
|
||||
"histograms": summary["available_histograms"],
|
||||
},
|
||||
}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
except Exception as e:
|
||||
logger_api.error(f"Failed to get summary for {experiment_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/scalars/{metric_name:path}")
|
||||
async def get_scalar_metric(experiment_id: str, metric_name: str):
|
||||
"""Get scalar metric as step-value pairs
|
||||
|
||||
Note: metric_name can contain slashes (e.g., "train/loss")
|
||||
FastAPI path parameter automatically URL-decodes it
|
||||
"""
|
||||
logger_api.info(f"Fetching scalar '{metric_name}' for experiment: {experiment_id}")
|
||||
|
||||
try:
|
||||
board_dir = Path(cfg.app.board_data_dir) / experiment_id
|
||||
reader = BoardReader(board_dir)
|
||||
data = reader.get_scalar_data(metric_name)
|
||||
|
||||
return {
|
||||
"experiment_id": experiment_id,
|
||||
"metric_name": metric_name,
|
||||
"data": data,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
except Exception as e:
|
||||
if "not found" in str(e).lower():
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Metric '{metric_name}' not found"
|
||||
)
|
||||
logger_api.error(f"Failed to get scalar {metric_name} for {experiment_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/media/{media_name:path}")
|
||||
async def get_media_log(experiment_id: str, media_name: str):
|
||||
"""Get media log entries"""
|
||||
logger_api.info(f"Fetching media '{media_name}' for experiment: {experiment_id}")
|
||||
|
||||
try:
|
||||
board_dir = Path(cfg.app.board_data_dir) / experiment_id
|
||||
reader = BoardReader(board_dir)
|
||||
data = reader.get_media_data(media_name)
|
||||
|
||||
# Transform to expected format
|
||||
media_entries = []
|
||||
for entry in data:
|
||||
media_entries.append(
|
||||
{
|
||||
"name": entry.get("media_id", ""),
|
||||
"step": entry.get("step", 0),
|
||||
"type": entry.get("type", "image"),
|
||||
"url": f"/api/boards/{experiment_id}/media/files/{entry.get('filename', '')}",
|
||||
"caption": entry.get("caption", ""),
|
||||
"width": entry.get("width"),
|
||||
"height": entry.get("height"),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"experiment_id": experiment_id,
|
||||
"media_name": media_name,
|
||||
"data": media_entries,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
except Exception as e:
|
||||
logger_api.error(f"Failed to get media {media_name} for {experiment_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/tables/{table_name:path}")
|
||||
async def get_table_log(experiment_id: str, table_name: str):
|
||||
"""Get table log entries"""
|
||||
logger_api.info(f"Fetching table '{table_name}' for experiment: {experiment_id}")
|
||||
|
||||
try:
|
||||
board_dir = Path(cfg.app.board_data_dir) / experiment_id
|
||||
reader = BoardReader(board_dir)
|
||||
data = reader.get_table_data(table_name)
|
||||
|
||||
# Transform to expected format (data is already parsed)
|
||||
return {
|
||||
"experiment_id": experiment_id,
|
||||
"table_name": table_name,
|
||||
"data": data,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
except Exception as e:
|
||||
logger_api.error(f"Failed to get table {table_name} for {experiment_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiments/{experiment_id}/histograms/{histogram_name:path}")
|
||||
async def get_histogram_log(experiment_id: str, histogram_name: str):
|
||||
"""Get histogram log entries"""
|
||||
logger_api.info(
|
||||
f"Fetching histogram '{histogram_name}' for experiment: {experiment_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
board_dir = Path(cfg.app.board_data_dir) / experiment_id
|
||||
reader = BoardReader(board_dir)
|
||||
data = reader.get_histogram_data(histogram_name)
|
||||
|
||||
return {
|
||||
"experiment_id": experiment_id,
|
||||
"histogram_name": histogram_name,
|
||||
"data": data,
|
||||
}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Experiment not found")
|
||||
except Exception as e:
|
||||
logger_api.error(
|
||||
f"Failed to get histogram {histogram_name} for {experiment_id}: {e}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -73,6 +73,7 @@ async def get_run_summary(
|
||||
|
||||
Returns:
|
||||
dict: Run summary with metadata, counts, available metrics/media/tables
|
||||
Same format as experiments API for compatibility
|
||||
"""
|
||||
logger_api.info(f"Fetching summary for {project}/{run_id}")
|
||||
|
||||
@@ -80,11 +81,30 @@ async def get_run_summary(
|
||||
reader = BoardReader(run_path)
|
||||
summary = reader.get_summary()
|
||||
|
||||
# Add project context
|
||||
summary["project"] = project
|
||||
summary["run_id"] = run_id
|
||||
# Return in same format as experiments API for frontend compatibility
|
||||
metadata = summary["metadata"]
|
||||
|
||||
return summary
|
||||
return {
|
||||
"experiment_id": run_id, # For compatibility with ConfigurableChartCard
|
||||
"project": project,
|
||||
"run_id": run_id,
|
||||
"experiment_info": {
|
||||
"id": run_id,
|
||||
"name": metadata.get("name", run_id),
|
||||
"description": f"Config: {metadata.get('config', {})}",
|
||||
"status": "completed",
|
||||
"total_steps": summary["metrics_count"],
|
||||
"duration": "N/A",
|
||||
"created_at": metadata.get("created_at", ""),
|
||||
},
|
||||
"total_steps": summary["metrics_count"],
|
||||
"available_data": {
|
||||
"scalars": summary["available_metrics"],
|
||||
"media": summary["available_media"],
|
||||
"tables": summary["available_tables"],
|
||||
"histograms": summary["available_histograms"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/metadata")
|
||||
@@ -119,7 +139,7 @@ async def get_available_scalars(
|
||||
return {"metrics": metrics}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/scalars/{metric}")
|
||||
@router.get("/projects/{project}/runs/{run_id}/scalars/{metric:path}")
|
||||
async def get_scalar_data(
|
||||
project: str,
|
||||
run_id: str,
|
||||
@@ -127,7 +147,11 @@ async def get_scalar_data(
|
||||
limit: int | None = Query(None, description="Maximum number of data points"),
|
||||
current_user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
"""Get scalar data for a specific metric"""
|
||||
"""Get scalar data for a specific metric
|
||||
|
||||
Note: metric can contain slashes (e.g., "train/loss")
|
||||
FastAPI path parameter automatically URL-decodes it
|
||||
"""
|
||||
logger_api.info(f"Fetching scalar data for {project}/{run_id}/{metric}")
|
||||
|
||||
run_path = get_run_path(project, run_id, current_user)
|
||||
@@ -153,7 +177,7 @@ async def get_available_media(
|
||||
return {"media": media_names}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/media/{name}")
|
||||
@router.get("/projects/{project}/runs/{run_id}/media/{name:path}")
|
||||
async def get_media_data(
|
||||
project: str,
|
||||
run_id: str,
|
||||
@@ -168,7 +192,22 @@ async def get_media_data(
|
||||
reader = BoardReader(run_path)
|
||||
data = reader.get_media_data(name, limit=limit)
|
||||
|
||||
return {"name": name, "data": data}
|
||||
# Transform to same format as experiments API
|
||||
media_entries = []
|
||||
for entry in data:
|
||||
media_entries.append(
|
||||
{
|
||||
"name": entry.get("media_id", ""),
|
||||
"step": entry.get("step", 0),
|
||||
"type": entry.get("type", "image"),
|
||||
"url": f"/api/projects/{project}/runs/{run_id}/media/files/{entry.get('filename', '')}",
|
||||
"caption": entry.get("caption", ""),
|
||||
"width": entry.get("width"),
|
||||
"height": entry.get("height"),
|
||||
}
|
||||
)
|
||||
|
||||
return {"experiment_id": run_id, "media_name": name, "data": media_entries}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/media/files/{filename}")
|
||||
@@ -228,7 +267,7 @@ async def get_available_tables(
|
||||
return {"tables": table_names}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/tables/{name}")
|
||||
@router.get("/projects/{project}/runs/{run_id}/tables/{name:path}")
|
||||
async def get_table_data(
|
||||
project: str,
|
||||
run_id: str,
|
||||
@@ -243,7 +282,7 @@ async def get_table_data(
|
||||
reader = BoardReader(run_path)
|
||||
data = reader.get_table_data(name, limit=limit)
|
||||
|
||||
return {"name": name, "data": data}
|
||||
return {"experiment_id": run_id, "table_name": name, "data": data}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/histograms")
|
||||
@@ -262,7 +301,7 @@ async def get_available_histograms(
|
||||
return {"histograms": histogram_names}
|
||||
|
||||
|
||||
@router.get("/projects/{project}/runs/{run_id}/histograms/{name}")
|
||||
@router.get("/projects/{project}/runs/{run_id}/histograms/{name:path}")
|
||||
async def get_histogram_data(
|
||||
project: str,
|
||||
run_id: str,
|
||||
@@ -277,4 +316,4 @@ async def get_histogram_data(
|
||||
reader = BoardReader(run_path)
|
||||
data = reader.get_histogram_data(name, limit=limit)
|
||||
|
||||
return {"name": name, "data": data}
|
||||
return {"experiment_id": run_id, "histogram_name": name, "data": data}
|
||||
|
||||
@@ -24,7 +24,7 @@ else:
|
||||
db_module.db = SqliteDatabase(":memory:")
|
||||
|
||||
# Now import routers (after db is initialized)
|
||||
from kohakuboard.api import boards, experiments, projects, runs, sync, system
|
||||
from kohakuboard.api import boards, projects, runs, sync, system
|
||||
|
||||
app = FastAPI(
|
||||
title="KohakuBoard API",
|
||||
@@ -51,11 +51,8 @@ app.include_router(runs.router, prefix=cfg.app.api_base, tags=["runs"])
|
||||
# Register sync router (remote mode only, but always registered for API docs)
|
||||
app.include_router(sync.router, prefix=cfg.app.api_base, tags=["sync"])
|
||||
|
||||
# Keep legacy routers for backward compatibility
|
||||
# Keep legacy boards router for backward compatibility (media file serving)
|
||||
app.include_router(boards.router, prefix=cfg.app.api_base, tags=["boards (legacy)"])
|
||||
app.include_router(
|
||||
experiments.router, prefix=cfg.app.api_base, tags=["experiments (legacy)"]
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@@ -81,8 +78,6 @@ async def root():
|
||||
"endpoints": {
|
||||
"system": f"{cfg.app.api_base}/system/info",
|
||||
"projects": f"{cfg.app.api_base}/projects",
|
||||
"experiments": f"{cfg.app.api_base}/experiments",
|
||||
"boards": f"{cfg.app.api_base}/boards",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user