mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-04 03:16:03 -05:00
refac: async db
This commit is contained in:
@@ -11,7 +11,7 @@ from open_webui.models.access_grants import (
|
||||
)
|
||||
from open_webui.config import DEFAULT_USER_PERMISSIONS
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
def fill_missing_permissions(permissions: dict[str, Any], default_permissions: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -28,10 +28,10 @@ def fill_missing_permissions(permissions: dict[str, Any], default_permissions: d
|
||||
return permissions
|
||||
|
||||
|
||||
def get_permissions(
|
||||
async def get_permissions(
|
||||
user_id: str,
|
||||
default_permissions: dict[str, Any],
|
||||
db: Session | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get all permissions for a user by combining the permissions of all groups the user is a member of.
|
||||
@@ -53,7 +53,7 @@ def get_permissions(
|
||||
permissions[key] = permissions[key] or value # Use the most permissive value (True > False)
|
||||
return permissions
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
# Deep copy default permissions to avoid modifying the original dict
|
||||
permissions = json.loads(json.dumps(default_permissions))
|
||||
@@ -68,11 +68,11 @@ def get_permissions(
|
||||
return permissions
|
||||
|
||||
|
||||
def has_permission(
|
||||
async def has_permission(
|
||||
user_id: str,
|
||||
permission_key: str,
|
||||
default_permissions: dict[str, Any] = {},
|
||||
db: Session | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has a specific permission by checking the group permissions
|
||||
@@ -93,7 +93,7 @@ def has_permission(
|
||||
permission_hierarchy = permission_key.split('.')
|
||||
|
||||
# Retrieve user group permissions
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
for group in user_groups:
|
||||
if get_permission(group.permissions or {}, permission_hierarchy):
|
||||
@@ -104,12 +104,12 @@ def has_permission(
|
||||
return get_permission(default_permissions, permission_hierarchy)
|
||||
|
||||
|
||||
def has_access(
|
||||
async def has_access(
|
||||
user_id: str,
|
||||
permission: str = 'read',
|
||||
access_grants: list | None = None,
|
||||
user_group_ids: set[str] | None = None,
|
||||
db: Session | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has the specified permission using an in-memory access_grants list.
|
||||
@@ -126,7 +126,7 @@ def has_access(
|
||||
return False
|
||||
|
||||
if user_group_ids is None:
|
||||
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
user_group_ids = {group.id for group in user_groups}
|
||||
|
||||
for grant in access_grants:
|
||||
@@ -144,7 +144,7 @@ def has_access(
|
||||
return False
|
||||
|
||||
|
||||
def has_connection_access(
|
||||
async def has_connection_access(
|
||||
user: UserModel,
|
||||
connection: dict,
|
||||
user_group_ids: set[str] | None = None,
|
||||
@@ -163,10 +163,10 @@ def has_connection_access(
|
||||
return True
|
||||
|
||||
if user_group_ids is None:
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||
user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id)}
|
||||
|
||||
access_grants = (connection.get('config') or {}).get('access_grants', [])
|
||||
return has_access(user.id, 'read', access_grants, user_group_ids)
|
||||
return await has_access(user.id, 'read', access_grants, user_group_ids)
|
||||
|
||||
|
||||
def migrate_access_control(data: dict, ac_key: str = 'access_control', grants_key: str = 'access_grants') -> None:
|
||||
@@ -210,13 +210,13 @@ def migrate_access_control(data: dict, ac_key: str = 'access_control', grants_ke
|
||||
data.pop(ac_key, None)
|
||||
|
||||
|
||||
def filter_allowed_access_grants(
|
||||
async def filter_allowed_access_grants(
|
||||
default_permissions: dict[str, Any],
|
||||
user_id: str,
|
||||
user_role: str,
|
||||
access_grants: list,
|
||||
public_permission_key: str,
|
||||
db: Session | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list:
|
||||
"""
|
||||
Checks if the user has the required permissions to grant access to a resource.
|
||||
@@ -228,7 +228,7 @@ def filter_allowed_access_grants(
|
||||
# Check if user can share publicly
|
||||
if (
|
||||
has_public_read_access_grant(access_grants) or has_public_write_access_grant(access_grants)
|
||||
) and not has_permission(
|
||||
) and not await has_permission(
|
||||
user_id,
|
||||
public_permission_key,
|
||||
default_permissions,
|
||||
@@ -246,7 +246,7 @@ def filter_allowed_access_grants(
|
||||
]
|
||||
|
||||
# Strip individual user sharing if user lacks permission
|
||||
if has_user_access_grant(access_grants) and not has_permission(
|
||||
if has_user_access_grant(access_grants) and not await has_permission(
|
||||
user_id,
|
||||
'access_grants.allow_users',
|
||||
default_permissions,
|
||||
@@ -257,7 +257,7 @@ def filter_allowed_access_grants(
|
||||
return access_grants
|
||||
|
||||
|
||||
def check_model_access(
|
||||
async def check_model_access(
|
||||
user: UserModel,
|
||||
model_info,
|
||||
bypass_filter: bool = False,
|
||||
@@ -270,7 +270,7 @@ def check_model_access(
|
||||
|
||||
Args:
|
||||
user: The authenticated user.
|
||||
model_info: The model record from Models.get_model_by_id(),
|
||||
model_info: The model record from await Models.get_model_by_id(),
|
||||
or None if the model is not registered.
|
||||
bypass_filter: If True, skip all access checks (used by
|
||||
internal callers and BYPASS_MODEL_ACCESS_CONTROL).
|
||||
@@ -284,10 +284,10 @@ def check_model_access(
|
||||
if user.role == 'user':
|
||||
from open_webui.models.access_grants import AccessGrants
|
||||
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
|
||||
user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id)}
|
||||
if not (
|
||||
user.id == model_info.user_id
|
||||
or AccessGrants.has_access(
|
||||
or await AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type='model',
|
||||
resource_id=model_info.id,
|
||||
|
||||
Reference in New Issue
Block a user