mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 17:59:28 -05:00
refac/enh: db session sharing
This commit is contained in:
@@ -28,6 +28,7 @@ def fill_missing_permissions(
|
||||
def get_permissions(
|
||||
user_id: str,
|
||||
default_permissions: Dict[str, Any],
|
||||
db: Optional[Any] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
||||
@@ -53,7 +54,7 @@ def get_permissions(
|
||||
) # Use the most permissive value (True > False)
|
||||
return permissions
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
# Deep copy default permissions to avoid modifying the original dict
|
||||
permissions = json.loads(json.dumps(default_permissions))
|
||||
@@ -72,6 +73,7 @@ def has_permission(
|
||||
user_id: str,
|
||||
permission_key: str,
|
||||
default_permissions: Dict[str, Any] = {},
|
||||
db: Optional[Any] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has a specific permission by checking the group permissions
|
||||
@@ -92,7 +94,7 @@ def has_permission(
|
||||
permission_hierarchy = permission_key.split(".")
|
||||
|
||||
# Retrieve user group permissions
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
for group in user_groups:
|
||||
if get_permission(group.permissions or {}, permission_hierarchy):
|
||||
@@ -127,6 +129,7 @@ def has_access(
|
||||
access_control: Optional[dict] = None,
|
||||
user_group_ids: Optional[Set[str]] = None,
|
||||
strict: bool = True,
|
||||
db: Optional[Any] = None,
|
||||
) -> bool:
|
||||
if access_control is None:
|
||||
if strict:
|
||||
@@ -135,7 +138,7 @@ def has_access(
|
||||
return True
|
||||
|
||||
if user_group_ids is None:
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_group_ids = {group.id for group in user_groups}
|
||||
|
||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||
@@ -152,10 +155,10 @@ def has_access(
|
||||
|
||||
# Get all users with access to a resource
|
||||
def get_users_with_access(
|
||||
type: str = "write", access_control: Optional[dict] = None
|
||||
type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None
|
||||
) -> list[UserModel]:
|
||||
if access_control is None:
|
||||
result = Users.get_users(filter={"roles": ["!pending"]})
|
||||
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
|
||||
return result.get("users", [])
|
||||
|
||||
permitted_ids = get_permitted_group_and_user_ids(type, access_control)
|
||||
@@ -167,8 +170,8 @@ def get_users_with_access(
|
||||
|
||||
user_ids_with_access = set(permitted_user_ids)
|
||||
|
||||
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids)
|
||||
group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db)
|
||||
for user_ids in group_user_ids_map.values():
|
||||
user_ids_with_access.update(user_ids)
|
||||
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access))
|
||||
return Users.get_users_by_user_ids(list(user_ids_with_access), db=db)
|
||||
|
||||
Reference in New Issue
Block a user