mirror of
https://github.com/better-auth/better-auth.git
synced 2026-05-25 00:22:43 -05:00
fix(sso): validate aud claim in OpenID Connect ID tokens (#7816)
This commit is contained in:
committed by
GitHub
parent
07162c66b1
commit
01faefc2ce
@@ -1,7 +1,6 @@
|
||||
import { base64 } from "@better-auth/utils/base64";
|
||||
import { betterFetch } from "@better-fetch/fetch";
|
||||
import type { JWK } from "jose";
|
||||
import { decodeProtectedHeader, importJWK, jwtVerify } from "jose";
|
||||
import { createRemoteJWKSet, jwtVerify } from "jose";
|
||||
import type { AwaitableFunction } from "../types";
|
||||
import type { ProviderOptions } from "./index";
|
||||
import { getOAuth2Tokens } from "./index";
|
||||
@@ -164,25 +163,18 @@ export async function validateAuthorizationCode({
|
||||
return tokens;
|
||||
}
|
||||
|
||||
export async function validateToken(token: string, jwksEndpoint: string) {
|
||||
const { data, error } = await betterFetch<{
|
||||
keys: JWK[];
|
||||
}>(jwksEndpoint, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
accept: "application/json",
|
||||
},
|
||||
export async function validateToken(
|
||||
token: string,
|
||||
jwksEndpoint: string,
|
||||
options?: {
|
||||
audience?: string | string[];
|
||||
issuer?: string | string[];
|
||||
},
|
||||
) {
|
||||
const jwks = createRemoteJWKSet(new URL(jwksEndpoint));
|
||||
const verified = await jwtVerify(token, jwks, {
|
||||
audience: options?.audience,
|
||||
issuer: options?.issuer,
|
||||
});
|
||||
if (error) {
|
||||
throw error;
|
||||
}
|
||||
const keys = data["keys"];
|
||||
const header = decodeProtectedHeader(token);
|
||||
const key = keys.find((k) => k.kid === header.kid);
|
||||
if (!key) {
|
||||
throw new Error("Key not found");
|
||||
}
|
||||
const cryptoKey = await importJWK(key, header.alg);
|
||||
const verified = await jwtVerify(token, cryptoKey);
|
||||
return verified;
|
||||
}
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
import type { JWK } from "jose";
|
||||
import { exportJWK, generateKeyPair, SignJWT } from "jose";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
afterAll,
|
||||
beforeAll,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
} from "vitest";
|
||||
import { validateToken } from "./validate-authorization-code";
|
||||
|
||||
vi.mock("@better-fetch/fetch", () => ({
|
||||
betterFetch: vi.fn(),
|
||||
}));
|
||||
|
||||
import { betterFetch } from "@better-fetch/fetch";
|
||||
|
||||
const mockedBetterFetch = vi.mocked(betterFetch);
|
||||
|
||||
describe("validateToken", () => {
|
||||
const originalFetch = globalThis.fetch;
|
||||
const mockedFetch = vi.fn() as unknown as typeof fetch &
|
||||
ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeAll(() => {
|
||||
globalThis.fetch = mockedFetch;
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
globalThis.fetch = originalFetch;
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockedFetch.mockReset();
|
||||
});
|
||||
|
||||
async function createTestJWKS(alg: string, crv?: string) {
|
||||
@@ -24,7 +37,9 @@ describe("validateToken", () => {
|
||||
const privateJWK = await exportJWK(privateKey);
|
||||
const kid = `test-key-${Date.now()}`;
|
||||
publicJWK.kid = kid;
|
||||
publicJWK.alg = alg;
|
||||
privateJWK.kid = kid;
|
||||
privateJWK.alg = alg;
|
||||
return { publicJWK, privateJWK, kid, publicKey, privateKey };
|
||||
}
|
||||
|
||||
@@ -47,14 +62,19 @@ describe("validateToken", () => {
|
||||
.sign(privateKey);
|
||||
}
|
||||
|
||||
function mockJWKSResponse(...publicJWKs: JWK[]) {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
new Response(JSON.stringify({ keys: publicJWKs }), {
|
||||
status: 200,
|
||||
headers: { "content-type": "application/json" },
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
it("should verify RS256 signed token", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: { keys: [publicJWK] },
|
||||
error: null,
|
||||
});
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
@@ -64,20 +84,12 @@ describe("validateToken", () => {
|
||||
expect(result).toBeDefined();
|
||||
expect(result.payload.sub).toBe("user-123");
|
||||
expect(result.payload.email).toBe("test@example.com");
|
||||
expect(mockedBetterFetch).toHaveBeenCalledWith(
|
||||
"https://example.com/.well-known/jwks",
|
||||
expect.objectContaining({ method: "GET" }),
|
||||
);
|
||||
});
|
||||
|
||||
it("should verify ES256 signed token", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("ES256");
|
||||
const token = await createSignedToken(privateKey, "ES256", kid);
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: { keys: [publicJWK] },
|
||||
error: null,
|
||||
});
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
@@ -94,11 +106,7 @@ describe("validateToken", () => {
|
||||
"Ed25519",
|
||||
);
|
||||
const token = await createSignedToken(privateKey, "EdDSA", kid);
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: { keys: [publicJWK] },
|
||||
error: null,
|
||||
});
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
@@ -109,19 +117,15 @@ describe("validateToken", () => {
|
||||
expect(result.payload.sub).toBe("user-123");
|
||||
});
|
||||
|
||||
it("should throw 'Key not found' when kid doesn't match", async () => {
|
||||
it("should throw when kid doesn't match any key", async () => {
|
||||
const { publicJWK, privateKey } = await createTestJWKS("RS256");
|
||||
publicJWK.kid = "different-kid";
|
||||
const token = await createSignedToken(privateKey, "RS256", "original-kid");
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: { keys: [publicJWK] },
|
||||
error: null,
|
||||
});
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
await expect(
|
||||
validateToken(token, "https://example.com/.well-known/jwks"),
|
||||
).rejects.toThrow("Key not found");
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should find correct key when multiple keys exist", async () => {
|
||||
@@ -129,11 +133,7 @@ describe("validateToken", () => {
|
||||
const key2 = await createTestJWKS("RS256");
|
||||
const key3 = await createTestJWKS("ES256");
|
||||
const token = await createSignedToken(key2.privateKey, "RS256", key2.kid);
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: { keys: [key1.publicJWK, key2.publicJWK, key3.publicJWK] },
|
||||
error: null,
|
||||
});
|
||||
mockJWKSResponse(key1.publicJWK, key2.publicJWK, key3.publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
@@ -147,28 +147,95 @@ describe("validateToken", () => {
|
||||
it("should throw when JWKS returns empty keys array", async () => {
|
||||
const { privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: { keys: [] },
|
||||
error: null,
|
||||
});
|
||||
mockJWKSResponse();
|
||||
|
||||
await expect(
|
||||
validateToken(token, "https://example.com/.well-known/jwks"),
|
||||
).rejects.toThrow("Key not found");
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should throw when JWKS fetch fails", async () => {
|
||||
const { privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
|
||||
mockedBetterFetch.mockResolvedValueOnce({
|
||||
data: null,
|
||||
error: { status: 500, statusText: "Internal Server Error" },
|
||||
});
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
new Response("Internal Server Error", { status: 500 }),
|
||||
);
|
||||
|
||||
await expect(
|
||||
validateToken(token, "https://example.com/.well-known/jwks"),
|
||||
).rejects.toBeDefined();
|
||||
});
|
||||
|
||||
it("should verify token with matching audience", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
"https://example.com/.well-known/jwks",
|
||||
{ audience: "test-client" },
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.payload.aud).toBe("test-client");
|
||||
});
|
||||
|
||||
it("should reject token with mismatched audience", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
await expect(
|
||||
validateToken(token, "https://example.com/.well-known/jwks", {
|
||||
audience: "wrong-client",
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should verify token with matching issuer", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
"https://example.com/.well-known/jwks",
|
||||
{ issuer: "https://example.com" },
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.payload.iss).toBe("https://example.com");
|
||||
});
|
||||
|
||||
it("should reject token with mismatched issuer", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
await expect(
|
||||
validateToken(token, "https://example.com/.well-known/jwks", {
|
||||
issuer: "https://wrong-issuer.com",
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
});
|
||||
|
||||
it("should verify token with both audience and issuer", async () => {
|
||||
const { publicJWK, privateKey, kid } = await createTestJWKS("RS256");
|
||||
const token = await createSignedToken(privateKey, "RS256", kid);
|
||||
mockJWKSResponse(publicJWK);
|
||||
|
||||
const result = await validateToken(
|
||||
token,
|
||||
"https://example.com/.well-known/jwks",
|
||||
{
|
||||
audience: "test-client",
|
||||
issuer: "https://example.com",
|
||||
},
|
||||
);
|
||||
|
||||
expect(result).toBeDefined();
|
||||
expect(result.payload.aud).toBe("test-client");
|
||||
expect(result.payload.iss).toBe("https://example.com");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1588,6 +1588,10 @@ export const callbackSSO = (options?: SSOOptions) => {
|
||||
const verified = await validateToken(
|
||||
tokenResponse.idToken,
|
||||
config.jwksEndpoint,
|
||||
{
|
||||
audience: config.clientId,
|
||||
issuer: provider.issuer,
|
||||
},
|
||||
).catch((e) => {
|
||||
ctx.context.logger.error(e);
|
||||
return null;
|
||||
@@ -1599,13 +1603,6 @@ export const callbackSSO = (options?: SSOOptions) => {
|
||||
}?error=invalid_provider&error_description=token_not_verified`,
|
||||
);
|
||||
}
|
||||
if (verified.payload.iss !== provider.issuer) {
|
||||
throw ctx.redirect(
|
||||
`${
|
||||
errorURL || callbackURL
|
||||
}?error=invalid_provider&error_description=issuer_mismatch`,
|
||||
);
|
||||
}
|
||||
|
||||
const mapping = config.mapping || {};
|
||||
userInfo = {
|
||||
|
||||
Reference in New Issue
Block a user