refac: async db

This commit is contained in:
Timothy Jaeryang Baek
2026-04-12 14:22:11 -05:00
parent b618d84065
commit 27169124f2
74 changed files with 4831 additions and 4479 deletions

View File

@@ -11,7 +11,7 @@ from open_webui.models.access_grants import (
)
from open_webui.config import DEFAULT_USER_PERMISSIONS
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
def fill_missing_permissions(permissions: dict[str, Any], default_permissions: dict[str, Any]) -> dict[str, Any]:
@@ -28,10 +28,10 @@ def fill_missing_permissions(permissions: dict[str, Any], default_permissions: d
return permissions
def get_permissions(
async def get_permissions(
user_id: str,
default_permissions: dict[str, Any],
db: Session | None = None,
db: AsyncSession | None = None,
) -> dict[str, Any]:
"""
Get all permissions for a user by combining the permissions of all groups the user is a member of.
@@ -53,7 +53,7 @@ def get_permissions(
permissions[key] = permissions[key] or value # Use the most permissive value (True > False)
return permissions
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
# Deep copy default permissions to avoid modifying the original dict
permissions = json.loads(json.dumps(default_permissions))
@@ -68,11 +68,11 @@ def get_permissions(
return permissions
def has_permission(
async def has_permission(
user_id: str,
permission_key: str,
default_permissions: dict[str, Any] = {},
db: Session | None = None,
db: AsyncSession | None = None,
) -> bool:
"""
Check if a user has a specific permission by checking the group permissions
@@ -93,7 +93,7 @@ def has_permission(
permission_hierarchy = permission_key.split('.')
# Retrieve user group permissions
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
for group in user_groups:
if get_permission(group.permissions or {}, permission_hierarchy):
@@ -104,12 +104,12 @@ def has_permission(
return get_permission(default_permissions, permission_hierarchy)
def has_access(
async def has_access(
user_id: str,
permission: str = 'read',
access_grants: list | None = None,
user_group_ids: set[str] | None = None,
db: Session | None = None,
db: AsyncSession | None = None,
) -> bool:
"""
Check if a user has the specified permission using an in-memory access_grants list.
@@ -126,7 +126,7 @@ def has_access(
return False
if user_group_ids is None:
user_groups = Groups.get_groups_by_member_id(user_id, db=db)
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
user_group_ids = {group.id for group in user_groups}
for grant in access_grants:
@@ -144,7 +144,7 @@ def has_access(
return False
def has_connection_access(
async def has_connection_access(
user: UserModel,
connection: dict,
user_group_ids: set[str] | None = None,
@@ -163,10 +163,10 @@ def has_connection_access(
return True
if user_group_ids is None:
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id)}
access_grants = (connection.get('config') or {}).get('access_grants', [])
return has_access(user.id, 'read', access_grants, user_group_ids)
return await has_access(user.id, 'read', access_grants, user_group_ids)
def migrate_access_control(data: dict, ac_key: str = 'access_control', grants_key: str = 'access_grants') -> None:
@@ -210,13 +210,13 @@ def migrate_access_control(data: dict, ac_key: str = 'access_control', grants_ke
data.pop(ac_key, None)
def filter_allowed_access_grants(
async def filter_allowed_access_grants(
default_permissions: dict[str, Any],
user_id: str,
user_role: str,
access_grants: list,
public_permission_key: str,
db: Session | None = None,
db: AsyncSession | None = None,
) -> list:
"""
Checks if the user has the required permissions to grant access to a resource.
@@ -228,7 +228,7 @@ def filter_allowed_access_grants(
# Check if user can share publicly
if (
has_public_read_access_grant(access_grants) or has_public_write_access_grant(access_grants)
) and not has_permission(
) and not await has_permission(
user_id,
public_permission_key,
default_permissions,
@@ -246,7 +246,7 @@ def filter_allowed_access_grants(
]
# Strip individual user sharing if user lacks permission
if has_user_access_grant(access_grants) and not has_permission(
if has_user_access_grant(access_grants) and not await has_permission(
user_id,
'access_grants.allow_users',
default_permissions,
@@ -257,7 +257,7 @@ def filter_allowed_access_grants(
return access_grants
def check_model_access(
async def check_model_access(
user: UserModel,
model_info,
bypass_filter: bool = False,
@@ -270,7 +270,7 @@ def check_model_access(
Args:
user: The authenticated user.
model_info: The model record from Models.get_model_by_id(),
model_info: The model record from await Models.get_model_by_id(),
or None if the model is not registered.
bypass_filter: If True, skip all access checks (used by
internal callers and BYPASS_MODEL_ACCESS_CONTROL).
@@ -284,10 +284,10 @@ def check_model_access(
if user.role == 'user':
from open_webui.models.access_grants import AccessGrants
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
user_group_ids = {group.id for group in await Groups.get_groups_by_member_id(user.id)}
if not (
user.id == model_info.user_id
or AccessGrants.has_access(
or await AccessGrants.has_access(
user_id=user.id,
resource_type='model',
resource_id=model_info.id,