mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 17:59:28 -05:00
refac
This commit is contained in:
@@ -49,7 +49,7 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
SESSION_SECRET = WEBUI_SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
ALGORITHM = 'HS256'
|
||||
|
||||
##############
|
||||
# Auth Utils
|
||||
@@ -74,62 +74,60 @@ def verify_signature(payload: str, signature: str) -> bool:
|
||||
|
||||
def override_static(path: str, content: str):
|
||||
# Ensure path is safe
|
||||
if "/" in path or ".." in path:
|
||||
log.error(f"Invalid path: {path}")
|
||||
if '/' in path or '..' in path:
|
||||
log.error(f'Invalid path: {path}')
|
||||
return
|
||||
|
||||
file_path = os.path.join(STATIC_DIR, path)
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
with open(file_path, 'wb') as f:
|
||||
f.write(base64.b64decode(content)) # Convert Base64 back to raw binary
|
||||
|
||||
|
||||
def get_license_data(app, key):
|
||||
def data_handler(data):
|
||||
for k, v in data.items():
|
||||
if k == "resources":
|
||||
if k == 'resources':
|
||||
for p, c in v.items():
|
||||
globals().get("override_static", lambda a, b: None)(p, c)
|
||||
elif k == "count":
|
||||
setattr(app.state, "USER_COUNT", v)
|
||||
elif k == "name":
|
||||
setattr(app.state, "WEBUI_NAME", v)
|
||||
elif k == "metadata":
|
||||
setattr(app.state, "LICENSE_METADATA", v)
|
||||
globals().get('override_static', lambda a, b: None)(p, c)
|
||||
elif k == 'count':
|
||||
setattr(app.state, 'USER_COUNT', v)
|
||||
elif k == 'name':
|
||||
setattr(app.state, 'WEBUI_NAME', v)
|
||||
elif k == 'metadata':
|
||||
setattr(app.state, 'LICENSE_METADATA', v)
|
||||
|
||||
def handler(u):
|
||||
res = requests.post(
|
||||
f"{u}/api/v1/license/",
|
||||
json={"key": key, "version": "1"},
|
||||
f'{u}/api/v1/license/',
|
||||
json={'key': key, 'version': '1'},
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
if getattr(res, "ok", False):
|
||||
payload = getattr(res, "json", lambda: {})()
|
||||
if getattr(res, 'ok', False):
|
||||
payload = getattr(res, 'json', lambda: {})()
|
||||
data_handler(payload)
|
||||
return True
|
||||
else:
|
||||
log.error(
|
||||
f"License: retrieval issue: {getattr(res, 'text', 'unknown error')}"
|
||||
)
|
||||
log.error(f'License: retrieval issue: {getattr(res, "text", "unknown error")}')
|
||||
|
||||
if key:
|
||||
us = [
|
||||
"https://api.openwebui.com",
|
||||
"https://licenses.api.openwebui.com",
|
||||
'https://api.openwebui.com',
|
||||
'https://licenses.api.openwebui.com',
|
||||
]
|
||||
try:
|
||||
for u in us:
|
||||
if handler(u):
|
||||
return True
|
||||
except Exception as ex:
|
||||
log.exception(f"License: Uncaught Exception: {ex}")
|
||||
log.exception(f'License: Uncaught Exception: {ex}')
|
||||
|
||||
try:
|
||||
if LICENSE_BLOB:
|
||||
nl = 12
|
||||
kb = hashlib.sha256((key.replace("-", "").upper()).encode()).digest()
|
||||
kb = hashlib.sha256((key.replace('-', '').upper()).encode()).digest()
|
||||
|
||||
def nt(b):
|
||||
return b[:nl], b[nl:]
|
||||
@@ -139,19 +137,19 @@ def get_license_data(app, key):
|
||||
|
||||
aesgcm = AESGCM(kb)
|
||||
p = json.loads(aesgcm.decrypt(ln, lt, None))
|
||||
pk.verify(base64.b64decode(p["s"]), p["p"].encode())
|
||||
pk.verify(base64.b64decode(p['s']), p['p'].encode())
|
||||
|
||||
pb = base64.b64decode(p["p"])
|
||||
pb = base64.b64decode(p['p'])
|
||||
pn, pt = nt(pb)
|
||||
|
||||
data = json.loads(aesgcm.decrypt(pn, pt, None).decode())
|
||||
if not data.get("exp") and data.get("exp") < datetime.now().date():
|
||||
if not data.get('exp') and data.get('exp') < datetime.now().date():
|
||||
return False
|
||||
|
||||
data_handler(data)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"License: {e}")
|
||||
log.error(f'License: {e}')
|
||||
|
||||
return False
|
||||
|
||||
@@ -161,12 +159,12 @@ bearer_security = HTTPBearer(auto_error=False)
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""Hash a password using bcrypt"""
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
|
||||
|
||||
|
||||
def validate_password(password: str) -> bool:
|
||||
# The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing.
|
||||
if len(password.encode("utf-8")) > 72:
|
||||
if len(password.encode('utf-8')) > 72:
|
||||
raise Exception(
|
||||
ERROR_MESSAGES.PASSWORD_TOO_LONG,
|
||||
)
|
||||
@@ -182,8 +180,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash"""
|
||||
return (
|
||||
bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"),
|
||||
hashed_password.encode("utf-8"),
|
||||
plain_password.encode('utf-8'),
|
||||
hashed_password.encode('utf-8'),
|
||||
)
|
||||
if hashed_password
|
||||
else None
|
||||
@@ -195,10 +193,10 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
payload.update({"exp": expire})
|
||||
payload.update({'exp': expire})
|
||||
|
||||
jti = str(uuid.uuid4())
|
||||
payload.update({"jti": jti})
|
||||
payload.update({'jti': jti})
|
||||
|
||||
encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
@@ -215,12 +213,10 @@ def decode_token(token: str) -> Optional[dict]:
|
||||
async def is_valid_token(request, decoded) -> bool:
|
||||
# Require Redis to check revoked tokens
|
||||
if request.app.state.redis:
|
||||
jti = decoded.get("jti")
|
||||
jti = decoded.get('jti')
|
||||
|
||||
if jti:
|
||||
revoked = await request.app.state.redis.get(
|
||||
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked"
|
||||
)
|
||||
revoked = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked')
|
||||
if revoked:
|
||||
return False
|
||||
|
||||
@@ -236,37 +232,35 @@ async def invalidate_token(request, token):
|
||||
|
||||
# Require Redis to store revoked tokens
|
||||
if request.app.state.redis:
|
||||
jti = decoded.get("jti")
|
||||
exp = decoded.get("exp")
|
||||
jti = decoded.get('jti')
|
||||
exp = decoded.get('exp')
|
||||
|
||||
if jti and exp:
|
||||
ttl = exp - int(
|
||||
datetime.now(UTC).timestamp()
|
||||
) # Calculate time-to-live for the token
|
||||
ttl = exp - int(datetime.now(UTC).timestamp()) # Calculate time-to-live for the token
|
||||
|
||||
if ttl > 0:
|
||||
# Store the revoked token in Redis with an expiration time
|
||||
await request.app.state.redis.set(
|
||||
f"{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked",
|
||||
"1",
|
||||
f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked',
|
||||
'1',
|
||||
ex=ttl,
|
||||
)
|
||||
|
||||
|
||||
def extract_token_from_auth_header(auth_header: str):
|
||||
return auth_header[len("Bearer ") :]
|
||||
return auth_header[len('Bearer ') :]
|
||||
|
||||
|
||||
def create_api_key():
|
||||
key = str(uuid.uuid4()).replace("-", "")
|
||||
return f"sk-{key}"
|
||||
key = str(uuid.uuid4()).replace('-', '')
|
||||
return f'sk-{key}'
|
||||
|
||||
|
||||
def get_http_authorization_cred(auth_header: Optional[str]):
|
||||
if not auth_header:
|
||||
return None
|
||||
try:
|
||||
scheme, credentials = auth_header.split(" ")
|
||||
scheme, credentials = auth_header.split(' ')
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -287,27 +281,27 @@ async def get_current_user(
|
||||
if auth_token is not None:
|
||||
token = auth_token.credentials
|
||||
|
||||
if token is None and "token" in request.cookies:
|
||||
token = request.cookies.get("token")
|
||||
if token is None and 'token' in request.cookies:
|
||||
token = request.cookies.get('token')
|
||||
|
||||
# Fallback to request.state.token (set by middleware, e.g. for x-api-key)
|
||||
if token is None and hasattr(request.state, "token") and request.state.token:
|
||||
if token is None and hasattr(request.state, 'token') and request.state.token:
|
||||
token = request.state.token.credentials
|
||||
|
||||
if token is None:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
raise HTTPException(status_code=401, detail='Not authenticated')
|
||||
|
||||
# auth by api key
|
||||
if token.startswith("sk-"):
|
||||
if token.startswith('sk-'):
|
||||
user = get_current_user_by_api_key(request, token)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
current_span.set_attribute('client.user.id', user.id)
|
||||
current_span.set_attribute('client.user.email', user.email)
|
||||
current_span.set_attribute('client.user.role', user.role)
|
||||
current_span.set_attribute('client.auth.type', 'api_key')
|
||||
|
||||
return user
|
||||
|
||||
@@ -318,17 +312,17 @@ async def get_current_user(
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
detail='Invalid token',
|
||||
)
|
||||
|
||||
if data is not None and "id" in data:
|
||||
if data.get("jti") and not await is_valid_token(request, data):
|
||||
if data is not None and 'id' in data:
|
||||
if data.get('jti') and not await is_valid_token(request, data):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token",
|
||||
detail='Invalid token',
|
||||
)
|
||||
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = Users.get_user_by_id(data['id'])
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -336,22 +330,20 @@ async def get_current_user(
|
||||
)
|
||||
else:
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
trusted_email = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, ""
|
||||
).lower()
|
||||
trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER, '').lower()
|
||||
if trusted_email and user.email != trusted_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User mismatch. Please sign in again.",
|
||||
detail='User mismatch. Please sign in again.',
|
||||
)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "jwt")
|
||||
current_span.set_attribute('client.user.id', user.id)
|
||||
current_span.set_attribute('client.user.email', user.email)
|
||||
current_span.set_attribute('client.user.role', user.role)
|
||||
current_span.set_attribute('client.auth.type', 'jwt')
|
||||
|
||||
# Refresh the user's last active timestamp asynchronously
|
||||
# to prevent blocking the request
|
||||
@@ -365,15 +357,15 @@ async def get_current_user(
|
||||
)
|
||||
except Exception as e:
|
||||
# Delete the token cookie
|
||||
if request.cookies.get("token"):
|
||||
response.delete_cookie("token")
|
||||
if request.cookies.get('token'):
|
||||
response.delete_cookie('token')
|
||||
|
||||
if request.cookies.get("oauth_id_token"):
|
||||
response.delete_cookie("oauth_id_token")
|
||||
if request.cookies.get('oauth_id_token'):
|
||||
response.delete_cookie('oauth_id_token')
|
||||
|
||||
# Delete OAuth session if present
|
||||
if request.cookies.get("oauth_session_id"):
|
||||
response.delete_cookie("oauth_session_id")
|
||||
if request.cookies.get('oauth_session_id'):
|
||||
response.delete_cookie('oauth_session_id')
|
||||
|
||||
raise e
|
||||
|
||||
@@ -389,31 +381,29 @@ def get_current_user_by_api_key(request, api_key: str):
|
||||
)
|
||||
|
||||
if not request.state.enable_api_keys or (
|
||||
user.role != "admin"
|
||||
user.role != 'admin'
|
||||
and not has_permission(
|
||||
user.id,
|
||||
"features.api_keys",
|
||||
'features.api_keys',
|
||||
request.app.state.config.USER_PERMISSIONS,
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED
|
||||
)
|
||||
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.API_KEY_NOT_ALLOWED)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
current_span.set_attribute("client.user.id", user.id)
|
||||
current_span.set_attribute("client.user.email", user.email)
|
||||
current_span.set_attribute("client.user.role", user.role)
|
||||
current_span.set_attribute("client.auth.type", "api_key")
|
||||
current_span.set_attribute('client.user.id', user.id)
|
||||
current_span.set_attribute('client.user.email', user.email)
|
||||
current_span.set_attribute('client.user.role', user.role)
|
||||
current_span.set_attribute('client.auth.type', 'api_key')
|
||||
|
||||
Users.update_last_active_by_id(user.id)
|
||||
return user
|
||||
|
||||
|
||||
def get_verified_user(user=Depends(get_current_user)):
|
||||
if user.role not in {"user", "admin"}:
|
||||
if user.role not in {'user', 'admin'}:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -422,7 +412,7 @@ def get_verified_user(user=Depends(get_current_user)):
|
||||
|
||||
|
||||
def get_admin_user(user=Depends(get_current_user)):
|
||||
if user.role != "admin":
|
||||
if user.role != 'admin':
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -430,7 +420,7 @@ def get_admin_user(user=Depends(get_current_user)):
|
||||
return user
|
||||
|
||||
|
||||
def create_admin_user(email: str, password: str, name: str = "Admin"):
|
||||
def create_admin_user(email: str, password: str, name: str = 'Admin'):
|
||||
"""
|
||||
Create an admin user from environment variables.
|
||||
Used for headless/automated deployments.
|
||||
@@ -441,24 +431,24 @@ def create_admin_user(email: str, password: str, name: str = "Admin"):
|
||||
return None
|
||||
|
||||
if Users.has_users():
|
||||
log.debug("Users already exist, skipping admin creation")
|
||||
log.debug('Users already exist, skipping admin creation')
|
||||
return None
|
||||
|
||||
log.info(f"Creating admin account from environment variables: {email}")
|
||||
log.info(f'Creating admin account from environment variables: {email}')
|
||||
try:
|
||||
hashed = get_password_hash(password)
|
||||
user = Auths.insert_new_auth(
|
||||
email=email.lower(),
|
||||
password=hashed,
|
||||
name=name,
|
||||
role="admin",
|
||||
role='admin',
|
||||
)
|
||||
if user:
|
||||
log.info(f"Admin account created successfully: {email}")
|
||||
log.info(f'Admin account created successfully: {email}')
|
||||
return user
|
||||
else:
|
||||
log.error("Failed to create admin account from environment variables")
|
||||
log.error('Failed to create admin account from environment variables')
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error creating admin account: {e}")
|
||||
log.error(f'Error creating admin account: {e}')
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user