def truncate_content(content: str, max_chars: int, mode: str = "middletruncate") -> str:
    """Shorten ``content`` so that at most ``max_chars`` of its characters remain.

    Modes:
    - middletruncate: keep the beginning and the end, joined with '...'.
      The marker is added on top of the ``max_chars`` retained characters.
    - start: keep the first ``max_chars`` characters.
    - end: keep the last ``max_chars`` characters.

    Any mode other than "start" or "end" is handled as middletruncate.

    Args:
        content: Text to shorten. Falsy values are returned unchanged.
        max_chars: Number of original characters to keep. Negative values
            are treated as 0.
        mode: One of the modes listed above.

    Returns:
        ``content`` itself when it already fits, otherwise the shortened text.
    """
    if not content:
        return content

    # A negative budget has no sensible meaning; treat it as "keep nothing".
    max_chars = max(max_chars, 0)
    total = len(content)
    if total <= max_chars:
        return content

    if mode == "start":
        return content[:max_chars]
    if mode == "end":
        # Slice from a positive index: content[-0:] would be the whole string.
        return content[total - max_chars:]

    # middletruncate: the tail receives the extra character for odd budgets.
    head_len = max_chars // 2
    tail_len = max_chars - head_len
    return f"{content[:head_len]}...{content[total - tail_len:]}"


def apply_content_filter(messages: list[dict], filter_str: str) -> list[dict]:
    """Apply a content filter to each message's content.

    ``filter_str`` looks like 'middletruncate:500', 'start:200' or 'end:200'
    (the mode is matched case-insensitively).

    String contents are truncated directly. For list contents, only dict
    items whose ``type`` is "text" have their ``text`` truncated; every other
    item is passed through as-is.

    Args:
        messages: Message dicts; each may carry a ``content`` key.
        filter_str: Filter specification in ``mode:count`` form.

    Returns:
        A new list of shallow-copied messages with truncated content. The
        input messages are not mutated. If ``filter_str`` is malformed
        (wrong shape, unknown mode, non-integer or negative count), the
        original ``messages`` list is returned unchanged.
    """
    parts = filter_str.split(":")
    if len(parts) != 2:
        return messages

    mode = parts[0].lower()
    if mode not in ("middletruncate", "start", "end"):
        return messages

    try:
        max_chars = int(parts[1])
    except ValueError:
        return messages
    # int() accepts a leading minus sign; a negative budget is not a valid filter.
    if max_chars < 0:
        return messages

    result = []
    for msg in messages:
        new_msg = dict(msg)
        content = new_msg.get("content")
        if isinstance(content, str):
            new_msg["content"] = truncate_content(content, max_chars, mode)
        elif isinstance(content, list):
            new_content = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    # Copy the part so the caller's dict is left untouched.
                    new_item = dict(item)
                    new_item["text"] = truncate_content(
                        item.get("text", ""), max_chars, mode
                    )
                    new_content.append(new_item)
                else:
                    new_content.append(item)
            new_msg["content"] = new_content
        result.append(new_msg)
    return result


def replace_messages_variable(
    template: str, messages: Optional[list[dict]] = None
) -> str:
    """Expand ``{{MESSAGES...}}`` placeholders in ``template``.

    Supported placeholders, each optionally followed by ``|mode:count`` to
    truncate the content of every selected message:

    - ``{{MESSAGES}}``: all messages
    - ``{{MESSAGES:START:n}}``: the first ``n`` messages
    - ``{{MESSAGES:END:n}}``: the last ``n`` messages
    - ``{{MESSAGES:MIDDLETRUNCATE:n}}``: the first ``n // 2`` messages plus
      the last ``n - n // 2`` messages when more than ``n`` are available

    Args:
        template: Text that may contain the placeholders above.
        messages: Messages to render. When ``None``, every placeholder is
            replaced with an empty string.

    Returns:
        The template with each placeholder replaced by the output of
        ``get_messages_content`` for the selected messages.
    """

    def replacement_function(match):
        # Groups: (1) filter for bare MESSAGES
        #         (2) START count,  (3) filter for START
        #         (4) END count,    (5) filter for END
        #         (6) MIDDLE count, (7) filter for MIDDLE
        (
            bare_filter,
            start_length,
            start_filter,
            end_length,
            end_filter,
            middle_length,
            middle_filter,
        ) = match.groups()

        # Without messages there is nothing to render.
        if messages is None:
            return ""

        total = len(messages)

        # Select messages based on the variant
        if start_length is not None:
            selected = messages[: int(start_length)]
            content_filter = start_filter
        elif end_length is not None:
            # Slice from a non-negative index so END:0 selects nothing
            # (messages[-0:] would select every message).
            selected = messages[max(total - int(end_length), 0):]
            content_filter = end_filter
        elif middle_length is not None:
            mid = int(middle_length)
            if total <= mid:
                selected = messages
            else:
                # The tail receives the extra message for odd counts.
                head_count = mid // 2
                tail_count = mid - head_count
                selected = messages[:head_count] + messages[total - tail_count:]
            content_filter = middle_filter
        else:
            # Bare {{MESSAGES}} or {{MESSAGES|filter}}
            selected = messages
            content_filter = bare_filter

        # Apply content filter if present
        if content_filter:
            selected = apply_content_filter(selected, content_filter)

        return get_messages_content(selected)

    template = re.sub(
        r"(?:"
        r"\{\{MESSAGES(?:\|(\w+:\d+))?\}\}"
        r"|\{\{MESSAGES:START:(\d+)(?:\|(\w+:\d+))?\}\}"
        r"|\{\{MESSAGES:END:(\d+)(?:\|(\w+:\d+))?\}\}"
        r"|\{\{MESSAGES:MIDDLETRUNCATE:(\d+)(?:\|(\w+:\d+))?\}\}"
        r")",
        replacement_function,
        template,
    )

    return template