This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -49,7 +49,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
log = logging.getLogger(__name__)
SESSION_SECRET = WEBUI_SECRET_KEY
ALGORITHM = "HS256"
ALGORITHM = 'HS256'
##############
# Auth Utils
@@ -74,62 +74,60 @@ def verify_signature(payload: str, signature: str) -> bool:
def override_static(path: str, content: str):
# Ensure path is safe
if "/" in path or ".." in path:
log.error(f"Invalid path: {path}")
if '/' in path or '..' in path:
log.error(f'Invalid path: {path}')
return
file_path = os.path.join(STATIC_DIR, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
with open(file_path, 'wb') as f:
f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
def get_license_data(app, key):
def data_handler(data):
for k, v in data.items():
if k == "resources":
if k == 'resources':
for p, c in v.items():
globals().get("override_static", lambda a, b: None)(p, c)
elif k == "count":
setattr(app.state, "USER_COUNT", v)
elif k == "name":
setattr(app.state, "WEBUI_NAME", v)
elif k == "metadata":
setattr(app.state, "LICENSE_METADATA", v)
globals().get('override_static', lambda a, b: None)(p, c)
elif k == 'count':
setattr(app.state, 'USER_COUNT', v)
elif k == 'name':
setattr(app.state, 'WEBUI_NAME', v)
elif k == 'metadata':
setattr(app.state, 'LICENSE_METADATA', v)
def handler(u):
res = requests.post(
f"{u}/api/v1/license/",
json={"key": key, "version": "1"},
f'{u}/api/v1/license/',
json={'key': key, 'version': '1'},
timeout=5,
)
if getattr(res, "ok", False):
payload = getattr(res, "json", lambda: {})()
if getattr(res, 'ok', False):
payload = getattr(res, 'json', lambda: {})()
data_handler(payload)
return True
else:
log.error(
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
)
log.error(f'License: retrieval issue: {getattr(res, "text", "unknown error")}')
if key:
us = [
"https://api.openwebui.com",
"https://licenses.api.openwebui.com",
'https://api.openwebui.com',
'https://licenses.api.openwebui.com',
]
try:
for u in us:
if handler(u):
return True
except Exception as ex:
log.exception(f"License: Uncaught Exception: {ex}")
log.exception(f'License: Uncaught Exception: {ex}')
try:
if LICENSE_BLOB:
nl = 12
kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
kb = hashlib.sha256((key.replace('-', '').upper()).encode()).digest()
def nt(b):
return b[:nl], b[nl:]
@@ -139,19 +137,19 @@ def get_license_data(app, key):
aesgcm = AESGCM(kb)
p = json.loads(aesgcm.decrypt(ln, lt, None))
pk.verify(base64.b64decode(p["s"]), p["p"].encode())
pk.verify(base64.b64decode(p['s']), p['p'].encode())
pb = base64.b64decode(p["p"])
pb = base64.b64decode(p['p'])
pn, pt = nt(pb)
data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
if not data.get("exp") and data.get("exp") < datetime.now().date():
if not data.get('exp') and data.get('exp') < datetime.now().date():
return False
data_handler(data)
return True
except Exception as e:
log.error(f"License: {e}")
log.error(f'License: {e}')
return False
@@ -161,12 +159,12 @@ bearer_security = HTTPBearer(auto_error=False)
def get_password_hash(password: str) -> str:
"""Hash a password using bcrypt"""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
def validate_password(password: str) -> bool:
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
if len(password.encode("utf-8")) > 72:
if len(password.encode('utf-8')) > 72:
raise Exception(
ERROR_MESSAGES.PASSWORD_TOO_LONG,
)
@@ -182,8 +180,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash"""
return (
bcrypt.checkpw(
plain_password.encode("utf-8"),
hashed_password.encode("utf-8"),
plain_password.encode('utf-8'),
hashed_password.encode('utf-8'),
)
if hashed_password
else None
@@ -195,10 +193,10 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
if expires_delta:
expire = datetime.now(UTC) + expires_delta
payload.update({"exp": expire})
payload.update({'exp': expire})
jti = str(uuid.uuid4())
payload.update({"jti": jti})
payload.update({'jti': jti})
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
return encoded_jwt
@@ -215,12 +213,10 @@ def decode_token(token: str) -> Optional[dict]:
async def is_valid_token(request, decoded) -> bool:
# Require Redis to check revoked tokens
if request.app.state.redis:
jti = decoded.get("jti")
jti = decoded.get('jti')
if jti:
revoked = await request.app.state.redis.get(
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
)
revoked = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked')
if revoked:
return False
@@ -236,37 +232,35 @@ async def invalidate_token(request, token):
# Require Redis to store revoked tokens
if request.app.state.redis:
jti = decoded.get("jti")
exp = decoded.get("exp")
jti = decoded.get('jti')
exp = decoded.get('exp')
if jti and exp:
ttl = exp - int(
datetime.now(UTC).timestamp()
) # Calculate time-to-live for the token
ttl = exp - int(datetime.now(UTC).timestamp()) # Calculate time-to-live for the token
if ttl > 0:
# Store the revoked token in Redis with an expiration time
await request.app.state.redis.set(
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
"1",
f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked',
'1',
ex=ttl,
)
def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]
return auth_header[len('Bearer ') :]
def create_api_key():
key = str(uuid.uuid4()).replace("-", "")
return f"sk-{key}"
key = str(uuid.uuid4()).replace('-', '')
return f'sk-{key}'
def get_http_authorization_cred(auth_header: Optional[str]):
if not auth_header:
return None
try:
scheme, credentials = auth_header.split(" ")
scheme, credentials = auth_header.split(' ')
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
except Exception:
return None
@@ -287,27 +281,27 @@ async def get_current_user(
if auth_token is not None:
token = auth_token.credentials
if token is None and "token" in request.cookies:
token = request.cookies.get("token")
if token is None and 'token' in request.cookies:
token = request.cookies.get('token')
# Fallback to request.state.token (set by middleware, e.g. for x-api-key)
if token is None and hasattr(request.state, "token") and request.state.token:
if token is None and hasattr(request.state, 'token') and request.state.token:
token = request.state.token.credentials
if token is None:
raise HTTPException(status_code=401, detail="Not authenticated")
raise HTTPException(status_code=401, detail='Not authenticated')
# auth by api key
if token.startswith("sk-"):
if token.startswith('sk-'):
user = get_current_user_by_api_key(request, token)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
current_span.set_attribute('client.user.id', user.id)
current_span.set_attribute('client.user.email', user.email)
current_span.set_attribute('client.user.role', user.role)
current_span.set_attribute('client.auth.type', 'api_key')
return user
@@ -318,17 +312,17 @@ async def get_current_user(
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
detail='Invalid token',
)
if data is not None and "id" in data:
if data.get("jti") and not await is_valid_token(request, data):
if data is not None and 'id' in data:
if data.get('jti') and not await is_valid_token(request, data):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token",
detail='Invalid token',
)
user = Users.get_user_by_id(data["id"])
user = Users.get_user_by_id(data['id'])
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -336,22 +330,20 @@ async def get_current_user(
)
else:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
trusted_email = request.headers.get(
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
).lower()
trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER, '').lower()
if trusted_email and user.email != trusted_email:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User mismatch. Please sign in again.",
detail='User mismatch. Please sign in again.',
)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "jwt")
current_span.set_attribute('client.user.id', user.id)
current_span.set_attribute('client.user.email', user.email)
current_span.set_attribute('client.user.role', user.role)
current_span.set_attribute('client.auth.type', 'jwt')
# Refresh the user's last active timestamp asynchronously
# to prevent blocking the request
@@ -365,15 +357,15 @@ async def get_current_user(
)
except Exception as e:
# Delete the token cookie
if request.cookies.get("token"):
response.delete_cookie("token")
if request.cookies.get('token'):
response.delete_cookie('token')
if request.cookies.get("oauth_id_token"):
response.delete_cookie("oauth_id_token")
if request.cookies.get('oauth_id_token'):
response.delete_cookie('oauth_id_token')
# Delete OAuth session if present
if request.cookies.get("oauth_session_id"):
response.delete_cookie("oauth_session_id")
if request.cookies.get('oauth_session_id'):
response.delete_cookie('oauth_session_id')
raise e
@@ -389,31 +381,29 @@ def get_current_user_by_api_key(request, api_key: str):
)
if not request.state.enable_api_keys or (
user.role != "admin"
user.role != 'admin'
and not has_permission(
user.id,
"features.api_keys",
'features.api_keys',
request.app.state.config.USER_PERMISSIONS,
)
):
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
)
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED)
# Add user info to current span
current_span = trace.get_current_span()
if current_span:
current_span.set_attribute("client.user.id", user.id)
current_span.set_attribute("client.user.email", user.email)
current_span.set_attribute("client.user.role", user.role)
current_span.set_attribute("client.auth.type", "api_key")
current_span.set_attribute('client.user.id', user.id)
current_span.set_attribute('client.user.email', user.email)
current_span.set_attribute('client.user.role', user.role)
current_span.set_attribute('client.auth.type', 'api_key')
Users.update_last_active_by_id(user.id)
return user
def get_verified_user(user=Depends(get_current_user)):
if user.role not in {"user", "admin"}:
if user.role not in {'user', 'admin'}:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@@ -422,7 +412,7 @@ def get_verified_user(user=Depends(get_current_user)):
def get_admin_user(user=Depends(get_current_user)):
if user.role != "admin":
if user.role != 'admin':
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
@@ -430,7 +420,7 @@ def get_admin_user(user=Depends(get_current_user)):
return user
def create_admin_user(email: str, password: str, name: str = "Admin"):
def create_admin_user(email: str, password: str, name: str = 'Admin'):
"""
Create an admin user from environment variables.
Used for headless/automated deployments.
@@ -441,24 +431,24 @@ def create_admin_user(email: str, password: str, name: str = "Admin"):
return None
if Users.has_users():
log.debug("Users already exist, skipping admin creation")
log.debug('Users already exist, skipping admin creation')
return None
log.info(f"Creating admin account from environment variables: {email}")
log.info(f'Creating admin account from environment variables: {email}')
try:
hashed = get_password_hash(password)
user = Auths.insert_new_auth(
email=email.lower(),
password=hashed,
name=name,
role="admin",
role='admin',
)
if user:
log.info(f"Admin account created successfully: {email}")
log.info(f'Admin account created successfully: {email}')
return user
else:
log.error("Failed to create admin account from environment variables")
log.error('Failed to create admin account from environment variables')
return None
except Exception as e:
log.error(f"Error creating admin account: {e}")
log.error(f'Error creating admin account: {e}')
return None