mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-03 02:39:11 -05:00
290 lines
9.8 KiB
Python
290 lines
9.8 KiB
Python
from contextlib import asynccontextmanager
|
||
from dataclasses import asdict, dataclass
|
||
from enum import Enum
|
||
import re
|
||
from typing import (
|
||
TYPE_CHECKING,
|
||
Any,
|
||
AsyncGenerator,
|
||
Dict,
|
||
MutableMapping,
|
||
Optional,
|
||
cast,
|
||
)
|
||
import uuid
|
||
|
||
from asgiref.typing import (
|
||
ASGI3Application,
|
||
ASGIReceiveCallable,
|
||
ASGIReceiveEvent,
|
||
ASGISendCallable,
|
||
ASGISendEvent,
|
||
Scope as ASGIScope,
|
||
)
|
||
from loguru import logger
|
||
from starlette.requests import Request
|
||
|
||
from open_webui.env import AUDIT_LOG_LEVEL, ENABLE_AUDIT_GET_REQUESTS, AUDIT_INCLUDED_PATHS, MAX_BODY_LOG_SIZE
|
||
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
|
||
from open_webui.models.users import UserModel
|
||
|
||
if TYPE_CHECKING:
|
||
from loguru import Logger
|
||
|
||
|
||
@dataclass(frozen=True)
class AuditLogEntry:
    """
    Immutable record describing one audited HTTP exchange.

    Fields are grouped by the audit level that is expected to populate them;
    everything except the first five fields is optional and defaults to None.
    """

    # --- Populated at the `Metadata` audit level ---
    # Unique identifier of this audit entry.
    id: str
    # Information about the user who issued the request, if any.
    user: Optional[dict[str, Any]]
    # Name of the audit level the entry was recorded under.
    audit_level: str
    # HTTP method of the request.
    verb: str
    # URI that was requested.
    request_uri: str
    # Value of the client's User-Agent header.
    user_agent: Optional[str] = None
    # Address of the client that sent the request.
    source_ip: Optional[str] = None

    # --- Populated at the `Request` audit level ---
    # Captured request payload.
    request_object: Any = None

    # --- Populated at the `Request Response` audit level ---
    # Captured response payload.
    response_object: Any = None
    # HTTP status code of the response.
    response_status_code: Optional[int] = None
|
||
|
||
|
||
class AuditLevel(str, Enum):
    """
    Closed set of audit verbosity levels.

    Members are also `str` instances, so they compare equal to their plain
    string values.
    """

    # Auditing disabled.
    NONE = 'NONE'
    # Metadata-only entries.
    METADATA = 'METADATA'
    # Metadata plus the request payload.
    REQUEST = 'REQUEST'
    # Metadata plus request and response payloads.
    REQUEST_RESPONSE = 'REQUEST_RESPONSE'
|
||
|
||
|
||
class AuditLogger:
    """
    Thin wrapper around a Loguru logger dedicated to audit records. The wrapped logger is bound with `auditable=True` so that audit entries can be told apart from regular log records.

    Parameters:
    logger (Logger): An instance of Loguru's logger.
    """

    def __init__(self, logger: 'Logger'):
        # Every record emitted through this wrapper carries the `auditable` marker.
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = 'INFO',
        extra: Optional[dict] = None,
    ):
        """
        Emit one audit entry through the bound logger.

        Parameters:
        audit_entry (AuditLogEntry): The entry to record; its fields are passed to the logger as keyword arguments.
        log_level (str): Level name handed to the logger. Defaults to 'INFO'.
        extra (Optional[dict]): Additional data, attached under the `extra` key only when non-empty.
        """
        fields = asdict(audit_entry)

        if extra:
            fields.update(extra=extra)

        # The message itself is empty: all information travels in the keyword arguments.
        self.logger.log(log_level, '', **fields)
|
||
|
||
|
||
class AuditContext:
    """
    Per-request scratch space that accumulates the HTTP request and response bodies while a request is processed. Each body is capped at a configurable number of bytes so that memory usage stays bounded.

    Attributes:
    request_body (bytearray): Accumulated request payload.
    response_body (bytearray): Accumulated response payload.
    max_body_size (int): Maximum number of bytes to capture.
    metadata (Dict[str, Any]): A dictionary to store additional audit metadata (user, http verb, user agent, etc.).
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}
        self.request_body = bytearray()
        self.response_body = bytearray()

    @staticmethod
    def _append_capped(buffer: bytearray, chunk: bytes, limit: int) -> None:
        """Append as much of `chunk` to `buffer` as fits without exceeding `limit` bytes."""
        room = limit - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Record a piece of the request body, dropping whatever exceeds the size cap."""
        self._append_capped(self.request_body, chunk, self.max_body_size)

    def add_response_chunk(self, chunk: bytes):
        """Record a piece of the response body, dropping whatever exceeds the size cap."""
        self._append_capped(self.response_body, chunk, self.max_body_size)
|
||
|
||
|
||
class AuditLoggingMiddleware:
    """
    ASGI middleware that intercepts HTTP requests and responses to perform audit logging. It captures request/response bodies (depending on audit level), headers, HTTP methods, and user information, then logs a structured audit entry at the end of the request cycle.

    Parameters:
    app (ASGI3Application): The wrapped ASGI application.
    excluded_paths (Optional[list[str]]): Path fragments (relative to `/api/` or `/api/v1/`) that must not be audited. Blank entries are ignored.
    included_paths (Optional[list[str]]): Path fragments to audit exclusively. When non-empty this whitelist takes precedence over `excluded_paths`. Blank entries are ignored.
    max_body_size (int): Maximum number of request/response body bytes captured per request.
    audit_level (AuditLevel): Controls whether request and response bodies are captured.
    audit_get_requests (bool): Also audit GET requests when True.
    """

    DEFAULT_AUDITED_METHODS = {'PUT', 'PATCH', 'DELETE', 'POST'}

    # Authentication endpoints are audited even when the request carries no credentials.
    # Kept as a tuple so it can be handed directly to `str.startswith`.
    ALWAYS_LOG_ENDPOINTS = (
        '/api/v1/auths/signin',
        '/api/v1/auths/signout',
        '/api/v1/auths/signup',
    )

    # Matches any JSON string member whose key contains "password" (e.g. `password`,
    # `new_password`). The value part tolerates escaped quotes and a value that was
    # cut off by body truncation (no closing quote before the end of the text).
    _PASSWORD_FIELD = re.compile(
        r'"(\w*password\w*)"\s*:\s*"(?:[^"\\]|\\.)*\\?(?:"|$)',
        re.IGNORECASE,
    )

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        included_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
        audit_get_requests: bool = False,
    ) -> None:
        self.app = app
        self.audit_logger = AuditLogger(logger)
        # Blank entries would turn into an empty regex alternative that matches
        # every path, so they are dropped up front.
        self.excluded_paths = [path for path in (excluded_paths or []) if path]
        self.included_paths = [path for path in (included_paths or []) if path]
        self.max_body_size = max_body_size
        self.audited_methods = set(self.DEFAULT_AUDITED_METHODS)
        if audit_get_requests:
            self.audited_methods.add('GET')
        self.audit_level = audit_level

        # The path filters never change after construction: build them once
        # instead of on every request.
        self._included_pattern = self._compile_path_pattern(self.included_paths)
        self._excluded_pattern = self._compile_path_pattern(self.excluded_paths)

        if self.included_paths and self.excluded_paths:
            logger.warning(
                'Both AUDIT_INCLUDED_PATHS and AUDIT_EXCLUDED_PATHS are set. '
                'AUDIT_INCLUDED_PATHS (whitelist) takes precedence.'
            )

    @staticmethod
    def _compile_path_pattern(paths: list[str]) -> Optional[re.Pattern]:
        """
        Build the regex matching `/api/<path>` or `/api/v1/<path>` for any of `paths`.

        Returns None when `paths` is empty, meaning "no filter configured".
        Entries are inserted verbatim, i.e. they are treated as regex fragments.
        """
        if not paths:
            return None
        return re.compile(r'^/api(?:/v1)?/(' + '|'.join(paths) + r')\b')

    @classmethod
    def _redact_passwords(cls, text: str) -> str:
        """Replace the value of every password-like JSON string member in `text` with asterisks."""
        return cls._PASSWORD_FIELD.sub(r'"\1": "********"', text)

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """
        ASGI entry point: pass non-HTTP and non-audited traffic straight through,
        otherwise wrap `receive`/`send` so bodies can be captured while the
        wrapped application handles the request.
        """
        if scope['type'] != 'http':
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        capture_request = self.audit_level in (
            AuditLevel.REQUEST,
            AuditLevel.REQUEST_RESPONSE,
        )
        capture_response = self.audit_level == AuditLevel.REQUEST_RESPONSE

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                if capture_response:
                    await self._capture_response(message, context)

                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                message = await receive()

                if capture_request:
                    await self._capture_request(message, context)

                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(self, request: Request) -> AsyncGenerator[AuditContext, None]:
        """
        async context manager that ensures that an audit log entry is recorded after the request is processed.
        """
        # Honour the size limit this middleware was configured with rather than
        # falling back to AuditContext's default.
        context = AuditContext(self.max_body_size)
        try:
            yield context
        finally:
            # Runs even when the wrapped application raised.
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> Optional[UserModel]:
        """
        Resolve the user behind `request`.

        Returns None when resolution fails for any reason; the failure is only
        logged at debug level because auditing must not break the request.
        """
        auth_header = request.headers.get('Authorization')

        try:
            user = await get_current_user(request, None, None, get_http_authorization_cred(auth_header))
            return user
        except Exception as e:
            logger.debug(f'Failed to get authenticated user: {str(e)}')

        return None

    def _should_skip_auditing(self, request: Request) -> bool:
        """
        Decide whether `request` is exempt from auditing.

        Returns True (skip) when auditing is globally disabled, the HTTP method
        is not audited, the request carries no credentials, or the path is
        filtered out by the whitelist/blacklist. Authentication endpoints are
        never skipped once the level and method checks have passed.
        """
        if AUDIT_LOG_LEVEL == 'NONE':
            return True

        if request.method not in self.audited_methods:
            return True

        if request.url.path.lower().startswith(self.ALWAYS_LOG_ENDPOINTS):
            return False  # Do NOT skip logging for auth endpoints

        # Skip logging if the request is not authenticated
        # Check both Authorization header (API keys) and token cookie (browser sessions)
        if not request.headers.get('authorization') and not request.cookies.get('token'):
            return True

        path = request.url.path

        # Whitelist mode: only log paths that match included_paths
        if self._included_pattern is not None:
            return self._included_pattern.match(path) is None

        # Blacklist mode: skip paths that match excluded_paths.
        # No configured exclusions means nothing is skipped here.
        if self._excluded_pattern is not None and self._excluded_pattern.match(path):
            return True

        return False

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Append the body of an `http.request` event to the captured request payload."""
        if message['type'] == 'http.request':
            body = message.get('body', b'')
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Record the status code from `http.response.start` and body bytes from `http.response.body`."""
        if message['type'] == 'http.response.start':
            context.metadata['response_status_code'] = message['status']

        elif message['type'] == 'http.response.body':
            body = message.get('body', b'')
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """
        Assemble an `AuditLogEntry` from the request and captured context and write it.

        Any failure is logged and swallowed so that auditing never turns a
        handled request into an error.
        """
        try:
            user = await self._get_authenticated_user(request)

            user = user.model_dump(include={'id', 'name', 'email', 'role'}) if user else {}

            # Bodies may be truncated or binary: never fail on undecodable bytes.
            request_body = context.request_body.decode('utf-8', errors='replace')
            response_body = context.response_body.decode('utf-8', errors='replace')

            # Redact sensitive information
            request_body = self._redact_passwords(request_body)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user,
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get('response_status_code', None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get('user-agent'),
                request_object=request_body,
                response_object=response_body,
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f'Failed to log audit entry: {str(e)}')
|