fix(sso): validate aud claim in OpenID Connect ID tokens (#7816)

This commit is contained in:
Paola Estefanía de Campos
2026-02-05 16:17:05 -08:00
committed by GitHub
parent 07162c66b1
commit 01faefc2ce
3 changed files with 136 additions and 80 deletions

View File

@@ -1,7 +1,6 @@
import { base64 } from "@better-auth/utils/base64";
import { betterFetch } from "@better-fetch/fetch";
import type { JWK } from "jose";
import { decodeProtectedHeader, importJWK, jwtVerify } from "jose";
import { createRemoteJWKSet, jwtVerify } from "jose";
import type { AwaitableFunction } from "../types";
import type { ProviderOptions } from "./index";
import { getOAuth2Tokens } from "./index";
@@ -164,25 +163,18 @@ export async function validateAuthorizationCode({
return tokens;
}
export async function validateToken(token: string, jwksEndpoint: string) {
const { data, error } = await betterFetch<{
keys: JWK[];
}>(jwksEndpoint, {
method: "GET",
headers: {
accept: "application/json",
},
export async function validateToken(
token: string,
jwksEndpoint: string,
options?: {
audience?: string | string[];
issuer?: string | string[];
},
) {
const jwks = createRemoteJWKSet(new URL(jwksEndpoint));
const verified = await jwtVerify(token, jwks, {
audience: options?.audience,
issuer: options?.issuer,
});
if (error) {
throw error;
}
const keys = data["keys"];
const header = decodeProtectedHeader(token);
const key = keys.find((k) => k.kid === header.kid);
if (!key) {
throw new Error("Key not found");
}
const cryptoKey = await importJWK(key, header.alg);
const verified = await jwtVerify(token, cryptoKey);
return verified;
}

View File

@@ -1,18 +1,31 @@
import type { JWK } from "jose";
import { exportJWK, generateKeyPair, SignJWT } from "jose";
import { beforeEach, describe, expect, it, vi } from "vitest";
import {
afterAll,
beforeAll,
beforeEach,
describe,
expect,
it,
vi,
} from "vitest";
import { validateToken } from "./validate-authorization-code";
vi.mock("@better-fetch/fetch", () => ({
betterFetch: vi.fn(),
}));
import { betterFetch } from "@better-fetch/fetch";
const mockedBetterFetch = vi.mocked(betterFetch);
describe("validateToken", () => {
const originalFetch = globalThis.fetch;
const mockedFetch = vi.fn() as unknown as typeof fetch &
ReturnType<typeof vi.fn>;
beforeAll(() => {
globalThis.fetch = mockedFetch;
});
afterAll(() => {
globalThis.fetch = originalFetch;
});
beforeEach(() => {
vi.clearAllMocks();
mockedFetch.mockReset();
});
async function createTestJWKS(alg: string, crv?: string) {
@@ -24,7 +37,9 @@ describe("validateToken", () => {
const privateJWK = await exportJWK(privateKey);
const kid = `test-key-${Date.now()}`;
publicJWK.kid = kid;
publicJWK.alg = alg;
privateJWK.kid = kid;
privateJWK.alg = alg;
return { publicJWK, privateJWK, kid, publicKey, privateKey };
}
@@ -47,14 +62,19 @@ describe("validateToken", () => {
.sign(privateKey);
}
function mockJWKSResponse(...publicJWKs: JWK[]) {
mockedFetch.mockResolvedValueOnce(
new Response(JSON.stringify({ keys: publicJWKs }), {
status: 200,
headers: { "content-type": "application/json" },
}),
);
}
it("should verify RS256 signed token", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockedBetterFetch.mockResolvedValueOnce({
data: { keys: [publicJWK] },
error: null,
});
mockJWKSResponse(publicJWK);
const result = await validateToken(
token,
@@ -64,20 +84,12 @@ describe("validateToken", () => {
expect(result).toBeDefined();
expect(result.payload.sub).toBe("user-123");
expect(result.payload.email).toBe("test@example.com");
expect(mockedBetterFetch).toHaveBeenCalledWith(
"https://example.com/.well-known/jwks",
expect.objectContaining({ method: "GET" }),
);
});
it("should verify ES256 signed token", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("ES256");
const token = await createSignedToken(privateKey, "ES256", kid);
mockedBetterFetch.mockResolvedValueOnce({
data: { keys: [publicJWK] },
error: null,
});
mockJWKSResponse(publicJWK);
const result = await validateToken(
token,
@@ -94,11 +106,7 @@ describe("validateToken", () => {
"Ed25519",
);
const token = await createSignedToken(privateKey, "EdDSA", kid);
mockedBetterFetch.mockResolvedValueOnce({
data: { keys: [publicJWK] },
error: null,
});
mockJWKSResponse(publicJWK);
const result = await validateToken(
token,
@@ -109,19 +117,15 @@ describe("validateToken", () => {
expect(result.payload.sub).toBe("user-123");
});
it("should throw 'Key not found' when kid doesn't match", async () => {
it("should throw when kid doesn't match any key", async () => {
const { publicJWK, privateKey } = await createTestJWKS("RS256");
publicJWK.kid = "different-kid";
const token = await createSignedToken(privateKey, "RS256", "original-kid");
mockedBetterFetch.mockResolvedValueOnce({
data: { keys: [publicJWK] },
error: null,
});
mockJWKSResponse(publicJWK);
await expect(
validateToken(token, "https://example.com/.well-known/jwks"),
).rejects.toThrow("Key not found");
).rejects.toThrow();
});
it("should find correct key when multiple keys exist", async () => {
@@ -129,11 +133,7 @@ describe("validateToken", () => {
const key2 = await createTestJWKS("RS256");
const key3 = await createTestJWKS("ES256");
const token = await createSignedToken(key2.privateKey, "RS256", key2.kid);
mockedBetterFetch.mockResolvedValueOnce({
data: { keys: [key1.publicJWK, key2.publicJWK, key3.publicJWK] },
error: null,
});
mockJWKSResponse(key1.publicJWK, key2.publicJWK, key3.publicJWK);
const result = await validateToken(
token,
@@ -147,28 +147,95 @@ describe("validateToken", () => {
it("should throw when JWKS returns empty keys array", async () => {
const { privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockedBetterFetch.mockResolvedValueOnce({
data: { keys: [] },
error: null,
});
mockJWKSResponse();
await expect(
validateToken(token, "https://example.com/.well-known/jwks"),
).rejects.toThrow("Key not found");
).rejects.toThrow();
});
it("should throw when JWKS fetch fails", async () => {
const { privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockedBetterFetch.mockResolvedValueOnce({
data: null,
error: { status: 500, statusText: "Internal Server Error" },
});
mockedFetch.mockResolvedValueOnce(
new Response("Internal Server Error", { status: 500 }),
);
await expect(
validateToken(token, "https://example.com/.well-known/jwks"),
).rejects.toBeDefined();
});
it("should verify token with matching audience", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockJWKSResponse(publicJWK);
const result = await validateToken(
token,
"https://example.com/.well-known/jwks",
{ audience: "test-client" },
);
expect(result).toBeDefined();
expect(result.payload.aud).toBe("test-client");
});
it("should reject token with mismatched audience", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockJWKSResponse(publicJWK);
await expect(
validateToken(token, "https://example.com/.well-known/jwks", {
audience: "wrong-client",
}),
).rejects.toThrow();
});
it("should verify token with matching issuer", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockJWKSResponse(publicJWK);
const result = await validateToken(
token,
"https://example.com/.well-known/jwks",
{ issuer: "https://example.com" },
);
expect(result).toBeDefined();
expect(result.payload.iss).toBe("https://example.com");
});
it("should reject token with mismatched issuer", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockJWKSResponse(publicJWK);
await expect(
validateToken(token, "https://example.com/.well-known/jwks", {
issuer: "https://wrong-issuer.com",
}),
).rejects.toThrow();
});
it("should verify token with both audience and issuer", async () => {
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
const token = await createSignedToken(privateKey, "RS256", kid);
mockJWKSResponse(publicJWK);
const result = await validateToken(
token,
"https://example.com/.well-known/jwks",
{
audience: "test-client",
issuer: "https://example.com",
},
);
expect(result).toBeDefined();
expect(result.payload.aud).toBe("test-client");
expect(result.payload.iss).toBe("https://example.com");
});
});

View File

@@ -1588,6 +1588,10 @@ export const callbackSSO = (options?: SSOOptions) => {
const verified = await validateToken(
tokenResponse.idToken,
config.jwksEndpoint,
{
audience: config.clientId,
issuer: provider.issuer,
},
).catch((e) => {
ctx.context.logger.error(e);
return null;
@@ -1599,13 +1603,6 @@ export const callbackSSO = (options?: SSOOptions) => {
}?error=invalid_provider&error_description=token_not_verified`,
);
}
if (verified.payload.iss !== provider.issuer) {
throw ctx.redirect(
`${
errorURL || callbackURL
}?error=invalid_provider&error_description=issuer_mismatch`,
);
}
const mapping = config.mapping || {};
userInfo = {