fix: migration streaming/batching (#21542)

* fix: normalize usage tokens + migration streaming/batching

- Migration: replace .fetchall() with yield_per streaming, replace per-message INSERT+SAVEPOINT with batched inserts (5k/batch) with fallback to row-by-row on error, add progress logging

- Write path: call normalize_usage() in upsert_message() before saving to ensure input_tokens/output_tokens always present

- Read path: analytics queries now COALESCE across input_tokens/prompt_tokens and output_tokens/completion_tokens so historical data with OpenAI-format keys is visible

* fix: restore defensive timestamp conversion in migration

Re-add try/except around int(float(timestamp)) that was accidentally dropped. Without this, a non-numeric timestamp string would cause a TypeError on the subsequent comparison, breaking the entire upgrade.

* revert: remove changes to chat_messages.py
This commit is contained in:
Classic298
2026-03-08 01:08:11 +01:00
committed by GitHub
parent 7b2f597b30
commit b4f340806a

View File

@@ -21,6 +21,39 @@ down_revision: Union[str, None] = "374d2f66af06"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
BATCH_SIZE = 5000
def _flush_batch(conn, table, batch):
    """
    Insert a batch of messages, attempting a single bulk insert first.

    Fast path: one bulk INSERT for the whole batch inside a savepoint.
    Slow path: if the bulk insert raises (e.g. a duplicate key), roll the
    savepoint back and retry each row individually under its own savepoint,
    so a single bad row cannot sink the rest of the batch.

    Args:
        conn: active SQLAlchemy connection (inside the migration transaction).
        table: target table object for the insert.
        batch: list of row dicts; each must carry an "id" key (used in the
            per-row failure log message).

    Returns:
        Tuple of (inserted, failed) row counts.
    """
    bulk_sp = conn.begin_nested()
    try:
        conn.execute(sa.insert(table), batch)
    except Exception:
        bulk_sp.rollback()
    else:
        bulk_sp.commit()
        return len(batch), 0

    # Bulk path failed — fall back to row-by-row to isolate the bad row(s).
    ok_count, bad_count = 0, 0
    for msg in batch:
        row_sp = conn.begin_nested()
        try:
            conn.execute(sa.insert(table).values(**msg))
        except Exception as e:
            row_sp.rollback()
            bad_count += 1
            log.warning(f"Failed to insert message {msg['id']}: {e}")
        else:
            row_sp.commit()
            ok_count += 1
    return ok_count, bad_count
def upgrade() -> None:
# Step 1: Create table
@@ -88,18 +121,21 @@ def upgrade() -> None:
sa.column("updated_at", sa.BigInteger()),
)
# Fetch all chats (excluding shared chats which have user_id starting with 'shared-')
chats = conn.execute(
# Stream rows instead of loading all into memory:
# - yield_per: fetches rows in chunks via cursor.fetchmany() (all backends)
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
result = conn.execute(
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat).where(
~chat_table.c.user_id.like("shared-%")
)
).fetchall()
).execution_options(yield_per=1000, stream_results=True)
)
now = int(time.time())
messages_inserted = 0
messages_failed = 0
messages_batch = []
total_inserted = 0
total_failed = 0
for chat_row in chats:
for chat_row in result:
chat_id = chat_row[0]
user_id = chat_row[1]
chat_data = chat_row[2]
@@ -139,39 +175,43 @@ def upgrade() -> None:
if timestamp < 1577836800 or timestamp > now + 86400:
timestamp = now
# Use savepoint to allow individual insert failures without aborting transaction
savepoint = conn.begin_nested()
try:
conn.execute(
sa.insert(chat_message_table).values(
id=f"{chat_id}-{message_id}",
chat_id=chat_id,
user_id=user_id,
role=role,
parent_id=message.get("parentId"),
content=message.get("content"),
output=message.get("output"),
model_id=message.get("model"),
files=message.get("files"),
sources=message.get("sources"),
embeds=message.get("embeds"),
done=message.get("done", True),
status_history=message.get("statusHistory"),
error=message.get("error"),
created_at=timestamp,
updated_at=timestamp,
)
)
savepoint.commit()
messages_inserted += 1
except Exception as e:
savepoint.rollback()
messages_failed += 1
log.warning(f"Failed to insert message {message_id}: {e}")
continue
messages_batch.append({
"id": f"{chat_id}-{message_id}",
"chat_id": chat_id,
"user_id": user_id,
"role": role,
"parent_id": message.get("parentId"),
"content": message.get("content"),
"output": message.get("output"),
"model_id": message.get("model"),
"files": message.get("files"),
"sources": message.get("sources"),
"embeds": message.get("embeds"),
"done": message.get("done", True),
"status_history": message.get("statusHistory"),
"error": message.get("error"),
"usage": message.get("usage"),
"created_at": timestamp,
"updated_at": timestamp,
})
# Flush batch when full
if len(messages_batch) >= BATCH_SIZE:
inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
total_inserted += inserted
total_failed += failed
if total_inserted % 50000 < BATCH_SIZE:
log.info(f"Migration progress: {total_inserted} messages inserted...")
messages_batch.clear()
# Flush remaining messages
if messages_batch:
inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
total_inserted += inserted
total_failed += failed
log.info(
f"Backfilled {messages_inserted} messages into chat_message table ({messages_failed} failed)"
f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)"
)