mirror of
https://github.com/open-webui/open-webui.git
synced 2026-03-11 17:47:44 -05:00
refac
This commit is contained in:
@@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, defer
|
||||
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.models.groups import Groups
|
||||
@@ -156,9 +156,14 @@ class ToolsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tools(self, db: Optional[Session] = None) -> list[ToolUserModel]:
|
||||
def get_tools(
|
||||
self, defer_content: bool = False, db: Optional[Session] = None
|
||||
) -> list[ToolUserModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all()
|
||||
query = db.query(Tool).order_by(Tool.updated_at.desc())
|
||||
if defer_content:
|
||||
query = query.options(defer(Tool.content), defer(Tool.specs))
|
||||
all_tools = query.all()
|
||||
|
||||
user_ids = list(set(tool.user_id for tool in all_tools))
|
||||
tool_ids = [tool.id for tool in all_tools]
|
||||
@@ -185,9 +190,9 @@ class ToolsTable:
|
||||
return tools
|
||||
|
||||
def get_tools_by_user_id(
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
self, user_id: str, permission: str = "write", defer_content: bool = False, db: Optional[Session] = None
|
||||
) -> list[ToolUserModel]:
|
||||
tools = self.get_tools(db=db)
|
||||
tools = self.get_tools(defer_content=defer_content, db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
|
||||
@@ -64,13 +64,13 @@ async def get_tools(
|
||||
tools = []
|
||||
|
||||
# Local Tools
|
||||
for tool in Tools.get_tools(db=db):
|
||||
tool_module = get_tool_module(request, tool.id)
|
||||
for tool in Tools.get_tools(defer_content=True, db=db):
|
||||
tool_module = request.app.state.TOOLS.get(tool.id) if hasattr(request.app.state, 'TOOLS') else None
|
||||
tools.append(
|
||||
ToolUserResponse(
|
||||
**{
|
||||
**tool.model_dump(),
|
||||
"has_user_valves": hasattr(tool_module, "UserValves"),
|
||||
"has_user_valves": hasattr(tool_module, "UserValves") if tool_module else False,
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -196,27 +196,35 @@ async def get_tool_list(
|
||||
user=Depends(get_verified_user), db: Session = Depends(get_session)
|
||||
):
|
||||
if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
|
||||
tools = Tools.get_tools(db=db)
|
||||
tools = Tools.get_tools(defer_content=True, db=db)
|
||||
else:
|
||||
tools = Tools.get_tools_by_user_id(user.id, "read", db=db)
|
||||
tools = Tools.get_tools_by_user_id(user.id, "read", defer_content=True, db=db)
|
||||
|
||||
return [
|
||||
ToolAccessResponse(
|
||||
**tool.model_dump(),
|
||||
write_access=(
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == tool.user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user.id,
|
||||
resource_type="tool",
|
||||
resource_id=tool.id,
|
||||
permission="write",
|
||||
db=db,
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
}
|
||||
|
||||
result = []
|
||||
for tool in tools:
|
||||
has_write = (
|
||||
(user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL)
|
||||
or user.id == tool.user_id
|
||||
or any(
|
||||
g.permission == "write"
|
||||
and (
|
||||
(g.principal_type == "user" and (g.principal_id == user.id or g.principal_id == "*"))
|
||||
or (g.principal_type == "group" and g.principal_id in user_group_ids)
|
||||
)
|
||||
),
|
||||
for g in tool.access_grants
|
||||
)
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
result.append(
|
||||
ToolAccessResponse(
|
||||
**tool.model_dump(),
|
||||
write_access=has_write,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
############################
|
||||
|
||||
Reference in New Issue
Block a user