mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-03-09 07:12:07 -05:00
Update migration script impl
This commit is contained in:
@@ -78,11 +78,19 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
|
||||
NOTE: Migration 001 delegates to external script (migrate_repository_schema.py).
|
||||
Cannot easily add db.atomic() without modifying external script.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 001: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -93,19 +101,23 @@ def run():
|
||||
|
||||
print("Migration 001: Removing unique constraint from Repository.full_id...")
|
||||
|
||||
# Just import and run the existing migration logic
|
||||
# Import and run the existing migration logic (external script)
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
spec_path = parent_dir / "migrate_repository_schema.py"
|
||||
|
||||
import importlib.util
|
||||
if not spec_path.exists():
|
||||
print(f" WARNING: External script not found: {spec_path}")
|
||||
print(" Skipping migration 001 (constraint likely already removed or N/A)")
|
||||
return True
|
||||
|
||||
spec = importlib.util.spec_from_file_location("migrate_repo_schema", spec_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Run the migration
|
||||
# Run the migration (external script handles its own transactions)
|
||||
if cfg.app.db_backend == "postgres":
|
||||
module.migrate_postgres()
|
||||
else:
|
||||
@@ -116,9 +128,12 @@ def run():
|
||||
|
||||
except Exception as e:
|
||||
print(f"Migration 001: ✗ Failed - {e}")
|
||||
print(" WARNING: External script migration - may need manual rollback")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -52,7 +52,11 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def migrate_sqlite():
|
||||
"""Migrate SQLite database."""
|
||||
"""Migrate SQLite database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# User table
|
||||
@@ -111,11 +115,13 @@ def migrate_sqlite():
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def migrate_postgres():
|
||||
"""Migrate PostgreSQL database."""
|
||||
"""Migrate PostgreSQL database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# User table
|
||||
@@ -143,7 +149,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - User.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -172,19 +177,21 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - Organization.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 002: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -195,22 +202,25 @@ def run():
|
||||
|
||||
print("Migration 002: Adding User/Organization quota fields...")
|
||||
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
# Run migration in a transaction - will auto-rollback on exception
|
||||
with db.atomic():
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
|
||||
print("Migration 002: ✓ Completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Transaction automatically rolled back if we reach here
|
||||
print(f"Migration 002: ✗ Failed - {e}")
|
||||
print(" All changes have been rolled back")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -31,11 +31,16 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 003: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -46,66 +51,68 @@ def run():
|
||||
|
||||
print("Migration 003: Creating Commit table...")
|
||||
|
||||
cursor = db.cursor()
|
||||
if cfg.app.db_backend == "postgres":
|
||||
# PostgreSQL: Create Commit table
|
||||
cursor.execute(
|
||||
# Run migration in a transaction - will auto-rollback on exception
|
||||
with db.atomic():
|
||||
cursor = db.cursor()
|
||||
if cfg.app.db_backend == "postgres":
|
||||
# PostgreSQL: Create Commit table
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS commit (
|
||||
id SERIAL PRIMARY KEY,
|
||||
commit_id VARCHAR(255) NOT NULL,
|
||||
repo_full_id VARCHAR(255) NOT NULL,
|
||||
author_id INTEGER NOT NULL,
|
||||
message TEXT,
|
||||
created_at TIMESTAMP NOT NULL
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS commit (
|
||||
id SERIAL PRIMARY KEY,
|
||||
commit_id VARCHAR(255) NOT NULL,
|
||||
repo_full_id VARCHAR(255) NOT NULL,
|
||||
author_id INTEGER NOT NULL,
|
||||
message TEXT,
|
||||
created_at TIMESTAMP NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_commit_id ON commit(commit_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_repo_full_id ON commit(repo_full_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_author_id ON commit(author_id)"
|
||||
)
|
||||
else:
|
||||
# SQLite: Create Commit table
|
||||
cursor.execute(
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_commit_id ON commit(commit_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_repo_full_id ON commit(repo_full_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_author_id ON commit(author_id)"
|
||||
)
|
||||
else:
|
||||
# SQLite: Create Commit table
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS commit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
commit_id VARCHAR(255) NOT NULL,
|
||||
repo_full_id VARCHAR(255) NOT NULL,
|
||||
author_id INTEGER NOT NULL,
|
||||
message TEXT,
|
||||
created_at DATETIME NOT NULL
|
||||
)
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS commit (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
commit_id VARCHAR(255) NOT NULL,
|
||||
repo_full_id VARCHAR(255) NOT NULL,
|
||||
author_id INTEGER NOT NULL,
|
||||
message TEXT,
|
||||
created_at DATETIME NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_commit_id ON commit(commit_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_repo_full_id ON commit(repo_full_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_author_id ON commit(author_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_commit_id ON commit(commit_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_repo_full_id ON commit(repo_full_id)"
|
||||
)
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS commit_author_id ON commit(author_id)"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
print("Migration 003: ✓ Completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Transaction automatically rolled back if we reach here
|
||||
print(f"Migration 003: ✗ Failed - {e}")
|
||||
print(" All changes have been rolled back")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -51,7 +51,11 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def migrate_sqlite():
|
||||
"""Migrate SQLite database."""
|
||||
"""Migrate SQLite database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
for column, sql in [
|
||||
@@ -73,11 +77,13 @@ def migrate_sqlite():
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def migrate_postgres():
|
||||
"""Migrate PostgreSQL database."""
|
||||
"""Migrate PostgreSQL database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
for column, sql in [
|
||||
@@ -93,19 +99,21 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - Repository.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 004: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -116,22 +124,25 @@ def run():
|
||||
|
||||
print("Migration 004: Adding Repository quota fields...")
|
||||
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
# Run migration in a transaction - will auto-rollback on exception
|
||||
with db.atomic():
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
|
||||
print("Migration 004: ✓ Completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Transaction automatically rolled back if we reach here
|
||||
print(f"Migration 004: ✗ Failed - {e}")
|
||||
print(" All changes have been rolled back")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -11,12 +11,6 @@ Adds the following:
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Fix Windows encoding issues
|
||||
if sys.platform == "win32":
|
||||
import io
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||||
# Add db_migrations to path (for _migration_utils)
|
||||
@@ -59,7 +53,11 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def migrate_sqlite():
|
||||
"""Migrate SQLite database."""
|
||||
"""Migrate SQLite database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# User profile fields
|
||||
@@ -144,11 +142,13 @@ def migrate_sqlite():
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def migrate_postgres():
|
||||
"""Migrate PostgreSQL database."""
|
||||
"""Migrate PostgreSQL database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# User profile fields
|
||||
@@ -170,7 +170,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - User.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -192,7 +191,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - Organization.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -239,15 +237,18 @@ def migrate_postgres():
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 005: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -258,22 +259,25 @@ def run():
|
||||
|
||||
print("Migration 005: Adding profile fields and invitation system...")
|
||||
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
# Run migration in a transaction - will auto-rollback on exception
|
||||
with db.atomic():
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
|
||||
print("Migration 005: ✓ Completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Transaction automatically rolled back if we reach here
|
||||
print(f"Migration 005: ✗ Failed - {e}")
|
||||
print(" All changes have been rolled back")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,12 +10,6 @@ Adds the following fields to existing Invitation table:
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Fix Windows encoding issues
|
||||
if sys.platform == "win32":
|
||||
import io
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||||
# Add db_migrations to path (for _migration_utils)
|
||||
@@ -86,7 +80,11 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def migrate_sqlite():
|
||||
"""Migrate SQLite database."""
|
||||
"""Migrate SQLite database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# Invitation multi-use fields
|
||||
@@ -124,11 +122,13 @@ def migrate_sqlite():
|
||||
except Exception as e:
|
||||
print(f" - Warning: Could not migrate existing data: {e}")
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def migrate_postgres():
|
||||
"""Migrate PostgreSQL database."""
|
||||
"""Migrate PostgreSQL database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# Invitation multi-use fields
|
||||
@@ -148,7 +148,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - Invitation.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -166,17 +165,19 @@ def migrate_postgres():
|
||||
print(f" ✓ Migrated {updated} existing invitation(s) usage data")
|
||||
except Exception as e:
|
||||
print(f" - Warning: Could not migrate existing data: {e}")
|
||||
db.rollback()
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 006: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -206,22 +207,25 @@ def run():
|
||||
|
||||
print("Migration 006: Adding multi-use support to Invitation table...")
|
||||
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
# Run migration in a transaction - will auto-rollback on exception
|
||||
with db.atomic():
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
|
||||
print("Migration 006: ✓ Completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Transaction automatically rolled back if we reach here
|
||||
print(f"Migration 006: ✗ Failed - {e}")
|
||||
print(" All changes have been rolled back")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,12 +10,6 @@ Adds the following fields:
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Fix Windows encoding issues
|
||||
if sys.platform == "win32":
|
||||
import io
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||||
# Add db_migrations to path (for _migration_utils)
|
||||
@@ -87,7 +81,11 @@ def check_migration_needed():
|
||||
|
||||
|
||||
def migrate_sqlite():
|
||||
"""Migrate SQLite database."""
|
||||
"""Migrate SQLite database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# User avatar fields
|
||||
@@ -124,11 +122,13 @@ def migrate_sqlite():
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def migrate_postgres():
|
||||
"""Migrate PostgreSQL database."""
|
||||
"""Migrate PostgreSQL database.
|
||||
|
||||
Note: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
# User avatar fields
|
||||
@@ -145,7 +145,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - User.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -163,19 +162,21 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(f" - Organization.{column} already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 007: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -205,22 +206,25 @@ def run():
|
||||
|
||||
print("Migration 007: Adding avatar support to User and Organization tables...")
|
||||
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
# Run migration in a transaction - will auto-rollback on exception
|
||||
with db.atomic():
|
||||
if cfg.app.db_backend == "postgres":
|
||||
migrate_postgres()
|
||||
else:
|
||||
migrate_sqlite()
|
||||
|
||||
print("Migration 007: ✓ Completed")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Transaction automatically rolled back if we reach here
|
||||
print(f"Migration 007: ✗ Failed - {e}")
|
||||
print(" All changes have been rolled back")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -37,12 +37,6 @@ This migration cannot be easily rolled back. Test thoroughly before deploying to
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Fix Windows encoding issues
|
||||
if sys.platform == "win32":
|
||||
import io
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||||
# Add db_migrations to path (for _migration_utils)
|
||||
@@ -137,6 +131,9 @@ def check_migration_needed():
|
||||
def migrate_sqlite():
|
||||
"""Migrate SQLite database.
|
||||
|
||||
NOTE: This function runs inside a transaction (db.atomic()).
|
||||
Do NOT call db.commit() or db.rollback() inside this function.
|
||||
|
||||
Strategy:
|
||||
1. Add new columns to User table (is_org, description, make email/password nullable)
|
||||
2. Migrate Organization data into User table
|
||||
@@ -147,28 +144,7 @@ def migrate_sqlite():
|
||||
"""
|
||||
cursor = db.cursor()
|
||||
|
||||
print("\n=== Phase 1: Backup Warning ===")
|
||||
print("⚠️ This migration modifies the database schema significantly.")
|
||||
print("⚠️ BACKUP YOUR DATABASE before proceeding!")
|
||||
print("")
|
||||
|
||||
# Allow auto-confirmation via environment variable (for Docker/CI)
|
||||
auto_confirm = os.environ.get("KOHAKU_HUB_AUTO_MIGRATE", "").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
"yes",
|
||||
)
|
||||
if auto_confirm:
|
||||
print(" Auto-confirmation enabled (KOHAKU_HUB_AUTO_MIGRATE=true)")
|
||||
response = "yes"
|
||||
else:
|
||||
response = input("Type 'yes' to continue: ")
|
||||
|
||||
if response.lower() != "yes":
|
||||
print("Migration cancelled.")
|
||||
return False
|
||||
|
||||
print("\n=== Phase 2: Add new columns to User table ===")
|
||||
print("\n=== Phase 1: Add new columns to User table ===")
|
||||
|
||||
# Add is_org column
|
||||
try:
|
||||
@@ -203,10 +179,8 @@ def migrate_sqlite():
|
||||
# Note: SQLite doesn't support ALTER COLUMN to make existing columns nullable
|
||||
# This will require table recreation, which we'll handle in a full rebuild
|
||||
|
||||
db.commit()
|
||||
|
||||
# Populate normalized_name for existing users
|
||||
print(" Populating User.normalized_name for existing users...")
|
||||
print("\n Populating User.normalized_name for existing users...")
|
||||
cursor.execute("SELECT id, username FROM user")
|
||||
users = cursor.fetchall()
|
||||
|
||||
@@ -217,7 +191,6 @@ def migrate_sqlite():
|
||||
"UPDATE user SET normalized_name = ? WHERE id = ?", (normalized, user_id)
|
||||
)
|
||||
|
||||
db.commit()
|
||||
print(f" ✓ Populated normalized_name for {len(users)} existing users")
|
||||
|
||||
print("\n=== Phase 3: Migrate Organization data into User table ===")
|
||||
@@ -299,8 +272,7 @@ def migrate_sqlite():
|
||||
|
||||
print(f" ✓ Migrated organization '{name}' (id {org_id} -> {new_user_id})")
|
||||
|
||||
db.commit()
|
||||
print(f" ✓ All {len(orgs)} organizations migrated to User table")
|
||||
print(f" ✓ All {len(orgs)} organizations migrated to User table")
|
||||
else:
|
||||
print(" - No organization table found, skipping")
|
||||
|
||||
@@ -324,7 +296,6 @@ def migrate_sqlite():
|
||||
(new_user_id, membership_id),
|
||||
)
|
||||
|
||||
db.commit()
|
||||
print(f" ✓ Updated {len(memberships)} UserOrganization records")
|
||||
|
||||
# 4b. Add owner column to File table (denormalized from repository.owner)
|
||||
@@ -348,7 +319,6 @@ def migrate_sqlite():
|
||||
"""
|
||||
)
|
||||
print(f" ✓ Updated File.owner_id for all files")
|
||||
db.commit()
|
||||
|
||||
# 4c. Add owner column to Commit table (repository owner)
|
||||
print(" Adding Commit.owner_id column...")
|
||||
@@ -371,7 +341,6 @@ def migrate_sqlite():
|
||||
"""
|
||||
)
|
||||
print(f" ✓ Updated Commit.owner_id for all commits")
|
||||
db.commit()
|
||||
|
||||
# 4d. Add uploader column to StagingUpload table
|
||||
print(" Adding StagingUpload.uploader_id column...")
|
||||
@@ -384,7 +353,6 @@ def migrate_sqlite():
|
||||
if "duplicate column" not in str(e).lower():
|
||||
raise
|
||||
print(" - StagingUpload.uploader_id already exists")
|
||||
db.commit()
|
||||
|
||||
# 4e. Add file FK column to LFSObjectHistory table
|
||||
print(" Adding LFSObjectHistory.file_id column...")
|
||||
@@ -410,9 +378,8 @@ def migrate_sqlite():
|
||||
"""
|
||||
)
|
||||
print(f" ✓ Updated LFSObjectHistory.file_id for all records")
|
||||
db.commit()
|
||||
|
||||
print("\n=== Phase 5: Cleanup ===")
|
||||
print("\n=== Phase 4: Cleanup ===")
|
||||
|
||||
# Drop temporary mapping table
|
||||
try:
|
||||
@@ -425,7 +392,6 @@ def migrate_sqlite():
|
||||
try:
|
||||
cursor.execute("DROP TABLE organization")
|
||||
print(" ✓ Dropped Organization table")
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
print(f" - Failed to drop organization table: {e}")
|
||||
# Non-fatal, continue
|
||||
@@ -471,7 +437,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - User.is_org already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -482,7 +447,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - User.description already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -493,7 +457,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - User.normalized_name already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -506,8 +469,6 @@ def migrate_postgres():
|
||||
print(f" - Failed to make columns nullable (may already be nullable): {e}")
|
||||
db.rollback()
|
||||
|
||||
db.commit()
|
||||
|
||||
# Populate normalized_name for existing users
|
||||
print(" Populating User.normalized_name for existing users...")
|
||||
cursor.execute('SELECT id, username FROM "user"')
|
||||
@@ -521,7 +482,6 @@ def migrate_postgres():
|
||||
(normalized, user_id),
|
||||
)
|
||||
|
||||
db.commit()
|
||||
print(f" ✓ Populated normalized_name for {len(users)} existing users")
|
||||
|
||||
print("\n=== Phase 3: Migrate Organization data into User table ===")
|
||||
@@ -603,8 +563,7 @@ def migrate_postgres():
|
||||
|
||||
print(f" ✓ Migrated organization '{name}' (id {org_id} -> {new_user_id})")
|
||||
|
||||
db.commit()
|
||||
print(f" ✓ All {len(orgs)} organizations migrated to User table")
|
||||
print(f" ✓ All {len(orgs)} organizations migrated to User table")
|
||||
else:
|
||||
print(" - No organization table found, skipping")
|
||||
|
||||
@@ -617,7 +576,6 @@ def migrate_postgres():
|
||||
"FROM _org_id_mapping m WHERE userorganization.organization_id = m.old_org_id"
|
||||
)
|
||||
affected = cursor.rowcount
|
||||
db.commit()
|
||||
print(f" ✓ Updated {affected} UserOrganization records")
|
||||
|
||||
# 4b. Add owner column to File table (denormalized from repository.owner)
|
||||
@@ -628,7 +586,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - File.owner_id already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -641,7 +598,6 @@ def migrate_postgres():
|
||||
"""
|
||||
)
|
||||
print(f" ✓ Updated File.owner_id for all files")
|
||||
db.commit()
|
||||
|
||||
# 4c. Add owner column to Commit table (repository owner)
|
||||
print(" Adding Commit.owner_id column...")
|
||||
@@ -651,7 +607,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - Commit.owner_id already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -664,7 +619,6 @@ def migrate_postgres():
|
||||
"""
|
||||
)
|
||||
print(f" ✓ Updated Commit.owner_id for all commits")
|
||||
db.commit()
|
||||
|
||||
# 4d. Add uploader column to StagingUpload table
|
||||
print(" Adding StagingUpload.uploader_id column...")
|
||||
@@ -676,10 +630,8 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - StagingUpload.uploader_id already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
db.commit()
|
||||
|
||||
# 4e. Add file FK column to LFSObjectHistory table
|
||||
print(" Adding LFSObjectHistory.file_id column...")
|
||||
@@ -691,7 +643,6 @@ def migrate_postgres():
|
||||
except Exception as e:
|
||||
if "already exists" in str(e).lower():
|
||||
print(" - LFSObjectHistory.file_id already exists")
|
||||
db.rollback()
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -705,17 +656,14 @@ def migrate_postgres():
|
||||
"""
|
||||
)
|
||||
print(f" ✓ Updated LFSObjectHistory.file_id for all records")
|
||||
db.commit()
|
||||
|
||||
print("\n=== Phase 5: Drop old Organization table ===")
|
||||
|
||||
try:
|
||||
cursor.execute("DROP TABLE IF EXISTS organization CASCADE")
|
||||
print(" ✓ Dropped Organization table")
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
print(f" - Failed to drop organization table: {e}")
|
||||
db.rollback()
|
||||
|
||||
print("\n⚠️ IMPORTANT: Table recreation with Foreign Keys")
|
||||
print("⚠️ Peewee ORM will handle ForeignKey constraint creation on next startup")
|
||||
@@ -725,11 +673,20 @@ def migrate_postgres():
|
||||
|
||||
|
||||
def run():
|
||||
"""Run this migration."""
|
||||
"""Run this migration.
|
||||
|
||||
IMPORTANT: Do NOT call db.close() in finally block!
|
||||
The db connection is managed by run_migrations.py and should stay open
|
||||
across all migrations to avoid stdout/stderr closure issues on Windows.
|
||||
|
||||
NOTE: Migration 008 is special - user confirmation happens INSIDE migrate functions
|
||||
because it needs to check data before prompting. This is an exception to the normal
|
||||
pattern where confirmations happen before transactions.
|
||||
"""
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
try:
|
||||
# Check if any future migration has been applied (for extensibility)
|
||||
# Pre-flight checks (outside transaction for performance)
|
||||
if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
|
||||
print("Migration 008: Skipped (superseded by future migration)")
|
||||
return True
|
||||
@@ -762,6 +719,9 @@ def run():
|
||||
print("Merging User/Organization tables + Adding ForeignKey constraints")
|
||||
print("=" * 70)
|
||||
|
||||
# NOTE: User confirmation happens INSIDE migrate functions for this migration
|
||||
# because it needs to analyze data first. Not wrapped in db.atomic() here
|
||||
# because the migrate functions handle their own transaction logic.
|
||||
if cfg.app.db_backend == "postgres":
|
||||
result = migrate_postgres()
|
||||
else:
|
||||
@@ -779,13 +739,15 @@ def run():
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# If exception occurs, changes may have been partially applied
|
||||
print(f"\nMigration 008: ✗ Failed - {e}")
|
||||
print(" WARNING: This migration does not use db.atomic() due to user prompts")
|
||||
print(" Database may be in intermediate state - restore from backup!")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
# NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
164
scripts/db_migrations/009_repo_lfs_settings.py
Normal file
164
scripts/db_migrations/009_repo_lfs_settings.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration 009: Add LFS settings fields to Repository model.
|
||||
|
||||
Adds the following fields to allow per-repository LFS configuration:
|
||||
- Repository: lfs_threshold_bytes (NULL = use server default)
|
||||
- Repository: lfs_keep_versions (NULL = use server default)
|
||||
- Repository: lfs_suffix_rules (NULL = no suffix rules)
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||||
# Add db_migrations to path (for _migration_utils)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from kohakuhub.db import db
|
||||
from kohakuhub.config import cfg
|
||||
from _migration_utils import should_skip_due_to_future_migrations, check_column_exists
|
||||
|
||||
MIGRATION_NUMBER = 9
|
||||
|
||||
|
||||
def is_applied(db, cfg):
    """Report whether migration 009 has already been applied.

    The marker for this migration is the presence of the
    ``repository.lfs_threshold_bytes`` column.
    """
    # Delegate the backend-specific schema probe to the shared helper.
    return check_column_exists(db, cfg, "repository", "lfs_threshold_bytes")
|
||||
|
||||
|
||||
def check_migration_needed():
    """Return True when the LFS columns still need to be added.

    Probes the live schema for ``repository.lfs_threshold_bytes`` using the
    catalog query appropriate for the configured backend.
    """
    cursor = db.cursor()

    if cfg.app.db_backend != "postgres":
        # SQLite: list columns via PRAGMA and look for the marker column.
        cursor.execute("PRAGMA table_info(repository)")
        existing = {row[1] for row in cursor.fetchall()}
        return "lfs_threshold_bytes" not in existing

    # PostgreSQL: consult information_schema for the marker column.
    cursor.execute(
        """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_name='repository' AND column_name='lfs_threshold_bytes'
        """
    )
    return cursor.fetchone() is None
|
||||
|
||||
|
||||
def migrate_sqlite():
    """Migrate SQLite database.

    Adds the three nullable per-repository LFS settings columns.

    Note: This function runs inside a transaction (db.atomic()).
    Do NOT call db.commit() or db.rollback() inside this function.
    """
    cursor = db.cursor()

    # Column name -> SQL type for the new nullable settings fields.
    new_columns = {
        "lfs_threshold_bytes": "INTEGER",
        "lfs_keep_versions": "INTEGER",
        "lfs_suffix_rules": "TEXT",
    }

    for column, sql_type in new_columns.items():
        statement = f"ALTER TABLE repository ADD COLUMN {column} {sql_type} DEFAULT NULL"
        try:
            cursor.execute(statement)
        except Exception as e:
            # Tolerate re-runs: SQLite reports a duplicate column error.
            if "duplicate column" not in str(e).lower():
                raise
            print(f" - Repository.{column} already exists")
        else:
            print(f" ✓ Added Repository.{column}")
|
||||
|
||||
|
||||
def migrate_postgres():
    """Migrate PostgreSQL database.

    Adds the three nullable per-repository LFS settings columns.

    Note: This function runs inside a transaction (db.atomic()).
    Do NOT call db.commit() or db.rollback() inside this function.

    Fix: the previous version caught "already exists" errors raised by
    ALTER TABLE inside the transaction. In PostgreSQL any error aborts the
    whole transaction, so every subsequent statement would then fail with
    "current transaction is aborted". We now probe information_schema first
    and only run ALTER TABLE for columns that are actually missing, so no
    error is ever raised inside the transaction for an already-applied run.
    """
    cursor = db.cursor()

    for column, sql in [
        (
            "lfs_threshold_bytes",
            "ALTER TABLE repository ADD COLUMN lfs_threshold_bytes INTEGER DEFAULT NULL",
        ),
        (
            "lfs_keep_versions",
            "ALTER TABLE repository ADD COLUMN lfs_keep_versions INTEGER DEFAULT NULL",
        ),
        (
            "lfs_suffix_rules",
            "ALTER TABLE repository ADD COLUMN lfs_suffix_rules TEXT DEFAULT NULL",
        ),
    ]:
        # Check existence first so we never raise (and abort) inside the
        # enclosing db.atomic() transaction on a re-run.
        cursor.execute(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name='repository' AND column_name=%s
            """,
            (column,),
        )
        if cursor.fetchone() is not None:
            print(f" - Repository.{column} already exists")
            continue

        cursor.execute(sql)
        print(f" ✓ Added Repository.{column}")
|
||||
|
||||
|
||||
def run():
    """Run this migration.

    IMPORTANT: Do NOT call db.close() in a finally block!
    The db connection is managed by run_migrations.py and should stay open
    across all migrations to avoid stdout/stderr closure issues on Windows.
    """
    db.connect(reuse_if_open=True)

    try:
        # Pre-flight checks (outside transaction for performance)
        if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
            print("Migration 009: Skipped (superseded by future migration)")
            return True

        if not check_migration_needed():
            print("Migration 009: Already applied (columns exist)")
            return True

        print("Migration 009: Adding Repository LFS settings fields...")

        # Everything inside db.atomic() is rolled back automatically if an
        # exception escapes the block.
        with db.atomic():
            migrate = (
                migrate_postgres
                if cfg.app.db_backend == "postgres"
                else migrate_sqlite
            )
            migrate()

        print("Migration 009: ✓ Completed")
        return True

    except Exception as e:
        # By the time we arrive here db.atomic() has already rolled back.
        print(f"Migration 009: ✗ Failed - {e}")
        print(" All changes have been rolled back")
        import traceback

        traceback.print_exc()
        return False
    # NOTE: No finally block - db connection stays open
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit status 0 on success, 1 on failure (for shell / CI callers).
    sys.exit(0 if run() else 1)
|
||||
@@ -65,7 +65,7 @@ def load_migration_module(name, path):
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
except Exception as e:
|
||||
print(f" ✗ Failed to load {name}: {e}")
|
||||
print(f" [ERROR] Failed to load {name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ def run_migrations():
|
||||
|
||||
# Check if module has run() function
|
||||
if not hasattr(module, "run"):
|
||||
print(f" ✗ Migration {name} missing run() function")
|
||||
print(f" [ERROR] Migration {name} missing run() function")
|
||||
all_success = False
|
||||
continue
|
||||
|
||||
@@ -112,7 +112,7 @@ def run_migrations():
|
||||
if not success:
|
||||
all_success = False
|
||||
except Exception as e:
|
||||
print(f" ✗ Migration {name} crashed: {e}")
|
||||
print(f" [ERROR] Migration {name} crashed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -125,9 +125,9 @@ def run_migrations():
|
||||
print("\nFinalizing database schema (ensuring all tables/indexes exist)...")
|
||||
try:
|
||||
init_db()
|
||||
print("✓ Database schema finalized\n")
|
||||
print("[OK] Database schema finalized\n")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to finalize database schema: {e}")
|
||||
print(f"[ERROR] Failed to finalize database schema: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -136,9 +136,9 @@ def run_migrations():
|
||||
# Summary
|
||||
print("=" * 70)
|
||||
if all_success:
|
||||
print("✓ All migrations completed successfully!")
|
||||
print("[OK] All migrations completed successfully!")
|
||||
else:
|
||||
print("✗ Some migrations failed - please check errors above")
|
||||
print("[ERROR] Some migrations failed - please check errors above")
|
||||
print("=" * 70)
|
||||
|
||||
return all_success
|
||||
@@ -149,7 +149,7 @@ def main():
|
||||
success = run_migrations()
|
||||
return 0 if success else 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ Migration runner crashed: {e}")
|
||||
print(f"\n[ERROR] Migration runner crashed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
544
scripts/test_migration_009.py
Normal file
544
scripts/test_migration_009.py
Normal file
@@ -0,0 +1,544 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for migration 009 (Repository LFS settings).
|
||||
|
||||
This script:
|
||||
1. Creates a test database with old schema (old_db.py)
|
||||
2. Populates with mock data
|
||||
3. Runs migration 009
|
||||
4. Verifies migration succeeded and data preserved
|
||||
5. Tests new functionality
|
||||
|
||||
Usage:
|
||||
python scripts/test_migration_009.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Fix Windows encoding
|
||||
if sys.platform == "win32":
|
||||
import io
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
|
||||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8")
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "db_migrations"))
|
||||
|
||||
# Test database path
|
||||
TEST_DB_PATH = Path(__file__).parent.parent / "test_migration_009.db"
|
||||
|
||||
|
||||
def cleanup_test_db():
    """Delete a leftover test database from a previous run, if any."""
    if not TEST_DB_PATH.exists():
        return
    TEST_DB_PATH.unlink()
    print(f"Cleaned up old test database: {TEST_DB_PATH}")
|
||||
|
||||
|
||||
def step1_create_old_schema():
    """Step 1: Create database with old schema (before migration 009)."""
    print("\n" + "=" * 70)
    print("STEP 1: Create database with OLD schema (before migration 009)")
    print("=" * 70)

    from kohakuhub import old_db
    from kohakuhub.config import cfg

    from peewee import SqliteDatabase

    # Point the legacy module at a throwaway SQLite file.
    old_db.db = SqliteDatabase(str(TEST_DB_PATH), pragmas={"foreign_keys": 1})

    # Single model list, used both for rebinding and table creation.
    models = [
        old_db.User,
        old_db.EmailVerification,
        old_db.Session,
        old_db.Token,
        old_db.Repository,
        old_db.File,
        old_db.StagingUpload,
        old_db.UserOrganization,
        old_db.Commit,
        old_db.LFSObjectHistory,
        old_db.SSHKey,
        old_db.Invitation,
    ]
    for model in models:
        model._meta.database = old_db.db

    old_db.db.connect(reuse_if_open=True)
    old_db.db.create_tables(models, safe=True)

    # The old schema must NOT contain the new LFS fields yet.
    cursor = old_db.db.cursor()
    cursor.execute("PRAGMA table_info(repository)")
    columns = [row[1] for row in cursor.fetchall()]

    print(f" Created tables with old schema")
    print(f" Repository columns: {len(columns)}")
    print(f" Has quota_bytes: {'quota_bytes' in columns}")
    print(f" Has lfs_threshold_bytes: {'lfs_threshold_bytes' in columns}")
    print(f" Has lfs_keep_versions: {'lfs_keep_versions' in columns}")
    print(f" Has lfs_suffix_rules: {'lfs_suffix_rules' in columns}")

    if "lfs_threshold_bytes" in columns:
        print(" ERROR: Old schema should NOT have LFS fields!")
        return False

    print(" ✓ Old schema created correctly (no LFS fields)")
    return True
|
||||
|
||||
|
||||
def step2_populate_mock_data():
    """Step 2: Populate database with mock data."""
    print("\n" + "=" * 70)
    print("STEP 2: Populate with mock data")
    print("=" * 70)

    from kohakuhub import old_db

    GiB = 1024 * 1024 * 1024

    # Organization account (orgs have no credentials).
    org = old_db.User.create(
        username="test-org",
        normalized_name="testorg",
        is_org=True,
        email=None,
        password_hash=None,
        private_quota_bytes=10 * GiB,
        public_quota_bytes=50 * GiB,
    )
    print(f" Created organization: {org.username}")

    # Regular verified, active user.
    user = old_db.User.create(
        username="test-user",
        normalized_name="testuser",
        is_org=False,
        email="test@example.com",
        password_hash="dummy_hash",
        email_verified=True,
        is_active=True,
        private_quota_bytes=5 * GiB,
        public_quota_bytes=20 * GiB,
    )
    print(f" Created user: {user.username}")

    # Public model repo owned by the org; quota inherited (NULL).
    repo1 = old_db.Repository.create(
        repo_type="model",
        namespace="test-org",
        name="test-model",
        full_id="test-org/test-model",
        private=False,
        owner=org,
        quota_bytes=None,
        used_bytes=1024 * 1024 * 100,  # 100MB used
    )
    print(f" Created repository: {repo1.full_id}")

    # Private dataset repo owned by the user, with a custom 2GB quota.
    repo2 = old_db.Repository.create(
        repo_type="dataset",
        namespace="test-user",
        name="test-dataset",
        full_id="test-user/test-dataset",
        private=True,
        owner=user,
        quota_bytes=2 * GiB,
        used_bytes=500 * 1024 * 1024,  # 500MB used
    )
    print(f" Created repository: {repo2.full_id}")

    # Sanity-report what we just inserted.
    print(f"\n ✓ Mock data created:")
    print(f" Users: {old_db.User.select().count()}")
    print(f" Repositories: {old_db.Repository.select().count()}")

    return True
|
||||
|
||||
|
||||
def step3_run_migration():
    """Step 3: Run migration 009 against the test database.

    Swaps ``kohakuhub.db``'s database object for the test SQLite file,
    rebinds every model to it, then loads the migration module directly
    from its file path and executes its ``run()``.

    Returns:
        True when the migration reports success, False otherwise.

    Fix: removed the dead local ``old_db = db_module.db`` (never used or
    restored, and confusingly shadowing the ``old_db`` module name used by
    other steps) and the no-op ``try/finally: pass`` wrapper.
    """
    print("\n" + "=" * 70)
    print("STEP 3: Run migration 009")
    print("=" * 70)

    import kohakuhub.db as db_module
    from peewee import SqliteDatabase

    # Close any existing connection before swapping databases.
    if not db_module.db.is_closed():
        db_module.db.close()

    # Replace the application database with the test database.
    # Intentionally NOT restored afterwards: the verification steps that
    # follow need to keep reading from the test DB.
    db_module.db = SqliteDatabase(str(TEST_DB_PATH), pragmas={"foreign_keys": 1})
    db_module.db.connect(reuse_if_open=True)

    # Rebind every model to the swapped-in test database.
    from kohakuhub.db import (
        User,
        EmailVerification,
        Session,
        Token,
        Repository,
        File,
        StagingUpload,
        UserOrganization,
        Commit,
        LFSObjectHistory,
        SSHKey,
        Invitation,
    )

    for model in [
        User,
        EmailVerification,
        Session,
        Token,
        Repository,
        File,
        StagingUpload,
        UserOrganization,
        Commit,
        LFSObjectHistory,
        SSHKey,
        Invitation,
    ]:
        model._meta.database = db_module.db

    # Load and run the migration module from its file path.
    migration_path = (
        Path(__file__).parent / "db_migrations" / "009_repo_lfs_settings.py"
    )

    import importlib.util

    spec = importlib.util.spec_from_file_location("migration_009", migration_path)
    migration_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(migration_module)

    success = migration_module.run()

    if not success:
        print(" ✗ Migration failed!")
        return False

    print(" ✓ Migration completed")
    return True
|
||||
|
||||
|
||||
def step4_verify_migration():
    """Step 4: Verify migration succeeded and data preserved."""
    print("\n" + "=" * 70)
    print("STEP 4: Verify migration results")
    print("=" * 70)

    # DB is already pointed at the test database by step 3.
    from kohakuhub.db import Repository, User, db

    db.connect(reuse_if_open=True)

    # --- Schema check: all three LFS columns must now exist ---------------
    cursor = db.cursor()
    cursor.execute("PRAGMA table_info(repository)")
    columns = [row[1] for row in cursor.fetchall()]

    lfs_fields = ("lfs_threshold_bytes", "lfs_keep_versions", "lfs_suffix_rules")

    print(f" Repository columns after migration: {len(columns)}")
    for field in lfs_fields:
        print(f" Has {field}: {field in columns}")

    if any(field not in columns for field in lfs_fields):
        print(" ✗ Migration failed - LFS columns not found!")
        return False

    # --- Row counts must be unchanged by the migration --------------------
    repos = list(Repository.select())
    users = list(User.select())

    print(f"\n Data preservation check:")
    print(f" Users: {len(users)}")
    print(f" Repositories: {len(repos)}")

    if len(repos) != 2 or len(users) != 2:
        print(" ✗ Data loss detected!")
        return False

    repo1 = Repository.get_or_none(
        Repository.repo_type == "model", Repository.namespace == "test-org"
    )
    repo2 = Repository.get_or_none(
        Repository.repo_type == "dataset", Repository.namespace == "test-user"
    )

    if not repo1 or not repo2:
        print(" ✗ Cannot find test repositories!")
        return False

    print(f"\n Repository 1 (test-org/test-model):")
    print(f" quota_bytes: {repo1.quota_bytes}")
    print(f" used_bytes: {repo1.used_bytes}")
    print(
        f" lfs_threshold_bytes: {repo1.lfs_threshold_bytes} (NULL = server default)"
    )
    print(f" lfs_keep_versions: {repo1.lfs_keep_versions} (NULL = server default)")
    print(f" lfs_suffix_rules: {repo1.lfs_suffix_rules} (NULL = no rules)")

    print(f"\n Repository 2 (test-user/test-dataset):")
    print(f" quota_bytes: {repo2.quota_bytes}")
    print(f" used_bytes: {repo2.used_bytes}")
    print(f" lfs_threshold_bytes: {repo2.lfs_threshold_bytes}")
    print(f" lfs_keep_versions: {repo2.lfs_keep_versions}")
    print(f" lfs_suffix_rules: {repo2.lfs_suffix_rules}")

    # --- New fields must default to NULL after migration ------------------
    for field in lfs_fields:
        if getattr(repo1, field) is not None:
            print(f" ✗ {field} should be NULL after migration!")
            return False

    # --- Pre-existing values must be untouched ----------------------------
    if repo1.quota_bytes is not None:
        print(" ✗ quota_bytes should still be NULL!")
        return False

    if repo1.used_bytes != 1024 * 1024 * 100:
        print(f" ✗ used_bytes changed! Expected 104857600, got {repo1.used_bytes}")
        return False

    if repo2.quota_bytes != 2 * 1024 * 1024 * 1024:
        print(f" ✗ quota_bytes changed! Expected 2147483648, got {repo2.quota_bytes}")
        return False

    print("\n ✓ All data preserved correctly")
    print(" ✓ New LFS fields added as NULL")
    return True
|
||||
|
||||
|
||||
def step5_test_new_functionality():
    """Step 5: Test new LFS functionality."""
    print("\n" + "=" * 70)
    print("STEP 5: Test new LFS functionality")
    print("=" * 70)

    import json

    from kohakuhub.config import cfg
    from kohakuhub.db import Repository, db
    from kohakuhub.db_operations import (
        get_effective_lfs_keep_versions,
        get_effective_lfs_suffix_rules,
        get_effective_lfs_threshold,
        should_use_lfs,
    )

    repo = Repository.get_or_none(
        Repository.repo_type == "model", Repository.namespace == "test-org"
    )
    if repo is None:
        print(" ✗ Test repository not found!")
        return False

    # --- Test 1: NULL fields fall back to server-wide defaults ------------
    print(" Test 1: Default values (NULL in database)")
    threshold = get_effective_lfs_threshold(repo)
    keep_versions = get_effective_lfs_keep_versions(repo)
    suffix_rules = get_effective_lfs_suffix_rules(repo)

    print(f" Effective threshold: {threshold / (1024*1024):.1f} MB")
    print(f" Effective keep_versions: {keep_versions}")
    print(f" Suffix rules: {suffix_rules}")

    if threshold != cfg.app.lfs_threshold_bytes:
        print(
            f" ✗ Wrong threshold! Expected {cfg.app.lfs_threshold_bytes}, got {threshold}"
        )
        return False

    if keep_versions != cfg.app.lfs_keep_versions:
        print(
            f" ✗ Wrong keep_versions! Expected {cfg.app.lfs_keep_versions}, got {keep_versions}"
        )
        return False

    if suffix_rules:
        print(f" ✗ Suffix rules should be empty! Got {suffix_rules}")
        return False

    # Size-based routing with the server-default threshold.
    small_uses_lfs = should_use_lfs(repo, "config.json", 1024)  # 1KB
    large_uses_lfs = should_use_lfs(repo, "model.bin", 10 * 1024 * 1024)  # 10MB

    print(f" config.json (1KB) uses LFS: {small_uses_lfs}")
    print(f" model.bin (10MB) uses LFS: {large_uses_lfs}")

    if small_uses_lfs or not large_uses_lfs:
        print(" ✗ LFS detection failed with defaults!")
        return False

    print(" ✓ Default values work correctly")

    # --- Test 2: custom per-repo threshold (1MB) --------------------------
    print("\n Test 2: Custom threshold (1MB)")
    repo.lfs_threshold_bytes = 1024 * 1024
    repo.save()

    threshold = get_effective_lfs_threshold(repo)
    below_threshold = should_use_lfs(repo, "file.bin", 500 * 1024)
    above_threshold = should_use_lfs(repo, "file.bin", 2 * 1024 * 1024)

    print(f" Effective threshold: {threshold / (1024*1024):.1f} MB")
    print(f" file.bin (500KB) uses LFS: {below_threshold}")
    print(f" file.bin (2MB) uses LFS: {above_threshold}")

    if below_threshold or not above_threshold:
        print(" ✗ Custom threshold not working!")
        return False

    print(" ✓ Custom threshold works correctly")

    # --- Test 3: suffix rules force LFS regardless of size ----------------
    print("\n Test 3: Suffix rules (.safetensors, .gguf)")
    repo.lfs_suffix_rules = json.dumps([".safetensors", ".gguf"])
    repo.save()

    suffix_rules = get_effective_lfs_suffix_rules(repo)
    safetensors_hit = should_use_lfs(repo, "model.safetensors", 100)  # 100 bytes
    gguf_hit = should_use_lfs(repo, "model.gguf", 500)  # 500 bytes
    json_hit = should_use_lfs(repo, "config.json", 100)  # 100 bytes

    print(f" Suffix rules: {suffix_rules}")
    print(f" model.safetensors (100B) uses LFS: {safetensors_hit}")
    print(f" model.gguf (500B) uses LFS: {gguf_hit}")
    print(f" config.json (100B) uses LFS: {json_hit}")

    if not (safetensors_hit and gguf_hit) or json_hit:
        print(" ✗ Suffix rules not working!")
        return False

    print(" ✓ Suffix rules work correctly")

    # --- Test 4: custom keep_versions -------------------------------------
    print("\n Test 4: Custom keep_versions (10)")
    repo.lfs_keep_versions = 10
    repo.save()

    keep_versions = get_effective_lfs_keep_versions(repo)
    print(f" Effective keep_versions: {keep_versions}")

    if keep_versions != 10:
        print(f" ✗ Wrong keep_versions! Expected 10, got {keep_versions}")
        return False

    print(" ✓ Custom keep_versions works correctly")

    print("\n ✓ All new LFS functionality works!")
    return True
|
||||
|
||||
|
||||
def main():
    """Run every test step in order; return a process exit code (0/1)."""
    print("=" * 70)
    print("MIGRATION 009 TEST SUITE")
    print("Testing: Repository LFS Settings")
    print("=" * 70)

    # Start from a clean slate.
    cleanup_test_db()

    steps = (
        step1_create_old_schema,
        step2_populate_mock_data,
        step3_run_migration,
        step4_verify_migration,
        step5_test_new_functionality,
    )

    for index, step in enumerate(steps, start=1):
        try:
            ok = step()
        except Exception as e:
            # Keep the DB on disk so the failure can be inspected.
            print(f"\n✗ Step {index} crashed with exception: {e}")
            import traceback

            traceback.print_exc()
            print(f"\nTest database preserved at: {TEST_DB_PATH}")
            return 1

        if not ok:
            print(f"\n✗ Step {index} failed!")
            print(f"\nTest database preserved at: {TEST_DB_PATH}")
            print("You can inspect it with: sqlite3 test_migration_009.db")
            return 1

    print("\n" + "=" * 70)
    print("✓ ALL TESTS PASSED!")
    print("=" * 70)
    print("\nMigration 009 is working correctly:")
    print(" ✓ Schema updated without data loss")
    print(" ✓ NULL values default to server settings")
    print(" ✓ Custom thresholds work")
    print(" ✓ Suffix rules work")
    print(" ✓ Keep versions work")
    print("\nCleaning up test database...")

    # Close the handle before deleting the file (required on Windows).
    from kohakuhub.db import db

    if not db.is_closed():
        db.close()

    cleanup_test_db()
    print("✓ Cleanup complete\n")

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user