Lance implementation for histograms

Kohaku-Blueleaf
2025-10-28 04:15:04 +08:00
parent 9beca81307
commit 9c6e61480b
10 changed files with 671 additions and 129 deletions

View File

@@ -100,23 +100,25 @@ class Board:
# _step increments on EVERY log/media/table call (auto-increment)
# _global_step is set explicitly via step() method
self._step = -1 # Start at -1, will be 0 on first log
self._global_step: Optional[int] = None
self._global_step: int = 0 # Start at 0, NOT None!
# Shutdown tracking
self._is_finishing = False # Prevent re-entrant finish() calls
self._interrupt_count = 0 # Track Ctrl+C presses for force exit
# Multiprocessing setup
self.queue = mp.Queue(
maxsize=50000
) # Very large queue for heavy logging (e.g., per-step histograms)
# Multiprocessing setup - use Manager.Queue to avoid Windows deadlock
# See: https://bugs.python.org/issue29797
manager = mp.Manager()
self.queue = manager.Queue(maxsize=50000)
self.stop_event = mp.Event()
# Start writer process
# Start single writer process
from kohakuboard.client.writer import writer_process_main
self.writer_process = mp.Process(
target=writer_process_main,
args=(self.board_dir, self.queue, self.stop_event, self.backend),
daemon=False, # Not daemon - we want clean shutdown
daemon=False,
)
self.writer_process.start()
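The Manager.Queue swap above works around bpo-29797, where a plain mp.Queue's feeder thread can deadlock the parent process on Windows. A minimal, self-contained sketch of the pattern (worker name and payload are illustrative, not library code):

    import multiprocessing as mp
    from queue import Empty

    def worker(q, stop):
        # Drain until the stop event is set AND the queue is empty
        while not stop.is_set() or not q.empty():
            try:
                print("got", q.get(timeout=0.1))
            except Empty:
                continue

    if __name__ == "__main__":
        manager = mp.Manager()            # queue lives in a server process
        q = manager.Queue(maxsize=50000)
        stop = mp.Event()
        p = mp.Process(target=worker, args=(q, stop), daemon=False)
        p.start()
        q.put({"type": "log", "metrics": {"loss": 0.1}})
        stop.set()
        p.join()                          # safe: no feeder thread to deadlock on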
@@ -319,7 +321,11 @@ class Board:
self.queue.put(message)
def log_histogram(
self, name: str, values: Union[List[float], Any], num_bins: int = 64
self,
name: str,
values: Union[List[float], Any],
num_bins: int = 64,
precision: str = "exact",
):
"""Log histogram data (non-blocking)
@@ -327,15 +333,16 @@ class Board:
name: Name for this histogram log (supports namespace: "gradients/layer1")
values: List of values or tensor to create histogram from
num_bins: Number of bins for histogram (default: 64)
precision: "exact" (int32, default) or "compact" (uint8, ~1% loss)
Example:
>>> # Log gradient histogram
>>> # Log gradient histogram (compact)
>>> grads = [p.grad.flatten().cpu().numpy() for p in model.parameters()]
>>> board.log_histogram("gradients/all", np.concatenate(grads))
>>>
>>> # Log parameter histogram
>>> # Log parameter histogram (exact counts)
>>> params = model.fc1.weight.detach().cpu().numpy().flatten()
>>> board.log_histogram("params/fc1_weight", params)
>>> board.log_histogram("params/fc1_weight", params, precision="exact")
"""
# Increment step (auto-increment on every log call)
self._step += 1
@@ -345,7 +352,7 @@ class Board:
queue_size = self.queue.qsize()
if queue_size > 40000:
logger.warning(
f"Queue size is {queue_size}/50000. Consider reducing histogram logging frequency."
f"Queue size is {queue_size}/50000. Consider reducing logging frequency."
)
except NotImplementedError:
pass # qsize() not supported on all platforms
@@ -365,6 +372,7 @@ class Board:
"name": name,
"values": values,
"num_bins": num_bins,
"precision": precision,
}
self.queue.put(message)
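To put a number on the "~1% loss" claim for compact mode: scaling counts into uint8 and back bounds the per-bin error by max_count/255, i.e. under 0.4% of the peak bin. A quick standalone numpy check (not library code):

    import numpy as np

    rng = np.random.default_rng(0)
    counts, _ = np.histogram(rng.normal(size=100_000), bins=64)

    # Same conversion the compact path uses: normalize counts into 0..255
    max_count = counts.max()
    compact = (counts / max_count * 255).astype(np.uint8)

    # Reconstruct and measure the worst-case error relative to the peak bin
    restored = compact.astype(np.float64) / 255 * max_count
    print(np.abs(restored - counts).max() / max_count)  # ~0.004, well under 1%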
@@ -384,10 +392,7 @@ class Board:
... loss = train_step(batch)
... board.log(loss=loss) # All batches share same global_step
"""
if self._global_step is None:
self._global_step = 0
else:
self._global_step += increment
self._global_step += increment
def flush(self):
"""Flush all pending logs to disk (blocking)
@@ -395,8 +400,10 @@ class Board:
Normally logs are flushed automatically. Use this for
critical checkpoints or before long-running operations.
"""
message = {"type": "flush"}
self.queue.put(message)
# Send flush signal
flush_msg = {"type": "flush"}
self.queue.put(flush_msg)
time.sleep(0.5) # Give writer time to flush
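Typical use of the blocking flush, e.g. before writing a checkpoint (the torch.save call is illustrative; board and model are assumed to exist):

    board.log(loss=loss.item())
    board.flush()  # enqueues {"type": "flush"}, then waits 0.5s for the writer
    torch.save(model.state_dict(), "checkpoint.pt")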
def finish(self):
"""Finish logging and clean up
@@ -409,87 +416,74 @@ class Board:
if self._is_finishing:
logger.debug("finish() already in progress, skipping re-entrant call")
return # Prevent re-entrant calls from signal handler
return
self._is_finishing = True
logger.info(f"Finishing board: {self.name}")
# Check queue size
try:
queue_size = self.queue.qsize()
logger.info(f"Queue size: {queue_size} messages")
except Exception:
queue_size = 0
# Stop output capture
if self.output_capture:
self.output_capture.stop()
# Signal writer to stop
# Signal workers to stop
self.stop_event.set()
logger.info("Stop event set, waiting for writer to drain queue...")
logger.info("Stop event set, waiting for workers to drain queues...")
# Give writer a moment to start draining queue
time.sleep(0.1)
# Give workers a moment to start draining
time.sleep(0.5)
# Check queue size and wait for processing
try:
queue_size = self.queue.qsize()
except NotImplementedError:
queue_size = 0 # Some platforms don't support qsize()
if queue_size > 0:
logger.info(
f"Waiting for writer to process {queue_size} remaining messages..."
)
# Poll queue until empty or timeout
max_wait_time = max(
30, queue_size * 0.05
) # At least 30s, plus 50ms per message
# Poll queue to monitor draining (max 30 seconds)
max_wait_time = 30
start_time = time.time()
last_size = queue_size
while time.time() - start_time < max_wait_time:
try:
current_size = self.queue.qsize()
if current_size == 0:
logger.info("Queue is empty, writer should finish soon")
logger.info("Queue empty - waiting 1s for writer to finish...")
time.sleep(1)
break
# Log progress if queue size changed significantly
if (
last_size - current_size >= 100
or (time.time() - start_time) % 5 < 0.5
):
logger.info(f"Queue progress: {current_size} messages remaining...")
# Log progress
if current_size != last_size:
logger.info(f"Queue: {current_size} remaining")
last_size = current_size
time.sleep(0.5) # Check every 500ms
except NotImplementedError:
# qsize() not supported, just wait
time.sleep(0.5)
except KeyboardInterrupt:
logger.warning("Interrupted during drain - forcing shutdown")
break
except Exception:
break
# Wait for writer process to exit (with generous timeout)
final_timeout = 10
logger.info(
f"Waiting for writer process to exit (timeout: {final_timeout}s)..."
)
self.writer_process.join(timeout=final_timeout)
if time.time() - start_time >= max_wait_time:
logger.error(f"Timeout after {max_wait_time}s - KILLING writer")
if self.writer_process.is_alive():
self.writer_process.kill()
logger.error("Writer killed, exiting")
delattr(self, "writer_process")
sys.exit(1)
# Wait for writer to exit
logger.info("Waiting for writer process to exit...")
self.writer_process.join(timeout=2)
if self.writer_process.is_alive():
logger.warning(
"Writer process did not exit gracefully after queue drained. Waiting 5 more seconds..."
)
self.writer_process.join(timeout=5)
if self.writer_process.is_alive():
logger.error("Writer process still alive, terminating forcefully...")
self.writer_process.terminate()
self.writer_process.join(timeout=2)
# Force kill if still alive
if self.writer_process.is_alive():
logger.error("Writer process did not terminate, killing...")
self.writer_process.kill()
self.writer_process.join(timeout=1)
logger.warning("Writer still alive after 2s, killing...")
self.writer_process.kill()
logger.info(f"Board finished: {self.name}")
# Remove finish method to prevent double-call
# Drop writer_process reference to prevent double-call
delattr(self, "writer_process")
def _register_signal_handlers(self):
@@ -508,20 +502,29 @@ class Board:
if self._interrupt_count == 1:
logger.warning(f"Received {sig_name}, shutting down gracefully...")
logger.warning("Press Ctrl+C again to force exit (may lose data)")
logger.warning("Press Ctrl+C again within 3 seconds to FORCE EXIT")
try:
self.finish()
except Exception as e:
logger.error(f"Error during signal handler cleanup: {e}")
logger.error(f"Error during graceful shutdown: {e}")
finally:
logger.info("Shutdown complete")
sys.exit(0)
else:
elif self._interrupt_count == 2:
logger.error(
f"Received {sig_name} again - FORCE EXIT (data may be lost!)"
"Second Ctrl+C - KILLING writer process (data will be lost!)"
)
if hasattr(self, "writer_process") and self.writer_process.is_alive():
self.writer_process.kill()
time.sleep(0.5)
logger.error("Force exit")
sys.exit(1)
else:
# Third+ interrupt - nuclear option
logger.error("THIRD Ctrl+C - IMMEDIATE EXIT")
import os
os._exit(1)
# Register signal handlers (Ctrl+C, kill)
signal.signal(signal.SIGINT, signal_handler)
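The three-stage escalation in isolation (a self-contained sketch; the 10-second sleep stands in for finish() draining queues and joining the writer):

    import os
    import signal
    import sys
    import time

    interrupts = 0

    def handler(signum, frame):
        global interrupts
        interrupts += 1
        if interrupts == 1:
            print("graceful shutdown, press Ctrl+C again to force...")
            time.sleep(10)  # stand-in for finish(); a second SIGINT re-enters here
            sys.exit(0)
        elif interrupts == 2:
            print("force exit, data may be lost")
            sys.exit(1)
        else:
            os._exit(1)  # third press: bypass all cleanup

    signal.signal(signal.SIGINT, handler)
    time.sleep(60)  # idle; press Ctrl+C to exercise the handler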

View File

@@ -0,0 +1,23 @@
"""Storage backends for KohakuBoard
Available backends:
- HybridStorage: Lance (metrics) + SQLite (metadata) + Histograms (recommended)
- DuckDBStorage: Multi-file DuckDB (backward compatible)
- ParquetStorage: Parquet files (backward compatible)
"""
from kohakuboard.client.storage.base import ParquetStorage
from kohakuboard.client.storage.duckdb import DuckDBStorage
from kohakuboard.client.storage.histogram import HistogramStorage
from kohakuboard.client.storage.hybrid import HybridStorage
from kohakuboard.client.storage.lance import LanceMetricsStorage
from kohakuboard.client.storage.sqlite import SQLiteMetadataStorage
__all__ = [
"HybridStorage",
"DuckDBStorage",
"ParquetStorage",
"HistogramStorage",
"LanceMetricsStorage",
"SQLiteMetadataStorage",
]
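A usage sketch for picking a backend from this package (the base-directory constructor argument is inferred from the hunks below, so treat the signatures as assumptions):

    from pathlib import Path

    from kohakuboard.client.storage import DuckDBStorage, HybridStorage, ParquetStorage

    data_dir = Path("boards/run-001/data")
    storage = HybridStorage(data_dir)      # recommended: Lance + SQLite + histograms
    # storage = DuckDBStorage(data_dir)    # backward compatible alternative
    # storage = ParquetStorage(data_dir)   # backward compatible alternative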

View File

@@ -448,6 +448,7 @@ class DuckDBStorage:
name: str,
values: List[float],
num_bins: int = 64,
precision: str = "compact",
):
"""Append histogram log entry (pre-computed bins to save space)
@@ -457,6 +458,7 @@ class DuckDBStorage:
name: Histogram log name
values: List of values to create histogram from
num_bins: Number of bins for histogram
precision: Ignored for DuckDB backend
"""
# Compute histogram (bins + counts) instead of storing raw values
values_array = np.array(values, dtype=np.float32)

View File

@@ -0,0 +1,188 @@
"""Histogram storage using Lance (grouped by namespace)
Strategy:
1. One Lance file per namespace:
- params/layer1, params/layer2 → params_i32.lance (if int32)
- gradients/layer1, gradients/layer2 → gradients_i32.lance
- custom → custom_i32.lance
2. Precision is per-file (suffix: _u8 or _i32)
3. Schema includes "name" field (full name with namespace)
Schema:
- step: int64
- global_step: int64
- name: string (full name: "params/layer1")
- counts: list<uint8 or int32>
- min: float32 (p1)
- max: float32 (p99)
"""
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
import pyarrow as pa
from lance.dataset import write_dataset
from loguru import logger
class HistogramStorage:
"""Histogram storage with namespace-based grouping"""
def __init__(self, base_dir: Path, num_bins: int = 64):
"""Initialize histogram storage
Args:
base_dir: Base directory
num_bins: Number of bins (default: 64)
"""
self.base_dir = base_dir
self.base_dir.mkdir(parents=True, exist_ok=True)
self.histograms_dir = base_dir / "histograms"
self.histograms_dir.mkdir(exist_ok=True)
self.num_bins = num_bins
# Buffers grouped by namespace + precision
# Key: "{namespace}_{u8|i32}"
self.buffers: Dict[str, List[Dict[str, Any]]] = {}
# Schemas
self.schema_uint8 = pa.schema(
[
pa.field("step", pa.int64()),
pa.field("global_step", pa.int64()),
pa.field("name", pa.string()),
pa.field("counts", pa.list_(pa.uint8())),
pa.field("min", pa.float32()),
pa.field("max", pa.float32()),
]
)
self.schema_int32 = pa.schema(
[
pa.field("step", pa.int64()),
pa.field("global_step", pa.int64()),
pa.field("name", pa.string()),
pa.field("counts", pa.list_(pa.int32())),
pa.field("min", pa.float32()),
pa.field("max", pa.float32()),
]
)
def append_histogram(
self,
step: int,
global_step: Optional[int],
name: str,
values: List[float],
num_bins: Optional[int] = None,
precision: str = "exact",
):
"""Append histogram
Args:
step: Step number
global_step: Global step
name: Histogram name (e.g., "gradients/layer1")
values: Raw values
num_bins: Ignored
precision: "exact" (int32, default) or "compact" (uint8)
"""
if values is None or len(values) == 0:
return
values_array = np.array(values, dtype=np.float32)
values_array = values_array[np.isfinite(values_array)]
if len(values_array) == 0:
return
# Compute p1-p99 range
p1 = float(np.percentile(values_array, 1))
p99 = float(np.percentile(values_array, 99))
if p99 - p1 < 1e-6:
p1 = float(values_array.min())
p99 = float(values_array.max())
if p99 - p1 < 1e-6:
p1 -= 0.5
p99 += 0.5
# Compute histogram
counts, _ = np.histogram(values_array, bins=self.num_bins, range=(p1, p99))
# Convert based on precision
if precision == "compact":
max_count = counts.max()
final_counts = (
(counts / max_count * 255).astype(np.uint8)
if max_count > 0
else counts.astype(np.uint8)
)
schema = self.schema_uint8
suffix = "_u8"
else:
final_counts = counts.astype(np.int32)
schema = self.schema_int32
suffix = "_i32"
# Extract namespace
namespace = name.split("/")[0] if "/" in name else name
buffer_key = f"{namespace}{suffix}"
# Initialize buffer
if buffer_key not in self.buffers:
self.buffers[buffer_key] = []
# Add to buffer
self.buffers[buffer_key].append(
{
"step": step,
"global_step": global_step,
"name": name,
"counts": final_counts.tolist(),
"min": p1,
"max": p99,
}
)
# Store schema
if not hasattr(self, "_schemas"):
self._schemas = {}
self._schemas[buffer_key] = schema
def flush(self):
"""Flush all buffers (writes to ~2-4 Lance files total)"""
if not self.buffers:
return
total_entries = sum(len(buf) for buf in self.buffers.values())
total_files = len(self.buffers)
for buffer_key, buffer in list(self.buffers.items()):
if not buffer:
continue
try:
schema = self._schemas.get(buffer_key, self.schema_int32)
table = pa.Table.from_pylist(buffer, schema=schema)
hist_file = self.histograms_dir / f"{buffer_key}.lance"
if hist_file.exists():
write_dataset(table, str(hist_file), mode="append")
else:
write_dataset(table, str(hist_file))
buffer.clear()
except Exception as e:
logger.error(f"Failed to flush {buffer_key}: {e}")
logger.debug(f"Flushed {total_entries} histograms to {total_files} Lance files")
def close(self):
"""Close storage"""
self.flush()
logger.debug("Histogram storage closed")

View File

@@ -3,6 +3,7 @@
Combines the best of both worlds:
- Lance: Dynamic schema, efficient columnar storage for metrics
- SQLite: Fixed schema, excellent concurrency for media/tables
- Adaptive histograms: Lance with percentile-based range tracking
"""
from pathlib import Path
@@ -10,8 +11,9 @@ from typing import Any, Dict, List, Optional
from loguru import logger
from kohakuboard.client.storage_lance import LanceMetricsStorage
from kohakuboard.client.storage_sqlite import SQLiteMetadataStorage
from kohakuboard.client.storage.histogram import HistogramStorage
from kohakuboard.client.storage.lance import LanceMetricsStorage
from kohakuboard.client.storage.sqlite import SQLiteMetadataStorage
class HybridStorage:
@@ -42,8 +44,9 @@ class HybridStorage:
# Initialize sub-storages
self.metrics_storage = LanceMetricsStorage(base_dir)
self.metadata_storage = SQLiteMetadataStorage(base_dir)
self.histogram_storage = HistogramStorage(base_dir, num_bins=64)
logger.debug("Hybrid storage initialized (Lance + SQLite)")
logger.debug("Hybrid storage initialized (Lance + SQLite + Histograms)")
def append_metrics(
self,
@@ -89,6 +92,12 @@ class HybridStorage:
media_list: List of media metadata dicts
caption: Optional caption
"""
# Record step info (use current timestamp)
from datetime import datetime, timezone
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
self.metadata_storage.append_step_info(step, global_step, timestamp_ms)
self.metadata_storage.append_media(step, global_step, name, media_list, caption)
def append_table(
@@ -106,6 +115,12 @@ class HybridStorage:
name: Table log name
table_data: Table dict
"""
# Record step info
from datetime import datetime, timezone
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
self.metadata_storage.append_step_info(step, global_step, timestamp_ms)
self.metadata_storage.append_table(step, global_step, name, table_data)
def append_histogram(
@@ -115,28 +130,27 @@ class HybridStorage:
name: str,
values: List[float],
num_bins: int = 64,
precision: str = "compact",
):
"""Append histogram (SKIPPED - not logged locally)
Histograms are not logged locally in hybrid backend (following wandb pattern).
This silently skips - no error, just logs debug message once per histogram name.
"""Append histogram with configurable precision
Args:
step: Step number
global_step: Global step
name: Histogram name
values: Values (ignored)
num_bins: Number of bins (ignored)
name: Histogram name (e.g., "gradients/layer1")
values: Raw values array
num_bins: Number of bins
precision: "compact" (uint8) or "exact" (int32)
"""
# Silent skip - only log once per histogram name to avoid spam
if not hasattr(self, "_logged_histogram_skip"):
self._logged_histogram_skip = set()
# Record step info
from datetime import datetime, timezone
if name not in self._logged_histogram_skip:
logger.debug(
f"Histogram '{name}' skipped (hybrid backend doesn't log histograms locally)"
)
self._logged_histogram_skip.add(name)
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
self.metadata_storage.append_step_info(step, global_step, timestamp_ms)
self.histogram_storage.append_histogram(
step, global_step, name, values, num_bins, precision
)
def flush_metrics(self):
"""Flush metrics buffer to Lance"""
@@ -151,18 +165,21 @@ class HybridStorage:
self.metadata_storage._flush_tables()
def flush_histograms(self):
"""Flush histograms (no-op, skipped)"""
pass
"""Flush histogram buffer"""
self.histogram_storage.flush()
def flush_all(self):
"""Flush all buffers"""
self.flush_metrics()
self.metadata_storage._flush_steps() # CRITICAL: Flush step info!
self.flush_media()
self.flush_tables()
self.flush_histograms()
logger.info("Flushed all buffers (hybrid storage)")
def close(self):
"""Close all storage backends"""
self.metrics_storage.close()
self.metadata_storage.close()
self.histogram_storage.close()
logger.debug("Hybrid storage closed")

View File

@@ -58,12 +58,13 @@ class LanceMetricsStorage:
self.flush_interval = 2.0 # Flush every 2 seconds (not too aggressive)
# Fixed schema for all metrics
# Use float32 for values (sufficient precision for ML metrics, saves space)
self.schema = pa.schema(
[
pa.field("step", pa.int64()),
pa.field("global_step", pa.int64()),
pa.field("timestamp", pa.int64()),
pa.field("value", pa.float64()),
pa.field("value", pa.float32()), # float32, not float64
]
)
@@ -116,17 +117,8 @@ class LanceMetricsStorage:
if escaped_name not in self.last_flush_time:
self.last_flush_time[escaped_name] = time.time()
# Flush this metric if threshold reached OR time interval elapsed
current_time = time.time()
time_since_flush = current_time - self.last_flush_time[escaped_name]
should_flush = (
len(self.buffers[escaped_name]) >= self.flush_threshold
or time_since_flush >= self.flush_interval
)
if should_flush:
self._flush_metric(escaped_name)
# Don't auto-flush - writer will call flush() periodically
# This allows batching ALL pending data at once
def _flush_metric(self, metric_name: str):
"""Flush a single metric's buffer to its Lance file

View File

@@ -134,9 +134,7 @@ class SQLiteMetadataStorage:
"""
self.step_buffer.append((step, global_step, timestamp))
# Batch flush when threshold reached
if len(self.step_buffer) >= self.step_flush_threshold:
self._flush_steps()
# Don't auto-flush - writer will call flush() periodically
def _flush_steps(self):
"""Flush steps buffer"""
@@ -186,9 +184,7 @@ class SQLiteMetadataStorage:
)
self.media_buffer.append(row)
# Batch flush
if len(self.media_buffer) >= self.media_flush_threshold:
self._flush_media()
# Don't auto-flush - writer will call flush() periodically
def _flush_media(self):
"""Flush media buffer"""
@@ -233,9 +229,7 @@ class SQLiteMetadataStorage:
)
self.table_buffer.append(row)
# Batch flush
if len(self.table_buffer) >= self.table_flush_threshold:
self._flush_tables()
# Don't auto-flush - writer will call flush() periodically
def _flush_tables(self):
"""Flush tables buffer"""

View File

@@ -0,0 +1,287 @@
"""Individual worker processes for each storage type
Architecture:
- One process per storage type (metrics, media, tables, histograms)
- Each reads from its own queue
- Batching: count-based (threshold) OR time-based (interval)
- No GIL contention between storage types
"""
import multiprocessing as mp
import time
from pathlib import Path
from queue import Empty
from typing import Any
from loguru import logger
from kohakuboard.client.storage.histogram import HistogramStorage
from kohakuboard.client.storage.lance import LanceMetricsStorage
from kohakuboard.client.storage.sqlite import SQLiteMetadataStorage
def metrics_worker_main(board_dir: Path, queue: Any, stop_event: Any):
"""Worker process for scalar metrics (Lance storage)
Args:
board_dir: Board directory
queue: Message queue (mp.Queue)
stop_event: Stop event (mp.Event)
"""
# Configure logger
logger.remove()
logger.add(
board_dir / "logs" / "metrics_worker.log", rotation="10 MB", level="DEBUG"
)
storage = LanceMetricsStorage(board_dir / "data")
# Batching config
batch_threshold = 1000 # Flush after 1000 messages
batch_interval = 2.0 # OR flush after 2 seconds
logger.info("Metrics worker started")
try:
while not stop_event.is_set():
# STEP 1: Drain ALL messages from queue
messages = []
while True:
try:
messages.append(queue.get_nowait())
except Empty:
break
# STEP 2: Process all messages
if messages:
batch_start = time.time()
for message in messages:
storage.append_metrics(
message["step"],
message.get("global_step"),
message["metrics"],
message.get("timestamp"),
)
# STEP 3: Flush once
storage.flush()
batch_time = time.time() - batch_start
logger.debug(
f"Processed and flushed {len(messages)} metrics in {batch_time*1000:.1f}ms"
)
else:
# No messages - sleep
time.sleep(0.01)
# Final flush
logger.info("Metrics worker shutting down, draining queue...")
final_count = 0
while not queue.empty():
try:
message = queue.get_nowait()
storage.append_metrics(
message["step"],
message.get("global_step"),
message["metrics"],
message.get("timestamp"),
)
final_count += 1
except Empty:
break
storage.flush()
storage.close()
logger.info(f"Metrics worker stopped (drained {final_count}, flushed all)")
except Exception as e:
logger.error(f"Metrics worker error: {e}")
raise
def metadata_worker_main(
board_dir: Path, media_queue: Any, tables_queue: Any, stop_event: Any
):
"""Worker process for media and tables (SQLite storage)
Args:
board_dir: Board directory
media_queue: Media queue (mp.Queue)
tables_queue: Tables queue (mp.Queue)
stop_event: Stop event (mp.Event)
"""
logger.remove()
logger.add(
board_dir / "logs" / "metadata_worker.log", rotation="10 MB", level="DEBUG"
)
from kohakuboard.client.media import MediaHandler
storage = SQLiteMetadataStorage(board_dir / "data")
media_handler = MediaHandler(board_dir / "media")
batch_threshold = 100
batch_interval = 2.0
logger.info("Metadata worker started")
try:
while not stop_event.is_set():
# STEP 1: Drain ALL messages from both queues
media_messages = []
table_messages = []
while True:
try:
media_messages.append(media_queue.get_nowait())
except Empty:
break
while True:
try:
table_messages.append(tables_queue.get_nowait())
except Empty:
break
# STEP 2: Process all messages
total = 0
if media_messages or table_messages:
batch_start = time.time()
for message in media_messages:
if "images" in message:
images = message["images"]
media_list = media_handler.process_images(
images, message["name"], message["step"]
)
else:
media_type = message.get("media_type", "image")
media_data = message["media_data"]
media_meta = media_handler.process_media(
media_data, message["name"], message["step"], media_type
)
media_list = [media_meta]
storage.append_media(
message["step"],
message.get("global_step"),
message["name"],
media_list,
message.get("caption"),
)
for message in table_messages:
storage.append_table(
message["step"],
message.get("global_step"),
message["name"],
message["table_data"],
)
# STEP 3: Flush once
storage.flush_all()
total = len(media_messages) + len(table_messages)
batch_time = time.time() - batch_start
logger.debug(
f"Processed and flushed {total} metadata entries ({len(media_messages)} media, {len(table_messages)} tables) in {batch_time*1000:.1f}ms"
)
else:
# No messages - sleep
time.sleep(0.01)
# Final flush
logger.info("Metadata worker shutting down...")
storage.flush_all()
storage.close()
logger.info("Metadata worker stopped")
except Exception as e:
logger.error(f"Metadata worker error: {e}")
raise
def histogram_worker_main(board_dir: Path, queue: Any, stop_event: Any):
"""Worker process for histograms (Lance storage)
Args:
board_dir: Board directory
queue: Histogram queue (mp.Queue)
stop_event: Stop event (mp.Event)
"""
logger.remove()
logger.add(
board_dir / "logs" / "histogram_worker.log", rotation="10 MB", level="DEBUG"
)
storage = HistogramStorage(board_dir / "data", num_bins=64)
batch_threshold = 500 # Total histograms across all names
batch_interval = 2.0
logger.info("Histogram worker started")
try:
while not stop_event.is_set():
# Check queue size before draining
try:
queue_size = queue.qsize()
except Exception:
queue_size = 0
# STEP 1: Drain ALL messages from queue
messages = []
drain_start = time.time()
while True:
try:
messages.append(queue.get_nowait())
except Empty:
break
drain_time = time.time() - drain_start
# STEP 2: Process all messages (add to storage buffers)
if messages:
batch_start = time.time()
for message in messages:
storage.append_histogram(
message["step"],
message.get("global_step"),
message["name"],
message["values"],
message.get("num_bins", 64),
message.get("precision", "exact"),
)
# STEP 3: Flush ALL buffers once
storage.flush()
batch_time = time.time() - batch_start
logger.info(
f"Drained {len(messages)}/{queue_size} from queue in {drain_time*1000:.1f}ms, processed+flushed in {batch_time*1000:.1f}ms"
)
else:
# No messages - sleep
time.sleep(0.01)
# Final flush
logger.info("Histogram worker shutting down, draining queue...")
drained = 0
while not queue.empty():
try:
message = queue.get_nowait()
storage.append_histogram(
message["step"],
message.get("global_step"),
message["name"],
message["values"],
message.get("num_bins", 64),
message.get("precision", "exact"),
)
drained += 1
except Empty:
break
storage.flush()
storage.close()
logger.info(f"Histogram worker stopped (drained {drained}, flushed all)")
except Exception as e:
logger.error(f"Histogram worker error: {e}")
raise
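A launcher sketch wiring the three workers to per-type queues (the module path kohakuboard.client.workers and the queue sizes are assumptions; the entry points are the ones defined above):

    import multiprocessing as mp
    from pathlib import Path

    from kohakuboard.client.workers import (
        histogram_worker_main,
        metadata_worker_main,
        metrics_worker_main,
    )

    if __name__ == "__main__":
        board_dir = Path("boards/run-001")
        manager = mp.Manager()
        metrics_q = manager.Queue(maxsize=50000)
        media_q = manager.Queue(maxsize=5000)
        tables_q = manager.Queue(maxsize=5000)
        hist_q = manager.Queue(maxsize=50000)
        stop = mp.Event()

        workers = [
            mp.Process(target=metrics_worker_main, args=(board_dir, metrics_q, stop)),
            mp.Process(target=metadata_worker_main, args=(board_dir, media_q, tables_q, stop)),
            mp.Process(target=histogram_worker_main, args=(board_dir, hist_q, stop)),
        ]
        for w in workers:
            w.start()
        # ... main process enqueues log messages here ...
        stop.set()
        for w in workers:
            w.join()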

View File

@@ -18,9 +18,7 @@ from typing import Any
from loguru import logger
from kohakuboard.client.media import MediaHandler
from kohakuboard.client.storage import ParquetStorage
from kohakuboard.client.storage_duckdb import DuckDBStorage
from kohakuboard.client.storage_hybrid import HybridStorage
from kohakuboard.client.storage import DuckDBStorage, HybridStorage, ParquetStorage
class LogWriter:
@@ -59,21 +57,57 @@ class LogWriter:
self.auto_flush_interval = 5 # Auto-flush every 5 seconds (aggressive)
def run(self):
"""Main loop - process messages from queue"""
"""Main loop - adaptive batching with exponential backoff"""
logger.info(f"LogWriter started for {self.board_dir}")
# Adaptive sleep parameters
min_period = 0.01 # 10ms minimum sleep
max_period = 1.0 # 1s maximum sleep
consecutive_empty = 0 # Track consecutive empty queue reads
try:
while not self.stop_event.is_set():
try:
# Get message from queue (shorter timeout for faster shutdown)
message = self.queue.get(timeout=0.5)
self._process_message(message)
self.messages_processed += 1
# Process ALL available messages in queue
batch_count = 0
batch_start = time.time()
except Empty:
# No message - check if we need to auto-flush
if time.time() - self.last_flush_time > self.auto_flush_interval:
self._auto_flush()
# Drain queue completely (up to 10k to allow stop_event check)
while batch_count < 10000 and not self.stop_event.is_set():
try:
message = self.queue.get_nowait()
self._process_message(message)
self.messages_processed += 1
batch_count += 1
except Empty:
break
# Flush immediately after processing ANY messages
if batch_count > 0:
self.storage.flush_all()
batch_time = time.time() - batch_start
logger.debug(
f"Processed and flushed {batch_count} messages in {batch_time*1000:.1f}ms"
)
self.last_flush_time = time.time()
# Reset backoff counter - we had work to do
consecutive_empty = 0
else:
# Queue empty - increase backoff
consecutive_empty += 1
# Adaptive sleep: min_period * 2^k, capped at max_period
sleep_time = min(
min_period * (2**consecutive_empty), max_period
)
# Sleep in small chunks to allow responsive shutdown
# e.g. instead of sleep(1.0), do sleep(0.05) × 20 and check stop_event each time
slept = 0.0
while slept < sleep_time and not self.stop_event.is_set():
time.sleep(0.05) # Sleep in 50ms chunks
slept += 0.05
except KeyboardInterrupt:
# Received interrupt in worker - DRAIN QUEUE FIRST
@@ -81,7 +115,6 @@ class LogWriter:
"Writer received interrupt, draining queue before stopping..."
)
# Continue processing until stop_event is set by main process
# Don't break immediately!
except Exception as e:
logger.error(f"Error processing message: {e}")
@@ -187,8 +220,11 @@ class LogWriter:
name = message["name"]
values = message["values"]
num_bins = message.get("num_bins", 64)
precision = message.get("precision", "compact")
self.storage.append_histogram(step, global_step, name, values, num_bins)
self.storage.append_histogram(
step, global_step, name, values, num_bins, precision
)
def _handle_flush(self):
"""Handle explicit flush request"""