fix: migration streaming/batching (#21542)

* fix: normalize usage tokens + migration streaming/batching

- Migration: replace .fetchall() with yield_per streaming, replace per-message INSERT+SAVEPOINT with batched inserts (5k/batch) with fallback to row-by-row on error, add progress logging

- Write path: call normalize_usage() in upsert_message() before saving to ensure input_tokens/output_tokens always present

- Read path: analytics queries now COALESCE across input_tokens/prompt_tokens and output_tokens/completion_tokens so historical data with OpenAI-format keys is visible

* fix: restore defensive timestamp conversion in migration

Re-add try/except around int(float(timestamp)) that was accidentally dropped. Without this, a non-numeric timestamp string would cause a TypeError on the subsequent comparison, breaking the entire upgrade.

* revert: remove changes to chat_messages.py
This commit is contained in:
Classic298
2026-03-08 01:08:11 +01:00
committed by GitHub
parent 7b2f597b30
commit b4f340806a

View File

@@ -21,6 +21,39 @@ down_revision: Union[str, None] = "374d2f66af06"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
BATCH_SIZE = 5000


def _flush_batch(conn, table, batch):
    """
    Insert a batch of messages, falling back to row-by-row on error.

    A single bulk insert is attempted first (fast path). When it fails
    (e.g. a duplicate key), each row is retried individually inside its
    own savepoint so one bad row cannot abort the rest of the batch.

    Returns a ``(inserted, failed)`` tuple of row counts.
    """
    bulk_sp = conn.begin_nested()
    try:
        # Fast path: one executemany-style bulk insert for the whole batch.
        conn.execute(sa.insert(table), batch)
        bulk_sp.commit()
        return len(batch), 0
    except Exception:
        bulk_sp.rollback()

    # Batch failed - insert one-by-one to isolate the bad row(s)
    inserted, failed = 0, 0
    for msg in batch:
        row_sp = conn.begin_nested()
        try:
            conn.execute(sa.insert(table).values(**msg))
            row_sp.commit()
        except Exception as e:
            row_sp.rollback()
            failed += 1
            log.warning(f"Failed to insert message {msg['id']}: {e}")
        else:
            inserted += 1
    return inserted, failed
def upgrade() -> None: def upgrade() -> None:
# Step 1: Create table # Step 1: Create table
@@ -88,18 +121,21 @@ def upgrade() -> None:
sa.column("updated_at", sa.BigInteger()), sa.column("updated_at", sa.BigInteger()),
) )
# Fetch all chats (excluding shared chats which have user_id starting with 'shared-') # Stream rows instead of loading all into memory:
chats = conn.execute( # - yield_per: fetches rows in chunks via cursor.fetchmany() (all backends)
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
result = conn.execute(
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat).where( sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat).where(
~chat_table.c.user_id.like("shared-%") ~chat_table.c.user_id.like("shared-%")
) ).execution_options(yield_per=1000, stream_results=True)
).fetchall() )
now = int(time.time()) now = int(time.time())
messages_inserted = 0 messages_batch = []
messages_failed = 0 total_inserted = 0
total_failed = 0
for chat_row in chats: for chat_row in result:
chat_id = chat_row[0] chat_id = chat_row[0]
user_id = chat_row[1] user_id = chat_row[1]
chat_data = chat_row[2] chat_data = chat_row[2]
@@ -139,39 +175,43 @@ def upgrade() -> None:
if timestamp < 1577836800 or timestamp > now + 86400: if timestamp < 1577836800 or timestamp > now + 86400:
timestamp = now timestamp = now
# Use savepoint to allow individual insert failures without aborting transaction messages_batch.append({
savepoint = conn.begin_nested() "id": f"{chat_id}-{message_id}",
try: "chat_id": chat_id,
conn.execute( "user_id": user_id,
sa.insert(chat_message_table).values( "role": role,
id=f"{chat_id}-{message_id}", "parent_id": message.get("parentId"),
chat_id=chat_id, "content": message.get("content"),
user_id=user_id, "output": message.get("output"),
role=role, "model_id": message.get("model"),
parent_id=message.get("parentId"), "files": message.get("files"),
content=message.get("content"), "sources": message.get("sources"),
output=message.get("output"), "embeds": message.get("embeds"),
model_id=message.get("model"), "done": message.get("done", True),
files=message.get("files"), "status_history": message.get("statusHistory"),
sources=message.get("sources"), "error": message.get("error"),
embeds=message.get("embeds"), "usage": message.get("usage"),
done=message.get("done", True), "created_at": timestamp,
status_history=message.get("statusHistory"), "updated_at": timestamp,
error=message.get("error"), })
created_at=timestamp,
updated_at=timestamp, # Flush batch when full
) if len(messages_batch) >= BATCH_SIZE:
) inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
savepoint.commit() total_inserted += inserted
messages_inserted += 1 total_failed += failed
except Exception as e: if total_inserted % 50000 < BATCH_SIZE:
savepoint.rollback() log.info(f"Migration progress: {total_inserted} messages inserted...")
messages_failed += 1 messages_batch.clear()
log.warning(f"Failed to insert message {message_id}: {e}")
continue # Flush remaining messages
if messages_batch:
inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
total_inserted += inserted
total_failed += failed
log.info( log.info(
f"Backfilled {messages_inserted} messages into chat_message table ({messages_failed} failed)" f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)"
) )