add histogram logging and better flush strategy

This commit is contained in:
Kohaku-Blueleaf
2025-10-27 00:33:21 +08:00
parent 8af2f3a6c2
commit 13ee449b23
5 changed files with 179 additions and 15 deletions

View File

@@ -103,7 +103,9 @@ class Board:
self._global_step: Optional[int] = None
# Multiprocessing setup
self.queue = mp.Queue(maxsize=10000) # Large queue for buffering
self.queue = mp.Queue(
maxsize=50000
) # Very large queue for heavy logging (e.g., per-step histograms)
self.stop_event = mp.Event()
# Start writer process
@@ -312,6 +314,56 @@ class Board:
}
self.queue.put(message)
def log_histogram(
    self, name: str, values: Union[List[float], Any], num_bins: int = 64
):
    """Log histogram data (non-blocking)

    Args:
        name: Name for this histogram log (supports namespace: "gradients/layer1")
        values: List of values, NumPy array, or PyTorch tensor to create
            the histogram from
        num_bins: Number of bins for histogram (default: 64)

    Example:
        >>> # Log gradient histogram
        >>> grads = [p.grad.flatten().cpu().numpy() for p in model.parameters()]
        >>> board.log_histogram("gradients/all", np.concatenate(grads))
        >>>
        >>> # Log parameter histogram
        >>> params = model.fc1.weight.detach().cpu().numpy().flatten()
        >>> board.log_histogram("params/fc1_weight", params)
    """
    # Increment step (auto-increment on every log call)
    self._step += 1
    # Check queue size and warn if getting full
    try:
        queue_size = self.queue.qsize()
        if queue_size > 40000:
            logger.warning(
                f"Queue size is {queue_size}/50000. Consider reducing histogram logging frequency."
            )
    except NotImplementedError:
        pass  # qsize() not supported on all platforms
    # Normalize input to a flat Python list so it pickles cleanly across
    # the process queue.
    if hasattr(values, "cpu"):  # PyTorch tensor (has .cpu/.detach)
        values = values.detach().cpu().numpy().flatten().tolist()
    elif hasattr(values, "flatten"):
        # NumPy array. NOTE: ndarrays do NOT have a ".numpy" attribute, so
        # checking hasattr(values, "numpy") would never match and a 2-D
        # array would fall through to list(values), yielding a list of row
        # arrays instead of flat floats — detect via .flatten instead.
        values = values.flatten().tolist()
    elif not isinstance(values, list):
        values = list(values)
    message = {
        "type": "histogram",
        "step": self._step,
        "global_step": self._global_step,
        "name": name,
        "values": values,
        "num_bins": num_bins,
    }
    self.queue.put(message)
def step(self, increment: int = 1):
"""Explicit step increment

View File

@@ -78,17 +78,31 @@ class TeeStream:
self.stream1 = stream1
self.stream2 = stream2
self.prefix = prefix
self.current_line = "" # Track current line for \r handling
def write(self, data):
    """Write *data* to both streams.

    The terminal stream receives the raw data verbatim (including "\r",
    so tqdm progress bars render normally). The file stream is
    line-buffered: a carriage return discards the partial line (a tqdm
    redraw), and only completed lines (terminated by "\n") are written,
    each with the optional prefix.

    Note: the rendered diff showed both the old unconditional
    self.stream2.write(...) path and the new line-buffered path; keeping
    both would write every chunk to the file twice, so only the
    line-buffered path is kept.
    """
    # Write to terminal (with \r for tqdm)
    self.stream1.write(data)
    # For file: handle \r properly (tqdm progress bars)
    for char in data:
        if char == "\r":
            # Carriage return - discard current line, start fresh
            self.current_line = ""
        elif char == "\n":
            # Newline - write current line to file
            if self.current_line:
                if self.prefix:
                    self.stream2.write(self.prefix + self.current_line + "\n")
                else:
                    self.stream2.write(self.current_line + "\n")
                self.current_line = ""
            else:
                # Preserve intentional blank lines
                self.stream2.write("\n")
        else:
            # Regular character - add to current line
            self.current_line += char
def flush(self):
"""Flush both streams"""

View File

@@ -95,9 +95,11 @@ class MediaHandler:
pil_image = self._to_pil(image)
# Generate filename and hash
# Replace "/" with "__" in name to avoid subdirectory issues
safe_name = name.replace("/", "__")
image_hash = self._hash_media(pil_image)
ext = "png"
filename = f"{name}_{step:08d}_{image_hash[:8]}.{ext}"
filename = f"{safe_name}_{step:08d}_{image_hash[:8]}.{ext}"
filepath = self.media_dir / filename
# Save image
@@ -155,8 +157,10 @@ class MediaHandler:
media_hash = self._hash_file(source_path)
# Preserve original extension
# Replace "/" with "__" in name to avoid subdirectory issues
safe_name = name.replace("/", "__")
ext = source_path.suffix.lstrip(".")
filename = f"{name}_{step:08d}_{media_hash[:8]}.{ext}"
filename = f"{safe_name}_{step:08d}_{media_hash[:8]}.{ext}"
dest_path = self.media_dir / filename
# Copy file if it doesn't exist (deduplication)

View File

@@ -43,12 +43,16 @@ class DuckDBStorage:
self.metrics_buffer: List[Dict[str, Any]] = []
self.media_buffer: List[Dict[str, Any]] = []
self.tables_buffer: List[Dict[str, Any]] = []
self.histograms_buffer: List[Dict[str, Any]] = []
# Track known columns for schema evolution
self.known_metric_cols = {"step", "global_step", "timestamp"}
# Flush threshold
self.flush_threshold = 10
# Flush thresholds
self.flush_threshold = 10 # Metrics
self.histogram_flush_threshold = (
100 # Histograms (batch aggressively for performance)
)
def _init_tables(self):
"""Initialize database tables"""
@@ -97,6 +101,20 @@ class DuckDBStorage:
"""
)
# Histograms table (pre-computed bins to save space)
self.conn.execute(
"""
CREATE TABLE IF NOT EXISTS histograms (
step BIGINT NOT NULL,
global_step BIGINT,
name VARCHAR NOT NULL,
num_bins INTEGER,
bins VARCHAR,
counts VARCHAR
)
"""
)
self.conn.commit()
def append_metrics(
@@ -111,14 +129,17 @@ class DuckDBStorage:
Args:
step: Auto-increment step
global_step: Explicit global step (optional)
metrics: Dict of metric name -> value
metrics: Dict of metric name -> value (can contain "/" for namespaces)
timestamp: Timestamp of log event (datetime object)
"""
# Escape metric names (replace "/" with "__" for DuckDB)
escaped_metrics = {k.replace("/", "__"): v for k, v in metrics.items()}
row = {
"step": step,
"global_step": global_step,
"timestamp": timestamp,
**metrics,
**escaped_metrics,
}
self.metrics_buffer.append(row)
@@ -139,12 +160,14 @@ class DuckDBStorage:
"""
for col in new_cols:
try:
# Escape column name (replace "/" with "__" for DuckDB compatibility)
escaped_col = col.replace("/", "__")
# Add column as DOUBLE (works for most ML metrics)
self.conn.execute(
f"ALTER TABLE metrics ADD COLUMN IF NOT EXISTS {col} DOUBLE"
f'ALTER TABLE metrics ADD COLUMN IF NOT EXISTS "{escaped_col}" DOUBLE'
)
self.known_metric_cols.add(col)
logger.debug(f"Added new metric column: {col}")
logger.debug(f"Added new metric column: {col} (as {escaped_col})")
except Exception as e:
logger.error(f"Failed to add column {col}: {e}")
@@ -271,11 +294,70 @@ class DuckDBStorage:
except Exception as e:
logger.error(f"Failed to flush tables: {e}")
def append_histogram(
    self,
    step: int,
    global_step: Optional[int],
    name: str,
    values: List[float],
    num_bins: int = 64,
):
    """Buffer one histogram entry, storing pre-computed bins rather than raw values.

    Args:
        step: Auto-increment step
        global_step: Explicit global step
        name: Histogram log name
        values: List of values to create histogram from
        num_bins: Number of bins for histogram
    """
    import numpy as np

    # Bin the raw values up front; only edges + per-bin counts are kept,
    # which is far smaller than persisting every value.
    counts, bin_edges = np.histogram(
        np.asarray(values, dtype=np.float32), bins=num_bins
    )
    self.histograms_buffer.append(
        {
            "step": step,
            "global_step": global_step,
            "name": name,
            "num_bins": num_bins,
            "bins": json.dumps(bin_edges.tolist()),
            "counts": json.dumps(counts.tolist()),
        }
    )
    # Flush in batches rather than per-entry to keep DuckDB writes cheap.
    if len(self.histograms_buffer) >= self.histogram_flush_threshold:
        self.flush_histograms()
def flush_histograms(self):
    """Flush histograms buffer to DuckDB (TRUE INCREMENTAL!)"""
    buffered = self.histograms_buffer
    if not buffered:
        return
    try:
        # Append-only write: rows go straight into the histograms table.
        self.conn.append("histograms", pd.DataFrame(buffered), by_name=True)
        self.conn.commit()
        logger.debug(f"Appended {len(buffered)} histogram rows (INCREMENTAL)")
        buffered.clear()
    except KeyboardInterrupt:
        # Drop the batch on interrupt so shutdown isn't blocked by a retry.
        logger.warning("Histograms flush interrupted")
        buffered.clear()
    except Exception as e:
        # Keep the buffer so a later flush can retry these rows.
        logger.error(f"Failed to flush histograms: {e}")
def flush_all(self):
    """Flush all buffers"""
    # Order matters only for log readability; each flush is independent.
    for flush in (
        self.flush_metrics,
        self.flush_media,
        self.flush_tables,
        self.flush_histograms,
    ):
        flush()
    logger.info("Flushed all buffers to DuckDB")
def close(self):

View File

@@ -95,6 +95,8 @@ class LogWriter:
self._handle_media(message)
elif msg_type == "table":
self._handle_table(message)
elif msg_type == "histogram":
self._handle_histogram(message)
elif msg_type == "flush":
self._handle_flush()
else:
@@ -163,6 +165,16 @@ class LogWriter:
self.storage.append_table(step, global_step, name, table_data)
def _handle_histogram(self, message: dict):
"""Handle histogram logging"""
step = message["step"]
global_step = message.get("global_step")
name = message["name"]
values = message["values"]
num_bins = message.get("num_bins", 64)
self.storage.append_histogram(step, global_step, name, values, num_bins)
def _handle_flush(self):
"""Handle explicit flush request"""
self.storage.flush_all()