refac: defer profile

This commit is contained in:
Timothy Jaeryang Baek
2026-02-13 14:08:07 -06:00
parent 589c4e64c1
commit b7549d2f6c
4 changed files with 30 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
import time
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, defer
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
@@ -15,7 +15,7 @@ from open_webui.utils.misc import throttle
from open_webui.utils.validate import validate_profile_image_url
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from sqlalchemy import (
BigInteger,
JSON,
@@ -28,7 +28,7 @@ from sqlalchemy import (
select,
cast,
)
from sqlalchemy import or_, case
from sqlalchemy import or_, case, func
from sqlalchemy.dialects.postgresql import JSONB
import datetime
@@ -86,7 +86,7 @@ class UserModel(BaseModel):
name: str
profile_image_url: str
profile_image_url: Optional[str] = None
profile_banner_image_url: Optional[str] = None
bio: Optional[str] = None
@@ -110,6 +110,12 @@ class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
@model_validator(mode="after")
def set_profile_image_url(self):
if not self.profile_image_url:
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
return self
class UserStatusModel(UserModel):
is_active: bool = False
@@ -315,8 +321,12 @@ class UsersTable:
) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
user = (
db.query(User)
.filter(func.lower(User.email) == email.lower())
.first()
)
return UserModel.model_validate(user) if user else None
except Exception:
return None
@@ -350,7 +360,7 @@ class UsersTable:
) -> dict:
with get_db_context(db) as db:
# Join GroupMember so we can order by group_id when requested
query = db.query(User)
query = db.query(User).options(defer(User.profile_image_url))
if filter:
query_key = filter.get("query")
@@ -485,6 +495,7 @@ class UsersTable:
with get_db_context(db) as db:
users = (
db.query(User)
.options(defer(User.profile_image_url))
.join(GroupMember, User.id == GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.all()
@@ -495,7 +506,7 @@ class UsersTable:
self, user_ids: list[str], db: Optional[Session] = None
) -> list[UserStatusModel]:
with get_db_context(db) as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users]
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: