This commit is contained in:
Timothy Jaeryang Baek
2026-03-01 13:49:36 -06:00
parent 597883a179
commit 259d5ca596
5 changed files with 98 additions and 59 deletions

View File

@@ -57,64 +57,7 @@ log = logging.getLogger(__name__)
router = APIRouter()
############################
# Check if the current user has access to a file through any knowledge bases the user may be in.
############################
# TODO: Optimize this function to use the knowledge_file table for faster lookups.
def has_access_to_file(
file_id: Optional[str],
access_type: str,
user=Depends(get_verified_user),
db: Optional[Session] = None,
) -> bool:
file = Files.get_file_by_id(file_id, db=db)
log.debug(f"Checking if user has {access_type} access to file")
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
# Check if the file is associated with any knowledge bases the user has access to
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
for knowledge_base in knowledge_bases:
if knowledge_base.user_id == user.id or AccessGrants.has_access(
user_id=user.id,
resource_type="knowledge",
resource_id=knowledge_base.id,
permission=access_type,
user_group_ids=user_group_ids,
db=db,
):
return True
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type, db=db
)
for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id:
return True
# Check if the file is associated with any channels the user has access to
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db)
if access_type == "read" and channels:
return True
# Check if the file is associated with any chats the user has access to
# TODO: Granular access control for chats
chats = Chats.get_shared_chats_by_file_id(file_id, db=db)
if chats:
return True
return False
from open_webui.utils.access_control.files import has_access_to_file
############################
# Upload File

View File

@@ -37,6 +37,7 @@ from langchain_text_splitters import (
from langchain_core.documents import Document
from open_webui.models.files import FileModel, FileUpdateForm, Files
from open_webui.utils.access_control.files import has_access_to_file
from open_webui.models.knowledge import Knowledges
from open_webui.storage.provider import Storage
from open_webui.internal.db import get_session, get_db
@@ -2579,6 +2580,30 @@ async def query_collection_handler(
form_data: QueryCollectionsForm,
user=Depends(get_verified_user),
):
# Ownership validation: prevent accessing other users' private collections
if user.role != "admin":
for collection_name in form_data.collection_names:
if collection_name.startswith("user-memory-"):
owner_id = collection_name[len("user-memory-") :]
if owner_id != user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
elif collection_name.startswith("file-"):
file_id = collection_name[len("file-") :]
file = Files.get_file_by_id(file_id)
if file and file.user_id != user.id:
if not has_access_to_file(
file_id=file_id,
access_type="read",
user=user,
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
try:
if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and (
form_data.hybrid is None or form_data.hybrid

View File

@@ -1627,7 +1627,7 @@ async def view_file(
try:
from open_webui.models.files import Files
from open_webui.routers.files import has_access_to_file
from open_webui.utils.access_control.files import has_access_to_file
user_id = __user__.get("id")
user_role = __user__.get("role", "user")

View File

@@ -0,0 +1,71 @@
import logging
from typing import Optional, Any
from open_webui.models.users import UserModel
from open_webui.models.files import Files
from open_webui.models.knowledge import Knowledges
from open_webui.models.channels import Channels
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
from open_webui.models.access_grants import AccessGrants
log = logging.getLogger(__name__)
def has_access_to_file(
file_id: Optional[str],
access_type: str,
user: UserModel,
db: Optional[Any] = None,
) -> bool:
"""
Check if a user has the specified access to a file through any of:
- Knowledge bases (ownership or access grants)
- Channels the user is a member of
- Shared chats
NOTE: This does NOT check direct file ownership — callers should check
file.user_id == user.id separately before calling this.
"""
file = Files.get_file_by_id(file_id, db=db)
log.debug(f"Checking if user has {access_type} access to file")
if not file:
return False
# Check if the file is associated with any knowledge bases the user has access to
knowledge_bases = Knowledges.get_knowledges_by_file_id(file_id, db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
for knowledge_base in knowledge_bases:
if knowledge_base.user_id == user.id or AccessGrants.has_access(
user_id=user.id,
resource_type="knowledge",
resource_id=knowledge_base.id,
permission=access_type,
user_group_ids=user_group_ids,
db=db,
):
return True
knowledge_base_id = file.meta.get("collection_name") if file.meta else None
if knowledge_base_id:
knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(
user.id, access_type, db=db
)
for knowledge_base in knowledge_bases:
if knowledge_base.id == knowledge_base_id:
return True
# Check if the file is associated with any channels the user has access to
channels = Channels.get_channels_by_file_id_and_user_id(file_id, user.id, db=db)
if access_type == "read" and channels:
return True
# Check if the file is associated with any chats the user has access to
# TODO: Granular access control for chats
chats = Chats.get_shared_chats_by_file_id(file_id, db=db)
if chats:
return True
return False