mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-04-30 17:37:51 -05:00
Better confirmation system.
This commit is contained in:
186
scripts/db_migrations/015_confirmation_tokens.py
Normal file
186
scripts/db_migrations/015_confirmation_tokens.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration 015: Add ConfirmationToken table for two-step dangerous operations.
|
||||
|
||||
Provides general-purpose confirmation system for operations like:
|
||||
- S3 prefix deletion
|
||||
- Bulk repository deletion
|
||||
- Any operation requiring explicit user confirmation
|
||||
|
||||
Changes:
|
||||
- Add ConfirmationToken table with expiration and auto-cleanup
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||||
# Add db_migrations to path (for _migration_utils)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from kohakuhub.config import cfg
|
||||
from kohakuhub.db import ConfirmationToken, db
|
||||
from _migration_utils import check_table_exists, should_skip_due_to_future_migrations
|
||||
|
||||
MIGRATION_NUMBER = 15
|
||||
|
||||
|
||||
def is_applied(db, cfg):
    """Report whether migration 015 has already run.

    The migration counts as applied once the ``confirmationtoken`` table
    exists. ``cfg`` is unused here but accepted so every migration's
    ``is_applied`` can be called with the same ``(db, cfg)`` signature.
    """
    table_present = check_table_exists(db, "confirmationtoken")
    return table_present
|
||||
|
||||
|
||||
def migrate_postgres():
    """Create the ConfirmationToken table and its indexes in PostgreSQL."""
    cursor = db.cursor()

    print("Creating ConfirmationToken table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS confirmationtoken (
            id SERIAL PRIMARY KEY,
            token VARCHAR(255) UNIQUE NOT NULL,
            action_type VARCHAR(255) NOT NULL,
            action_data TEXT NOT NULL,
            created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            expires_at TIMESTAMP NOT NULL
        )
        """
    )
    print(" ✓ Created ConfirmationToken table")

    # Issue each index statement from one loop so they are all created the
    # same way. NOTE(review): the UNIQUE constraint on `token` already gives
    # PostgreSQL an implicit index, so confirmationtoken_token is likely
    # redundant — kept to preserve this migration's exact effect.
    print("Creating indexes...")
    index_ddl = (
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_token
        ON confirmationtoken(token)
        """,
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_action_type
        ON confirmationtoken(action_type)
        """,
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_expires_at
        ON confirmationtoken(expires_at)
        """,
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_action_type_expires_at
        ON confirmationtoken(action_type, expires_at)
        """,
    )
    for ddl in index_ddl:
        cursor.execute(ddl)
    print(" ✓ Created indexes")
|
||||
|
||||
|
||||
def migrate_sqlite():
    """Create the ConfirmationToken table and its indexes in SQLite."""
    cursor = db.cursor()

    print("Creating ConfirmationToken table...")
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS confirmationtoken (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            token VARCHAR(255) UNIQUE NOT NULL,
            action_type VARCHAR(255) NOT NULL,
            action_data TEXT NOT NULL,
            created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
            expires_at TIMESTAMP NOT NULL
        )
        """
    )
    print(" ✓ Created ConfirmationToken table")

    # Same index set as the PostgreSQL branch, issued from one loop.
    # NOTE(review): SQLite also auto-indexes UNIQUE columns, so the
    # explicit index on `token` is likely redundant — kept so the
    # migration's effect is unchanged.
    print("Creating indexes...")
    for ddl in (
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_token
        ON confirmationtoken(token)
        """,
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_action_type
        ON confirmationtoken(action_type)
        """,
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_expires_at
        ON confirmationtoken(expires_at)
        """,
        """
        CREATE INDEX IF NOT EXISTS confirmationtoken_action_type_expires_at
        ON confirmationtoken(action_type, expires_at)
        """,
    ):
        cursor.execute(ddl)
    print(" ✓ Created indexes")
|
||||
|
||||
|
||||
def run():
    """Run migration 015.

    Returns:
        True if successful or already applied, False otherwise
    """
    db.connect(reuse_if_open=True)

    banner = "=" * 70

    try:
        # Guard 1: a later migration may already cover this change.
        if should_skip_due_to_future_migrations(MIGRATION_NUMBER, db, cfg):
            print(
                f"Migration {MIGRATION_NUMBER}: Skipped (superseded by future migration)"
            )
            return True

        # Guard 2: nothing to do if the table is already in place.
        if is_applied(db, cfg):
            print(
                f"Migration {MIGRATION_NUMBER}: Already applied (ConfirmationToken table exists)"
            )
            return True

        print(banner)
        print(f"Migration {MIGRATION_NUMBER}: Add ConfirmationToken table")
        print(banner)

        # Pick the backend-specific DDL and apply it inside one transaction.
        backend_migrations = {"postgres": migrate_postgres}
        with db.atomic():
            backend_migrations.get(cfg.app.db_backend, migrate_sqlite)()

        print("\n" + banner)
        print(f"Migration {MIGRATION_NUMBER}: ✓ Completed Successfully")
        print(banner)
        print("\nSummary:")
        print(" • Added ConfirmationToken table for two-step confirmations")
        print(" • Supports: S3 deletion, bulk operations, dangerous actions")
        print(" • Auto-expiration with TTL")
        print(" • Works across multiple workers")
        return True

    except Exception as e:
        print(f"\n✗ Migration {MIGRATION_NUMBER} failed: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # run() returns True on success/no-op and False on failure, but the
    # original discarded it, so the script always exited 0 even when the
    # migration failed. Propagate the outcome as the process exit status so
    # shell orchestration and CI can detect failures (`sys` is imported at
    # the top of this file).
    sys.exit(0 if run() else 1)
|
||||
@@ -396,21 +396,62 @@ async def prepare_delete_prefix(
|
||||
Confirmation token, prefix, estimated count, expiration
|
||||
"""
|
||||
|
||||
# Handle R2 path-in-endpoint (same logic as list_objects)
|
||||
parsed = urlparse(cfg.s3.endpoint)
|
||||
endpoint_path = parsed.path.strip("/")
|
||||
|
||||
if endpoint_path:
|
||||
# Endpoint has path - need to add base prefix
|
||||
path_parts = endpoint_path.split("/")
|
||||
actual_bucket = path_parts[0]
|
||||
base_prefix_parts = path_parts[1:] + [cfg.s3.bucket]
|
||||
base_prefix = "/".join(base_prefix_parts)
|
||||
actual_prefix = f"{base_prefix}/{prefix}" if prefix else f"{base_prefix}/"
|
||||
else:
|
||||
# Standard S3 - use prefix as-is
|
||||
actual_bucket = cfg.s3.bucket
|
||||
actual_prefix = prefix
|
||||
|
||||
logger.info(f"Counting objects: bucket={actual_bucket}, prefix={actual_prefix}")
|
||||
|
||||
# Count objects with prefix
|
||||
def _count():
|
||||
s3 = get_s3_client()
|
||||
# Parse endpoint for R2 support
|
||||
if endpoint_path:
|
||||
root_endpoint = f"{parsed.scheme}://{parsed.netloc}"
|
||||
s3_config = {}
|
||||
if cfg.s3.force_path_style:
|
||||
s3_config["addressing_style"] = "path"
|
||||
boto_config = BotoConfig(signature_version="s3v4", s3=s3_config)
|
||||
s3 = boto3.client(
|
||||
"s3",
|
||||
endpoint_url=root_endpoint,
|
||||
aws_access_key_id=cfg.s3.access_key,
|
||||
aws_secret_access_key=cfg.s3.secret_key,
|
||||
region_name=cfg.s3.region,
|
||||
config=boto_config,
|
||||
)
|
||||
else:
|
||||
s3 = get_s3_client()
|
||||
|
||||
paginator = s3.get_paginator("list_objects_v2")
|
||||
count = 0
|
||||
for page in paginator.paginate(Bucket=cfg.s3.bucket, Prefix=prefix):
|
||||
for page in paginator.paginate(Bucket=actual_bucket, Prefix=actual_prefix):
|
||||
count += len(page.get("Contents", []))
|
||||
return count
|
||||
|
||||
estimated = await run_in_s3_executor(_count)
|
||||
|
||||
# Create confirmation token in database (works across workers)
|
||||
# Store ACTUAL prefix (with base prefix if needed) for deletion
|
||||
conf_token = create_confirmation_token(
|
||||
action_type="delete_s3_prefix",
|
||||
action_data={"prefix": prefix, "estimated_count": estimated},
|
||||
action_data={
|
||||
"display_prefix": prefix, # What frontend sees
|
||||
"actual_prefix": actual_prefix, # What S3 sees
|
||||
"actual_bucket": actual_bucket, # Actual bucket name
|
||||
"estimated_count": estimated,
|
||||
},
|
||||
ttl_seconds=60,
|
||||
)
|
||||
|
||||
@@ -449,12 +490,18 @@ async def delete_s3_prefix(
|
||||
if not action_data:
|
||||
raise HTTPException(400, detail="Invalid or expired confirmation token")
|
||||
|
||||
# Verify action type and prefix match
|
||||
if action_data.get("prefix") != prefix:
|
||||
# Verify display prefix matches (what user requested)
|
||||
if action_data.get("display_prefix") != prefix:
|
||||
raise HTTPException(400, detail="Prefix mismatch with confirmation token")
|
||||
|
||||
# Delete objects
|
||||
deleted_count = await delete_objects_with_prefix(cfg.s3.bucket, prefix)
|
||||
# Use actual S3 prefix and bucket from token (handles R2 path-in-endpoint)
|
||||
actual_prefix = action_data.get("actual_prefix")
|
||||
actual_bucket = action_data.get("actual_bucket")
|
||||
|
||||
logger.info(f"Deleting: bucket={actual_bucket}, prefix={actual_prefix}")
|
||||
|
||||
# Delete objects using ACTUAL prefix and bucket
|
||||
deleted_count = await delete_objects_with_prefix(actual_bucket, actual_prefix)
|
||||
|
||||
logger.warning(f"Admin deleted S3 prefix: {prefix} ({deleted_count} objects)")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user