mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
lance impl for hist
This commit is contained in:
@@ -100,23 +100,25 @@ class Board:
|
||||
# _step increments on EVERY log/media/table call (auto-increment)
|
||||
# _global_step is set explicitly via step() method
|
||||
self._step = -1 # Start at -1, will be 0 on first log
|
||||
self._global_step: Optional[int] = None
|
||||
self._global_step: int = 0 # Start at 0, NOT None!
|
||||
|
||||
# Shutdown tracking
|
||||
self._is_finishing = False # Prevent re-entrant finish() calls
|
||||
self._interrupt_count = 0 # Track Ctrl+C presses for force exit
|
||||
|
||||
# Multiprocessing setup
|
||||
self.queue = mp.Queue(
|
||||
maxsize=50000
|
||||
) # Very large queue for heavy logging (e.g., per-step histograms)
|
||||
# Multiprocessing setup - use Manager.Queue to avoid Windows deadlock
|
||||
# See: https://bugs.python.org/issue29797
|
||||
manager = mp.Manager()
|
||||
self.queue = manager.Queue(maxsize=50000)
|
||||
self.stop_event = mp.Event()
|
||||
|
||||
# Start writer process
|
||||
# Start single writer process
|
||||
from kohakuboard.client.writer import writer_process_main
|
||||
|
||||
self.writer_process = mp.Process(
|
||||
target=writer_process_main,
|
||||
args=(self.board_dir, self.queue, self.stop_event, self.backend),
|
||||
daemon=False, # Not daemon - we want clean shutdown
|
||||
daemon=False,
|
||||
)
|
||||
self.writer_process.start()
|
||||
|
||||
@@ -319,7 +321,11 @@ class Board:
|
||||
self.queue.put(message)
|
||||
|
||||
def log_histogram(
|
||||
self, name: str, values: Union[List[float], Any], num_bins: int = 64
|
||||
self,
|
||||
name: str,
|
||||
values: Union[List[float], Any],
|
||||
num_bins: int = 64,
|
||||
precision: str = "exact",
|
||||
):
|
||||
"""Log histogram data (non-blocking)
|
||||
|
||||
@@ -327,15 +333,16 @@ class Board:
|
||||
name: Name for this histogram log (supports namespace: "gradients/layer1")
|
||||
values: List of values or tensor to create histogram from
|
||||
num_bins: Number of bins for histogram (default: 64)
|
||||
precision: "exact" (int32, default) or "compact" (uint8, ~1% loss)
|
||||
|
||||
Example:
|
||||
>>> # Log gradient histogram
|
||||
>>> # Log gradient histogram (compact)
|
||||
>>> grads = [p.grad.flatten().cpu().numpy() for p in model.parameters()]
|
||||
>>> board.log_histogram("gradients/all", np.concatenate(grads))
|
||||
>>>
|
||||
>>> # Log parameter histogram
|
||||
>>> # Log parameter histogram (exact counts)
|
||||
>>> params = model.fc1.weight.detach().cpu().numpy().flatten()
|
||||
>>> board.log_histogram("params/fc1_weight", params)
|
||||
>>> board.log_histogram("params/fc1_weight", params, precision="exact")
|
||||
"""
|
||||
# Increment step (auto-increment on every log call)
|
||||
self._step += 1
|
||||
@@ -345,7 +352,7 @@ class Board:
|
||||
queue_size = self.queue.qsize()
|
||||
if queue_size > 40000:
|
||||
logger.warning(
|
||||
f"Queue size is {queue_size}/50000. Consider reducing histogram logging frequency."
|
||||
f"Queue size is {queue_size}/50000. Consider reducing logging frequency."
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass # qsize() not supported on all platforms
|
||||
@@ -365,6 +372,7 @@ class Board:
|
||||
"name": name,
|
||||
"values": values,
|
||||
"num_bins": num_bins,
|
||||
"precision": precision,
|
||||
}
|
||||
self.queue.put(message)
|
||||
|
||||
@@ -384,10 +392,7 @@ class Board:
|
||||
... loss = train_step(batch)
|
||||
... board.log(loss=loss) # All batches share same global_step
|
||||
"""
|
||||
if self._global_step is None:
|
||||
self._global_step = 0
|
||||
else:
|
||||
self._global_step += increment
|
||||
self._global_step += increment
|
||||
|
||||
def flush(self):
|
||||
"""Flush all pending logs to disk (blocking)
|
||||
@@ -395,8 +400,10 @@ class Board:
|
||||
Normally logs are flushed automatically. Use this for
|
||||
critical checkpoints or before long-running operations.
|
||||
"""
|
||||
message = {"type": "flush"}
|
||||
self.queue.put(message)
|
||||
# Send flush signal
|
||||
flush_msg = {"type": "flush"}
|
||||
self.queue.put(flush_msg)
|
||||
time.sleep(0.5) # Give writer time to flush
|
||||
|
||||
def finish(self):
|
||||
"""Finish logging and clean up
|
||||
@@ -409,87 +416,74 @@ class Board:
|
||||
|
||||
if self._is_finishing:
|
||||
logger.debug("finish() already in progress, skipping re-entrant call")
|
||||
return # Prevent re-entrant calls from signal handler
|
||||
return
|
||||
|
||||
self._is_finishing = True
|
||||
logger.info(f"Finishing board: {self.name}")
|
||||
|
||||
# Check queue size
|
||||
try:
|
||||
queue_size = self.queue.qsize()
|
||||
logger.info(f"Queue size: {queue_size} messages")
|
||||
except:
|
||||
queue_size = 0
|
||||
|
||||
# Stop output capture
|
||||
if self.output_capture:
|
||||
self.output_capture.stop()
|
||||
|
||||
# Signal writer to stop
|
||||
# Signal workers to stop
|
||||
self.stop_event.set()
|
||||
logger.info("Stop event set, waiting for writer to drain queue...")
|
||||
logger.info("Stop event set, waiting for workers to drain queues...")
|
||||
|
||||
# Give writer a moment to start draining queue
|
||||
time.sleep(0.1)
|
||||
# Give workers a moment to start draining
|
||||
time.sleep(0.5)
|
||||
|
||||
# Check queue size and wait for processing
|
||||
try:
|
||||
queue_size = self.queue.qsize()
|
||||
except NotImplementedError:
|
||||
queue_size = 0 # Some platforms don't support qsize()
|
||||
|
||||
if queue_size > 0:
|
||||
logger.info(
|
||||
f"Waiting for writer to process {queue_size} remaining messages..."
|
||||
)
|
||||
|
||||
# Poll queue until empty or timeout
|
||||
max_wait_time = max(
|
||||
30, queue_size * 0.05
|
||||
) # At least 30s, plus 50ms per message
|
||||
# Poll queue to monitor draining (max 30 seconds)
|
||||
max_wait_time = 30
|
||||
start_time = time.time()
|
||||
last_size = queue_size
|
||||
|
||||
while time.time() - start_time < max_wait_time:
|
||||
try:
|
||||
current_size = self.queue.qsize()
|
||||
|
||||
if current_size == 0:
|
||||
logger.info("Queue is empty, writer should finish soon")
|
||||
logger.info("Queue empty - waiting 1s for writer to finish...")
|
||||
time.sleep(1)
|
||||
break
|
||||
|
||||
# Log progress if queue size changed significantly
|
||||
if (
|
||||
last_size - current_size >= 100
|
||||
or (time.time() - start_time) % 5 < 0.5
|
||||
):
|
||||
logger.info(f"Queue progress: {current_size} messages remaining...")
|
||||
# Log progress
|
||||
if current_size != last_size:
|
||||
logger.info(f"Queue: {current_size} remaining")
|
||||
last_size = current_size
|
||||
|
||||
time.sleep(0.5) # Check every 500ms
|
||||
except NotImplementedError:
|
||||
# qsize() not supported, just wait
|
||||
time.sleep(0.5)
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("Interrupted during drain - forcing shutdown")
|
||||
break
|
||||
except:
|
||||
break
|
||||
|
||||
# Wait for writer process to exit (with generous timeout)
|
||||
final_timeout = 10
|
||||
logger.info(
|
||||
f"Waiting for writer process to exit (timeout: {final_timeout}s)..."
|
||||
)
|
||||
self.writer_process.join(timeout=final_timeout)
|
||||
if time.time() - start_time >= max_wait_time:
|
||||
logger.error(f"Timeout after {max_wait_time}s - KILLING writer")
|
||||
if self.writer_process.is_alive():
|
||||
self.writer_process.kill()
|
||||
logger.error("Writer killed, exiting")
|
||||
delattr(self, "writer_process")
|
||||
sys.exit(1)
|
||||
|
||||
# Wait for writer to exit
|
||||
logger.info("Waiting for writer process to exit...")
|
||||
self.writer_process.join(timeout=2)
|
||||
|
||||
if self.writer_process.is_alive():
|
||||
logger.warning(
|
||||
"Writer process did not exit gracefully after queue drained. Waiting 5 more seconds..."
|
||||
)
|
||||
self.writer_process.join(timeout=5)
|
||||
|
||||
if self.writer_process.is_alive():
|
||||
logger.error("Writer process still alive, terminating forcefully...")
|
||||
self.writer_process.terminate()
|
||||
self.writer_process.join(timeout=2)
|
||||
|
||||
# Force kill if still alive
|
||||
if self.writer_process.is_alive():
|
||||
logger.error("Writer process did not terminate, killing...")
|
||||
self.writer_process.kill()
|
||||
self.writer_process.join(timeout=1)
|
||||
logger.warning("Writer still alive after 2s, killing...")
|
||||
self.writer_process.kill()
|
||||
|
||||
logger.info(f"Board finished: {self.name}")
|
||||
|
||||
# Remove finish method to prevent double-call
|
||||
# Remove to prevent double-call
|
||||
delattr(self, "writer_process")
|
||||
|
||||
def _register_signal_handlers(self):
|
||||
@@ -508,20 +502,29 @@ class Board:
|
||||
|
||||
if self._interrupt_count == 1:
|
||||
logger.warning(f"Received {sig_name}, shutting down gracefully...")
|
||||
logger.warning("Press Ctrl+C again to force exit (may lose data)")
|
||||
logger.warning("Press Ctrl+C again within 3 seconds to FORCE EXIT")
|
||||
try:
|
||||
self.finish()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during signal handler cleanup: {e}")
|
||||
logger.error(f"Error during graceful shutdown: {e}")
|
||||
finally:
|
||||
logger.info("Shutdown complete")
|
||||
sys.exit(0)
|
||||
else:
|
||||
elif self._interrupt_count == 2:
|
||||
logger.error(
|
||||
f"Received {sig_name} again - FORCE EXIT (data may be lost!)"
|
||||
"Second Ctrl+C - KILLING writer process (data will be lost!)"
|
||||
)
|
||||
if hasattr(self, "writer_process") and self.writer_process.is_alive():
|
||||
self.writer_process.kill()
|
||||
time.sleep(0.5)
|
||||
logger.error("Force exit")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Third+ interrupt - nuclear option
|
||||
logger.error("THIRD Ctrl+C - IMMEDIATE EXIT")
|
||||
import os
|
||||
|
||||
os._exit(1)
|
||||
|
||||
# Register signal handlers (Ctrl+C, kill)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
23
src/kohakuboard/client/storage/__init__.py
Normal file
23
src/kohakuboard/client/storage/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Storage backends for KohakuBoard
|
||||
|
||||
Available backends:
|
||||
- HybridStorage: Lance (metrics) + SQLite (metadata) + Histograms (recommended)
|
||||
- DuckDBStorage: Multi-file DuckDB (backward compatible)
|
||||
- ParquetStorage: Parquet files (backward compatible)
|
||||
"""
|
||||
|
||||
from kohakuboard.client.storage.base import ParquetStorage
|
||||
from kohakuboard.client.storage.duckdb import DuckDBStorage
|
||||
from kohakuboard.client.storage.histogram import HistogramStorage
|
||||
from kohakuboard.client.storage.hybrid import HybridStorage
|
||||
from kohakuboard.client.storage.lance import LanceMetricsStorage
|
||||
from kohakuboard.client.storage.sqlite import SQLiteMetadataStorage
|
||||
|
||||
__all__ = [
|
||||
"HybridStorage",
|
||||
"DuckDBStorage",
|
||||
"ParquetStorage",
|
||||
"HistogramStorage",
|
||||
"LanceMetricsStorage",
|
||||
"SQLiteMetadataStorage",
|
||||
]
|
||||
@@ -448,6 +448,7 @@ class DuckDBStorage:
|
||||
name: str,
|
||||
values: List[float],
|
||||
num_bins: int = 64,
|
||||
precision: str = "compact",
|
||||
):
|
||||
"""Append histogram log entry (pre-computed bins to save space)
|
||||
|
||||
@@ -457,6 +458,7 @@ class DuckDBStorage:
|
||||
name: Histogram log name
|
||||
values: List of values to create histogram from
|
||||
num_bins: Number of bins for histogram
|
||||
precision: Ignored for DuckDB backend
|
||||
"""
|
||||
# Compute histogram (bins + counts) instead of storing raw values
|
||||
values_array = np.array(values, dtype=np.float32)
|
||||
188
src/kohakuboard/client/storage/histogram.py
Normal file
188
src/kohakuboard/client/storage/histogram.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Histogram storage using Lance (grouped by namespace)
|
||||
|
||||
Strategy:
|
||||
1. One Lance file per namespace:
|
||||
- params/layer1, params/layer2 → params_i32.lance (if int32)
|
||||
- gradients/layer1, gradients/layer2 → gradients_i32.lance
|
||||
- custom → custom_i32.lance
|
||||
2. Precision is per-file (suffix: _u8 or _i32)
|
||||
3. Schema includes "name" field (full name with namespace)
|
||||
|
||||
Schema:
|
||||
- step: int64
|
||||
- global_step: int64
|
||||
- name: string (full name: "params/layer1")
|
||||
- counts: list<uint8 or int32>
|
||||
- min: float32 (p1)
|
||||
- max: float32 (p99)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from lance.dataset import write_dataset
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class HistogramStorage:
    """Histogram storage with namespace-based grouping

    Raw values are reduced to ``num_bins`` bin counts over their p1-p99 range
    and buffered in memory. ``flush()`` writes each buffer to
    ``<base_dir>/histograms/{namespace}_{u8|i32}.lance``.
    """

    def __init__(self, base_dir: Path, num_bins: int = 64):
        """Initialize histogram storage

        Args:
            base_dir: Base directory (created if missing)
            num_bins: Number of bins used for every histogram (default: 64)
        """
        self.base_dir = base_dir
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.histograms_dir = base_dir / "histograms"
        self.histograms_dir.mkdir(exist_ok=True)

        self.num_bins = num_bins

        # Buffers grouped by namespace + precision
        # Key: "{namespace}_{u8|i32}"
        self.buffers: Dict[str, List[Dict[str, Any]]] = {}

        # Schema to use when writing each buffer key. Created eagerly so that
        # flush() never depends on append_histogram() having run first.
        self._schemas: Dict[str, Any] = {}

        # The two schemas differ only in the element type of "counts"
        self.schema_uint8 = self._build_schema(pa.uint8())
        self.schema_int32 = self._build_schema(pa.int32())

    @staticmethod
    def _build_schema(count_type):
        """Build the histogram row schema for the given ``counts`` element type"""
        return pa.schema(
            [
                pa.field("step", pa.int64()),
                pa.field("global_step", pa.int64()),
                pa.field("name", pa.string()),
                pa.field("counts", pa.list_(count_type)),
                pa.field("min", pa.float32()),
                pa.field("max", pa.float32()),
            ]
        )

    def append_histogram(
        self,
        step: int,
        global_step: Optional[int],
        name: str,
        values: List[float],
        num_bins: Optional[int] = None,
        precision: str = "exact",
    ):
        """Append histogram

        Non-finite values are dropped. Nothing is recorded when no finite
        value remains.

        Args:
            step: Step number
            global_step: Global step
            name: Histogram name (e.g., "gradients/layer1")
            values: Raw values (list or NumPy array)
            num_bins: Ignored (the instance-wide ``self.num_bins`` is used)
            precision: "exact" (int32, default) or "compact" (uint8)
        """
        if values is None:
            return

        # Convert before testing for emptiness: truth-testing a NumPy array
        # with more than one element raises ValueError.
        values_array = np.asarray(values, dtype=np.float32).ravel()
        values_array = values_array[np.isfinite(values_array)]

        if values_array.size == 0:
            return

        # Compute p1-p99 range
        p1 = float(np.percentile(values_array, 1))
        p99 = float(np.percentile(values_array, 99))

        if p99 - p1 < 1e-6:
            # Degenerate percentile range: fall back to the full min/max range
            p1 = float(values_array.min())
            p99 = float(values_array.max())
            if p99 - p1 < 1e-6:
                # (Nearly) constant data: widen so the bins have nonzero width
                p1 -= 0.5
                p99 += 0.5

        # Compute histogram (values outside the range are not counted)
        counts, _ = np.histogram(values_array, bins=self.num_bins, range=(p1, p99))

        # Convert based on precision
        if precision == "compact":
            # Rescale so the fullest bin maps to 255
            max_count = counts.max()
            if max_count > 0:
                final_counts = (counts / max_count * 255).astype(np.uint8)
            else:
                final_counts = counts.astype(np.uint8)
            schema = self.schema_uint8
            suffix = "_u8"
        else:
            final_counts = counts.astype(np.int32)
            schema = self.schema_int32
            suffix = "_i32"

        # Namespace is the text before the first "/", or the whole name
        namespace = name.split("/", 1)[0]
        buffer_key = f"{namespace}{suffix}"

        self.buffers.setdefault(buffer_key, []).append(
            {
                "step": step,
                "global_step": global_step,
                "name": name,
                "counts": final_counts.tolist(),
                "min": p1,
                "max": p99,
            }
        )

        # Remember which schema this buffer must be written with
        self._schemas[buffer_key] = schema

    def flush(self):
        """Flush all non-empty buffers (one Lance dataset per buffer key)

        A buffer whose write fails is logged and kept in memory so a later
        flush can retry it.
        """
        pending = {key: rows for key, rows in self.buffers.items() if rows}
        if not pending:
            return

        flushed_entries = 0
        flushed_files = 0

        for buffer_key, buffer in pending.items():
            try:
                schema = self._schemas.get(buffer_key, self.schema_int32)
                table = pa.Table.from_pylist(buffer, schema=schema)
                hist_file = self.histograms_dir / f"{buffer_key}.lance"

                if hist_file.exists():
                    write_dataset(table, str(hist_file), mode="append")
                else:
                    write_dataset(table, str(hist_file))
            except Exception as e:
                logger.error(f"Failed to flush {buffer_key}: {e}")
                continue

            # Only count and clear buffers that were actually written
            flushed_entries += len(buffer)
            flushed_files += 1
            buffer.clear()

        logger.debug(
            f"Flushed {flushed_entries} histograms to {flushed_files} Lance files"
        )

    def close(self):
        """Close storage"""
        self.flush()
        logger.debug("Histogram storage closed")
|
||||
@@ -3,6 +3,7 @@
|
||||
Combines the best of both worlds:
|
||||
- Lance: Dynamic schema, efficient columnar storage for metrics
|
||||
- SQLite: Fixed schema, excellent concurrency for media/tables
|
||||
- Adaptive histograms: Lance with percentile-based range tracking
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -10,8 +11,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from kohakuboard.client.storage_lance import LanceMetricsStorage
|
||||
from kohakuboard.client.storage_sqlite import SQLiteMetadataStorage
|
||||
from kohakuboard.client.storage.histogram import HistogramStorage
|
||||
from kohakuboard.client.storage.lance import LanceMetricsStorage
|
||||
from kohakuboard.client.storage.sqlite import SQLiteMetadataStorage
|
||||
|
||||
|
||||
class HybridStorage:
|
||||
@@ -42,8 +44,9 @@ class HybridStorage:
|
||||
# Initialize sub-storages
|
||||
self.metrics_storage = LanceMetricsStorage(base_dir)
|
||||
self.metadata_storage = SQLiteMetadataStorage(base_dir)
|
||||
self.histogram_storage = HistogramStorage(base_dir, num_bins=64)
|
||||
|
||||
logger.debug("Hybrid storage initialized (Lance + SQLite)")
|
||||
logger.debug("Hybrid storage initialized (Lance + SQLite + Histograms)")
|
||||
|
||||
def append_metrics(
|
||||
self,
|
||||
@@ -89,6 +92,12 @@ class HybridStorage:
|
||||
media_list: List of media metadata dicts
|
||||
caption: Optional caption
|
||||
"""
|
||||
# Record step info (use current timestamp)
|
||||
from datetime import datetime, timezone
|
||||
|
||||
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
self.metadata_storage.append_step_info(step, global_step, timestamp_ms)
|
||||
|
||||
self.metadata_storage.append_media(step, global_step, name, media_list, caption)
|
||||
|
||||
def append_table(
|
||||
@@ -106,6 +115,12 @@ class HybridStorage:
|
||||
name: Table log name
|
||||
table_data: Table dict
|
||||
"""
|
||||
# Record step info
|
||||
from datetime import datetime, timezone
|
||||
|
||||
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
self.metadata_storage.append_step_info(step, global_step, timestamp_ms)
|
||||
|
||||
self.metadata_storage.append_table(step, global_step, name, table_data)
|
||||
|
||||
def append_histogram(
|
||||
@@ -115,28 +130,27 @@ class HybridStorage:
|
||||
name: str,
|
||||
values: List[float],
|
||||
num_bins: int = 64,
|
||||
precision: str = "compact",
|
||||
):
|
||||
"""Append histogram (SKIPPED - not logged locally)
|
||||
|
||||
Histograms are not logged locally in hybrid backend (following wandb pattern).
|
||||
This silently skips - no error, just logs debug message once per histogram name.
|
||||
"""Append histogram with configurable precision
|
||||
|
||||
Args:
|
||||
step: Step number
|
||||
global_step: Global step
|
||||
name: Histogram name
|
||||
values: Values (ignored)
|
||||
num_bins: Number of bins (ignored)
|
||||
name: Histogram name (e.g., "gradients/layer1")
|
||||
values: Raw values array
|
||||
num_bins: Number of bins
|
||||
precision: "compact" (uint8) or "exact" (int32)
|
||||
"""
|
||||
# Silent skip - only log once per histogram name to avoid spam
|
||||
if not hasattr(self, "_logged_histogram_skip"):
|
||||
self._logged_histogram_skip = set()
|
||||
# Record step info
|
||||
from datetime import datetime, timezone
|
||||
|
||||
if name not in self._logged_histogram_skip:
|
||||
logger.debug(
|
||||
f"Histogram '{name}' skipped (hybrid backend doesn't log histograms locally)"
|
||||
)
|
||||
self._logged_histogram_skip.add(name)
|
||||
timestamp_ms = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
self.metadata_storage.append_step_info(step, global_step, timestamp_ms)
|
||||
|
||||
self.histogram_storage.append_histogram(
|
||||
step, global_step, name, values, num_bins, precision
|
||||
)
|
||||
|
||||
def flush_metrics(self):
|
||||
"""Flush metrics buffer to Lance"""
|
||||
@@ -151,18 +165,21 @@ class HybridStorage:
|
||||
self.metadata_storage._flush_tables()
|
||||
|
||||
def flush_histograms(self):
|
||||
"""Flush histograms (no-op, skipped)"""
|
||||
pass
|
||||
"""Flush histogram buffer"""
|
||||
self.histogram_storage.flush()
|
||||
|
||||
def flush_all(self):
|
||||
"""Flush all buffers"""
|
||||
self.flush_metrics()
|
||||
self.metadata_storage._flush_steps() # CRITICAL: Flush step info!
|
||||
self.flush_media()
|
||||
self.flush_tables()
|
||||
self.flush_histograms()
|
||||
logger.info("Flushed all buffers (hybrid storage)")
|
||||
|
||||
def close(self):
|
||||
"""Close all storage backends"""
|
||||
self.metrics_storage.close()
|
||||
self.metadata_storage.close()
|
||||
self.histogram_storage.close()
|
||||
logger.debug("Hybrid storage closed")
|
||||
@@ -58,12 +58,13 @@ class LanceMetricsStorage:
|
||||
self.flush_interval = 2.0 # Flush every 2 seconds (not too aggressive)
|
||||
|
||||
# Fixed schema for all metrics
|
||||
# Use float32 for values (sufficient precision for ML metrics, saves space)
|
||||
self.schema = pa.schema(
|
||||
[
|
||||
pa.field("step", pa.int64()),
|
||||
pa.field("global_step", pa.int64()),
|
||||
pa.field("timestamp", pa.int64()),
|
||||
pa.field("value", pa.float64()),
|
||||
pa.field("value", pa.float32()), # float32, not float64
|
||||
]
|
||||
)
|
||||
|
||||
@@ -116,17 +117,8 @@ class LanceMetricsStorage:
|
||||
if escaped_name not in self.last_flush_time:
|
||||
self.last_flush_time[escaped_name] = time.time()
|
||||
|
||||
# Flush this metric if threshold reached OR time interval elapsed
|
||||
current_time = time.time()
|
||||
time_since_flush = current_time - self.last_flush_time[escaped_name]
|
||||
|
||||
should_flush = (
|
||||
len(self.buffers[escaped_name]) >= self.flush_threshold
|
||||
or time_since_flush >= self.flush_interval
|
||||
)
|
||||
|
||||
if should_flush:
|
||||
self._flush_metric(escaped_name)
|
||||
# Don't auto-flush - writer will call flush() periodically
|
||||
# This allows batching ALL pending data at once
|
||||
|
||||
def _flush_metric(self, metric_name: str):
|
||||
"""Flush a single metric's buffer to its Lance file
|
||||
@@ -134,9 +134,7 @@ class SQLiteMetadataStorage:
|
||||
"""
|
||||
self.step_buffer.append((step, global_step, timestamp))
|
||||
|
||||
# Batch flush when threshold reached
|
||||
if len(self.step_buffer) >= self.step_flush_threshold:
|
||||
self._flush_steps()
|
||||
# Don't auto-flush - writer will call flush() periodically
|
||||
|
||||
def _flush_steps(self):
|
||||
"""Flush steps buffer"""
|
||||
@@ -186,9 +184,7 @@ class SQLiteMetadataStorage:
|
||||
)
|
||||
self.media_buffer.append(row)
|
||||
|
||||
# Batch flush
|
||||
if len(self.media_buffer) >= self.media_flush_threshold:
|
||||
self._flush_media()
|
||||
# Don't auto-flush - writer will call flush() periodically
|
||||
|
||||
def _flush_media(self):
|
||||
"""Flush media buffer"""
|
||||
@@ -233,9 +229,7 @@ class SQLiteMetadataStorage:
|
||||
)
|
||||
self.table_buffer.append(row)
|
||||
|
||||
# Batch flush
|
||||
if len(self.table_buffer) >= self.table_flush_threshold:
|
||||
self._flush_tables()
|
||||
# Don't auto-flush - writer will call flush() periodically
|
||||
|
||||
def _flush_tables(self):
|
||||
"""Flush tables buffer"""
|
||||
287
src/kohakuboard/client/workers.py
Normal file
287
src/kohakuboard/client/workers.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Individual worker processes for each storage type
|
||||
|
||||
Architecture:
|
||||
- One process per storage backend: metrics, media + tables (one shared process), histograms
|
||||
- Each reads from its own queue
|
||||
- Batching: each pass drains every queued message, then flushes once
|
||||
- No GIL contention between storage types
|
||||
"""
|
||||
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from kohakuboard.client.storage.histogram import HistogramStorage
|
||||
from kohakuboard.client.storage.lance import LanceMetricsStorage
|
||||
from kohakuboard.client.storage.sqlite import SQLiteMetadataStorage
|
||||
|
||||
|
||||
def metrics_worker_main(board_dir: Path, queue: Any, stop_event: Any):
    """Worker process for scalar metrics (Lance storage)

    Until ``stop_event`` is set, each pass drains every message currently in
    the queue, appends it to storage and flushes once. After the stop signal
    the remaining messages are drained, then storage is flushed and closed.

    Args:
        board_dir: Board directory
        queue: Message queue (mp.Queue)
        stop_event: Stop event (mp.Event)

    Raises:
        Exception: Any error raised while processing is logged and re-raised.
    """
    # Configure logger
    logger.remove()
    logger.add(
        board_dir / "logs" / "metrics_worker.log", rotation="10 MB", level="DEBUG"
    )

    storage = LanceMetricsStorage(board_dir / "data")

    def _store(message):
        """Append one queued metrics message to the storage buffer"""
        storage.append_metrics(
            message["step"],
            message.get("global_step"),
            message["metrics"],
            message.get("timestamp"),
        )

    logger.info("Metrics worker started")

    try:
        while not stop_event.is_set():
            # STEP 1: Drain ALL messages from queue
            messages = []
            while True:
                try:
                    messages.append(queue.get_nowait())
                except Empty:
                    break

            if not messages:
                # No messages - sleep briefly instead of busy-waiting
                time.sleep(0.01)
                continue

            # STEP 2: Process all messages
            batch_start = time.time()
            for message in messages:
                _store(message)

            # STEP 3: Flush once
            storage.flush()
            batch_time = time.time() - batch_start
            logger.debug(
                f"Processed and flushed {len(messages)} metrics in {batch_time*1000:.1f}ms"
            )

        # Final flush: pick up anything enqueued after the last pass
        logger.info("Metrics worker shutting down, draining queue...")
        final_count = 0
        while True:
            try:
                message = queue.get_nowait()
            except Empty:
                break
            _store(message)
            final_count += 1

        storage.flush()
        storage.close()
        logger.info(f"Metrics worker stopped (drained {final_count}, flushed all)")

    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Metrics worker error: {e}")
        raise
|
||||
|
||||
|
||||
def metadata_worker_main(
    board_dir: Path, media_queue: Any, tables_queue: Any, stop_event: Any
):
    """Worker process for media and tables (SQLite storage)

    Until ``stop_event`` is set, each pass drains both queues, stores every
    message and flushes once. After the stop signal both queues are drained
    one last time so that messages enqueued after the final pass are not
    lost, then storage is flushed and closed.

    Args:
        board_dir: Board directory
        media_queue: Media queue (mp.Queue)
        tables_queue: Tables queue (mp.Queue)
        stop_event: Stop event (mp.Event)

    Raises:
        Exception: Any error raised while processing is logged and re-raised.
    """
    logger.remove()
    logger.add(
        board_dir / "logs" / "metadata_worker.log", rotation="10 MB", level="DEBUG"
    )

    from kohakuboard.client.media import MediaHandler

    storage = SQLiteMetadataStorage(board_dir / "data")
    media_handler = MediaHandler(board_dir / "media")

    def _drain(source_queue):
        """Return every message currently available in ``source_queue``"""
        drained = []
        while True:
            try:
                drained.append(source_queue.get_nowait())
            except Empty:
                return drained

    def _store_media(message):
        """Process the media files of one message and record their metadata"""
        if "images" in message:
            media_list = media_handler.process_images(
                message["images"], message["name"], message["step"]
            )
        else:
            media_type = message.get("media_type", "image")
            media_meta = media_handler.process_media(
                message["media_data"], message["name"], message["step"], media_type
            )
            media_list = [media_meta]

        storage.append_media(
            message["step"],
            message.get("global_step"),
            message["name"],
            media_list,
            message.get("caption"),
        )

    def _store_table(message):
        """Record one table message"""
        storage.append_table(
            message["step"],
            message.get("global_step"),
            message["name"],
            message["table_data"],
        )

    logger.info("Metadata worker started")

    try:
        while not stop_event.is_set():
            # STEP 1: Drain ALL messages from both queues
            media_messages = _drain(media_queue)
            table_messages = _drain(tables_queue)

            if not media_messages and not table_messages:
                # No messages - sleep briefly instead of busy-waiting
                time.sleep(0.01)
                continue

            # STEP 2: Process all messages
            batch_start = time.time()
            for message in media_messages:
                _store_media(message)
            for message in table_messages:
                _store_table(message)

            # STEP 3: Flush once
            storage.flush_all()
            total = len(media_messages) + len(table_messages)
            batch_time = time.time() - batch_start
            logger.debug(
                f"Processed and flushed {total} metadata entries ({len(media_messages)} media, {len(table_messages)} tables) in {batch_time*1000:.1f}ms"
            )

        # Final drain + flush: messages enqueued after the last pass would
        # otherwise be dropped on shutdown.
        logger.info("Metadata worker shutting down...")
        for message in _drain(media_queue):
            _store_media(message)
        for message in _drain(tables_queue):
            _store_table(message)

        storage.flush_all()
        storage.close()
        logger.info("Metadata worker stopped")

    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Metadata worker error: {e}")
        raise
|
||||
|
||||
|
||||
def histogram_worker_main(board_dir: Path, queue: Any, stop_event: Any):
    """Worker process for histograms (Lance storage)

    Until ``stop_event`` is set, each pass drains every message currently in
    the queue, appends it to storage and flushes once. After the stop signal
    the remaining messages are drained, then storage is flushed and closed.

    Args:
        board_dir: Board directory
        queue: Histogram queue (mp.Queue)
        stop_event: Stop event (mp.Event)

    Raises:
        Exception: Any error raised while processing is logged and re-raised.
    """
    logger.remove()
    logger.add(
        board_dir / "logs" / "histogram_worker.log", rotation="10 MB", level="DEBUG"
    )

    storage = HistogramStorage(board_dir / "data", num_bins=64)

    def _store(message):
        """Append one queued histogram message to the storage buffer"""
        storage.append_histogram(
            message["step"],
            message.get("global_step"),
            message["name"],
            message["values"],
            message.get("num_bins", 64),
            message.get("precision", "exact"),
        )

    logger.info("Histogram worker started")

    try:
        while not stop_event.is_set():
            # Queue size is only used in the log line below, so a failing
            # qsize() must not stop the worker. Catch Exception rather than
            # everything so KeyboardInterrupt/SystemExit still propagate.
            try:
                queue_size = queue.qsize()
            except Exception:
                queue_size = 0

            # STEP 1: Drain ALL messages from queue
            messages = []
            drain_start = time.time()
            while True:
                try:
                    messages.append(queue.get_nowait())
                except Empty:
                    break
            drain_time = time.time() - drain_start

            if not messages:
                # No messages - sleep briefly instead of busy-waiting
                time.sleep(0.01)
                continue

            # STEP 2: Process all messages (add to storage buffers)
            batch_start = time.time()
            for message in messages:
                _store(message)

            # STEP 3: Flush ALL buffers once
            storage.flush()
            batch_time = time.time() - batch_start
            logger.info(
                f"Drained {len(messages)}/{queue_size} from queue in {drain_time*1000:.1f}ms, processed+flushed in {batch_time*1000:.1f}ms"
            )

        # Final flush: pick up anything enqueued after the last pass
        logger.info("Histogram worker shutting down, draining queue...")
        drained = 0
        while True:
            try:
                message = queue.get_nowait()
            except Empty:
                break
            _store(message)
            drained += 1

        storage.flush()
        storage.close()
        logger.info(f"Histogram worker stopped (drained {drained}, flushed all)")

    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Histogram worker error: {e}")
        raise
|
||||
@@ -18,9 +18,7 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from kohakuboard.client.media import MediaHandler
|
||||
from kohakuboard.client.storage import ParquetStorage
|
||||
from kohakuboard.client.storage_duckdb import DuckDBStorage
|
||||
from kohakuboard.client.storage_hybrid import HybridStorage
|
||||
from kohakuboard.client.storage import DuckDBStorage, HybridStorage, ParquetStorage
|
||||
|
||||
|
||||
class LogWriter:
|
||||
@@ -59,21 +57,57 @@ class LogWriter:
|
||||
self.auto_flush_interval = 5 # Auto-flush every 5 seconds (aggressive)
|
||||
|
||||
def run(self):
|
||||
"""Main loop - process messages from queue"""
|
||||
"""Main loop - adaptive batching with exponential backoff"""
|
||||
logger.info(f"LogWriter started for {self.board_dir}")
|
||||
|
||||
# Adaptive sleep parameters
|
||||
min_period = 0.01 # 10ms minimum sleep
|
||||
max_period = 1.0 # 1s maximum sleep
|
||||
consecutive_empty = 0 # Track consecutive empty queue reads
|
||||
|
||||
try:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# Get message from queue (shorter timeout for faster shutdown)
|
||||
message = self.queue.get(timeout=0.5)
|
||||
self._process_message(message)
|
||||
self.messages_processed += 1
|
||||
# Process ALL available messages in queue
|
||||
batch_count = 0
|
||||
batch_start = time.time()
|
||||
|
||||
except Empty:
|
||||
# No message - check if we need to auto-flush
|
||||
if time.time() - self.last_flush_time > self.auto_flush_interval:
|
||||
self._auto_flush()
|
||||
# Drain queue completely (up to 10k to allow stop_event check)
|
||||
while batch_count < 10000 and not self.stop_event.is_set():
|
||||
try:
|
||||
message = self.queue.get_nowait()
|
||||
self._process_message(message)
|
||||
self.messages_processed += 1
|
||||
batch_count += 1
|
||||
except Empty:
|
||||
break
|
||||
|
||||
# Flush immediately after processing ANY messages
|
||||
if batch_count > 0:
|
||||
self.storage.flush_all()
|
||||
batch_time = time.time() - batch_start
|
||||
logger.debug(
|
||||
f"Processed and flushed {batch_count} messages in {batch_time*1000:.1f}ms"
|
||||
)
|
||||
self.last_flush_time = time.time()
|
||||
|
||||
# Reset backoff counter - we had work to do
|
||||
consecutive_empty = 0
|
||||
else:
|
||||
# Queue empty - increase backoff
|
||||
consecutive_empty += 1
|
||||
|
||||
# Adaptive sleep: min_period * 2^k, capped at max_period
|
||||
sleep_time = min(
|
||||
min_period * (2**consecutive_empty), max_period
|
||||
)
|
||||
|
||||
# Sleep in small chunks to allow responsive shutdown
|
||||
# e.g. instead of a single sleep(1.0), sleep(0.05) × 20 and check stop_event between chunks
|
||||
slept = 0.0
|
||||
while slept < sleep_time and not self.stop_event.is_set():
|
||||
time.sleep(0.05) # Sleep in 50ms chunks
|
||||
slept += 0.05
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# Received interrupt in worker - DRAIN QUEUE FIRST
|
||||
@@ -81,7 +115,6 @@ class LogWriter:
|
||||
"Writer received interrupt, draining queue before stopping..."
|
||||
)
|
||||
# Continue processing until stop_event is set by main process
|
||||
# Don't break immediately!
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
@@ -187,8 +220,11 @@ class LogWriter:
|
||||
name = message["name"]
|
||||
values = message["values"]
|
||||
num_bins = message.get("num_bins", 64)
|
||||
precision = message.get("precision", "compact")
|
||||
|
||||
self.storage.append_histogram(step, global_step, name, values, num_bins)
|
||||
self.storage.append_histogram(
|
||||
step, global_step, name, values, num_bins, precision
|
||||
)
|
||||
|
||||
def _handle_flush(self):
|
||||
"""Handle explicit flush request"""
|
||||
|
||||
Reference in New Issue
Block a user