diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 1397a254c6..d821985a4e 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -1337,5 +1337,17 @@ class ChatTable: except Exception: return False + def get_shared_chats_by_file_id(self, file_id: str) -> list[ChatModel]: + with get_db() as db: + # Join Chat and ChatFile tables to get shared chats associated with the file_id + all_chats = ( + db.query(Chat) + .join(ChatFile, Chat.id == ChatFile.chat_id) + .filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None)) + .all() + ) + + return [ChatModel.model_validate(chat) for chat in all_chats] + Chats = ChatTable() diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 0986f5a76a..723a150197 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -34,6 +34,7 @@ from open_webui.models.files import ( FileModelResponse, Files, ) +from open_webui.models.chats import Chats from open_webui.models.knowledge import Knowledges from open_webui.models.groups import Groups @@ -71,9 +72,9 @@ def has_access_to_file( detail=ERROR_MESSAGES.NOT_FOUND, ) + # Check if the file is associated with any knowledge bases the user has access to knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id) user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} - for knowledge_base in knowledge_bases: if knowledge_base.user_id == user.id or has_access( user.id, access_type, knowledge_base.access_control, user_group_ids @@ -89,10 +90,17 @@ def has_access_to_file( if knowledge_base.id == knowledge_base_id: return True + # Check if the file is associated with any channels the user has access to channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id) if access_type == "read" and channels: return True + # Check if the file is associated with any chats the user has access to + # TODO: Granular access control for chats + chats = Chats.get_shared_chats_by_file_id(file_id) + if chats: + return True + return False