This commit is contained in:
Timothy Jaeryang Baek
2026-02-21 16:27:25 -06:00
parent 74e771fec6
commit b48594a166
2 changed files with 38 additions and 25 deletions

View File

@@ -2,7 +2,7 @@ import logging
import time
from typing import Optional
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, defer
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.models.users import Users, UserResponse
from open_webui.models.groups import Groups
@@ -156,9 +156,14 @@ class ToolsTable:
except Exception:
return None
def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]:
def get_tools(
self, defer_content: bool = False, db: Optional[Session] = None
) -> list[ToolUserModel]:
with get_db_context(db) as db:
all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all()
query = db.query(Tool).order_by(Tool.updated_at.desc())
if defer_content:
query = query.options(defer(Tool.content), defer(Tool.specs))
all_tools = query.all()
user_ids = list(set(tool.user_id for tool in all_tools))
tool_ids = [tool.id for tool in all_tools]
@@ -185,9 +190,9 @@ class ToolsTable:
return tools
def get_tools_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None
self, user_id: str, permission: str = "write", defer_content: bool = False, db: Optional[Session] = None
) -> list[ToolUserModel]:
tools = self.get_tools(db=db)
tools = self.get_tools(defer_content=defer_content, db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}

View File

@@ -64,13 +64,13 @@ async def get_tools(
tools = []
# Local Tools
for tool in Tools.get_tools(db=db):
tool_module = get_tool_module(request, tool.id)
for tool in Tools.get_tools(defer_content=True, db=db):
tool_module = request.app.state.TOOLS.get(tool.id) if hasattr(request.app.state, 'TOOLS') else None
tools.append(
ToolUserResponse(
**{
**tool.model_dump(),
"has_user_valves": hasattr(tool_module, "UserValves"),
"has_user_valves": hasattr(tool_module, "UserValves") if tool_module else False,
}
)
)
@@ -196,27 +196,35 @@ async def get_tool_list(
user=Depends(get_verified_user), db: Session = Depends(get_session)
):
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
tools = Tools.get_tools(db=db)
tools = Tools.get_tools(defer_content=True, db=db)
else:
tools = Tools.get_tools_by_user_id(user.id, "read", db=db)
tools = Tools.get_tools_by_user_id(user.id, "read", defer_content=True, db=db)
return [
ToolAccessResponse(
**tool.model_dump(),
write_access=(
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == tool.user_id
or AccessGrants.has_access(
user_id=user.id,
resource_type="tool",
resource_id=tool.id,
permission="write",
db=db,
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
}
result = []
for tool in tools:
has_write = (
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
or user.id == tool.user_id
or any(
g.permission == "write"
and (
(g.principal_type == "user" and (g.principal_id == user.id or g.principal_id == "*"))
or (g.principal_type == "group" and g.principal_id in user_group_ids)
)
),
for g in tool.access_grants
)
)
for tool in tools
]
result.append(
ToolAccessResponse(
**tool.model_dump(),
write_access=has_write,
)
)
return result
############################