diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 68ef37afe4..6e3271d7ec 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -57,64 +57,7 @@ log = logging.getLogger(__name__) router = APIRouter() -############################ -# Check if the current user has access to a file through any knowledge bases the user may be in. -############################ - - -# TODO: Optimize this function to use the knowledge_file table for faster lookups. -def has_access_to_file( - file_id: Optional[str], - access_type: str, - user=Depends(get_verified_user), - db: Optional[Session] = None, -) -> bool: - file = Files.get_file_by_id(file_id, db=db) - log.debug(f"Checking if user has {access_type} access to file") - if not file: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=ERROR_MESSAGES.NOT_FOUND, - ) - - # Check if the file is associated with any knowledge bases the user has access to - knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db) - user_group_ids = { - group.id for group in Groups.get_groups_by_member_id(user.id, db=db) - } - for knowledge_base in knowledge_bases: - if knowledge_base.user_id == user.id or AccessGrants.has_access( - user_id=user.id, - resource_type="knowledge", - resource_id=knowledge_base.id, - permission=access_type, - user_group_ids=user_group_ids, - db=db, - ): - return True - - knowledge_base_id = file.meta.get("collection_name") if file.meta else None - if knowledge_base_id: - knowledge_bases = Knowledges.get_knowledge_bases_by_user_id( - user.id, access_type, db=db - ) - for knowledge_base in knowledge_bases: - if knowledge_base.id == knowledge_base_id: - return True - - # Check if the file is associated with any channels the user has access to - channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db) - if access_type == "read" and channels: - return True - - # Check if the file is associated with 
any chats the user has access to - # TODO: Granular access control for chats - chats = Chats.get_shared_chats_by_file_id(file_id, db=db) - if chats: - return True - - return False - +from open_webui.utils.access_control.files import has_access_to_file ############################ # Upload File diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 4c27ffea47..5292546285 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -37,6 +37,7 @@ from langchain_text_splitters import ( from langchain_core.documents import Document from open_webui.models.files import FileModel, FileUpdateForm, Files +from open_webui.utils.access_control.files import has_access_to_file from open_webui.models.knowledge import Knowledges from open_webui.storage.provider import Storage from open_webui.internal.db import get_session, get_db @@ -2579,6 +2580,30 @@ async def query_collection_handler( form_data: QueryCollectionsForm, user=Depends(get_verified_user), ): + # Ownership validation: prevent accessing other users' private collections + if user.role != "admin": + for collection_name in form_data.collection_names: + if collection_name.startswith("user-memory-"): + owner_id = collection_name[len("user-memory-") :] + if owner_id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + elif collection_name.startswith("file-"): + file_id = collection_name[len("file-") :] + file = Files.get_file_by_id(file_id) + if file and file.user_id != user.id: + if not has_access_to_file( + file_id=file_id, + access_type="read", + user=user, + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + ) + try: if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and ( form_data.hybrid is None or form_data.hybrid diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py 
def has_access_to_file(
    file_id: Optional[str],
    access_type: str,
    user: UserModel,
    db: Optional[Any] = None,
) -> bool:
    """
    Check whether ``user`` has ``access_type`` access to a file indirectly.

    Access is granted when any of the following holds:
    - the file belongs to a knowledge base that the user owns or holds an
      access grant for (directly or through one of their groups);
    - the file's ``meta["collection_name"]`` equals the id of a knowledge
      base returned for the user and ``access_type``;
    - for ``"read"`` requests only: the file is attached to a channel
      returned for the user, or to a shared chat.

    NOTE: This does NOT check direct file ownership — callers should check
    file.user_id == user.id separately before calling this.

    Args:
        file_id: Id of the file to check. ``None`` or an empty id is
            treated as "no access".
        access_type: Permission being requested (e.g. ``"read"``).
        user: The user whose access is being evaluated.
        db: Optional session forwarded to every model lookup.

    Returns:
        ``True`` if one of the paths above grants access, ``False``
        otherwise (including when the file does not exist).
    """
    # Nothing to look up without an id; avoid a pointless query.
    if not file_id:
        return False

    file = Files.get_file_by_id(file_id, db=db)
    log.debug("Checking if user has %s access to file", access_type)
    if not file:
        return False

    # Check if the file is associated with any knowledge bases the user has access to
    knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
    if knowledge_bases:
        # Group membership is only needed for grant checks, so resolve it
        # lazily and only once for all knowledge bases.
        user_group_ids = {
            group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
        }
        for knowledge_base in knowledge_bases:
            if knowledge_base.user_id == user.id or AccessGrants.has_access(
                user_id=user.id,
                resource_type="knowledge",
                resource_id=knowledge_base.id,
                permission=access_type,
                user_group_ids=user_group_ids,
                db=db,
            ):
                return True

    # The file metadata may name a knowledge base directly via "collection_name".
    knowledge_base_id = file.meta.get("collection_name") if file.meta else None
    if knowledge_base_id and any(
        knowledge_base.id == knowledge_base_id
        for knowledge_base in Knowledges.get_knowledge_bases_by_user_id(
            user.id, access_type, db=db
        )
    ):
        return True

    # Channels and shared chats only ever expose a file for reading. Any
    # other access type stops here, which also skips two unused queries.
    if access_type != "read":
        return False

    # Check if the file is associated with any channels the user has access to
    if Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db):
        return True

    # Check if the file is associated with any shared chats.
    # Previously this granted *every* access type, so a file referenced by
    # any shared chat could be modified by any user; it is now read-only.
    # TODO: Granular access control for chats
    return bool(Chats.get_shared_chats_by_file_id(file_id, db=db))