fix: code

This commit is contained in:
Alex Yang
2025-11-19 13:51:19 -08:00
parent 336e17ce03
commit fa57bef704

View File

@@ -9,7 +9,7 @@ import {
import { getCurrentAuthContext } from "@better-auth/core/context";
import { base64 } from "@better-auth/utils/base64";
import { createHash } from "@better-auth/utils/hash";
import { jwtVerify, SignJWT } from "jose";
import { importJWK, jwtVerify, SignJWT } from "jose";
import * as z from "zod";
import { APIError, getSessionFromCtx, sessionMiddleware } from "../../api";
import { parseSetCookieHeader } from "../../cookies";
@@ -21,6 +21,7 @@ import {
import { mergeSchema } from "../../db";
import type { jwt } from "../jwt";
import { getJwtToken } from "../jwt";
import { getJwksAdapter } from "../jwt/adapter";
import { authorize } from "./authorize";
import type { OAuthApplication } from "./schema";
import { schema } from "./schema";
@@ -40,6 +41,68 @@ const getJwtPlugin = (ctx: GenericEndpointContext) => {
) as ReturnType<typeof jwt>;
};
/**
* Verify a JWT token using the JWKS public keys
* Returns the payload if valid, null otherwise
*/
async function verifyJwtWithJWKS(
ctx: GenericEndpointContext,
token: string,
jwtPlugin: ReturnType<typeof jwt>,
): Promise<{ sub: string; aud: string } | null> {
try {
const parts = token.split(".");
if (parts.length !== 3) {
return null;
}
const headerStr = new TextDecoder().decode(base64.decode(parts[0]!));
const header = JSON.parse(headerStr);
const kid = header.kid;
if (!kid) {
ctx.context.logger.debug("JWT missing kid in header");
return null;
}
// Get all JWKS keys
const adapter = getJwksAdapter(ctx.context.adapter, jwtPlugin.options);
const keys = await adapter.getAllKeys(ctx);
if (!keys || keys.length === 0) {
ctx.context.logger.debug("No JWKS keys available");
return null;
}
const key = keys.find((k) => k.id === kid);
if (!key) {
ctx.context.logger.debug(`No JWKS key found for kid: ${kid}`);
return null;
}
const publicKey = JSON.parse(key.publicKey);
const alg =
key.alg ?? jwtPlugin.options?.jwks?.keyPairConfig?.alg ?? "EdDSA";
const cryptoKey = await importJWK(publicKey, alg);
const { payload } = await jwtVerify(token, cryptoKey, {
issuer: jwtPlugin.options?.jwt?.issuer ?? ctx.context.options.baseURL,
});
if (!payload.sub || !payload.aud) {
return null;
}
return {
sub: payload.sub as string,
aud: payload.aud as string,
};
} catch (error) {
ctx.context.logger.debug("JWT verification failed", error);
return null;
}
}
/**
* Get a client by ID, checking trusted clients first, then database
*/
@@ -1613,21 +1676,15 @@ export const oidcProvider = (options: OIDCOptions) => {
try {
const jwtPlugin = getJwtPlugin(ctx);
if (jwtPlugin && jwtPlugin.options && options?.useJWTPlugin) {
// For JWT plugin tokens, we need to verify using the JWKS
// We'll extract the audience to get the client and verify with appropriate key
// For now, we'll decode without verification to get the audience
const parts = id_token_hint.split(".");
if (parts.length === 3) {
try {
const payloadStr = new TextDecoder().decode(
base64.decode(parts[1]!),
);
const payload = JSON.parse(payloadStr);
validatedUserId = payload.sub as string;
validatedClientId = payload.aud as string;
} catch {
// Invalid token format
}
// For JWT plugin tokens, verify using JWKS
const verified = await verifyJwtWithJWKS(
ctx,
id_token_hint,
jwtPlugin,
);
if (verified) {
validatedUserId = verified.sub;
validatedClientId = verified.aud;
}
} else {
// For HS256 tokens, we need the client_id to verify
@@ -1711,19 +1768,6 @@ export const oidcProvider = (options: OIDCOptions) => {
if (validatedUserId || session) {
const userId = validatedUserId || session?.user.id;
if (userId) {
// Revoke access tokens for this user and client (client-scoped logout by default)
const globalLogout = ctx.context.options?.globalLogout === true;
const tokens =
await ctx.context.adapter.findMany<OAuthAccessToken>({
model: modelName.oauthAccessToken,
where: [
{ field: "userId", value: userId },
...(!globalLogout && validatedClientId
? [{ field: "clientId", value: validatedClientId }]
: []),
],
});
await ctx.context.adapter.deleteMany({
model: modelName.oauthAccessToken,
where: [{ field: "userId", value: userId }],