mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-11 17:34:08 -05:00
add histogram logging and better flush strategy
This commit is contained in:
@@ -103,7 +103,9 @@ class Board:
|
||||
self._global_step: Optional[int] = None
|
||||
|
||||
# Multiprocessing setup
|
||||
self.queue = mp.Queue(maxsize=10000) # Large queue for buffering
|
||||
self.queue = mp.Queue(
|
||||
maxsize=50000
|
||||
) # Very large queue for heavy logging (e.g., per-step histograms)
|
||||
self.stop_event = mp.Event()
|
||||
|
||||
# Start writer process
|
||||
@@ -312,6 +314,56 @@ class Board:
|
||||
}
|
||||
self.queue.put(message)
|
||||
|
||||
def log_histogram(
|
||||
self, name: str, values: Union[List[float], Any], num_bins: int = 64
|
||||
):
|
||||
"""Log histogram data (non-blocking)
|
||||
|
||||
Args:
|
||||
name: Name for this histogram log (supports namespace: "gradients/layer1")
|
||||
values: List of values or tensor to create histogram from
|
||||
num_bins: Number of bins for histogram (default: 64)
|
||||
|
||||
Example:
|
||||
>>> # Log gradient histogram
|
||||
>>> grads = [p.grad.flatten().cpu().numpy() for p in model.parameters()]
|
||||
>>> board.log_histogram("gradients/all", np.concatenate(grads))
|
||||
>>>
|
||||
>>> # Log parameter histogram
|
||||
>>> params = model.fc1.weight.detach().cpu().numpy().flatten()
|
||||
>>> board.log_histogram("params/fc1_weight", params)
|
||||
"""
|
||||
# Increment step (auto-increment on every log call)
|
||||
self._step += 1
|
||||
|
||||
# Check queue size and warn if getting full
|
||||
try:
|
||||
queue_size = self.queue.qsize()
|
||||
if queue_size > 40000:
|
||||
logger.warning(
|
||||
f"Queue size is {queue_size}/50000. Consider reducing histogram logging frequency."
|
||||
)
|
||||
except NotImplementedError:
|
||||
pass # qsize() not supported on all platforms
|
||||
|
||||
# Convert tensor to list if needed
|
||||
if hasattr(values, "cpu"): # PyTorch tensor
|
||||
values = values.detach().cpu().numpy().flatten().tolist()
|
||||
elif hasattr(values, "numpy"): # NumPy array
|
||||
values = values.flatten().tolist()
|
||||
elif not isinstance(values, list):
|
||||
values = list(values)
|
||||
|
||||
message = {
|
||||
"type": "histogram",
|
||||
"step": self._step,
|
||||
"global_step": self._global_step,
|
||||
"name": name,
|
||||
"values": values,
|
||||
"num_bins": num_bins,
|
||||
}
|
||||
self.queue.put(message)
|
||||
|
||||
def step(self, increment: int = 1):
|
||||
"""Explicit step increment
|
||||
|
||||
|
||||
@@ -78,17 +78,31 @@ class TeeStream:
|
||||
self.stream1 = stream1
|
||||
self.stream2 = stream2
|
||||
self.prefix = prefix
|
||||
self.current_line = "" # Track current line for \r handling
|
||||
|
||||
def write(self, data):
|
||||
"""Write data to both streams"""
|
||||
# Write to terminal
|
||||
# Write to terminal (with \r for tqdm)
|
||||
self.stream1.write(data)
|
||||
|
||||
# Write to file with optional prefix
|
||||
if self.prefix and data.strip():
|
||||
self.stream2.write(self.prefix + data)
|
||||
else:
|
||||
self.stream2.write(data)
|
||||
# For file: handle \r properly (tqdm progress bars)
|
||||
for char in data:
|
||||
if char == "\r":
|
||||
# Carriage return - discard current line, start fresh
|
||||
self.current_line = ""
|
||||
elif char == "\n":
|
||||
# Newline - write current line to file
|
||||
if self.current_line:
|
||||
if self.prefix:
|
||||
self.stream2.write(self.prefix + self.current_line + "\n")
|
||||
else:
|
||||
self.stream2.write(self.current_line + "\n")
|
||||
self.current_line = ""
|
||||
else:
|
||||
self.stream2.write("\n")
|
||||
else:
|
||||
# Regular character - add to current line
|
||||
self.current_line += char
|
||||
|
||||
def flush(self):
|
||||
"""Flush both streams"""
|
||||
|
||||
@@ -95,9 +95,11 @@ class MediaHandler:
|
||||
pil_image = self._to_pil(image)
|
||||
|
||||
# Generate filename and hash
|
||||
# Replace "/" with "__" in name to avoid subdirectory issues
|
||||
safe_name = name.replace("/", "__")
|
||||
image_hash = self._hash_media(pil_image)
|
||||
ext = "png"
|
||||
filename = f"{name}_{step:08d}_{image_hash[:8]}.{ext}"
|
||||
filename = f"{safe_name}_{step:08d}_{image_hash[:8]}.{ext}"
|
||||
filepath = self.media_dir / filename
|
||||
|
||||
# Save image
|
||||
@@ -155,8 +157,10 @@ class MediaHandler:
|
||||
media_hash = self._hash_file(source_path)
|
||||
|
||||
# Preserve original extension
|
||||
# Replace "/" with "__" in name to avoid subdirectory issues
|
||||
safe_name = name.replace("/", "__")
|
||||
ext = source_path.suffix.lstrip(".")
|
||||
filename = f"{name}_{step:08d}_{media_hash[:8]}.{ext}"
|
||||
filename = f"{safe_name}_{step:08d}_{media_hash[:8]}.{ext}"
|
||||
dest_path = self.media_dir / filename
|
||||
|
||||
# Copy file if it doesn't exist (deduplication)
|
||||
|
||||
@@ -43,12 +43,16 @@ class DuckDBStorage:
|
||||
self.metrics_buffer: List[Dict[str, Any]] = []
|
||||
self.media_buffer: List[Dict[str, Any]] = []
|
||||
self.tables_buffer: List[Dict[str, Any]] = []
|
||||
self.histograms_buffer: List[Dict[str, Any]] = []
|
||||
|
||||
# Track known columns for schema evolution
|
||||
self.known_metric_cols = {"step", "global_step", "timestamp"}
|
||||
|
||||
# Flush threshold
|
||||
self.flush_threshold = 10
|
||||
# Flush thresholds
|
||||
self.flush_threshold = 10 # Metrics
|
||||
self.histogram_flush_threshold = (
|
||||
100 # Histograms (batch aggressively for performance)
|
||||
)
|
||||
|
||||
def _init_tables(self):
|
||||
"""Initialize database tables"""
|
||||
@@ -97,6 +101,20 @@ class DuckDBStorage:
|
||||
"""
|
||||
)
|
||||
|
||||
# Histograms table (pre-computed bins to save space)
|
||||
self.conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS histograms (
|
||||
step BIGINT NOT NULL,
|
||||
global_step BIGINT,
|
||||
name VARCHAR NOT NULL,
|
||||
num_bins INTEGER,
|
||||
bins VARCHAR,
|
||||
counts VARCHAR
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def append_metrics(
|
||||
@@ -111,14 +129,17 @@ class DuckDBStorage:
|
||||
Args:
|
||||
step: Auto-increment step
|
||||
global_step: Explicit global step (optional)
|
||||
metrics: Dict of metric name -> value
|
||||
metrics: Dict of metric name -> value (can contain "/" for namespaces)
|
||||
timestamp: Timestamp of log event (datetime object)
|
||||
"""
|
||||
# Escape metric names (replace "/" with "__" for DuckDB)
|
||||
escaped_metrics = {k.replace("/", "__"): v for k, v in metrics.items()}
|
||||
|
||||
row = {
|
||||
"step": step,
|
||||
"global_step": global_step,
|
||||
"timestamp": timestamp,
|
||||
**metrics,
|
||||
**escaped_metrics,
|
||||
}
|
||||
self.metrics_buffer.append(row)
|
||||
|
||||
@@ -139,12 +160,14 @@ class DuckDBStorage:
|
||||
"""
|
||||
for col in new_cols:
|
||||
try:
|
||||
# Escape column name (replace "/" with "__" for DuckDB compatibility)
|
||||
escaped_col = col.replace("/", "__")
|
||||
# Add column as DOUBLE (works for most ML metrics)
|
||||
self.conn.execute(
|
||||
f"ALTER TABLE metrics ADD COLUMN IF NOT EXISTS {col} DOUBLE"
|
||||
f'ALTER TABLE metrics ADD COLUMN IF NOT EXISTS "{escaped_col}" DOUBLE'
|
||||
)
|
||||
self.known_metric_cols.add(col)
|
||||
logger.debug(f"Added new metric column: {col}")
|
||||
logger.debug(f"Added new metric column: {col} (as {escaped_col})")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add column {col}: {e}")
|
||||
|
||||
@@ -271,11 +294,70 @@ class DuckDBStorage:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to flush tables: {e}")
|
||||
|
||||
def append_histogram(
|
||||
self,
|
||||
step: int,
|
||||
global_step: Optional[int],
|
||||
name: str,
|
||||
values: List[float],
|
||||
num_bins: int = 64,
|
||||
):
|
||||
"""Append histogram log entry (pre-computed bins to save space)
|
||||
|
||||
Args:
|
||||
step: Auto-increment step
|
||||
global_step: Explicit global step
|
||||
name: Histogram log name
|
||||
values: List of values to create histogram from
|
||||
num_bins: Number of bins for histogram
|
||||
"""
|
||||
# Compute histogram (bins + counts) instead of storing raw values
|
||||
import numpy as np
|
||||
|
||||
values_array = np.array(values, dtype=np.float32)
|
||||
counts, bin_edges = np.histogram(values_array, bins=num_bins)
|
||||
|
||||
row = {
|
||||
"step": step,
|
||||
"global_step": global_step,
|
||||
"name": name,
|
||||
"num_bins": num_bins,
|
||||
"bins": json.dumps(bin_edges.tolist()), # Bin edges
|
||||
"counts": json.dumps(counts.tolist()), # Counts per bin
|
||||
}
|
||||
self.histograms_buffer.append(row)
|
||||
|
||||
# Batch flush when threshold reached (not immediate!)
|
||||
if len(self.histograms_buffer) >= self.histogram_flush_threshold:
|
||||
self.flush_histograms()
|
||||
|
||||
def flush_histograms(self):
|
||||
"""Flush histograms buffer to DuckDB (TRUE INCREMENTAL!)"""
|
||||
if not self.histograms_buffer:
|
||||
return
|
||||
|
||||
try:
|
||||
df = pd.DataFrame(self.histograms_buffer)
|
||||
self.conn.append("histograms", df, by_name=True)
|
||||
self.conn.commit()
|
||||
|
||||
logger.debug(
|
||||
f"Appended {len(self.histograms_buffer)} histogram rows (INCREMENTAL)"
|
||||
)
|
||||
self.histograms_buffer.clear()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.warning("Histograms flush interrupted")
|
||||
self.histograms_buffer.clear()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to flush histograms: {e}")
|
||||
|
||||
def flush_all(self):
    """Flush every buffer (metrics, media, tables, histograms) in that order."""
    flushers = (
        self.flush_metrics,
        self.flush_media,
        self.flush_tables,
        self.flush_histograms,
    )
    for flush in flushers:
        flush()
    logger.info("Flushed all buffers to DuckDB")
|
||||
|
||||
def close(self):
|
||||
|
||||
@@ -95,6 +95,8 @@ class LogWriter:
|
||||
self._handle_media(message)
|
||||
elif msg_type == "table":
|
||||
self._handle_table(message)
|
||||
elif msg_type == "histogram":
|
||||
self._handle_histogram(message)
|
||||
elif msg_type == "flush":
|
||||
self._handle_flush()
|
||||
else:
|
||||
@@ -163,6 +165,16 @@ class LogWriter:
|
||||
|
||||
self.storage.append_table(step, global_step, name, table_data)
|
||||
|
||||
def _handle_histogram(self, message: dict):
|
||||
"""Handle histogram logging"""
|
||||
step = message["step"]
|
||||
global_step = message.get("global_step")
|
||||
name = message["name"]
|
||||
values = message["values"]
|
||||
num_bins = message.get("num_bins", 64)
|
||||
|
||||
self.storage.append_histogram(step, global_step, name, values, num_bins)
|
||||
|
||||
def _handle_flush(self):
|
||||
"""Handle explicit flush request"""
|
||||
self.storage.flush_all()
|
||||
|
||||
Reference in New Issue
Block a user