correct args and logging

This commit is contained in:
Kohaku-Blueleaf
2025-10-27 14:43:12 +08:00
parent 59bab740ee
commit be448f6c03
8 changed files with 234 additions and 47 deletions

View File

@@ -17,6 +17,7 @@ from kohakuboard.db_operations import (
update_user_organization,
)
from kohakuboard.logger import logger_api
from kohakuboard.utils.datetime_utils import safe_isoformat
from kohakuboard.utils.names import normalize_name
router = APIRouter()
@@ -99,7 +100,7 @@ async def get_organization_info(org_name: str):
return {
"name": org.username,
"description": org.description,
"created_at": org.created_at.isoformat() if org.created_at else None,
"created_at": safe_isoformat(org.created_at),
}
@@ -263,7 +264,7 @@ async def list_organization_members_endpoint(org_name: str):
{
"user": m.user.username,
"role": m.role,
"created_at": m.created_at.isoformat() if m.created_at else None,
"created_at": safe_isoformat(m.created_at),
}
for m in members
]
@@ -296,7 +297,7 @@ async def list_user_organizations_endpoint(username: str):
"name": org.organization.username,
"description": org.organization.description,
"role": org.role,
"created_at": org.created_at.isoformat() if org.created_at else None,
"created_at": safe_isoformat(org.created_at),
}
for org in orgs
]

View File

@@ -11,6 +11,7 @@ from kohakuboard.auth import get_optional_user
from kohakuboard.config import cfg
from kohakuboard.db import Board, User
from kohakuboard.logger import logger_api
from kohakuboard.utils.datetime_utils import safe_isoformat
router = APIRouter()
@@ -66,11 +67,9 @@ def fetchProjectRuns(project_name: str, current_user: User | None):
"run_id": run.run_id,
"name": run.name,
"private": run.private,
"created_at": run.created_at.isoformat(),
"updated_at": run.updated_at.isoformat(),
"last_synced_at": (
run.last_synced_at.isoformat() if run.last_synced_at else None
),
"created_at": safe_isoformat(run.created_at),
"updated_at": safe_isoformat(run.updated_at),
"last_synced_at": safe_isoformat(run.last_synced_at),
"total_size": run.total_size_bytes,
"config": json.loads(run.config) if run.config else {},
}
@@ -141,12 +140,8 @@ async def list_projects(current_user: User | None = Depends(get_optional_user)):
"name": project.project_name,
"display_name": project.project_name.replace("-", " ").title(),
"run_count": project.run_count,
"created_at": (
project.created_at.isoformat() if project.created_at else None
),
"updated_at": (
project.updated_at.isoformat() if project.updated_at else None
),
"created_at": safe_isoformat(project.created_at),
"updated_at": safe_isoformat(project.updated_at),
}
)

View File

@@ -33,6 +33,7 @@ from kohakuboard.db_operations import (
update_user,
)
from kohakuboard.logger import logger_api
from kohakuboard.utils.datetime_utils import safe_isoformat
from kohakuboard.utils.names import normalize_name
router = APIRouter(prefix="/auth", tags=["auth"])
@@ -278,7 +279,7 @@ def get_me(user: User = Depends(get_current_user)):
"username": user.username,
"email": user.email,
"email_verified": user.email_verified,
"created_at": user.created_at.isoformat() if user.created_at else None,
"created_at": safe_isoformat(user.created_at),
}
@@ -293,8 +294,8 @@ async def list_tokens(user: User = Depends(get_current_user)):
{
"id": t.id,
"name": t.name,
"last_used": t.last_used.isoformat() if t.last_used else None,
"created_at": t.created_at.isoformat() if t.created_at else None,
"last_used": safe_isoformat(t.last_used),
"created_at": safe_isoformat(t.created_at),
}
for t in tokens
]

View File

@@ -35,8 +35,8 @@ def cli():
@click.option("--port", default=48889, help="Server port (default: 48889)")
@click.option("--host", default="0.0.0.0", help="Server host (default: 0.0.0.0)")
@click.option("--reload", is_flag=True, help="Enable auto-reload for development")
@click.option("--no-browser", is_flag=True, help="Do not open browser automatically")
def open(folder, port, host, reload, no_browser):
@click.option("--browser", is_flag=True, help="Open browser automatically")
def open(folder, port, host, reload, browser):
"""Open local board folder in browser
Starts a local web server to browse boards in the specified folder.
@@ -45,7 +45,7 @@ def open(folder, port, host, reload, no_browser):
Examples:
kobo open ./kohakuboard
kobo open /path/to/experiments --port 8080
kobo open ./boards --reload --no-browser
kobo open ./boards --reload --browser
"""
folder_path = Path(folder).resolve()
@@ -63,7 +63,7 @@ def open(folder, port, host, reload, no_browser):
click.echo()
# Open browser after delay
if not no_browser:
if browser:
def open_browser():
time.sleep(1.5) # Wait for server to start
@@ -124,12 +124,12 @@ def open(folder, port, host, reload, no_browser):
help="Session secret for authentication (required in production)",
)
@click.option(
"--no-browser",
"--browser",
is_flag=True,
help="Do not open browser automatically",
help="Open browser automatically",
)
def serve(
host, port, data_dir, db, db_backend, reload, workers, session_secret, no_browser
host, port, data_dir, db, db_backend, reload, workers, session_secret, browser
):
"""Start KohakuBoard server in remote mode with authentication
@@ -190,7 +190,7 @@ def serve(
click.echo()
# Open browser after delay
if not no_browser:
if browser:
def open_browser():
time.sleep(2) # Wait for server to start

View File

@@ -13,6 +13,7 @@ from peewee import (
BlobField,
BooleanField,
CharField,
DatabaseProxy,
DateTimeField,
ForeignKeyField,
Model,
@@ -21,14 +22,12 @@ from peewee import (
TextField,
)
# Database connection will be initialized based on config
db = None
# Database proxy - will be initialized later
db = DatabaseProxy()
def init_db(backend: str, database_url: str):
"""Initialize database connection"""
global db
if backend == "postgres":
# Parse PostgreSQL URL
url = database_url.replace("postgresql://", "")
@@ -40,7 +39,7 @@ def init_db(backend: str, database_url: str):
else:
host, port = host_port, 5432
db = PostgresqlDatabase(
real_db = PostgresqlDatabase(
dbname,
user=user,
password=password,
@@ -50,10 +49,10 @@ def init_db(backend: str, database_url: str):
else:
# SQLite
db_path = database_url.replace("sqlite:///", "")
db = SqliteDatabase(db_path, pragmas={"foreign_keys": 1})
real_db = SqliteDatabase(db_path, pragmas={"foreign_keys": 1})
# Bind models to db
BaseModel._meta.database = db
# Initialize the proxy with the real database
db.initialize(real_db)
# Create tables
with db:

View File

@@ -2,6 +2,8 @@
import logging
import sys
import traceback as tb
from typing import Optional
class ColoredFormatter(logging.Formatter):
@@ -10,6 +12,7 @@ class ColoredFormatter(logging.Formatter):
COLORS = {
"DEBUG": "\033[0;36m", # Cyan
"INFO": "\033[0;32m", # Green
"SUCCESS": "\033[0;92m", # Bright Green
"WARNING": "\033[0;33m", # Yellow
"ERROR": "\033[0;31m", # Red
"CRITICAL": "\033[1;31m", # Bold Red
@@ -23,18 +26,124 @@ class ColoredFormatter(logging.Formatter):
return super().format(record)
def get_logger(name: str) -> logging.Logger:
"""Get a colored logger instance"""
logger = logging.getLogger(name)
class Logger:
"""Custom logger with success() and exception() methods"""
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
formatter = ColoredFormatter("%(name)s %(levelname)s: %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def __init__(self, name: str):
"""Initialize logger with name.
return logger
Args:
name: Name of the logger (e.g., "API", "MOCK")
"""
self.name = name
self._logger = logging.getLogger(name)
if not self._logger.handlers:
handler = logging.StreamHandler(sys.stdout)
formatter = ColoredFormatter("%(name)s %(levelname)s: %(message)s")
handler.setFormatter(formatter)
self._logger.addHandler(handler)
self._logger.setLevel(logging.INFO)
def debug(self, message: str):
self._logger.debug(message)
def info(self, message: str):
self._logger.info(message)
def success(self, message: str):
"""Log success message (custom level, logs as INFO with SUCCESS prefix)"""
# Create a custom log record with SUCCESS level
record = self._logger.makeRecord(
self._logger.name,
logging.INFO,
"(unknown file)",
0,
message,
(),
None,
)
record.levelname = "SUCCESS"
self._logger.handle(record)
def warning(self, message: str):
self._logger.warning(message)
def error(self, message: str):
self._logger.error(message)
def critical(self, message: str):
self._logger.critical(message)
def exception(self, message: str, exc: Optional[Exception] = None):
"""Log exception with formatted traceback.
Args:
message: Error message
exc: Exception object (if None, uses sys.exc_info())
"""
self.error(message)
self._print_formatted_traceback(exc)
def _print_formatted_traceback(self, exc: Optional[Exception] = None):
"""Print formatted traceback.
Args:
exc: Exception object (if None, uses sys.exc_info())
"""
if exc is None:
exc_type, exc_value, exc_tb = sys.exc_info()
else:
exc_type = type(exc)
exc_value = exc
exc_tb = exc.__traceback__
if exc_tb is None:
return
# Extract traceback frames
frames = tb.extract_tb(exc_tb)
# Print header
self.debug(f"{'=' * 50}")
self.debug("TRACEBACK")
self.debug(f"{'=' * 50}")
# Print stack frames
for i, frame in enumerate(frames, 1):
is_last = i == len(frames)
self.debug(f"┌─ Frame #{i} {' (ERROR HERE)' if is_last else ''}")
self.debug(f"│ File: {frame.filename}")
self.debug(f"│ Line: {frame.lineno}")
if frame.name:
self.debug(f"│ In: {frame.name}()")
if frame.line:
self.debug(f"│ Code: {frame.line.strip()}")
self.debug(f"{'' * 49}")
# Print error details
self.debug(" EXCEPTION DETAILS ")
self.debug(f"{'' * 49}")
self.debug(f"│ Type: {exc_type.__name__}")
self.debug(f"│ Message: {str(exc_value)}")
if frames:
last_frame = frames[-1]
self.debug(f"│ Location: {last_frame.filename}:{last_frame.lineno}")
if last_frame.line:
self.debug(f"│ Code: {last_frame.line.strip()}")
self.debug(f"{'' * 49}")
def get_logger(name: str) -> Logger:
"""Get a custom logger instance
Args:
name: Name of the logger
Returns:
Logger: Custom logger instance
"""
return Logger(name)
# Pre-created loggers

View File

@@ -14,14 +14,14 @@ if cfg.app.mode == "remote":
logger_api.info(f" Backend: {cfg.app.db_backend}")
logger_api.info(f" URL: {cfg.app.database_url}")
init_db(cfg.app.db_backend, cfg.app.database_url)
logger_api.success("Database initialized")
logger_api.info("Database initialized")
else:
logger_api.info(f"Running in {cfg.app.mode} mode (no database needed)")
# Create a dummy db object for local mode so imports don't fail
from kohakuboard import db as db_module
# Initialize dummy database for local mode so imports don't fail
from kohakuboard.db import db
from peewee import SqliteDatabase
db_module.db = SqliteDatabase(":memory:")
db.initialize(SqliteDatabase(":memory:"))
# Now import routers (after db is initialized)
from kohakuboard.api import boards, org, projects, runs, sync, system

View File

@@ -0,0 +1,82 @@
"""Datetime utility functions for safe handling of database datetime fields."""
from datetime import datetime
from typing import Optional
def safe_isoformat(dt) -> str | None:
"""Safely convert datetime field to ISO format string.
Handles both datetime objects and string timestamps from database.
Peewee sometimes returns datetime fields as strings depending on the query.
Args:
dt: Either a datetime object, string timestamp, or None
Returns:
ISO format string or None
"""
if dt is None:
return None
elif isinstance(dt, str):
# Already a string, try to parse and re-format for consistency
try:
return datetime.fromisoformat(dt.replace("Z", "+00:00")).isoformat()
except (ValueError, AttributeError):
# If parsing fails, return as-is
return dt
elif isinstance(dt, datetime):
return dt.isoformat()
else:
# Fallback: convert to string
return str(dt)
def ensure_datetime(dt) -> Optional[datetime]:
"""Convert string or datetime to datetime object.
Handles both datetime objects and string timestamps from database.
Peewee sometimes returns datetime fields as strings depending on the query.
Args:
dt: Either a datetime object, string timestamp, or None
Returns:
datetime object or None
Raises:
ValueError: If string cannot be parsed as datetime
"""
if dt is None:
return None
elif isinstance(dt, datetime):
return dt
elif isinstance(dt, str):
# Try to parse string to datetime
try:
# Handle ISO format with 'Z' suffix
return datetime.fromisoformat(dt.replace("Z", "+00:00"))
except (ValueError, AttributeError) as e:
raise ValueError(f"Cannot parse datetime string: {dt}") from e
else:
raise TypeError(f"Expected datetime or str, got {type(dt)}")
def safe_strftime(dt, fmt: str) -> Optional[str]:
"""Safely format datetime field using strftime.
Handles both datetime objects and string timestamps from database.
Args:
dt: Either a datetime object, string timestamp, or None
fmt: strftime format string
Returns:
Formatted datetime string or None
"""
if dt is None:
return None
# Convert to datetime object first
dt_obj = ensure_datetime(dt)
return dt_obj.strftime(fmt)