mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
correct args and logging
This commit is contained in:
@@ -17,6 +17,7 @@ from kohakuboard.db_operations import (
|
||||
update_user_organization,
|
||||
)
|
||||
from kohakuboard.logger import logger_api
|
||||
from kohakuboard.utils.datetime_utils import safe_isoformat
|
||||
from kohakuboard.utils.names import normalize_name
|
||||
|
||||
router = APIRouter()
|
||||
@@ -99,7 +100,7 @@ async def get_organization_info(org_name: str):
|
||||
return {
|
||||
"name": org.username,
|
||||
"description": org.description,
|
||||
"created_at": org.created_at.isoformat() if org.created_at else None,
|
||||
"created_at": safe_isoformat(org.created_at),
|
||||
}
|
||||
|
||||
|
||||
@@ -263,7 +264,7 @@ async def list_organization_members_endpoint(org_name: str):
|
||||
{
|
||||
"user": m.user.username,
|
||||
"role": m.role,
|
||||
"created_at": m.created_at.isoformat() if m.created_at else None,
|
||||
"created_at": safe_isoformat(m.created_at),
|
||||
}
|
||||
for m in members
|
||||
]
|
||||
@@ -296,7 +297,7 @@ async def list_user_organizations_endpoint(username: str):
|
||||
"name": org.organization.username,
|
||||
"description": org.organization.description,
|
||||
"role": org.role,
|
||||
"created_at": org.created_at.isoformat() if org.created_at else None,
|
||||
"created_at": safe_isoformat(org.created_at),
|
||||
}
|
||||
for org in orgs
|
||||
]
|
||||
|
||||
@@ -11,6 +11,7 @@ from kohakuboard.auth import get_optional_user
|
||||
from kohakuboard.config import cfg
|
||||
from kohakuboard.db import Board, User
|
||||
from kohakuboard.logger import logger_api
|
||||
from kohakuboard.utils.datetime_utils import safe_isoformat
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -66,11 +67,9 @@ def fetchProjectRuns(project_name: str, current_user: User | None):
|
||||
"run_id": run.run_id,
|
||||
"name": run.name,
|
||||
"private": run.private,
|
||||
"created_at": run.created_at.isoformat(),
|
||||
"updated_at": run.updated_at.isoformat(),
|
||||
"last_synced_at": (
|
||||
run.last_synced_at.isoformat() if run.last_synced_at else None
|
||||
),
|
||||
"created_at": safe_isoformat(run.created_at),
|
||||
"updated_at": safe_isoformat(run.updated_at),
|
||||
"last_synced_at": safe_isoformat(run.last_synced_at),
|
||||
"total_size": run.total_size_bytes,
|
||||
"config": json.loads(run.config) if run.config else {},
|
||||
}
|
||||
@@ -141,12 +140,8 @@ async def list_projects(current_user: User | None = Depends(get_optional_user)):
|
||||
"name": project.project_name,
|
||||
"display_name": project.project_name.replace("-", " ").title(),
|
||||
"run_count": project.run_count,
|
||||
"created_at": (
|
||||
project.created_at.isoformat() if project.created_at else None
|
||||
),
|
||||
"updated_at": (
|
||||
project.updated_at.isoformat() if project.updated_at else None
|
||||
),
|
||||
"created_at": safe_isoformat(project.created_at),
|
||||
"updated_at": safe_isoformat(project.updated_at),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from kohakuboard.db_operations import (
|
||||
update_user,
|
||||
)
|
||||
from kohakuboard.logger import logger_api
|
||||
from kohakuboard.utils.datetime_utils import safe_isoformat
|
||||
from kohakuboard.utils.names import normalize_name
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
@@ -278,7 +279,7 @@ def get_me(user: User = Depends(get_current_user)):
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"email_verified": user.email_verified,
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"created_at": safe_isoformat(user.created_at),
|
||||
}
|
||||
|
||||
|
||||
@@ -293,8 +294,8 @@ async def list_tokens(user: User = Depends(get_current_user)):
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"last_used": t.last_used.isoformat() if t.last_used else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
"last_used": safe_isoformat(t.last_used),
|
||||
"created_at": safe_isoformat(t.created_at),
|
||||
}
|
||||
for t in tokens
|
||||
]
|
||||
|
||||
@@ -35,8 +35,8 @@ def cli():
|
||||
@click.option("--port", default=48889, help="Server port (default: 48889)")
|
||||
@click.option("--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)")
|
||||
@click.option("--reload", is_flag=True, help="Enable auto-reload for development")
|
||||
@click.option("--no-browser", is_flag=True, help="Do not open browser automatically")
|
||||
def open(folder, port, host, reload, no_browser):
|
||||
@click.option("--browser", is_flag=True, help="Open browser automatically")
|
||||
def open(folder, port, host, reload, browser):
|
||||
"""Open local board folder in browser
|
||||
|
||||
Starts a local web server to browse boards in the specified folder.
|
||||
@@ -45,7 +45,7 @@ def open(folder, port, host, reload, no_browser):
|
||||
Examples:
|
||||
kobo open ./kohakuboard
|
||||
kobo open /path/to/experiments --port 8080
|
||||
kobo open ./boards --reload --no-browser
|
||||
kobo open ./boards --reload --browser
|
||||
"""
|
||||
folder_path = Path(folder).resolve()
|
||||
|
||||
@@ -63,7 +63,7 @@ def open(folder, port, host, reload, no_browser):
|
||||
click.echo()
|
||||
|
||||
# Open browser after delay
|
||||
if not no_browser:
|
||||
if browser:
|
||||
|
||||
def open_browser():
|
||||
time.sleep(1.5) # Wait for server to start
|
||||
@@ -124,12 +124,12 @@ def open(folder, port, host, reload, no_browser):
|
||||
help="Session secret for authentication (required in production)",
|
||||
)
|
||||
@click.option(
|
||||
"--no-browser",
|
||||
"--browser",
|
||||
is_flag=True,
|
||||
help="Do not open browser automatically",
|
||||
help="Open browser automatically",
|
||||
)
|
||||
def serve(
|
||||
host, port, data_dir, db, db_backend, reload, workers, session_secret, no_browser
|
||||
host, port, data_dir, db, db_backend, reload, workers, session_secret, browser
|
||||
):
|
||||
"""Start KohakuBoard server in remote mode with authentication
|
||||
|
||||
@@ -190,7 +190,7 @@ def serve(
|
||||
click.echo()
|
||||
|
||||
# Open browser after delay
|
||||
if not no_browser:
|
||||
if browser:
|
||||
|
||||
def open_browser():
|
||||
time.sleep(2) # Wait for server to start
|
||||
|
||||
@@ -13,6 +13,7 @@ from peewee import (
|
||||
BlobField,
|
||||
BooleanField,
|
||||
CharField,
|
||||
DatabaseProxy,
|
||||
DateTimeField,
|
||||
ForeignKeyField,
|
||||
Model,
|
||||
@@ -21,14 +22,12 @@ from peewee import (
|
||||
TextField,
|
||||
)
|
||||
|
||||
# Database connection will be initialized based on config
|
||||
db = None
|
||||
# Database proxy - will be initialized later
|
||||
db = DatabaseProxy()
|
||||
|
||||
|
||||
def init_db(backend: str, database_url: str):
|
||||
"""Initialize database connection"""
|
||||
global db
|
||||
|
||||
if backend == "postgres":
|
||||
# Parse PostgreSQL URL
|
||||
url = database_url.replace("postgresql://", "")
|
||||
@@ -40,7 +39,7 @@ def init_db(backend: str, database_url: str):
|
||||
else:
|
||||
host, port = host_port, 5432
|
||||
|
||||
db = PostgresqlDatabase(
|
||||
real_db = PostgresqlDatabase(
|
||||
dbname,
|
||||
user=user,
|
||||
password=password,
|
||||
@@ -50,10 +49,10 @@ def init_db(backend: str, database_url: str):
|
||||
else:
|
||||
# SQLite
|
||||
db_path = database_url.replace("sqlite:///", "")
|
||||
db = SqliteDatabase(db_path, pragmas={"foreign_keys": 1})
|
||||
real_db = SqliteDatabase(db_path, pragmas={"foreign_keys": 1})
|
||||
|
||||
# Bind models to db
|
||||
BaseModel._meta.database = db
|
||||
# Initialize the proxy with the real database
|
||||
db.initialize(real_db)
|
||||
|
||||
# Create tables
|
||||
with db:
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import traceback as tb
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
@@ -10,6 +12,7 @@ class ColoredFormatter(logging.Formatter):
|
||||
COLORS = {
|
||||
"DEBUG": "\033[0;36m", # Cyan
|
||||
"INFO": "\033[0;32m", # Green
|
||||
"SUCCESS": "\033[0;92m", # Bright Green
|
||||
"WARNING": "\033[0;33m", # Yellow
|
||||
"ERROR": "\033[0;31m", # Red
|
||||
"CRITICAL": "\033[1;31m", # Bold Red
|
||||
@@ -23,18 +26,124 @@ class ColoredFormatter(logging.Formatter):
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""Get a colored logger instance"""
|
||||
logger = logging.getLogger(name)
|
||||
class Logger:
|
||||
"""Custom logger with success() and exception() methods"""
|
||||
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = ColoredFormatter("%(name)s %(levelname)s: %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
def __init__(self, name: str):
|
||||
"""Initialize logger with name.
|
||||
|
||||
return logger
|
||||
Args:
|
||||
name: Name of the logger (e.g., "API", "MOCK")
|
||||
"""
|
||||
self.name = name
|
||||
self._logger = logging.getLogger(name)
|
||||
|
||||
if not self._logger.handlers:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
formatter = ColoredFormatter("%(name)s %(levelname)s: %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
self._logger.addHandler(handler)
|
||||
self._logger.setLevel(logging.INFO)
|
||||
|
||||
def debug(self, message: str):
|
||||
self._logger.debug(message)
|
||||
|
||||
def info(self, message: str):
|
||||
self._logger.info(message)
|
||||
|
||||
def success(self, message: str):
|
||||
"""Log success message (custom level, logs as INFO with SUCCESS prefix)"""
|
||||
# Create a custom log record with SUCCESS level
|
||||
record = self._logger.makeRecord(
|
||||
self._logger.name,
|
||||
logging.INFO,
|
||||
"(unknown file)",
|
||||
0,
|
||||
message,
|
||||
(),
|
||||
None,
|
||||
)
|
||||
record.levelname = "SUCCESS"
|
||||
self._logger.handle(record)
|
||||
|
||||
def warning(self, message: str):
|
||||
self._logger.warning(message)
|
||||
|
||||
def error(self, message: str):
|
||||
self._logger.error(message)
|
||||
|
||||
def critical(self, message: str):
|
||||
self._logger.critical(message)
|
||||
|
||||
def exception(self, message: str, exc: Optional[Exception] = None):
|
||||
"""Log exception with formatted traceback.
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
exc: Exception object (if None, uses sys.exc_info())
|
||||
"""
|
||||
self.error(message)
|
||||
self._print_formatted_traceback(exc)
|
||||
|
||||
def _print_formatted_traceback(self, exc: Optional[Exception] = None):
|
||||
"""Print formatted traceback.
|
||||
|
||||
Args:
|
||||
exc: Exception object (if None, uses sys.exc_info())
|
||||
"""
|
||||
if exc is None:
|
||||
exc_type, exc_value, exc_tb = sys.exc_info()
|
||||
else:
|
||||
exc_type = type(exc)
|
||||
exc_value = exc
|
||||
exc_tb = exc.__traceback__
|
||||
|
||||
if exc_tb is None:
|
||||
return
|
||||
|
||||
# Extract traceback frames
|
||||
frames = tb.extract_tb(exc_tb)
|
||||
|
||||
# Print header
|
||||
self.debug(f"{'=' * 50}")
|
||||
self.debug("TRACEBACK")
|
||||
self.debug(f"{'=' * 50}")
|
||||
|
||||
# Print stack frames
|
||||
for i, frame in enumerate(frames, 1):
|
||||
is_last = i == len(frames)
|
||||
self.debug(f"┌─ Frame #{i} {' (ERROR HERE)' if is_last else ''}")
|
||||
self.debug(f"│ File: {frame.filename}")
|
||||
self.debug(f"│ Line: {frame.lineno}")
|
||||
if frame.name:
|
||||
self.debug(f"│ In: {frame.name}()")
|
||||
if frame.line:
|
||||
self.debug(f"│ Code: {frame.line.strip()}")
|
||||
self.debug(f"└{'─' * 49}")
|
||||
|
||||
# Print error details
|
||||
self.debug(" EXCEPTION DETAILS ")
|
||||
self.debug(f"┌{'─' * 49}")
|
||||
self.debug(f"│ Type: {exc_type.__name__}")
|
||||
self.debug(f"│ Message: {str(exc_value)}")
|
||||
if frames:
|
||||
last_frame = frames[-1]
|
||||
self.debug(f"│ Location: {last_frame.filename}:{last_frame.lineno}")
|
||||
if last_frame.line:
|
||||
self.debug(f"│ Code: {last_frame.line.strip()}")
|
||||
self.debug(f"└{'─' * 49}")
|
||||
|
||||
|
||||
def get_logger(name: str) -> Logger:
|
||||
"""Get a custom logger instance
|
||||
|
||||
Args:
|
||||
name: Name of the logger
|
||||
|
||||
Returns:
|
||||
Logger: Custom logger instance
|
||||
"""
|
||||
return Logger(name)
|
||||
|
||||
|
||||
# Pre-created loggers
|
||||
|
||||
@@ -14,14 +14,14 @@ if cfg.app.mode == "remote":
|
||||
logger_api.info(f" Backend: {cfg.app.db_backend}")
|
||||
logger_api.info(f" URL: {cfg.app.database_url}")
|
||||
init_db(cfg.app.db_backend, cfg.app.database_url)
|
||||
logger_api.success("Database initialized")
|
||||
logger_api.info("✓ Database initialized")
|
||||
else:
|
||||
logger_api.info(f"Running in {cfg.app.mode} mode (no database needed)")
|
||||
# Create a dummy db object for local mode so imports don't fail
|
||||
from kohakuboard import db as db_module
|
||||
# Initialize dummy database for local mode so imports don't fail
|
||||
from kohakuboard.db import db
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
db_module.db = SqliteDatabase(":memory:")
|
||||
db.initialize(SqliteDatabase(":memory:"))
|
||||
|
||||
# Now import routers (after db is initialized)
|
||||
from kohakuboard.api import boards, org, projects, runs, sync, system
|
||||
|
||||
82
src/kohakuboard/utils/datetime_utils.py
Normal file
82
src/kohakuboard/utils/datetime_utils.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Datetime utility functions for safe handling of database datetime fields."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def safe_isoformat(dt) -> str | None:
|
||||
"""Safely convert datetime field to ISO format string.
|
||||
|
||||
Handles both datetime objects and string timestamps from database.
|
||||
Peewee sometimes returns datetime fields as strings depending on the query.
|
||||
|
||||
Args:
|
||||
dt: Either a datetime object, string timestamp, or None
|
||||
|
||||
Returns:
|
||||
ISO format string or None
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
elif isinstance(dt, str):
|
||||
# Already a string, try to parse and re-format for consistency
|
||||
try:
|
||||
return datetime.fromisoformat(dt.replace("Z", "+00:00")).isoformat()
|
||||
except (ValueError, AttributeError):
|
||||
# If parsing fails, return as-is
|
||||
return dt
|
||||
elif isinstance(dt, datetime):
|
||||
return dt.isoformat()
|
||||
else:
|
||||
# Fallback: convert to string
|
||||
return str(dt)
|
||||
|
||||
|
||||
def ensure_datetime(dt) -> Optional[datetime]:
|
||||
"""Convert string or datetime to datetime object.
|
||||
|
||||
Handles both datetime objects and string timestamps from database.
|
||||
Peewee sometimes returns datetime fields as strings depending on the query.
|
||||
|
||||
Args:
|
||||
dt: Either a datetime object, string timestamp, or None
|
||||
|
||||
Returns:
|
||||
datetime object or None
|
||||
|
||||
Raises:
|
||||
ValueError: If string cannot be parsed as datetime
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
elif isinstance(dt, datetime):
|
||||
return dt
|
||||
elif isinstance(dt, str):
|
||||
# Try to parse string to datetime
|
||||
try:
|
||||
# Handle ISO format with 'Z' suffix
|
||||
return datetime.fromisoformat(dt.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError) as e:
|
||||
raise ValueError(f"Cannot parse datetime string: {dt}") from e
|
||||
else:
|
||||
raise TypeError(f"Expected datetime or str, got {type(dt)}")
|
||||
|
||||
|
||||
def safe_strftime(dt, fmt: str) -> Optional[str]:
|
||||
"""Safely format datetime field using strftime.
|
||||
|
||||
Handles both datetime objects and string timestamps from database.
|
||||
|
||||
Args:
|
||||
dt: Either a datetime object, string timestamp, or None
|
||||
fmt: strftime format string
|
||||
|
||||
Returns:
|
||||
Formatted datetime string or None
|
||||
"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
# Convert to datetime object first
|
||||
dt_obj = ensure_datetime(dt)
|
||||
return dt_obj.strftime(fmt)
|
||||
Reference in New Issue
Block a user