mirror of
https://github.com/KohakuBlueleaf/KohakuHub.git
synced 2026-04-28 18:38:17 -05:00
update config system for better fallback mech
This commit is contained in:
@@ -31,25 +31,14 @@ use_tls = true
|
||||
|
||||
[auth]
|
||||
require_email_verification = false
|
||||
invitation_only = false # Set to true to require invitation for registration
|
||||
session_secret = "CHANGE-ME-GENERATE-RANDOM" # Use: python -c "import secrets; print(secrets.token_urlsafe(48))"
|
||||
session_expire_hours = 168 # 7 days
|
||||
token_expire_days = 365
|
||||
# Admin secret token (for /admin/* endpoints)
|
||||
admin_secret_token = "CHANGE-ME-GENERATE-RANDOM" # Use: python -c "import secrets; print(secrets.token_urlsafe(48))"
|
||||
|
||||
[app]
|
||||
base_url = "http://127.0.0.1:28080" # Use nginx port (28080), not direct backend (48888)
|
||||
api_base = "/api"
|
||||
db_backend = "postgres" # or "sqlite"
|
||||
database_url = "postgresql://hub:hubpass@127.0.0.1:25432/hubdb"
|
||||
auto_migrate = false # Auto-confirm database migrations (set true for Docker)
|
||||
lfs_threshold_bytes = 10485760 # 10MB
|
||||
debug_log_payloads = false # Log commit payloads (development only)
|
||||
# LFS Garbage Collection settings
|
||||
lfs_keep_versions = 5 # Keep last K versions of each LFS file
|
||||
lfs_auto_gc = true # Automatically delete old LFS objects on commit
|
||||
# Site identification
|
||||
site_name = "KohakuHub" # Customizable site name (e.g., "MyCompany Hub")
|
||||
[admin]
|
||||
enabled = true
|
||||
secret_token = "CHANGE-ME-GENERATE-RANDOM" # Use: python -c "import secrets; print(secrets.token_urlsafe(48))"
|
||||
|
||||
[quota]
|
||||
# Default storage quotas for new users/organizations (in bytes, null = unlimited)
|
||||
@@ -57,3 +46,30 @@ default_user_private_quota_bytes = 10_000_000_000 # 10GB for private repos
|
||||
default_user_public_quota_bytes = 100_000_000_000 # 100GB for public repos
|
||||
default_org_private_quota_bytes = 10_000_000_000 # 10GB for private repos
|
||||
default_org_public_quota_bytes = 100_000_000_000 # 100GB for public repos
|
||||
|
||||
[fallback]
|
||||
enabled = true
|
||||
cache_ttl_seconds = 300 # 5 minutes cache for repo→source mappings
|
||||
timeout_seconds = 10
|
||||
max_concurrent_requests = 5
|
||||
# sources = [] # JSON list of fallback sources (optional)
|
||||
|
||||
[app]
|
||||
base_url = "http://127.0.0.1:28080" # Use nginx port (28080), not direct backend (48888)
|
||||
api_base = "/api"
|
||||
db_backend = "postgres" # or "sqlite"
|
||||
database_url = "postgresql://hub:hubpass@127.0.0.1:25432/hubdb"
|
||||
auto_migrate = false # Auto-confirm database migrations (set true for Docker)
|
||||
# LFS Configuration (sizes in decimal: 1MB = 1,000,000 bytes)
|
||||
lfs_threshold_bytes = 5_000_000 # 5MB - files larger use LFS
|
||||
lfs_multipart_threshold_bytes = 100_000_000 # 100MB - files larger use multipart upload
|
||||
lfs_multipart_chunk_size_bytes = 50_000_000 # 50MB - size of each part (min 5MB except last)
|
||||
lfs_keep_versions = 5 # Keep last K versions of each LFS file
|
||||
lfs_auto_gc = true # Automatically delete old LFS objects on commit
|
||||
# Download tracking settings
|
||||
download_time_bucket_seconds = 900 # 15 minutes - session deduplication window
|
||||
download_session_cleanup_threshold = 100 # Trigger cleanup when sessions > this
|
||||
download_keep_sessions_days = 30 # Keep sessions from last N days
|
||||
debug_log_payloads = false # Log commit payloads (development only)
|
||||
# Site identification
|
||||
site_name = "KohakuHub" # Customizable site name (e.g., "MyCompany Hub")
|
||||
|
||||
@@ -756,17 +756,6 @@ def generate_config_toml(config: dict) -> str:
|
||||
# Generated by KohakuHub docker-compose generator
|
||||
# Use this for local development server
|
||||
|
||||
[app]
|
||||
base_url = "http://localhost:48888" # Dev server URL
|
||||
api_base = "/api"
|
||||
site_name = "KohakuHub"
|
||||
workers = 1 # Single worker for dev
|
||||
|
||||
[database]
|
||||
backend = "postgres"
|
||||
url = "{db_url}"
|
||||
auto_migrate = true # Auto-confirm migrations
|
||||
|
||||
[s3]
|
||||
endpoint = "{s3_endpoint_internal}"
|
||||
public_endpoint = "{s3_endpoint_public}"
|
||||
@@ -774,6 +763,7 @@ access_key = "{config['s3_access_key']}"
|
||||
secret_key = "{config['s3_secret_key']}"
|
||||
bucket = "hub-storage"
|
||||
region = "{s3_region}"
|
||||
force_path_style = true
|
||||
"""
|
||||
|
||||
# Add signature_version only if set (for external S3)
|
||||
@@ -786,36 +776,55 @@ endpoint = "http://localhost:28000"
|
||||
repo_namespace = "hf"
|
||||
# Credentials auto-generated on first start
|
||||
|
||||
[lfs]
|
||||
threshold_bytes = 1_000_000 # 1MB
|
||||
keep_versions = 5
|
||||
auto_gc = true
|
||||
|
||||
[auth]
|
||||
session_secret = "{config['session_secret']}"
|
||||
session_expire_hours = 168 # 7 days
|
||||
token_expire_days = 365
|
||||
require_email_verification = false
|
||||
invitation_only = false
|
||||
|
||||
[admin]
|
||||
enabled = true
|
||||
secret_token = "{config['admin_secret']}"
|
||||
|
||||
[smtp]
|
||||
enabled = false
|
||||
host = "smtp.gmail.com"
|
||||
port = 587
|
||||
username = ""
|
||||
password = ""
|
||||
from = "noreply@kohakuhub.local"
|
||||
tls = true
|
||||
from_email = "noreply@kohakuhub.local"
|
||||
use_tls = true
|
||||
|
||||
[auth]
|
||||
require_email_verification = false
|
||||
invitation_only = false
|
||||
session_secret = "{config['session_secret']}"
|
||||
session_expire_hours = 168 # 7 days
|
||||
token_expire_days = 365
|
||||
|
||||
[admin]
|
||||
enabled = true
|
||||
secret_token = "{config['admin_secret']}"
|
||||
|
||||
[quota]
|
||||
default_user_private_bytes = 10_000_000 # 10MB
|
||||
default_user_public_bytes = 100_000_000 # 100MB
|
||||
default_org_private_bytes = 10_000_000 # 10MB
|
||||
default_org_public_bytes = 100_000_000 # 100MB
|
||||
default_user_private_quota_bytes = 10_000_000 # 10MB
|
||||
default_user_public_quota_bytes = 100_000_000 # 100MB
|
||||
default_org_private_quota_bytes = 10_000_000 # 10MB
|
||||
default_org_public_quota_bytes = 100_000_000 # 100MB
|
||||
|
||||
[fallback]
|
||||
enabled = true
|
||||
cache_ttl_seconds = 300
|
||||
timeout_seconds = 10
|
||||
max_concurrent_requests = 5
|
||||
|
||||
[app]
|
||||
base_url = "http://localhost:48888" # Dev server URL
|
||||
api_base = "/api"
|
||||
db_backend = "postgres"
|
||||
database_url = "{db_url}"
|
||||
# LFS Configuration (sizes in decimal: 1MB = 1,000,000 bytes)
|
||||
lfs_threshold_bytes = 5_000_000 # 5MB - files larger use LFS
|
||||
lfs_multipart_threshold_bytes = 100_000_000 # 100MB - files larger use multipart upload
|
||||
lfs_multipart_chunk_size_bytes = 50_000_000 # 50MB - size of each part (min 5MB except last)
|
||||
lfs_keep_versions = 5 # Keep last K versions of each LFS file
|
||||
lfs_auto_gc = true # Automatically delete old LFS objects on commit
|
||||
# Download tracking settings
|
||||
download_time_bucket_seconds = 900 # 15 minutes - session deduplication window
|
||||
download_session_cleanup_threshold = 100 # Trigger cleanup when sessions > this
|
||||
download_keep_sessions_days = 30 # Keep sessions from last N days
|
||||
debug_log_payloads = false
|
||||
site_name = "KohakuHub"
|
||||
"""
|
||||
|
||||
return toml_content
|
||||
|
||||
@@ -203,164 +203,236 @@ class Config(BaseModel):
|
||||
return warnings
|
||||
|
||||
|
||||
def update_recursive(d: dict, u: dict) -> dict:
|
||||
"""Recursively update a dictionary."""
|
||||
for k, v in u.items():
|
||||
if isinstance(v, dict):
|
||||
# get node or create one
|
||||
d[k] = update_recursive(d.get(k, {}), v)
|
||||
else:
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
|
||||
def _parse_quota(value: str | None) -> int | None:
|
||||
"""Parse quota value from environment variable."""
|
||||
if value is None or value.lower() in ("", "none", "unlimited"):
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
def _parse_fallback_sources(value: str | None) -> list[dict]:
|
||||
"""Parse fallback sources from JSON environment variable."""
|
||||
import json
|
||||
|
||||
if not value:
|
||||
return []
|
||||
try:
|
||||
sources = json.loads(value)
|
||||
if not isinstance(sources, list):
|
||||
return []
|
||||
return sources
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def load_config(path: str = None) -> Config:
|
||||
path = path or os.environ.get("HUB_CONFIG", None)
|
||||
if path is None:
|
||||
s3_config = S3Config(
|
||||
public_endpoint=os.environ.get(
|
||||
"KOHAKU_HUB_S3_PUBLIC_ENDPOINT", _DEFAULT_S3_ENDPOINT
|
||||
),
|
||||
endpoint=os.environ.get("KOHAKU_HUB_S3_ENDPOINT", _DEFAULT_S3_ENDPOINT),
|
||||
access_key=os.environ.get("KOHAKU_HUB_S3_ACCESS_KEY", "test-access-key"),
|
||||
secret_key=os.environ.get("KOHAKU_HUB_S3_SECRET_KEY", "test-secret-key"),
|
||||
bucket=os.environ.get("KOHAKU_HUB_S3_BUCKET", "test-bucket"),
|
||||
region=os.environ.get("KOHAKU_HUB_S3_REGION", "us-east-1"),
|
||||
signature_version=os.environ.get("KOHAKU_HUB_S3_SIGNATURE_VERSION", None),
|
||||
# 1. Determine config file path: explicit path, HUB_CONFIG env, or default "config.toml"
|
||||
config_path = path or os.environ.get("HUB_CONFIG") or "config.toml"
|
||||
|
||||
# 2. Load from TOML file if it exists
|
||||
config_from_file = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "rb") as f:
|
||||
config_from_file = tomllib.load(f)
|
||||
|
||||
# 3. Load from environment variables, building a nested dict
|
||||
config_from_env = {}
|
||||
|
||||
# S3
|
||||
s3_env = {}
|
||||
if "KOHAKU_HUB_S3_PUBLIC_ENDPOINT" in os.environ:
|
||||
s3_env["public_endpoint"] = os.environ["KOHAKU_HUB_S3_PUBLIC_ENDPOINT"]
|
||||
if "KOHAKU_HUB_S3_ENDPOINT" in os.environ:
|
||||
s3_env["endpoint"] = os.environ["KOHAKU_HUB_S3_ENDPOINT"]
|
||||
if "KOHAKU_HUB_S3_ACCESS_KEY" in os.environ:
|
||||
s3_env["access_key"] = os.environ["KOHAKU_HUB_S3_ACCESS_KEY"]
|
||||
if "KOHAKU_HUB_S3_SECRET_KEY" in os.environ:
|
||||
s3_env["secret_key"] = os.environ["KOHAKU_HUB_S3_SECRET_KEY"]
|
||||
if "KOHAKU_HUB_S3_BUCKET" in os.environ:
|
||||
s3_env["bucket"] = os.environ["KOHAKU_HUB_S3_BUCKET"]
|
||||
if "KOHAKU_HUB_S3_REGION" in os.environ:
|
||||
s3_env["region"] = os.environ["KOHAKU_HUB_S3_REGION"]
|
||||
if "KOHAKU_HUB_S3_SIGNATURE_VERSION" in os.environ:
|
||||
s3_env["signature_version"] = os.environ["KOHAKU_HUB_S3_SIGNATURE_VERSION"]
|
||||
if s3_env:
|
||||
config_from_env["s3"] = s3_env
|
||||
|
||||
# LakeFS
|
||||
lakefs_env = {}
|
||||
if "KOHAKU_HUB_LAKEFS_ENDPOINT" in os.environ:
|
||||
lakefs_env["endpoint"] = os.environ["KOHAKU_HUB_LAKEFS_ENDPOINT"]
|
||||
if "KOHAKU_HUB_LAKEFS_ACCESS_KEY" in os.environ:
|
||||
lakefs_env["access_key"] = os.environ["KOHAKU_HUB_LAKEFS_ACCESS_KEY"]
|
||||
if "KOHAKU_HUB_LAKEFS_SECRET_KEY" in os.environ:
|
||||
lakefs_env["secret_key"] = os.environ["KOHAKU_HUB_LAKEFS_SECRET_KEY"]
|
||||
if "KOHAKU_HUB_LAKEFS_REPO_NAMESPACE" in os.environ:
|
||||
lakefs_env["repo_namespace"] = os.environ["KOHAKU_HUB_LAKEFS_REPO_NAMESPACE"]
|
||||
if lakefs_env:
|
||||
config_from_env["lakefs"] = lakefs_env
|
||||
|
||||
# SMTP
|
||||
smtp_env = {}
|
||||
if "KOHAKU_HUB_SMTP_ENABLED" in os.environ:
|
||||
smtp_env["enabled"] = os.environ["KOHAKU_HUB_SMTP_ENABLED"].lower() == "true"
|
||||
if "KOHAKU_HUB_SMTP_HOST" in os.environ:
|
||||
smtp_env["host"] = os.environ["KOHAKU_HUB_SMTP_HOST"]
|
||||
if "KOHAKU_HUB_SMTP_PORT" in os.environ:
|
||||
smtp_env["port"] = int(os.environ["KOHAKU_HUB_SMTP_PORT"])
|
||||
if "KOHAKU_HUB_SMTP_USERNAME" in os.environ:
|
||||
smtp_env["username"] = os.environ["KOHAKU_HUB_SMTP_USERNAME"]
|
||||
if "KOHAKU_HUB_SMTP_PASSWORD" in os.environ:
|
||||
smtp_env["password"] = os.environ["KOHAKU_HUB_SMTP_PASSWORD"]
|
||||
if "KOHAKU_HUB_SMTP_FROM" in os.environ:
|
||||
smtp_env["from_email"] = os.environ["KOHAKU_HUB_SMTP_FROM"]
|
||||
if "KOHAKU_HUB_SMTP_TLS" in os.environ:
|
||||
smtp_env["use_tls"] = os.environ["KOHAKU_HUB_SMTP_TLS"].lower() == "true"
|
||||
if smtp_env:
|
||||
config_from_env["smtp"] = smtp_env
|
||||
|
||||
# Auth
|
||||
auth_env = {}
|
||||
if "KOHAKU_HUB_REQUIRE_EMAIL_VERIFICATION" in os.environ:
|
||||
auth_env["require_email_verification"] = (
|
||||
os.environ["KOHAKU_HUB_REQUIRE_EMAIL_VERIFICATION"].lower() == "true"
|
||||
)
|
||||
|
||||
lakefs_config = LakeFSConfig(
|
||||
endpoint=os.environ.get(
|
||||
"KOHAKU_HUB_LAKEFS_ENDPOINT", "http://localhost:8000"
|
||||
),
|
||||
access_key=os.environ.get(
|
||||
"KOHAKU_HUB_LAKEFS_ACCESS_KEY", "test-access-key"
|
||||
),
|
||||
secret_key=os.environ.get(
|
||||
"KOHAKU_HUB_LAKEFS_SECRET_KEY", "test-secret-key"
|
||||
),
|
||||
repo_namespace=os.environ.get("KOHAKU_HUB_LAKEFS_REPO_NAMESPACE", "hf"),
|
||||
if "KOHAKU_HUB_INVITATION_ONLY" in os.environ:
|
||||
auth_env["invitation_only"] = (
|
||||
os.environ["KOHAKU_HUB_INVITATION_ONLY"].lower() == "true"
|
||||
)
|
||||
|
||||
smtp_config = SMTPConfig(
|
||||
enabled=os.environ.get("KOHAKU_HUB_SMTP_ENABLED", "false").lower()
|
||||
== "true",
|
||||
host=os.environ.get("KOHAKU_HUB_SMTP_HOST", "localhost"),
|
||||
port=int(os.environ.get("KOHAKU_HUB_SMTP_PORT", "587")),
|
||||
username=os.environ.get("KOHAKU_HUB_SMTP_USERNAME", ""),
|
||||
password=os.environ.get("KOHAKU_HUB_SMTP_PASSWORD", ""),
|
||||
from_email=os.environ.get("KOHAKU_HUB_SMTP_FROM", "noreply@localhost"),
|
||||
use_tls=os.environ.get("KOHAKU_HUB_SMTP_TLS", "true").lower() == "true",
|
||||
if "KOHAKU_HUB_SESSION_SECRET" in os.environ:
|
||||
auth_env["session_secret"] = os.environ["KOHAKU_HUB_SESSION_SECRET"]
|
||||
if "KOHAKU_HUB_SESSION_EXPIRE_HOURS" in os.environ:
|
||||
auth_env["session_expire_hours"] = int(
|
||||
os.environ["KOHAKU_HUB_SESSION_EXPIRE_HOURS"]
|
||||
)
|
||||
if "KOHAKU_HUB_TOKEN_EXPIRE_DAYS" in os.environ:
|
||||
auth_env["token_expire_days"] = int(os.environ["KOHAKU_HUB_TOKEN_EXPIRE_DAYS"])
|
||||
if auth_env:
|
||||
config_from_env["auth"] = auth_env
|
||||
|
||||
auth_config = AuthConfig(
|
||||
require_email_verification=os.environ.get(
|
||||
"KOHAKU_HUB_REQUIRE_EMAIL_VERIFICATION", "false"
|
||||
).lower()
|
||||
== "true",
|
||||
invitation_only=os.environ.get(
|
||||
"KOHAKU_HUB_INVITATION_ONLY", "false"
|
||||
).lower()
|
||||
== "true",
|
||||
session_secret=os.environ.get(
|
||||
"KOHAKU_HUB_SESSION_SECRET", "change-me-in-production"
|
||||
),
|
||||
session_expire_hours=int(
|
||||
os.environ.get("KOHAKU_HUB_SESSION_EXPIRE_HOURS", "168")
|
||||
),
|
||||
token_expire_days=int(
|
||||
os.environ.get("KOHAKU_HUB_TOKEN_EXPIRE_DAYS", "365")
|
||||
),
|
||||
# Admin
|
||||
admin_env = {}
|
||||
if "KOHAKU_HUB_ADMIN_ENABLED" in os.environ:
|
||||
admin_env["enabled"] = os.environ["KOHAKU_HUB_ADMIN_ENABLED"].lower() == "true"
|
||||
if "KOHAKU_HUB_ADMIN_SECRET_TOKEN" in os.environ:
|
||||
admin_env["secret_token"] = os.environ["KOHAKU_HUB_ADMIN_SECRET_TOKEN"]
|
||||
if admin_env:
|
||||
config_from_env["admin"] = admin_env
|
||||
|
||||
# Quota
|
||||
quota_env = {}
|
||||
if "KOHAKU_HUB_DEFAULT_USER_PRIVATE_QUOTA_BYTES" in os.environ:
|
||||
quota_env["default_user_private_quota_bytes"] = _parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_USER_PRIVATE_QUOTA_BYTES")
|
||||
)
|
||||
|
||||
admin_config = AdminConfig(
|
||||
enabled=os.environ.get("KOHAKU_HUB_ADMIN_ENABLED", "true").lower()
|
||||
== "true",
|
||||
secret_token=os.environ.get(
|
||||
"KOHAKU_HUB_ADMIN_SECRET_TOKEN", "change-me-in-production"
|
||||
),
|
||||
if "KOHAKU_HUB_DEFAULT_USER_PUBLIC_QUOTA_BYTES" in os.environ:
|
||||
quota_env["default_user_public_quota_bytes"] = _parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_USER_PUBLIC_QUOTA_BYTES")
|
||||
)
|
||||
|
||||
def _parse_quota(value: str | None) -> int | None:
|
||||
"""Parse quota value from environment variable."""
|
||||
if value is None or value.lower() in ("", "none", "unlimited"):
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
quota_config = QuotaConfig(
|
||||
default_user_private_quota_bytes=_parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_USER_PRIVATE_QUOTA_BYTES")
|
||||
),
|
||||
default_user_public_quota_bytes=_parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_USER_PUBLIC_QUOTA_BYTES")
|
||||
),
|
||||
default_org_private_quota_bytes=_parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_ORG_PRIVATE_QUOTA_BYTES")
|
||||
),
|
||||
default_org_public_quota_bytes=_parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_ORG_PUBLIC_QUOTA_BYTES")
|
||||
),
|
||||
if "KOHAKU_HUB_DEFAULT_ORG_PRIVATE_QUOTA_BYTES" in os.environ:
|
||||
quota_env["default_org_private_quota_bytes"] = _parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_ORG_PRIVATE_QUOTA_BYTES")
|
||||
)
|
||||
|
||||
def _parse_fallback_sources(value: str | None) -> list[dict]:
|
||||
"""Parse fallback sources from JSON environment variable."""
|
||||
import json
|
||||
|
||||
if not value:
|
||||
return []
|
||||
try:
|
||||
sources = json.loads(value)
|
||||
if not isinstance(sources, list):
|
||||
return []
|
||||
return sources
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
fallback_config = FallbackConfig(
|
||||
enabled=os.environ.get("KOHAKU_HUB_FALLBACK_ENABLED", "true").lower()
|
||||
== "true",
|
||||
cache_ttl_seconds=int(
|
||||
os.environ.get("KOHAKU_HUB_FALLBACK_CACHE_TTL", "300")
|
||||
),
|
||||
timeout_seconds=int(os.environ.get("KOHAKU_HUB_FALLBACK_TIMEOUT", "10")),
|
||||
max_concurrent_requests=int(
|
||||
os.environ.get("KOHAKU_HUB_FALLBACK_MAX_CONCURRENT", "5")
|
||||
),
|
||||
sources=_parse_fallback_sources(
|
||||
os.environ.get("KOHAKU_HUB_FALLBACK_SOURCES")
|
||||
),
|
||||
if "KOHAKU_HUB_DEFAULT_ORG_PUBLIC_QUOTA_BYTES" in os.environ:
|
||||
quota_env["default_org_public_quota_bytes"] = _parse_quota(
|
||||
os.environ.get("KOHAKU_HUB_DEFAULT_ORG_PUBLIC_QUOTA_BYTES")
|
||||
)
|
||||
if quota_env:
|
||||
config_from_env["quota"] = quota_env
|
||||
|
||||
app_config = AppConfig(
|
||||
base_url=os.environ.get("KOHAKU_HUB_BASE_URL", "http://localhost:48888"),
|
||||
api_base=os.environ.get("KOHAKU_HUB_API_BASE", "/api"),
|
||||
db_backend=os.environ.get("KOHAKU_HUB_DB_BACKEND", "sqlite"),
|
||||
database_url=os.environ.get(
|
||||
"KOHAKU_HUB_DATABASE_URL", "sqlite:///./hub.db"
|
||||
),
|
||||
lfs_threshold_bytes=int(
|
||||
os.environ.get("KOHAKU_HUB_LFS_THRESHOLD_BYTES", "5242880")
|
||||
),
|
||||
lfs_multipart_threshold_bytes=int(
|
||||
os.environ.get("KOHAKU_HUB_LFS_MULTIPART_THRESHOLD_BYTES", "104857600")
|
||||
),
|
||||
lfs_multipart_chunk_size_bytes=int(
|
||||
os.environ.get("KOHAKU_HUB_LFS_MULTIPART_CHUNK_SIZE_BYTES", "52428800")
|
||||
),
|
||||
lfs_keep_versions=int(os.environ.get("KOHAKU_HUB_LFS_KEEP_VERSIONS", "5")),
|
||||
lfs_auto_gc=os.environ.get("KOHAKU_HUB_LFS_AUTO_GC", "false").lower()
|
||||
== "true",
|
||||
site_name=os.environ.get("KOHAKU_HUB_SITE_NAME", "KohakuHub"),
|
||||
debug_log_payloads=os.environ.get(
|
||||
"KOHAKU_HUB_DEBUG_LOG_PAYLOADS", "false"
|
||||
).lower()
|
||||
== "true",
|
||||
# Fallback
|
||||
fallback_env = {}
|
||||
if "KOHAKU_HUB_FALLBACK_ENABLED" in os.environ:
|
||||
fallback_env["enabled"] = (
|
||||
os.environ["KOHAKU_HUB_FALLBACK_ENABLED"].lower() == "true"
|
||||
)
|
||||
if "KOHAKU_HUB_FALLBACK_CACHE_TTL" in os.environ:
|
||||
fallback_env["cache_ttl_seconds"] = int(
|
||||
os.environ["KOHAKU_HUB_FALLBACK_CACHE_TTL"]
|
||||
)
|
||||
if "KOHAKU_HUB_FALLBACK_TIMEOUT" in os.environ:
|
||||
fallback_env["timeout_seconds"] = int(os.environ["KOHAKU_HUB_FALLBACK_TIMEOUT"])
|
||||
if "KOHAKU_HUB_FALLBACK_MAX_CONCURRENT" in os.environ:
|
||||
fallback_env["max_concurrent_requests"] = int(
|
||||
os.environ["KOHAKU_HUB_FALLBACK_MAX_CONCURRENT"]
|
||||
)
|
||||
if "KOHAKU_HUB_FALLBACK_SOURCES" in os.environ:
|
||||
fallback_env["sources"] = _parse_fallback_sources(
|
||||
os.environ.get("KOHAKU_HUB_FALLBACK_SOURCES")
|
||||
)
|
||||
if fallback_env:
|
||||
config_from_env["fallback"] = fallback_env
|
||||
|
||||
return Config(
|
||||
s3=s3_config,
|
||||
lakefs=lakefs_config,
|
||||
smtp=smtp_config,
|
||||
auth=auth_config,
|
||||
admin=admin_config,
|
||||
quota=quota_config,
|
||||
fallback=fallback_config,
|
||||
app=app_config,
|
||||
# App
|
||||
app_env = {}
|
||||
if "KOHAKU_HUB_BASE_URL" in os.environ:
|
||||
app_env["base_url"] = os.environ["KOHAKU_HUB_BASE_URL"]
|
||||
if "KOHAKU_HUB_API_BASE" in os.environ:
|
||||
app_env["api_base"] = os.environ["KOHAKU_HUB_API_BASE"]
|
||||
if "KOHAKU_HUB_DB_BACKEND" in os.environ:
|
||||
app_env["db_backend"] = os.environ["KOHAKU_HUB_DB_BACKEND"]
|
||||
if "KOHAKU_HUB_DATABASE_URL" in os.environ:
|
||||
app_env["database_url"] = os.environ["KOHAKU_HUB_DATABASE_URL"]
|
||||
if "KOHAKU_HUB_LFS_THRESHOLD_BYTES" in os.environ:
|
||||
app_env["lfs_threshold_bytes"] = int(
|
||||
os.environ["KOHAKU_HUB_LFS_THRESHOLD_BYTES"]
|
||||
)
|
||||
else:
|
||||
with open(path, "rb") as f:
|
||||
raw = tomllib.load(f)
|
||||
return Config(**raw)
|
||||
if "KOHAKU_HUB_LFS_MULTIPART_THRESHOLD_BYTES" in os.environ:
|
||||
app_env["lfs_multipart_threshold_bytes"] = int(
|
||||
os.environ["KOHAKU_HUB_LFS_MULTIPART_THRESHOLD_BYTES"]
|
||||
)
|
||||
if "KOHAKU_HUB_LFS_MULTIPART_CHUNK_SIZE_BYTES" in os.environ:
|
||||
app_env["lfs_multipart_chunk_size_bytes"] = int(
|
||||
os.environ["KOHAKU_HUB_LFS_MULTIPART_CHUNK_SIZE_BYTES"]
|
||||
)
|
||||
if "KOHAKU_HUB_LFS_KEEP_VERSIONS" in os.environ:
|
||||
app_env["lfs_keep_versions"] = int(os.environ["KOHAKU_HUB_LFS_KEEP_VERSIONS"])
|
||||
if "KOHAKU_HUB_LFS_AUTO_GC" in os.environ:
|
||||
app_env["lfs_auto_gc"] = os.environ["KOHAKU_HUB_LFS_AUTO_GC"].lower() == "true"
|
||||
if "KOHAKU_HUB_SITE_NAME" in os.environ:
|
||||
app_env["site_name"] = os.environ["KOHAKU_HUB_SITE_NAME"]
|
||||
if "KOHAKU_HUB_DEBUG_LOG_PAYLOADS" in os.environ:
|
||||
app_env["debug_log_payloads"] = (
|
||||
os.environ["KOHAKU_HUB_DEBUG_LOG_PAYLOADS"].lower() == "true"
|
||||
)
|
||||
if app_env:
|
||||
config_from_env["app"] = app_env
|
||||
|
||||
# 4. Merge: Start with file config, then recursively update with env config
|
||||
merged_config = update_recursive(config_from_file, config_from_env)
|
||||
|
||||
# 5. Instantiate config models, allowing Pydantic to handle defaults
|
||||
s3_config = S3Config(**merged_config.get("s3", {}))
|
||||
lakefs_config = LakeFSConfig(**merged_config.get("lakefs", {}))
|
||||
smtp_config = SMTPConfig(**merged_config.get("smtp", {}))
|
||||
auth_config = AuthConfig(**merged_config.get("auth", {}))
|
||||
admin_config = AdminConfig(**merged_config.get("admin", {}))
|
||||
quota_config = QuotaConfig(**merged_config.get("quota", {}))
|
||||
fallback_config = FallbackConfig(**merged_config.get("fallback", {}))
|
||||
app_config = AppConfig(**merged_config.get("app", {}))
|
||||
|
||||
return Config(
|
||||
s3=s3_config,
|
||||
lakefs=lakefs_config,
|
||||
smtp=smtp_config,
|
||||
auth=auth_config,
|
||||
admin=admin_config,
|
||||
quota=quota_config,
|
||||
fallback=fallback_config,
|
||||
app=app_config,
|
||||
)
|
||||
|
||||
|
||||
cfg = load_config()
|
||||
|
||||
Reference in New Issue
Block a user