mirror of
https://github.com/open-webui/open-webui.git
synced 2026-03-11 17:47:44 -05:00
perf: convert APIKeyRestrictionMiddleware to pure ASGI (#22188)
This commit is contained in:
@@ -1398,46 +1398,52 @@ app.add_middleware(RedirectMiddleware)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
auth_header = request.headers.get("Authorization")
|
||||
token = None
|
||||
class APIKeyRestrictionMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
if auth_header:
|
||||
parts = auth_header.split(" ", 1)
|
||||
if len(parts) == 2:
|
||||
token = parts[1]
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
request = Request(scope)
|
||||
auth_header = request.headers.get("Authorization")
|
||||
token = None
|
||||
|
||||
# Only apply restrictions if an sk- API key is used
|
||||
if token and token.startswith("sk-"):
|
||||
# Check if restrictions are enabled
|
||||
if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
|
||||
allowed_paths = [
|
||||
path.strip()
|
||||
for path in str(
|
||||
request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS
|
||||
).split(",")
|
||||
if path.strip()
|
||||
]
|
||||
if auth_header:
|
||||
parts = auth_header.split(" ", 1)
|
||||
if len(parts) == 2:
|
||||
token = parts[1]
|
||||
|
||||
request_path = request.url.path
|
||||
# Only apply restrictions if an sk- API key is used
|
||||
if token and token.startswith("sk-"):
|
||||
# Check if restrictions are enabled
|
||||
if app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
|
||||
allowed_paths = [
|
||||
path.strip()
|
||||
for path in str(
|
||||
app.state.config.API_KEYS_ALLOWED_ENDPOINTS
|
||||
).split(",")
|
||||
if path.strip()
|
||||
]
|
||||
|
||||
# Match exact path or prefix path
|
||||
is_allowed = any(
|
||||
request_path == allowed or request_path.startswith(allowed + "/")
|
||||
for allowed in allowed_paths
|
||||
)
|
||||
request_path = request.url.path
|
||||
|
||||
if not is_allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={
|
||||
"detail": "API key not allowed to access this endpoint."
|
||||
},
|
||||
# Match exact path or prefix path
|
||||
is_allowed = any(
|
||||
request_path == allowed
|
||||
or request_path.startswith(allowed + "/")
|
||||
for allowed in allowed_paths
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
if not is_allowed:
|
||||
await JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={
|
||||
"detail": "API key not allowed to access this endpoint."
|
||||
},
|
||||
)(scope, receive, send)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
app.add_middleware(APIKeyRestrictionMiddleware)
|
||||
|
||||
Reference in New Issue
Block a user