diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index f0dbf9c114..e0dcaff276 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -544,6 +544,12 @@ OAUTH_MAX_SESSIONS_PER_USER = int(os.environ.get('OAUTH_MAX_SESSIONS_PER_USER', # Allows external apps to exchange OAuth tokens for OpenWebUI tokens ENABLE_OAUTH_TOKEN_EXCHANGE = os.environ.get('ENABLE_OAUTH_TOKEN_EXCHANGE', 'False').lower() == 'true' +# Back-Channel Logout Configuration +# When enabled, exposes POST /oauth/backchannel-logout for IdP-initiated logout +# per OpenID Connect Back-Channel Logout 1.0 spec. +# Requires Redis for JWT revocation. +ENABLE_OAUTH_BACKCHANNEL_LOGOUT = os.environ.get('ENABLE_OAUTH_BACKCHANNEL_LOGOUT', 'False').lower() == 'true' + #################################### # SCIM Configuration #################################### diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 255c9ace5f..494d657fc8 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -511,6 +511,8 @@ from open_webui.env import ( WEBUI_ADMIN_NAME, ENABLE_EASTER_EGGS, LOG_FORMAT, + # OAuth Back-Channel Logout + ENABLE_OAUTH_BACKCHANNEL_LOGOUT, ) @@ -2477,6 +2479,21 @@ async def oauth_login_callback( return await oauth_manager.handle_callback(request, provider, response, db=db) +############################ +# OIDC Back-Channel Logout +############################ + + +@app.post('/oauth/backchannel-logout') +async def oauth_backchannel_logout( + request: Request, + db: Session = Depends(get_session), +): + if not ENABLE_OAUTH_BACKCHANNEL_LOGOUT: + raise HTTPException(status_code=404) + return await oauth_manager.handle_backchannel_logout(request, db=db) + + @app.get('/manifest.json') async def get_manifest_json(): if app.state.EXTERNAL_PWA_MANIFEST_URL: diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 34412d6041..16bc36500b 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -206,7 +206,7 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st payload.update({'exp': expire}) jti = str(uuid.uuid4()) - payload.update({'jti': jti}) + payload.update({'jti': jti, 'iat': datetime.now(UTC)}) encoded_jwt = jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM) return encoded_jwt @@ -221,15 +221,36 @@ def decode_token(token: str) -> Optional[dict]: async def is_valid_token(request, decoded) -> bool: - # Require Redis to check revoked tokens + """ + Check whether a JWT has been revoked. Two mechanisms: + 1. Per-token (jti) — used by user-initiated sign-out (known jti). + 2. Per-user (revoked_at) — used by OIDC back-channel logout when + individual jti values are unknown; rejects tokens with iat <= revoked_at. + """ if request.app.state.redis: + # Per-token revocation jti = decoded.get('jti') - if jti: revoked = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:auth:token:{jti}:revoked') if revoked: return False + # Per-user revocation (OIDC back-channel logout) + user_id = decoded.get('id') + if user_id: + revoked_at = await request.app.state.redis.get( + f'{REDIS_KEY_PREFIX}:auth:user:{user_id}:revoked_at' + ) + if revoked_at: + try: + revoked_at_ts = int(revoked_at) + token_iat = decoded.get('iat') + # No iat means legacy token — reject since we can't verify issue time + if token_iat is None or token_iat <= revoked_at_ts: + return False + except (ValueError, TypeError): + pass + return True diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 885b79afda..bf6be461e6 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -75,6 +75,7 @@ from open_webui.env import ( ENABLE_OAUTH_EMAIL_FALLBACK, OAUTH_CLIENT_INFO_ENCRYPTION_KEY, OAUTH_MAX_SESSIONS_PER_USER, + REDIS_KEY_PREFIX, ) from open_webui.utils.misc import parse_duration from open_webui.utils.auth import get_password_hash, create_token @@ -1693,3 +1694,181 @@ class OAuthManager: log.error(f'Failed to store OAuth session server-side: {e}') return response + + async def handle_backchannel_logout(self, request, db=None): + """ + Handle an OIDC Back-Channel Logout request. + Validates the logout_token, identifies the user, revokes their + sessions via Redis, and deletes their OAuth sessions. + Returns a JSONResponse per the OIDC Back-Channel Logout 1.0 spec. + """ + import jwt as pyjwt + from fastapi.responses import JSONResponse + + # 1. Extract logout_token from form body + try: + form = await request.form() + logout_token = form.get('logout_token') + except Exception: + logout_token = None + + if not logout_token: + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'Missing logout_token parameter'}, + ) + + # 2. Peek at unverified issuer to match against configured providers + try: + unverified_claims = pyjwt.decode(logout_token, options={'verify_signature': False}) + token_issuer = unverified_claims.get('iss') + except Exception as e: + log.warning(f'Back-channel logout: cannot decode logout_token: {e}') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'Malformed logout_token'}, + ) + + if not token_issuer: + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'logout_token missing iss claim'}, + ) + + # 3. Find the configured provider whose issuer matches the token + matched_provider = None + matched_client_id = None + matched_jwks_uri = None + matched_issuer = None + + for provider_name in OAUTH_PROVIDERS: + server_metadata_url = self.get_server_metadata_url(provider_name) + if not server_metadata_url: + continue + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get(server_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL) as r: + if r.status != 200: + continue + oidc_config = await r.json() + + provider_issuer = oidc_config.get('issuer') + if provider_issuer and provider_issuer == token_issuer: + client = self.get_client(provider_name) + matched_provider = provider_name + matched_client_id = client.client_id if client else None + matched_jwks_uri = oidc_config.get('jwks_uri') + matched_issuer = provider_issuer + break + except Exception as e: + log.debug(f'Back-channel logout: error checking provider {provider_name}: {e}') + continue + + if not matched_provider or not matched_client_id or not matched_jwks_uri: + log.warning(f'Back-channel logout: no configured provider matches issuer {token_issuer}') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'No configured provider matches token issuer'}, + ) + + # 4. Validate the logout_token signature and claims + try: + jwks_client = pyjwt.PyJWKClient(matched_jwks_uri) + signing_key = jwks_client.get_signing_key_from_jwt(logout_token) + + claims = pyjwt.decode( + logout_token, + signing_key.key, + algorithms=['RS256', 'RS384', 'RS512', 'ES256', 'ES384', 'ES512'], + audience=matched_client_id, + issuer=matched_issuer, + options={ + 'require': ['iss', 'aud', 'iat', 'events'], + }, + ) + except pyjwt.InvalidTokenError as e: + log.warning(f'Back-channel logout: invalid logout_token: {e}') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': f'Invalid logout_token: {e}'}, + ) + except Exception as e: + log.error(f'Back-channel logout: error validating logout_token: {e}') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'Failed to validate logout_token'}, + ) + + # 5. Validate events claim per spec + events = claims.get('events', {}) + if 'http://schemas.openid.net/event/backchannel-logout' not in events: + log.warning('Back-channel logout: missing required backchannel-logout event claim') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'Missing backchannel-logout event claim'}, + ) + + # 6. Per spec, back-channel logout tokens MUST NOT contain a nonce + if 'nonce' in claims: + log.warning('Back-channel logout: logout_token contains nonce (rejected per spec)') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'logout_token must not contain nonce'}, + ) + + # 7. Extract sub and/or sid — at least one must be present + sub = claims.get('sub') + sid = claims.get('sid') + + if not sub and not sid: + log.warning('Back-channel logout: logout_token contains neither sub nor sid') + return JSONResponse( + status_code=400, + content={'error': 'invalid_request', 'error_description': 'logout_token must contain sub or sid'}, + ) + + # 8. Identify users to log out + users_to_logout = [] + if sub: + user = Users.get_user_by_oauth_sub(matched_provider, sub, db=db) + if user: + users_to_logout.append(user) + + if not users_to_logout and sid: + log.info(f'Back-channel logout: no user found by sub, sid-based lookup not yet supported (sid={sid})') + + if not users_to_logout: + log.info(f'Back-channel logout: no matching user for provider={matched_provider}, sub={sub}, sid={sid}') + return JSONResponse(status_code=200, content={}) + + # 9. Revoke tokens and delete sessions + redis = request.app.state.redis + if not redis: + log.warning( + 'Back-channel logout: Redis not configured, cannot revoke JWT tokens. ' + 'OAuth sessions will be deleted but existing JWTs will remain valid until expiry.' + ) + + revoked_count = 0 + for user in users_to_logout: + sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db) + for oauth_session in sessions: + OAuthSessions.delete_session_by_id(oauth_session.id, db=db) + + if redis: + revocation_key = f'{REDIS_KEY_PREFIX}:auth:user:{user.id}:revoked_at' + await redis.set( + revocation_key, + str(int(time.time())), + ex=60 * 60 * 24 * 30, + ) + revoked_count += 1 + + log.info( + f'Back-channel logout: revoked sessions for user {user.id} ' + f'(email={user.email}, provider={matched_provider}, sessions_deleted={len(sessions)})' + ) + + log.info(f'Back-channel logout: completed for {len(users_to_logout)} user(s), {revoked_count} revocation(s) set') + return JSONResponse(status_code=200, content={})