mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-05 10:28:06 -05:00
refac: defer profile
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, defer
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from open_webui.utils.misc import throttle
|
||||
from open_webui.utils.validate import validate_profile_image_url
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
JSON,
|
||||
@@ -28,7 +28,7 @@ from sqlalchemy import (
|
||||
select,
|
||||
cast,
|
||||
)
|
||||
from sqlalchemy import or_, case
|
||||
from sqlalchemy import or_, case, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
import datetime
|
||||
@@ -86,7 +86,7 @@ class UserModel(BaseModel):
|
||||
|
||||
name: str
|
||||
|
||||
profile_image_url: str
|
||||
profile_image_url: Optional[str] = None
|
||||
profile_banner_image_url: Optional[str] = None
|
||||
|
||||
bio: Optional[str] = None
|
||||
@@ -110,6 +110,12 @@ class UserModel(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_profile_image_url(self):
|
||||
if not self.profile_image_url:
|
||||
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
|
||||
return self
|
||||
|
||||
|
||||
class UserStatusModel(UserModel):
|
||||
is_active: bool = False
|
||||
@@ -315,8 +321,12 @@ class UsersTable:
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(email=email).first()
|
||||
return UserModel.model_validate(user)
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(func.lower(User.email) == email.lower())
|
||||
.first()
|
||||
)
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -350,7 +360,7 @@ class UsersTable:
|
||||
) -> dict:
|
||||
with get_db_context(db) as db:
|
||||
# Join GroupMember so we can order by group_id when requested
|
||||
query = db.query(User)
|
||||
query = db.query(User).options(defer(User.profile_image_url))
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
@@ -485,6 +495,7 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
.options(defer(User.profile_image_url))
|
||||
.join(GroupMember, User.id == GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.all()
|
||||
@@ -495,7 +506,7 @@ class UsersTable:
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> list[UserStatusModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
||||
|
||||
Reference in New Issue
Block a user