From fa57bef7044943bb5680fa76140a5cb1e2e892d0 Mon Sep 17 00:00:00 2001 From: Alex Yang Date: Wed, 19 Nov 2025 13:51:19 -0800 Subject: [PATCH] fix: code --- .../src/plugins/oidc-provider/index.ts | 102 +++++++++++++----- 1 file changed, 73 insertions(+), 29 deletions(-) diff --git a/packages/better-auth/src/plugins/oidc-provider/index.ts b/packages/better-auth/src/plugins/oidc-provider/index.ts index 92daeaf9e2..1328ec28ea 100644 --- a/packages/better-auth/src/plugins/oidc-provider/index.ts +++ b/packages/better-auth/src/plugins/oidc-provider/index.ts @@ -9,7 +9,7 @@ import { import { getCurrentAuthContext } from "@better-auth/core/context"; import { base64 } from "@better-auth/utils/base64"; import { createHash } from "@better-auth/utils/hash"; -import { jwtVerify, SignJWT } from "jose"; +import { importJWK, jwtVerify, SignJWT } from "jose"; import * as z from "zod"; import { APIError, getSessionFromCtx, sessionMiddleware } from "../../api"; import { parseSetCookieHeader } from "../../cookies"; @@ -21,6 +21,7 @@ import { import { mergeSchema } from "../../db"; import type { jwt } from "../jwt"; import { getJwtToken } from "../jwt"; +import { getJwksAdapter } from "../jwt/adapter"; import { authorize } from "./authorize"; import type { OAuthApplication } from "./schema"; import { schema } from "./schema"; @@ -40,6 +41,68 @@ const getJwtPlugin = (ctx: GenericEndpointContext) => { ) as ReturnType; }; +/** + * Verify a JWT token using the JWKS public keys + * Returns the payload if valid, null otherwise + */ +async function verifyJwtWithJWKS( + ctx: GenericEndpointContext, + token: string, + jwtPlugin: ReturnType, +): Promise<{ sub: string; aud: string } | null> { + try { + const parts = token.split("."); + if (parts.length !== 3) { + return null; + } + + const headerStr = new TextDecoder().decode(base64.decode(parts[0]!)); + const header = JSON.parse(headerStr); + const kid = header.kid; + + if (!kid) { + ctx.context.logger.debug("JWT missing kid in header"); + return null; + } + + // Get all JWKS keys + const adapter = getJwksAdapter(ctx.context.adapter, jwtPlugin.options); + const keys = await adapter.getAllKeys(ctx); + + if (!keys || keys.length === 0) { + ctx.context.logger.debug("No JWKS keys available"); + return null; + } + + const key = keys.find((k) => k.id === kid); + if (!key) { + ctx.context.logger.debug(`No JWKS key found for kid: ${kid}`); + return null; + } + + const publicKey = JSON.parse(key.publicKey); + const alg = + key.alg ?? jwtPlugin.options?.jwks?.keyPairConfig?.alg ?? "EdDSA"; + const cryptoKey = await importJWK(publicKey, alg); + + const { payload } = await jwtVerify(token, cryptoKey, { + issuer: jwtPlugin.options?.jwt?.issuer ?? ctx.context.options.baseURL, + }); + + if (!payload.sub || !payload.aud) { + return null; + } + + return { + sub: payload.sub as string, + aud: payload.aud as string, + }; + } catch (error) { + ctx.context.logger.debug("JWT verification failed", error); + return null; + } +} + /** * Get a client by ID, checking trusted clients first, then database */ @@ -1613,21 +1676,15 @@ export const oidcProvider = (options: OIDCOptions) => { try { const jwtPlugin = getJwtPlugin(ctx); if (jwtPlugin && jwtPlugin.options && options?.useJWTPlugin) { - // For JWT plugin tokens, we need to verify using the JWKS - // We'll extract the audience to get the client and verify with appropriate key - // For now, we'll decode without verification to get the audience - const parts = id_token_hint.split("."); - if (parts.length === 3) { - try { - const payloadStr = new TextDecoder().decode( - base64.decode(parts[1]!), - ); - const payload = JSON.parse(payloadStr); - validatedUserId = payload.sub as string; - validatedClientId = payload.aud as string; - } catch { - // Invalid token format - } + // For JWT plugin tokens, verify using JWKS + const verified = await verifyJwtWithJWKS( + ctx, + id_token_hint, + jwtPlugin, + ); + if (verified) { + validatedUserId = verified.sub; + validatedClientId = verified.aud; } } else { // For HS256 tokens, we need the client_id to verify @@ -1711,19 +1768,6 @@ export const oidcProvider = (options: OIDCOptions) => { if (validatedUserId || session) { const userId = validatedUserId || session?.user.id; if (userId) { - // Revoke access tokens for this user and client (client-scoped logout by default) - const globalLogout = ctx.context.options?.globalLogout === true; - const tokens = - await ctx.context.adapter.findMany({ - model: modelName.oauthAccessToken, - where: [ - { field: "userId", value: userId }, - ...(!globalLogout && validatedClientId - ? [{ field: "clientId", value: validatedClientId }] - : []), - ], - }); - await ctx.context.adapter.deleteMany({ model: modelName.oauthAccessToken, where: [{ field: "userId", value: userId }],