mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-08 12:58:11 -05:00
feat: builtin native tools
This commit is contained in:
@@ -521,5 +521,34 @@ class MessageTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def search_messages_by_channel_ids(
|
||||
self,
|
||||
channel_ids: list[str],
|
||||
query: str,
|
||||
start_timestamp: Optional[int] = None,
|
||||
end_timestamp: Optional[int] = None,
|
||||
limit: int = 10,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[MessageModel]:
|
||||
"""Search messages in specified channels by content."""
|
||||
with get_db_context(db) as db:
|
||||
query_builder = db.query(Message).filter(
|
||||
Message.channel_id.in_(channel_ids),
|
||||
Message.content.ilike(f"%{query}%"),
|
||||
)
|
||||
|
||||
if start_timestamp:
|
||||
query_builder = query_builder.filter(Message.created_at >= start_timestamp)
|
||||
if end_timestamp:
|
||||
query_builder = query_builder.filter(Message.created_at <= end_timestamp)
|
||||
|
||||
messages = (
|
||||
query_builder
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [MessageModel.model_validate(msg) for msg in messages]
|
||||
|
||||
|
||||
Messages = MessageTable()
|
||||
|
||||
@@ -14,11 +14,128 @@ from fastapi import Request
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.routers.retrieval import search_web
|
||||
from open_webui.retrieval.utils import get_content_from_url
|
||||
from open_webui.routers.images import image_generations, image_edits, CreateImageForm, EditImageForm
|
||||
from open_webui.routers.memories import query_memory, add_memory, QueryMemoryForm, AddMemoryForm
|
||||
from open_webui.routers.images import (
|
||||
image_generations,
|
||||
image_edits,
|
||||
CreateImageForm,
|
||||
EditImageForm,
|
||||
)
|
||||
from open_webui.routers.memories import (
|
||||
query_memory,
|
||||
add_memory,
|
||||
QueryMemoryForm,
|
||||
AddMemoryForm,
|
||||
)
|
||||
from open_webui.models.notes import Notes
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.channels import Channels, ChannelMember, Channel
|
||||
from open_webui.models.messages import Messages, Message
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# TIME UTILITIES
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_current_timestamp(
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the current Unix timestamp in seconds.
|
||||
|
||||
:return: JSON with current_timestamp (seconds) and current_iso (ISO format)
|
||||
"""
|
||||
try:
|
||||
import datetime
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
return json.dumps(
|
||||
{
|
||||
"current_timestamp": int(now.timestamp()),
|
||||
"current_iso": now.isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"get_current_timestamp error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def calculate_timestamp(
|
||||
days_ago: int = 0,
|
||||
weeks_ago: int = 0,
|
||||
months_ago: int = 0,
|
||||
years_ago: int = 0,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the current Unix timestamp, optionally adjusted by days, weeks, months, or years.
|
||||
Use this to calculate timestamps for date filtering in search functions.
|
||||
Examples: "last week" = weeks_ago=1, "3 days ago" = days_ago=3, "a year ago" = years_ago=1
|
||||
|
||||
:param days_ago: Number of days to subtract from current time (default: 0)
|
||||
:param weeks_ago: Number of weeks to subtract from current time (default: 0)
|
||||
:param months_ago: Number of months to subtract from current time (default: 0)
|
||||
:param years_ago: Number of years to subtract from current time (default: 0)
|
||||
:return: JSON with current_timestamp and calculated_timestamp (both in seconds)
|
||||
"""
|
||||
try:
|
||||
import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
current_ts = int(now.timestamp())
|
||||
|
||||
# Calculate the adjusted time
|
||||
total_days = days_ago + (weeks_ago * 7)
|
||||
adjusted = now - datetime.timedelta(days=total_days)
|
||||
|
||||
# Handle months and years separately (variable length)
|
||||
if months_ago > 0 or years_ago > 0:
|
||||
adjusted = adjusted - relativedelta(months=months_ago, years=years_ago)
|
||||
|
||||
adjusted_ts = int(adjusted.timestamp())
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"current_timestamp": current_ts,
|
||||
"current_iso": now.isoformat(),
|
||||
"calculated_timestamp": adjusted_ts,
|
||||
"calculated_iso": adjusted.isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback without dateutil
|
||||
import datetime
|
||||
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
current_ts = int(now.timestamp())
|
||||
total_days = days_ago + (weeks_ago * 7) + (months_ago * 30) + (years_ago * 365)
|
||||
adjusted = now - datetime.timedelta(days=total_days)
|
||||
adjusted_ts = int(adjusted.timestamp())
|
||||
return json.dumps(
|
||||
{
|
||||
"current_timestamp": current_ts,
|
||||
"current_iso": now.isoformat(),
|
||||
"calculated_timestamp": adjusted_ts,
|
||||
"calculated_iso": adjusted.isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"calculate_timestamp error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WEB SEARCH TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def web_search(
|
||||
query: str,
|
||||
@@ -70,18 +187,23 @@ async def fetch_url(
|
||||
|
||||
try:
|
||||
content, _ = get_content_from_url(__request__, url)
|
||||
|
||||
|
||||
# Truncate if too long (avoid overwhelming context)
|
||||
max_length = 50000
|
||||
if len(content) > max_length:
|
||||
content = content[:max_length] + "\n\n[Content truncated...]"
|
||||
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
log.exception(f"fetch_url error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# IMAGE GENERATION TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def generate_image(
|
||||
prompt: str,
|
||||
__request__: Request = None,
|
||||
@@ -113,8 +235,7 @@ async def generate_image(
|
||||
"type": "files",
|
||||
"data": {
|
||||
"files": [
|
||||
{"type": "image", "url": img["url"]}
|
||||
for img in images
|
||||
{"type": "image", "url": img["url"]} for img in images
|
||||
]
|
||||
},
|
||||
}
|
||||
@@ -159,8 +280,7 @@ async def edit_image(
|
||||
"type": "files",
|
||||
"data": {
|
||||
"files": [
|
||||
{"type": "image", "url": img["url"]}
|
||||
for img in images
|
||||
{"type": "image", "url": img["url"]} for img in images
|
||||
]
|
||||
},
|
||||
}
|
||||
@@ -172,6 +292,11 @@ async def edit_image(
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MEMORY TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def memory_query(
|
||||
query: str,
|
||||
__request__: Request = None,
|
||||
@@ -199,9 +324,12 @@ async def memory_query(
|
||||
memories = []
|
||||
for doc_idx, doc in enumerate(results.documents[0]):
|
||||
created_at = "Unknown"
|
||||
if results.metadatas and results.metadatas[0][doc_idx].get("created_at"):
|
||||
if results.metadatas and results.metadatas[0][doc_idx].get(
|
||||
"created_at"
|
||||
):
|
||||
created_at = time.strftime(
|
||||
"%Y-%m-%d", time.localtime(results.metadatas[0][doc_idx]["created_at"])
|
||||
"%Y-%m-%d",
|
||||
time.localtime(results.metadatas[0][doc_idx]["created_at"]),
|
||||
)
|
||||
memories.append({"date": created_at, "content": doc})
|
||||
return json.dumps(memories, ensure_ascii=False)
|
||||
@@ -239,3 +367,701 @@ async def memory_add(
|
||||
except Exception as e:
|
||||
log.exception(f"memory_add error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NOTES TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def search_notes(
|
||||
query: str,
|
||||
count: int = 5,
|
||||
start_timestamp: Optional[int] = None,
|
||||
end_timestamp: Optional[int] = None,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search the user's notes by title and content.
|
||||
|
||||
:param query: The search query to find matching notes
|
||||
:param count: Maximum number of results to return (default: 5)
|
||||
:param start_timestamp: Only include notes updated after this Unix timestamp (seconds)
|
||||
:param end_timestamp: Only include notes updated before this Unix timestamp (seconds)
|
||||
:return: JSON with matching notes containing id, title, and content snippet
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)]
|
||||
|
||||
result = Notes.search_notes(
|
||||
user_id=user_id,
|
||||
filter={
|
||||
"query": query,
|
||||
"user_id": user_id,
|
||||
"group_ids": user_group_ids,
|
||||
"permission": "read",
|
||||
},
|
||||
skip=0,
|
||||
limit=count * 3, # Fetch more for filtering
|
||||
)
|
||||
|
||||
# Convert timestamps to nanoseconds for comparison
|
||||
start_ts = start_timestamp * 1_000_000_000 if start_timestamp else None
|
||||
end_ts = end_timestamp * 1_000_000_000 if end_timestamp else None
|
||||
|
||||
notes = []
|
||||
for note in result.items:
|
||||
# Apply date filters (updated_at is in nanoseconds)
|
||||
if start_ts and note.updated_at < start_ts:
|
||||
continue
|
||||
if end_ts and note.updated_at > end_ts:
|
||||
continue
|
||||
|
||||
# Extract a snippet from the markdown content
|
||||
content_snippet = ""
|
||||
if note.data and note.data.get("content", {}).get("md"):
|
||||
md_content = note.data["content"]["md"]
|
||||
lower_content = md_content.lower()
|
||||
lower_query = query.lower()
|
||||
idx = lower_content.find(lower_query)
|
||||
if idx != -1:
|
||||
start = max(0, idx - 50)
|
||||
end = min(len(md_content), idx + len(query) + 100)
|
||||
content_snippet = (
|
||||
("..." if start > 0 else "")
|
||||
+ md_content[start:end]
|
||||
+ ("..." if end < len(md_content) else "")
|
||||
)
|
||||
else:
|
||||
content_snippet = md_content[:150] + (
|
||||
"..." if len(md_content) > 150 else ""
|
||||
)
|
||||
|
||||
notes.append(
|
||||
{
|
||||
"id": note.id,
|
||||
"title": note.title,
|
||||
"snippet": content_snippet,
|
||||
"updated_at": note.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
if len(notes) >= count:
|
||||
break
|
||||
|
||||
return json.dumps(notes, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
log.exception(f"search_notes error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def view_note(
|
||||
note_id: str,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the full content of a note by its ID.
|
||||
|
||||
:param note_id: The ID of the note to retrieve
|
||||
:return: JSON with the note's id, title, and full markdown content
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
note = Notes.get_note_by_id(note_id)
|
||||
|
||||
if not note:
|
||||
return json.dumps({"error": "Note not found"})
|
||||
|
||||
# Check access permission
|
||||
user_id = __user__.get("id")
|
||||
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)]
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
if note.user_id != user_id and not has_access(
|
||||
user_id, "read", note.access_control, user_group_ids
|
||||
):
|
||||
return json.dumps({"error": "Access denied"})
|
||||
|
||||
# Extract markdown content
|
||||
content = ""
|
||||
if note.data and note.data.get("content", {}).get("md"):
|
||||
content = note.data["content"]["md"]
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"id": note.id,
|
||||
"title": note.title,
|
||||
"content": content,
|
||||
"updated_at": note.updated_at,
|
||||
"created_at": note.created_at,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"view_note error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def write_note(
|
||||
title: str,
|
||||
content: str,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new note with the given title and content.
|
||||
|
||||
:param title: The title of the new note
|
||||
:param content: The markdown content for the note
|
||||
:return: JSON with success status and new note id
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
from open_webui.models.notes import NoteForm
|
||||
|
||||
user_id = __user__.get("id")
|
||||
|
||||
form = NoteForm(
|
||||
title=title,
|
||||
data={"content": {"md": content}},
|
||||
access_control={}, # Private by default - only owner can access
|
||||
)
|
||||
|
||||
new_note = Notes.insert_new_note(user_id, form)
|
||||
|
||||
if not new_note:
|
||||
return json.dumps({"error": "Failed to create note"})
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "success",
|
||||
"id": new_note.id,
|
||||
"title": new_note.title,
|
||||
"created_at": new_note.created_at,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"write_note error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def replace_note_content(
|
||||
note_id: str,
|
||||
content: str,
|
||||
title: Optional[str] = None,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Update the content of a note. Use this to modify task lists, add notes, or update content.
|
||||
|
||||
:param note_id: The ID of the note to update
|
||||
:param content: The new markdown content for the note
|
||||
:param title: Optional new title for the note
|
||||
:return: JSON with success status and updated note info
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
from open_webui.models.notes import NoteUpdateForm
|
||||
|
||||
note = Notes.get_note_by_id(note_id)
|
||||
|
||||
if not note:
|
||||
return json.dumps({"error": "Note not found"})
|
||||
|
||||
# Check write permission
|
||||
user_id = __user__.get("id")
|
||||
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)]
|
||||
|
||||
from open_webui.utils.access_control import has_access
|
||||
|
||||
if note.user_id != user_id and not has_access(
|
||||
user_id, "write", note.access_control, user_group_ids
|
||||
):
|
||||
return json.dumps({"error": "Write access denied"})
|
||||
|
||||
# Build update form
|
||||
update_data = {"data": {"content": {"md": content}}}
|
||||
if title:
|
||||
update_data["title"] = title
|
||||
|
||||
form = NoteUpdateForm(**update_data)
|
||||
updated_note = Notes.update_note_by_id(note_id, form)
|
||||
|
||||
if not updated_note:
|
||||
return json.dumps({"error": "Failed to update note"})
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"status": "success",
|
||||
"id": updated_note.id,
|
||||
"title": updated_note.title,
|
||||
"updated_at": updated_note.updated_at,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"replace_note_content error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CHATS TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def search_chats(
|
||||
query: str,
|
||||
count: int = 5,
|
||||
start_timestamp: Optional[int] = None,
|
||||
end_timestamp: Optional[int] = None,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search the user's previous chat conversations by title and message content.
|
||||
|
||||
:param query: The search query to find matching chats
|
||||
:param count: Maximum number of results to return (default: 5)
|
||||
:param start_timestamp: Only include chats updated after this Unix timestamp (seconds)
|
||||
:param end_timestamp: Only include chats updated before this Unix timestamp (seconds)
|
||||
:return: JSON with matching chats containing id, title, updated_at, and content snippet
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
|
||||
chats = Chats.get_chats_by_user_id_and_search_text(
|
||||
user_id=user_id,
|
||||
search_text=query,
|
||||
include_archived=False,
|
||||
skip=0,
|
||||
limit=count * 3, # Fetch more for filtering
|
||||
)
|
||||
|
||||
results = []
|
||||
for chat in chats:
|
||||
# Apply date filters (updated_at is in seconds)
|
||||
if start_timestamp and chat.updated_at < start_timestamp:
|
||||
continue
|
||||
if end_timestamp and chat.updated_at > end_timestamp:
|
||||
continue
|
||||
|
||||
# Find a matching message snippet
|
||||
snippet = ""
|
||||
messages = chat.chat.get("history", {}).get("messages", {})
|
||||
lower_query = query.lower()
|
||||
|
||||
for msg_id, msg in messages.items():
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str) and lower_query in content.lower():
|
||||
idx = content.lower().find(lower_query)
|
||||
start = max(0, idx - 50)
|
||||
end = min(len(content), idx + len(query) + 100)
|
||||
snippet = (
|
||||
("..." if start > 0 else "")
|
||||
+ content[start:end]
|
||||
+ ("..." if end < len(content) else "")
|
||||
)
|
||||
break
|
||||
|
||||
if not snippet and lower_query in chat.title.lower():
|
||||
snippet = f"Title match: {chat.title}"
|
||||
|
||||
results.append(
|
||||
{
|
||||
"id": chat.id,
|
||||
"title": chat.title,
|
||||
"snippet": snippet,
|
||||
"updated_at": chat.updated_at,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) >= count:
|
||||
break
|
||||
|
||||
return json.dumps(results, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
log.exception(f"search_chats error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def view_chat(
|
||||
chat_id: str,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the full conversation history of a chat by its ID.
|
||||
|
||||
:param chat_id: The ID of the chat to retrieve
|
||||
:return: JSON with the chat's id, title, and messages
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
|
||||
chat = Chats.get_chat_by_id_and_user_id(chat_id, user_id)
|
||||
|
||||
if not chat:
|
||||
return json.dumps({"error": "Chat not found or access denied"})
|
||||
|
||||
# Extract messages from history
|
||||
messages = []
|
||||
history = chat.chat.get("history", {})
|
||||
msg_dict = history.get("messages", {})
|
||||
|
||||
# Build message chain from currentId
|
||||
current_id = history.get("currentId")
|
||||
visited = set()
|
||||
|
||||
while current_id and current_id not in visited:
|
||||
visited.add(current_id)
|
||||
msg = msg_dict.get(current_id)
|
||||
if msg:
|
||||
messages.append(
|
||||
{
|
||||
"role": msg.get("role", ""),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
)
|
||||
current_id = msg.get("parentId") if msg else None
|
||||
|
||||
# Reverse to get chronological order
|
||||
messages.reverse()
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"id": chat.id,
|
||||
"title": chat.title,
|
||||
"messages": messages,
|
||||
"updated_at": chat.updated_at,
|
||||
"created_at": chat.created_at,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"view_chat error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CHANNELS TOOLS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def search_channels(
|
||||
query: str,
|
||||
count: int = 5,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search for channels by name and description that the user has access to.
|
||||
|
||||
:param query: The search query to find matching channels
|
||||
:param count: Maximum number of results to return (default: 5)
|
||||
:return: JSON with matching channels containing id, name, description, and type
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
|
||||
# Get all channels the user has access to
|
||||
all_channels = Channels.get_channels_by_user_id(user_id)
|
||||
|
||||
# Filter by query
|
||||
lower_query = query.lower()
|
||||
matching_channels = []
|
||||
|
||||
for channel in all_channels:
|
||||
name_match = lower_query in channel.name.lower() if channel.name else False
|
||||
desc_match = lower_query in (channel.description or "").lower()
|
||||
|
||||
if name_match or desc_match:
|
||||
matching_channels.append(
|
||||
{
|
||||
"id": channel.id,
|
||||
"name": channel.name,
|
||||
"description": channel.description or "",
|
||||
"type": channel.type or "public",
|
||||
}
|
||||
)
|
||||
|
||||
if len(matching_channels) >= count:
|
||||
break
|
||||
|
||||
return json.dumps(matching_channels, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
log.exception(f"search_channels error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def search_channel_messages(
|
||||
query: str,
|
||||
count: int = 10,
|
||||
start_timestamp: Optional[int] = None,
|
||||
end_timestamp: Optional[int] = None,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Search for messages in channels the user is a member of, including thread replies.
|
||||
|
||||
:param query: The search query to find matching messages
|
||||
:param count: Maximum number of results to return (default: 10)
|
||||
:param start_timestamp: Only include messages created after this Unix timestamp (seconds)
|
||||
:param end_timestamp: Only include messages created before this Unix timestamp (seconds)
|
||||
:return: JSON with matching messages containing channel info, message content, and thread context
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
|
||||
# Get all channels the user has access to
|
||||
user_channels = Channels.get_channels_by_user_id(user_id)
|
||||
channel_ids = [c.id for c in user_channels]
|
||||
channel_map = {c.id: c for c in user_channels}
|
||||
|
||||
if not channel_ids:
|
||||
return json.dumps([])
|
||||
|
||||
# Convert timestamps to nanoseconds (Message.created_at is in nanoseconds)
|
||||
start_ts = start_timestamp * 1_000_000_000 if start_timestamp else None
|
||||
end_ts = end_timestamp * 1_000_000_000 if end_timestamp else None
|
||||
|
||||
# Search messages using the model method
|
||||
matching_messages = Messages.search_messages_by_channel_ids(
|
||||
channel_ids=channel_ids,
|
||||
query=query,
|
||||
start_timestamp=start_ts,
|
||||
end_timestamp=end_ts,
|
||||
limit=count,
|
||||
)
|
||||
|
||||
results = []
|
||||
for msg in matching_messages:
|
||||
channel = channel_map.get(msg.channel_id)
|
||||
|
||||
# Extract snippet around the match
|
||||
content = msg.content or ""
|
||||
lower_query = query.lower()
|
||||
idx = content.lower().find(lower_query)
|
||||
if idx != -1:
|
||||
start = max(0, idx - 50)
|
||||
end = min(len(content), idx + len(query) + 100)
|
||||
snippet = (
|
||||
("..." if start > 0 else "")
|
||||
+ content[start:end]
|
||||
+ ("..." if end < len(content) else "")
|
||||
)
|
||||
else:
|
||||
snippet = content[:150] + ("..." if len(content) > 150 else "")
|
||||
|
||||
results.append(
|
||||
{
|
||||
"channel_id": msg.channel_id,
|
||||
"channel_name": channel.name if channel else "Unknown",
|
||||
"message_id": msg.id,
|
||||
"content_snippet": snippet,
|
||||
"is_thread_reply": msg.parent_id is not None,
|
||||
"parent_id": msg.parent_id,
|
||||
"created_at": msg.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps(results, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
log.exception(f"search_channel_messages error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def view_channel_message(
|
||||
message_id: str,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the full content of a channel message by its ID, including thread replies.
|
||||
|
||||
:param message_id: The ID of the message to retrieve
|
||||
:return: JSON with the message content, channel info, and thread replies if any
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
|
||||
message = Messages.get_message_by_id(message_id)
|
||||
|
||||
if not message:
|
||||
return json.dumps({"error": "Message not found"})
|
||||
|
||||
# Verify user has access to the channel
|
||||
channel = Channels.get_channel_by_id(message.channel_id)
|
||||
if not channel:
|
||||
return json.dumps({"error": "Channel not found"})
|
||||
|
||||
# Check if user has access to the channel
|
||||
user_channels = Channels.get_channels_by_user_id(user_id)
|
||||
channel_ids = [c.id for c in user_channels]
|
||||
|
||||
if message.channel_id not in channel_ids:
|
||||
return json.dumps({"error": "Access denied"})
|
||||
|
||||
# Build response with thread information
|
||||
result = {
|
||||
"id": message.id,
|
||||
"channel_id": message.channel_id,
|
||||
"channel_name": channel.name,
|
||||
"content": message.content,
|
||||
"user_id": message.user_id,
|
||||
"is_thread_reply": message.parent_id is not None,
|
||||
"parent_id": message.parent_id,
|
||||
"reply_count": message.reply_count,
|
||||
"created_at": message.created_at,
|
||||
"updated_at": message.updated_at,
|
||||
}
|
||||
|
||||
# Include user info if available
|
||||
if message.user:
|
||||
result["user_name"] = message.user.name
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
log.exception(f"view_channel_message error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
|
||||
async def view_channel_thread(
|
||||
parent_message_id: str,
|
||||
__request__: Request = None,
|
||||
__user__: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get all messages in a channel thread, including the parent message and all replies.
|
||||
|
||||
:param parent_message_id: The ID of the parent message that started the thread
|
||||
:return: JSON with the parent message and all thread replies in chronological order
|
||||
"""
|
||||
if __request__ is None:
|
||||
return json.dumps({"error": "Request context not available"})
|
||||
|
||||
if not __user__:
|
||||
return json.dumps({"error": "User context not available"})
|
||||
|
||||
try:
|
||||
user_id = __user__.get("id")
|
||||
|
||||
# Get the parent message
|
||||
parent_message = Messages.get_message_by_id(parent_message_id)
|
||||
|
||||
if not parent_message:
|
||||
return json.dumps({"error": "Message not found"})
|
||||
|
||||
# Verify user has access to the channel
|
||||
channel = Channels.get_channel_by_id(parent_message.channel_id)
|
||||
if not channel:
|
||||
return json.dumps({"error": "Channel not found"})
|
||||
|
||||
user_channels = Channels.get_channels_by_user_id(user_id)
|
||||
channel_ids = [c.id for c in user_channels]
|
||||
|
||||
if parent_message.channel_id not in channel_ids:
|
||||
return json.dumps({"error": "Access denied"})
|
||||
|
||||
# Get all thread replies
|
||||
thread_replies = Messages.get_thread_replies_by_message_id(parent_message_id)
|
||||
|
||||
# Build the response
|
||||
messages = []
|
||||
|
||||
# Add parent message first
|
||||
messages.append(
|
||||
{
|
||||
"id": parent_message.id,
|
||||
"content": parent_message.content,
|
||||
"user_id": parent_message.user_id,
|
||||
"user_name": parent_message.user.name if parent_message.user else None,
|
||||
"is_parent": True,
|
||||
"created_at": parent_message.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
# Add thread replies (reverse to get chronological order)
|
||||
for reply in reversed(thread_replies):
|
||||
messages.append(
|
||||
{
|
||||
"id": reply.id,
|
||||
"content": reply.content,
|
||||
"user_id": reply.user_id,
|
||||
"user_name": reply.user.name if reply.user else None,
|
||||
"is_parent": False,
|
||||
"reply_to_id": reply.reply_to_id,
|
||||
"created_at": reply.created_at,
|
||||
}
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"channel_id": parent_message.channel_id,
|
||||
"channel_name": channel.name,
|
||||
"thread_id": parent_message_id,
|
||||
"message_count": len(messages),
|
||||
"messages": messages,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"view_channel_thread error: {e}")
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
@@ -44,8 +44,24 @@ from open_webui.env import (
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL,
|
||||
)
|
||||
from open_webui.tools.builtin import (
|
||||
web_search, fetch_url, generate_image, edit_image,
|
||||
memory_query, memory_add
|
||||
web_search,
|
||||
fetch_url,
|
||||
generate_image,
|
||||
edit_image,
|
||||
memory_query,
|
||||
memory_add,
|
||||
get_current_timestamp,
|
||||
calculate_timestamp,
|
||||
search_notes,
|
||||
search_chats,
|
||||
search_channels,
|
||||
search_channel_messages,
|
||||
view_note,
|
||||
view_chat,
|
||||
view_channel_message,
|
||||
view_channel_thread,
|
||||
replace_note_content,
|
||||
write_note,
|
||||
)
|
||||
|
||||
import copy
|
||||
@@ -324,7 +340,9 @@ async def get_tools(
|
||||
return tools_dict
|
||||
|
||||
|
||||
def get_builtin_tools(request: Request, extra_params: dict, features: dict = None) -> dict[str, dict]:
|
||||
def get_builtin_tools(
|
||||
request: Request, extra_params: dict, features: dict = None
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Get built-in tools for native function calling.
|
||||
Only returns tools when BOTH the global config is enabled AND the feature is enabled for this chat.
|
||||
@@ -333,23 +351,47 @@ def get_builtin_tools(request: Request, extra_params: dict, features: dict = Non
|
||||
builtin_functions = []
|
||||
features = features or {}
|
||||
|
||||
# Add web search tools if enabled globally AND for this chat
|
||||
if (getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False)
|
||||
and features.get("web_search")):
|
||||
builtin_functions.extend([web_search, fetch_url])
|
||||
# Time utilities - always available for date calculations
|
||||
builtin_functions.extend([get_current_timestamp, calculate_timestamp])
|
||||
|
||||
# Add image generation/edit tools if enabled globally AND for this chat
|
||||
if (getattr(request.app.state.config, "ENABLE_IMAGE_GENERATION", False)
|
||||
and features.get("image_generation")):
|
||||
builtin_functions.append(generate_image)
|
||||
if (getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False)
|
||||
and features.get("image_generation")):
|
||||
builtin_functions.append(edit_image)
|
||||
# Chats tools - search and fetch user's chat history (always available)
|
||||
builtin_functions.extend([search_chats, view_chat])
|
||||
|
||||
# Add memory tools if enabled for this chat
|
||||
if features.get("memory"):
|
||||
builtin_functions.extend([memory_query, memory_add])
|
||||
|
||||
# Add web search tools if enabled globally AND for this chat
|
||||
if getattr(request.app.state.config, "ENABLE_WEB_SEARCH", False) and features.get(
|
||||
"web_search"
|
||||
):
|
||||
builtin_functions.extend([web_search, fetch_url])
|
||||
|
||||
# Add image generation/edit tools if enabled globally AND for this chat
|
||||
if getattr(
|
||||
request.app.state.config, "ENABLE_IMAGE_GENERATION", False
|
||||
) and features.get("image_generation"):
|
||||
builtin_functions.append(generate_image)
|
||||
if getattr(request.app.state.config, "ENABLE_IMAGE_EDIT", False) and features.get(
|
||||
"image_generation"
|
||||
):
|
||||
builtin_functions.append(edit_image)
|
||||
|
||||
# Notes tools - search, view, create, and update user's notes (if notes enabled globally)
|
||||
if getattr(request.app.state.config, "ENABLE_NOTES", False):
|
||||
builtin_functions.extend([search_notes, view_note, write_note, replace_note_content])
|
||||
|
||||
# Channels tools - search channels and messages (if channels enabled globally)
|
||||
if getattr(request.app.state.config, "ENABLE_CHANNELS", False):
|
||||
builtin_functions.extend(
|
||||
[
|
||||
search_channels,
|
||||
search_channel_messages,
|
||||
view_channel_thread,
|
||||
view_channel_message,
|
||||
]
|
||||
)
|
||||
|
||||
for func in builtin_functions:
|
||||
callable = get_async_tool_function_and_apply_extra_params(
|
||||
func,
|
||||
|
||||
Reference in New Issue
Block a user