perf: convert APIKeyRestrictionMiddleware to pure ASGI (#22188)

This commit is contained in:
Algorithm5838
2026-03-06 23:54:03 +03:00
committed by GitHub
parent 2153c8ec9f
commit 39deadcab1

View File

@@ -1398,46 +1398,52 @@ app.add_middleware(RedirectMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
class APIKeyRestrictionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
auth_header = request.headers.get("Authorization")
token = None
class APIKeyRestrictionMiddleware:
def __init__(self, app):
self.app = app
if auth_header:
parts = auth_header.split(" ", 1)
if len(parts) == 2:
token = parts[1]
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
request = Request(scope)
auth_header = request.headers.get("Authorization")
token = None
# Only apply restrictions if an sk- API key is used
if token and token.startswith("sk-"):
# Check if restrictions are enabled
if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS
).split(",")
if path.strip()
]
if auth_header:
parts = auth_header.split(" ", 1)
if len(parts) == 2:
token = parts[1]
request_path = request.url.path
# Only apply restrictions if an sk- API key is used
if token and token.startswith("sk-"):
# Check if restrictions are enabled
if app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
allowed_paths = [
path.strip()
for path in str(
app.state.config.API_KEYS_ALLOWED_ENDPOINTS
).split(",")
if path.strip()
]
# Match exact path or prefix path
is_allowed = any(
request_path == allowed or request_path.startswith(allowed + "/")
for allowed in allowed_paths
)
request_path = request.url.path
if not is_allowed:
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={
"detail": "API key not allowed to access this endpoint."
},
# Match exact path or prefix path
is_allowed = any(
request_path == allowed
or request_path.startswith(allowed + "/")
for allowed in allowed_paths
)
response = await call_next(request)
return response
if not is_allowed:
await JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={
"detail": "API key not allowed to access this endpoint."
},
)(scope, receive, send)
return
await self.app(scope, receive, send)
app.add_middleware(APIKeyRestrictionMiddleware)