mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-03 18:59:38 -05:00
refac
This commit is contained in:
@@ -40,12 +40,12 @@ import datetime
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
@@ -83,7 +83,7 @@ class UserModel(BaseModel):
|
||||
|
||||
email: str
|
||||
username: Optional[str] = None
|
||||
role: str = "pending"
|
||||
role: str = 'pending'
|
||||
|
||||
name: str
|
||||
|
||||
@@ -112,10 +112,10 @@ class UserModel(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@model_validator(mode='after')
|
||||
def set_profile_image_url(self):
|
||||
if not self.profile_image_url:
|
||||
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
|
||||
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
|
||||
return self
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class UserStatusModel(UserModel):
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_key"
|
||||
__tablename__ = 'api_key'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
@@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel):
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
|
||||
@field_validator("profile_image_url")
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: str) -> str:
|
||||
return validate_profile_image_url(v)
|
||||
@@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel):
|
||||
|
||||
|
||||
class UserModelResponse(UserModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
@@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel):
|
||||
profile_image_url: str
|
||||
password: Optional[str] = None
|
||||
|
||||
@field_validator("profile_image_url")
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: str) -> str:
|
||||
return validate_profile_image_url(v)
|
||||
@@ -263,8 +263,8 @@ class UsersTable:
|
||||
id: str,
|
||||
name: str,
|
||||
email: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
profile_image_url: str = '/user.png',
|
||||
role: str = 'pending',
|
||||
username: Optional[str] = None,
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[Session] = None,
|
||||
@@ -272,16 +272,16 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
user = UserModel(
|
||||
**{
|
||||
"id": id,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"role": role,
|
||||
"profile_image_url": profile_image_url,
|
||||
"last_active_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"username": username,
|
||||
"oauth": oauth,
|
||||
'id': id,
|
||||
'email': email,
|
||||
'name': name,
|
||||
'role': role,
|
||||
'profile_image_url': profile_image_url,
|
||||
'last_active_at': int(time.time()),
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
'username': username,
|
||||
'oauth': oauth,
|
||||
}
|
||||
)
|
||||
result = User(**user.model_dump())
|
||||
@@ -293,9 +293,7 @@ class UsersTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_user_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -303,49 +301,32 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_api_key(
|
||||
self, api_key: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = (
|
||||
db.query(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
.filter(ApiKey.key == api_key)
|
||||
.first()
|
||||
)
|
||||
user = db.query(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key).first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_email(
|
||||
self, email: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(func.lower(User.email) == email.lower())
|
||||
.first()
|
||||
)
|
||||
user = db.query(User).filter(func.lower(User.email) == email.lower()).first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_oauth_sub(
|
||||
self, provider: str, sub: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db: # type: Session
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
query = db.query(User)
|
||||
if dialect_name == "sqlite":
|
||||
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
|
||||
elif dialect_name == "postgresql":
|
||||
query = query.filter(
|
||||
User.oauth[provider].cast(JSONB)["sub"].astext == sub
|
||||
)
|
||||
if dialect_name == 'sqlite':
|
||||
query = query.filter(User.oauth.contains({provider: {'sub': sub}}))
|
||||
elif dialect_name == 'postgresql':
|
||||
query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
|
||||
|
||||
user = query.first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
@@ -361,15 +342,10 @@ class UsersTable:
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
query = db.query(User)
|
||||
if dialect_name == "sqlite":
|
||||
query = query.filter(
|
||||
User.scim.contains({provider: {"external_id": external_id}})
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
query = query.filter(
|
||||
User.scim[provider].cast(JSONB)["external_id"].astext
|
||||
== external_id
|
||||
)
|
||||
if dialect_name == 'sqlite':
|
||||
query = query.filter(User.scim.contains({provider: {'external_id': external_id}}))
|
||||
elif dialect_name == 'postgresql':
|
||||
query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
|
||||
|
||||
user = query.first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
@@ -388,16 +364,16 @@ class UsersTable:
|
||||
query = db.query(User).options(defer(User.profile_image_url))
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f'%{query_key}%'),
|
||||
User.email.ilike(f'%{query_key}%'),
|
||||
)
|
||||
)
|
||||
|
||||
channel_id = filter.get("channel_id")
|
||||
channel_id = filter.get('channel_id')
|
||||
if channel_id:
|
||||
query = query.filter(
|
||||
exists(
|
||||
@@ -408,13 +384,13 @@ class UsersTable:
|
||||
)
|
||||
)
|
||||
|
||||
user_ids = filter.get("user_ids")
|
||||
group_ids = filter.get("group_ids")
|
||||
user_ids = filter.get('user_ids')
|
||||
group_ids = filter.get('group_ids')
|
||||
|
||||
if isinstance(user_ids, list) and isinstance(group_ids, list):
|
||||
# If both are empty lists, return no users
|
||||
if not user_ids and not group_ids:
|
||||
return {"users": [], "total": 0}
|
||||
return {'users': [], 'total': 0}
|
||||
|
||||
if user_ids:
|
||||
query = query.filter(User.id.in_(user_ids))
|
||||
@@ -429,21 +405,21 @@ class UsersTable:
|
||||
)
|
||||
)
|
||||
|
||||
roles = filter.get("roles")
|
||||
roles = filter.get('roles')
|
||||
if roles:
|
||||
include_roles = [role for role in roles if not role.startswith("!")]
|
||||
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
|
||||
include_roles = [role for role in roles if not role.startswith('!')]
|
||||
exclude_roles = [role[1:] for role in roles if role.startswith('!')]
|
||||
|
||||
if include_roles:
|
||||
query = query.filter(User.role.in_(include_roles))
|
||||
if exclude_roles:
|
||||
query = query.filter(~User.role.in_(exclude_roles))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
|
||||
if order_by and order_by.startswith("group_id:"):
|
||||
group_id = order_by.split(":", 1)[1]
|
||||
if order_by and order_by.startswith('group_id:'):
|
||||
group_id = order_by.split(':', 1)[1]
|
||||
|
||||
# Subquery that checks if the user belongs to the group
|
||||
membership_exists = exists(
|
||||
@@ -456,42 +432,42 @@ class UsersTable:
|
||||
# CASE: user in group → 1, user not in group → 0
|
||||
group_sort = case((membership_exists, 1), else_=0)
|
||||
|
||||
if direction == "asc":
|
||||
if direction == 'asc':
|
||||
query = query.order_by(group_sort.asc(), User.name.asc())
|
||||
else:
|
||||
query = query.order_by(group_sort.desc(), User.name.asc())
|
||||
|
||||
elif order_by == "name":
|
||||
if direction == "asc":
|
||||
elif order_by == 'name':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.name.asc())
|
||||
else:
|
||||
query = query.order_by(User.name.desc())
|
||||
|
||||
elif order_by == "email":
|
||||
if direction == "asc":
|
||||
elif order_by == 'email':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.email.asc())
|
||||
else:
|
||||
query = query.order_by(User.email.desc())
|
||||
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'created_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.created_at.desc())
|
||||
|
||||
elif order_by == "last_active_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'last_active_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.last_active_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.last_active_at.desc())
|
||||
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'updated_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.updated_at.desc())
|
||||
elif order_by == "role":
|
||||
if direction == "asc":
|
||||
elif order_by == 'role':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.role.asc())
|
||||
else:
|
||||
query = query.order_by(User.role.desc())
|
||||
@@ -510,13 +486,11 @@ class UsersTable:
|
||||
|
||||
users = query.all()
|
||||
return {
|
||||
"users": [UserModel.model_validate(user) for user in users],
|
||||
"total": total,
|
||||
'users': [UserModel.model_validate(user) for user in users],
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def get_users_by_group_id(
|
||||
self, group_id: str, db: Optional[Session] = None
|
||||
) -> list[UserModel]:
|
||||
def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
@@ -527,16 +501,9 @@ class UsersTable:
|
||||
)
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_users_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> list[UserStatusModel]:
|
||||
def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
.options(defer(User.profile_image_url))
|
||||
.filter(User.id.in_(user_ids))
|
||||
.all()
|
||||
)
|
||||
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
||||
@@ -555,9 +522,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_webhook_url_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[str]:
|
||||
def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -565,11 +530,7 @@ class UsersTable:
|
||||
if user.settings is None:
|
||||
return None
|
||||
else:
|
||||
return (
|
||||
user.settings.get("ui", {})
|
||||
.get("notifications", {})
|
||||
.get("webhook_url", None)
|
||||
)
|
||||
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -577,14 +538,10 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||
query = db.query(User).filter(
|
||||
User.last_active_at > today_midnight_timestamp
|
||||
)
|
||||
query = db.query(User).filter(User.last_active_at > today_midnight_timestamp)
|
||||
return query.count()
|
||||
|
||||
def update_user_role_by_id(
|
||||
self, id: str, role: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -629,9 +586,7 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
def update_last_active_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -665,10 +620,10 @@ class UsersTable:
|
||||
oauth = user.oauth or {}
|
||||
|
||||
# Update or insert provider entry
|
||||
oauth[provider] = {"sub": sub}
|
||||
oauth[provider] = {'sub': sub}
|
||||
|
||||
# Persist updated JSON
|
||||
db.query(User).filter_by(id=id).update({"oauth": oauth})
|
||||
db.query(User).filter_by(id=id).update({'oauth': oauth})
|
||||
db.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
@@ -698,9 +653,9 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
scim = user.scim or {}
|
||||
scim[provider] = {"external_id": external_id}
|
||||
scim[provider] = {'external_id': external_id}
|
||||
|
||||
db.query(User).filter_by(id=id).update({"scim": scim})
|
||||
db.query(User).filter_by(id=id).update({'scim': scim})
|
||||
db.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
@@ -708,9 +663,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -725,9 +678,7 @@ class UsersTable:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def update_user_settings_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -741,7 +692,7 @@ class UsersTable:
|
||||
|
||||
user_settings.update(updated)
|
||||
|
||||
db.query(User).filter_by(id=id).update({"settings": user_settings})
|
||||
db.query(User).filter_by(id=id).update({'settings': user_settings})
|
||||
db.commit()
|
||||
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -768,9 +719,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_user_api_key_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[str]:
|
||||
def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
||||
@@ -778,9 +727,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_api_key_by_id(
|
||||
self, id: str, api_key: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||
@@ -788,7 +735,7 @@ class UsersTable:
|
||||
|
||||
now = int(time.time())
|
||||
new_api_key = ApiKey(
|
||||
id=f"key_{id}",
|
||||
id=f'key_{id}',
|
||||
user_id=id,
|
||||
key=api_key,
|
||||
created_at=now,
|
||||
@@ -811,16 +758,14 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_valid_user_ids(
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> list[str]:
|
||||
def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [user.id for user in users]
|
||||
|
||||
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(role="admin").first()
|
||||
user = db.query(User).filter_by(role='admin').first()
|
||||
if user:
|
||||
return UserModel.model_validate(user)
|
||||
else:
|
||||
@@ -830,9 +775,7 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
count = (
|
||||
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
||||
)
|
||||
count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user