mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 09:49:03 -05:00
269 lines
10 KiB
Python
269 lines
10 KiB
Python
"""
|
|
Pure-ASGI replacements for the project's previous
|
|
`@app.middleware('http')` / `BaseHTTPMiddleware` middlewares.
|
|
|
|
Why this matters
|
|
----------------
|
|
Starlette's `BaseHTTPMiddleware` (which `@app.middleware('http')` is
|
|
sugar for) runs the downstream app inside an `anyio` task group. When
|
|
the wrapper exits — for any reason: response complete, client
|
|
disconnect, an outer middleware bailing out — the task group cancels
|
|
the inner task. That `CancelledError` then propagates into whatever
|
|
the inner task was doing, including in-flight DB queries, embedding
|
|
calls and disk I/O.
|
|
|
|
In Open WebUI this surfaces as:
|
|
|
|
* SQLAlchemy logging multi-page `NotImplementedError:
|
|
terminate_force_close()` tracebacks at ERROR every time a request is
|
|
cancelled mid-DB-call (the aiosqlite connector cleanup path).
|
|
* Spurious cancellations cascading through the four stacked
|
|
`@app.middleware('http')` wrappers.
|
|
|
|
Pure ASGI middleware does not introduce a cancel scope around the
|
|
downstream app, so client disconnects propagate the way ASGI was
|
|
designed to (via `receive()` returning `http.disconnect`) instead of
|
|
being injected as `CancelledError` into arbitrary `await` points.
|
|
|
|
Reference: https://www.starlette.io/middleware/#limitations
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
import time
|
|
from urllib.parse import parse_qs, urlencode
|
|
|
|
from fastapi.responses import JSONResponse, RedirectResponse
|
|
from fastapi.security import HTTPAuthorizationCredentials
|
|
from starlette.datastructures import MutableHeaders
|
|
from starlette.requests import Request
|
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
|
|
|
from open_webui.internal.db import ScopedSession
|
|
from open_webui.utils.auth import get_http_authorization_cred
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class CommitSessionMiddleware:
|
|
"""Commit and release the thread-local sync `ScopedSession` after each
|
|
HTTP request.
|
|
|
|
Most requests now use the async session; the sync ScopedSession is
|
|
only touched by startup, healthchecks, and a handful of legacy
|
|
helpers (notably the pgvector / opengauss vector-DB clients). The
|
|
middleware exists so that PostgreSQL connections do not accumulate
|
|
as "idle in transaction" and so that any pending sync work made
|
|
inside the request is durably persisted.
|
|
|
|
Failure semantics
|
|
-----------------
|
|
* Downstream raised → roll back any pending sync work, release the
|
|
connection, and re-raise so the outer exception middleware can
|
|
turn it into an error response. We never commit work on a
|
|
request that did not complete successfully.
|
|
* Downstream returned → commit pending sync work; on commit
|
|
failure, log loudly, roll back, and re-raise. Note that in pure
|
|
ASGI the response messages have already been emitted by the
|
|
time `await self.app(...)` returns, so a commit failure cannot
|
|
retroactively change what the client sees on the wire — but
|
|
re-raising still surfaces the error in logs and to ASGI servers
|
|
that expose it. We deliberately do not buffer the response to
|
|
gate it on commit success, because that would defeat streaming
|
|
responses (chat completions, SSE) which are core to the app.
|
|
|
|
For request paths where commit-before-send is required, manage the
|
|
sync session explicitly inside the handler instead of relying on
|
|
this middleware.
|
|
"""
|
|
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
if scope['type'] != 'http':
|
|
await self.app(scope, receive, send)
|
|
return
|
|
|
|
try:
|
|
await self.app(scope, receive, send)
|
|
except BaseException:
|
|
# Downstream did not complete successfully. Roll back any
|
|
# pending sync writes, release the connection, and let the
|
|
# exception propagate.
|
|
try:
|
|
ScopedSession.rollback()
|
|
except Exception:
|
|
log.exception('CommitSessionMiddleware: rollback failed after downstream error')
|
|
finally:
|
|
ScopedSession.remove()
|
|
raise
|
|
|
|
# Downstream completed. Commit pending sync work.
|
|
try:
|
|
ScopedSession.commit()
|
|
except Exception:
|
|
log.exception('CommitSessionMiddleware: post-request commit failed; response was already sent to client')
|
|
try:
|
|
ScopedSession.rollback()
|
|
except Exception:
|
|
log.exception('CommitSessionMiddleware: rollback failed after commit failure')
|
|
raise
|
|
finally:
|
|
# CRITICAL: remove() returns the connection to the pool.
|
|
# Without this, connections remain "checked out" and
|
|
# accumulate as "idle in transaction" in PostgreSQL.
|
|
ScopedSession.remove()
|
|
|
|
|
|
class AuthTokenMiddleware:
|
|
"""Extract the bearer/cookie/x-api-key credential and stash it on
|
|
`request.state.token`.
|
|
|
|
Routes that depend on `get_verified_user` etc. read this state.
|
|
Also exposes `request.state.enable_api_keys` (snapshotted at request
|
|
entry from runtime config) and stamps an `X-Process-Time` response
|
|
header.
|
|
"""
|
|
|
|
def __init__(self, app: ASGIApp, *, fastapi_app) -> None:
|
|
self.app = app
|
|
self._fastapi_app = fastapi_app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
if scope['type'] != 'http':
|
|
await self.app(scope, receive, send)
|
|
return
|
|
|
|
start_time = time.monotonic()
|
|
request = Request(scope)
|
|
|
|
token = get_http_authorization_cred(request.headers.get('Authorization'))
|
|
if token is None:
|
|
cookie_token = request.cookies.get('token')
|
|
if cookie_token:
|
|
token = HTTPAuthorizationCredentials(scheme='Bearer', credentials=cookie_token)
|
|
if token is None:
|
|
api_key = request.headers.get('x-api-key')
|
|
if api_key:
|
|
token = HTTPAuthorizationCredentials(scheme='Bearer', credentials=api_key)
|
|
|
|
request.state.token = token
|
|
request.state.enable_api_keys = self._fastapi_app.state.config.ENABLE_API_KEYS
|
|
|
|
async def send_with_timing(message: Message) -> None:
|
|
if message['type'] == 'http.response.start':
|
|
process_time = int(time.monotonic() - start_time)
|
|
headers = MutableHeaders(scope=message)
|
|
headers['X-Process-Time'] = str(process_time)
|
|
await send(message)
|
|
|
|
await self.app(scope, receive, send_with_timing)
|
|
|
|
|
|
class WebsocketUpgradeGuardMiddleware:
|
|
"""Reject HTTP requests to `/ws/socket.io` that claim
|
|
`transport=websocket` but lack the proper `Upgrade`/`Connection`
|
|
headers.
|
|
|
|
Works around https://github.com/miguelgrinberg/python-engineio/issues/367
|
|
where engineio mishandles such requests.
|
|
"""
|
|
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
if scope['type'] != 'http':
|
|
await self.app(scope, receive, send)
|
|
return
|
|
|
|
path = scope.get('path', '')
|
|
if '/ws/socket.io' in path:
|
|
query_string = scope.get('query_string', b'').decode('latin-1', errors='replace')
|
|
query_params = parse_qs(query_string)
|
|
if query_params.get('transport', [''])[0] == 'websocket':
|
|
headers = _scope_headers(scope)
|
|
upgrade = headers.get('upgrade', '').lower()
|
|
connection_tokens = [token.strip() for token in headers.get('connection', '').lower().split(',')]
|
|
if upgrade != 'websocket' or 'upgrade' not in connection_tokens:
|
|
response = JSONResponse(
|
|
status_code=400,
|
|
content={'detail': 'Invalid WebSocket upgrade request'},
|
|
)
|
|
await response(scope, receive, send)
|
|
return
|
|
|
|
await self.app(scope, receive, send)
|
|
|
|
|
|
class RedirectMiddleware:
|
|
"""Rewrites a couple of legacy entry-points to the SPA's own routes:
|
|
|
|
* ``GET /watch?v=ID`` (YouTube) → ``/?youtube=ID``
|
|
* ``GET /?shared=…`` (PWA share-target) → ``/?youtube=…`` /
|
|
``/?load-url=…`` / ``/?q=…``
|
|
"""
|
|
|
|
def __init__(self, app: ASGIApp) -> None:
|
|
self.app = app
|
|
|
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
if scope['type'] != 'http' or scope.get('method', '').upper() != 'GET':
|
|
await self.app(scope, receive, send)
|
|
return
|
|
|
|
path = scope.get('path', '')
|
|
query_string = scope.get('query_string', b'').decode('latin-1', errors='replace')
|
|
query_params = parse_qs(query_string)
|
|
|
|
redirect_params: dict[str, str] = {}
|
|
if path.endswith('/watch') and 'v' in query_params and query_params['v']:
|
|
redirect_params['youtube'] = query_params['v'][0]
|
|
|
|
if 'shared' in query_params and query_params['shared']:
|
|
text = query_params['shared'][0]
|
|
if text:
|
|
url_match = re.match(r'https://\S+', text)
|
|
if url_match:
|
|
# Local import: youtube loader pulls heavy deps and is
|
|
# only needed when a share-target actually contains a
|
|
# YouTube URL.
|
|
from open_webui.retrieval.loaders.youtube import _parse_video_id
|
|
|
|
youtube_video_id = _parse_video_id(url_match[0])
|
|
if youtube_video_id:
|
|
redirect_params['youtube'] = youtube_video_id
|
|
else:
|
|
redirect_params['load-url'] = url_match[0]
|
|
else:
|
|
redirect_params['q'] = text
|
|
|
|
if redirect_params:
|
|
redirect_url = f'/?{urlencode(redirect_params)}'
|
|
response = RedirectResponse(url=redirect_url)
|
|
await response(scope, receive, send)
|
|
return
|
|
|
|
await self.app(scope, receive, send)
|
|
|
|
|
|
def _scope_headers(scope: Scope) -> dict[str, str]:
|
|
"""Return ASGI scope headers as a lower-cased str→str dict.
|
|
|
|
ASGI delivers headers as a list of (bytes, bytes) pairs. For
|
|
convenience, fold duplicate keys with comma-joining (matching
|
|
HTTP/1.1 semantics).
|
|
"""
|
|
decoded: dict[str, str] = {}
|
|
for raw_key, raw_value in scope.get('headers', []):
|
|
key = raw_key.decode('latin-1').lower()
|
|
value = raw_value.decode('latin-1')
|
|
if key in decoded:
|
|
decoded[key] = f'{decoded[key]}, {value}'
|
|
else:
|
|
decoded[key] = value
|
|
return decoded
|