From 78f12ba6b73d9fb2e6e77a6bbccdf8a806ac86aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paola=20Estefan=C3=ADa=20de=20Campos?= <84341268+Paola3stefania@users.noreply.github.com> Date: Tue, 20 Jan 2026 00:36:13 -0300 Subject: [PATCH] fix(sso): fix validateToken JWK handling for all key types (#7479) --- .../src/oauth2/validate-authorization-code.ts | 19 +- .../core/src/oauth2/validate-token.test.ts | 174 ++++++++++++++++++ 2 files changed, 181 insertions(+), 12 deletions(-) create mode 100644 packages/core/src/oauth2/validate-token.test.ts diff --git a/packages/core/src/oauth2/validate-authorization-code.ts b/packages/core/src/oauth2/validate-authorization-code.ts index 9e8e22bca5..74631d6200 100644 --- a/packages/core/src/oauth2/validate-authorization-code.ts +++ b/packages/core/src/oauth2/validate-authorization-code.ts @@ -1,6 +1,7 @@ import { base64 } from "@better-auth/utils/base64"; import { betterFetch } from "@better-fetch/fetch"; -import { jwtVerify } from "jose"; +import type { JWK } from "jose"; +import { decodeProtectedHeader, importJWK, jwtVerify } from "jose"; import type { ProviderOptions } from "./index"; import { getOAuth2Tokens } from "./index"; @@ -126,14 +127,7 @@ export async function validateAuthorizationCode({ export async function validateToken(token: string, jwksEndpoint: string) { const { data, error } = await betterFetch<{ - keys: { - kid: string; - kty: string; - use: string; - n: string; - e: string; - x5c: string[]; - }[]; + keys: JWK[]; }>(jwksEndpoint, { method: "GET", headers: { @@ -144,11 +138,12 @@ export async function validateToken(token: string, jwksEndpoint: string) { throw error; } const keys = data["keys"]; - const header = JSON.parse(atob(token.split(".")[0]!)); - const key = keys.find((key) => key.kid === header.kid); + const header = decodeProtectedHeader(token); + const key = keys.find((k) => k.kid === header.kid); if (!key) { throw new Error("Key not found"); } - const verified = await jwtVerify(token, key); + const cryptoKey = await importJWK(key, header.alg); + const verified = await jwtVerify(token, cryptoKey); return verified; } diff --git a/packages/core/src/oauth2/validate-token.test.ts b/packages/core/src/oauth2/validate-token.test.ts new file mode 100644 index 0000000000..313db39dbc --- /dev/null +++ b/packages/core/src/oauth2/validate-token.test.ts @@ -0,0 +1,174 @@ +import { exportJWK, generateKeyPair, SignJWT } from "jose"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { validateToken } from "./validate-authorization-code"; + +vi.mock("@better-fetch/fetch", () => ({ + betterFetch: vi.fn(), +})); + +import { betterFetch } from "@better-fetch/fetch"; + +const mockedBetterFetch = vi.mocked(betterFetch); + +describe("validateToken", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + async function createTestJWKS(alg: string, crv?: string) { + const { publicKey, privateKey } = await generateKeyPair(alg, { + crv, + extractable: true, + }); + const publicJWK = await exportJWK(publicKey); + const privateJWK = await exportJWK(privateKey); + const kid = `test-key-${Date.now()}`; + publicJWK.kid = kid; + privateJWK.kid = kid; + return { publicJWK, privateJWK, kid, publicKey, privateKey }; + } + + async function createSignedToken( + privateKey: CryptoKey, + alg: string, + kid: string, + payload: Record = {}, + ) { + return await new SignJWT({ + sub: "user-123", + email: "test@example.com", + iss: "https://example.com", + aud: "test-client", + ...payload, + }) + .setProtectedHeader({ alg, kid }) + .setIssuedAt() + .setExpirationTime("1h") + .sign(privateKey); + } + + it("should verify RS256 signed token", async () => { + const { publicJWK, privateKey, kid } = await createTestJWKS("RS256"); + const token = await createSignedToken(privateKey, "RS256", kid); + + mockedBetterFetch.mockResolvedValueOnce({ + data: { keys: [publicJWK] }, + error: null, + }); + + const result = await validateToken( + token, + "https://example.com/.well-known/jwks", + ); + + expect(result).toBeDefined(); + expect(result.payload.sub).toBe("user-123"); + expect(result.payload.email).toBe("test@example.com"); + expect(mockedBetterFetch).toHaveBeenCalledWith( + "https://example.com/.well-known/jwks", + expect.objectContaining({ method: "GET" }), + ); + }); + + it("should verify ES256 signed token", async () => { + const { publicJWK, privateKey, kid } = await createTestJWKS("ES256"); + const token = await createSignedToken(privateKey, "ES256", kid); + + mockedBetterFetch.mockResolvedValueOnce({ + data: { keys: [publicJWK] }, + error: null, + }); + + const result = await validateToken( + token, + "https://example.com/.well-known/jwks", + ); + + expect(result).toBeDefined(); + expect(result.payload.sub).toBe("user-123"); + }); + + it("should verify EdDSA (Ed25519) signed token", async () => { + const { publicJWK, privateKey, kid } = await createTestJWKS( + "EdDSA", + "Ed25519", + ); + const token = await createSignedToken(privateKey, "EdDSA", kid); + + mockedBetterFetch.mockResolvedValueOnce({ + data: { keys: [publicJWK] }, + error: null, + }); + + const result = await validateToken( + token, + "https://example.com/.well-known/jwks", + ); + + expect(result).toBeDefined(); + expect(result.payload.sub).toBe("user-123"); + }); + + it("should throw 'Key not found' when kid doesn't match", async () => { + const { publicJWK, privateKey } = await createTestJWKS("RS256"); + publicJWK.kid = "different-kid"; + const token = await createSignedToken(privateKey, "RS256", "original-kid"); + + mockedBetterFetch.mockResolvedValueOnce({ + data: { keys: [publicJWK] }, + error: null, + }); + + await expect( + validateToken(token, "https://example.com/.well-known/jwks"), + ).rejects.toThrow("Key not found"); + }); + + it("should find correct key when multiple keys exist", async () => { + const key1 = await createTestJWKS("RS256"); + const key2 = await createTestJWKS("RS256"); + const key3 = await createTestJWKS("ES256"); + const token = await createSignedToken(key2.privateKey, "RS256", key2.kid); + + mockedBetterFetch.mockResolvedValueOnce({ + data: { keys: [key1.publicJWK, key2.publicJWK, key3.publicJWK] }, + error: null, + }); + + const result = await validateToken( + token, + "https://example.com/.well-known/jwks", + ); + + expect(result).toBeDefined(); + expect(result.payload.sub).toBe("user-123"); + }); + + it("should throw when JWKS returns empty keys array", async () => { + const { privateKey, kid } = await createTestJWKS("RS256"); + const token = await createSignedToken(privateKey, "RS256", kid); + + mockedBetterFetch.mockResolvedValueOnce({ + data: { keys: [] }, + error: null, + }); + + await expect( + validateToken(token, "https://example.com/.well-known/jwks"), + ).rejects.toThrow("Key not found"); + }); + + it("should throw when JWKS fetch fails", async () => { + const { privateKey, kid } = await createTestJWKS("RS256"); + const token = await createSignedToken(privateKey, "RS256", kid); + + mockedBetterFetch.mockResolvedValueOnce({ + data: null, + error: { status: 500, statusText: "Internal Server Error" }, + }); + + await expect( + validateToken(token, "https://example.com/.well-known/jwks"), + ).rejects.toBeDefined(); + }); +});