mirror of
https://github.com/better-auth/better-auth.git
synced 2026-05-21 22:06:04 -05:00
fix: code
This commit is contained in:
@@ -9,7 +9,7 @@ import {
|
||||
import { getCurrentAuthContext } from "@better-auth/core/context";
|
||||
import { base64 } from "@better-auth/utils/base64";
|
||||
import { createHash } from "@better-auth/utils/hash";
|
||||
import { jwtVerify, SignJWT } from "jose";
|
||||
import { importJWK, jwtVerify, SignJWT } from "jose";
|
||||
import * as z from "zod";
|
||||
import { APIError, getSessionFromCtx, sessionMiddleware } from "../../api";
|
||||
import { parseSetCookieHeader } from "../../cookies";
|
||||
@@ -21,6 +21,7 @@ import {
|
||||
import { mergeSchema } from "../../db";
|
||||
import type { jwt } from "../jwt";
|
||||
import { getJwtToken } from "../jwt";
|
||||
import { getJwksAdapter } from "../jwt/adapter";
|
||||
import { authorize } from "./authorize";
|
||||
import type { OAuthApplication } from "./schema";
|
||||
import { schema } from "./schema";
|
||||
@@ -40,6 +41,68 @@ const getJwtPlugin = (ctx: GenericEndpointContext) => {
|
||||
) as ReturnType<typeof jwt>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Verify a JWT token using the JWKS public keys
|
||||
* Returns the payload if valid, null otherwise
|
||||
*/
|
||||
async function verifyJwtWithJWKS(
|
||||
ctx: GenericEndpointContext,
|
||||
token: string,
|
||||
jwtPlugin: ReturnType<typeof jwt>,
|
||||
): Promise<{ sub: string; aud: string } | null> {
|
||||
try {
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const headerStr = new TextDecoder().decode(base64.decode(parts[0]!));
|
||||
const header = JSON.parse(headerStr);
|
||||
const kid = header.kid;
|
||||
|
||||
if (!kid) {
|
||||
ctx.context.logger.debug("JWT missing kid in header");
|
||||
return null;
|
||||
}
|
||||
|
||||
// Get all JWKS keys
|
||||
const adapter = getJwksAdapter(ctx.context.adapter, jwtPlugin.options);
|
||||
const keys = await adapter.getAllKeys(ctx);
|
||||
|
||||
if (!keys || keys.length === 0) {
|
||||
ctx.context.logger.debug("No JWKS keys available");
|
||||
return null;
|
||||
}
|
||||
|
||||
const key = keys.find((k) => k.id === kid);
|
||||
if (!key) {
|
||||
ctx.context.logger.debug(`No JWKS key found for kid: ${kid}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
const publicKey = JSON.parse(key.publicKey);
|
||||
const alg =
|
||||
key.alg ?? jwtPlugin.options?.jwks?.keyPairConfig?.alg ?? "EdDSA";
|
||||
const cryptoKey = await importJWK(publicKey, alg);
|
||||
|
||||
const { payload } = await jwtVerify(token, cryptoKey, {
|
||||
issuer: jwtPlugin.options?.jwt?.issuer ?? ctx.context.options.baseURL,
|
||||
});
|
||||
|
||||
if (!payload.sub || !payload.aud) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
sub: payload.sub as string,
|
||||
aud: payload.aud as string,
|
||||
};
|
||||
} catch (error) {
|
||||
ctx.context.logger.debug("JWT verification failed", error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a client by ID, checking trusted clients first, then database
|
||||
*/
|
||||
@@ -1613,21 +1676,15 @@ export const oidcProvider = (options: OIDCOptions) => {
|
||||
try {
|
||||
const jwtPlugin = getJwtPlugin(ctx);
|
||||
if (jwtPlugin && jwtPlugin.options && options?.useJWTPlugin) {
|
||||
// For JWT plugin tokens, we need to verify using the JWKS
|
||||
// We'll extract the audience to get the client and verify with appropriate key
|
||||
// For now, we'll decode without verification to get the audience
|
||||
const parts = id_token_hint.split(".");
|
||||
if (parts.length === 3) {
|
||||
try {
|
||||
const payloadStr = new TextDecoder().decode(
|
||||
base64.decode(parts[1]!),
|
||||
);
|
||||
const payload = JSON.parse(payloadStr);
|
||||
validatedUserId = payload.sub as string;
|
||||
validatedClientId = payload.aud as string;
|
||||
} catch {
|
||||
// Invalid token format
|
||||
}
|
||||
// For JWT plugin tokens, verify using JWKS
|
||||
const verified = await verifyJwtWithJWKS(
|
||||
ctx,
|
||||
id_token_hint,
|
||||
jwtPlugin,
|
||||
);
|
||||
if (verified) {
|
||||
validatedUserId = verified.sub;
|
||||
validatedClientId = verified.aud;
|
||||
}
|
||||
} else {
|
||||
// For HS256 tokens, we need the client_id to verify
|
||||
@@ -1711,19 +1768,6 @@ export const oidcProvider = (options: OIDCOptions) => {
|
||||
if (validatedUserId || session) {
|
||||
const userId = validatedUserId || session?.user.id;
|
||||
if (userId) {
|
||||
// Revoke access tokens for this user and client (client-scoped logout by default)
|
||||
const globalLogout = ctx.context.options?.globalLogout === true;
|
||||
const tokens =
|
||||
await ctx.context.adapter.findMany<OAuthAccessToken>({
|
||||
model: modelName.oauthAccessToken,
|
||||
where: [
|
||||
{ field: "userId", value: userId },
|
||||
...(!globalLogout && validatedClientId
|
||||
? [{ field: "clientId", value: validatedClientId }]
|
||||
: []),
|
||||
],
|
||||
});
|
||||
|
||||
await ctx.context.adapter.deleteMany({
|
||||
model: modelName.oauthAccessToken,
|
||||
where: [{ field: "userId", value: userId }],
|
||||
|
||||
Reference in New Issue
Block a user