This commit is contained in:
Bereket Engida
2025-11-06 14:39:01 -08:00
parent c8fcbed9b2
commit ea19399abf
6 changed files with 64 additions and 14 deletions

View File

@@ -35,7 +35,7 @@ export const getJwksAdapter = (
},
createJwk: async (ctx: GenericEndpointContext, webKey: Omit<Jwk, "id">) => {
if (options?.adapter?.createJwk) {
return await options.adapter.createJwk(ctx, webKey);
return await options.adapter.createJwk(webKey, ctx);
}
const jwk = await adapter.create<Omit<Jwk, "id">, Jwk>({
model: "jwks",

View File

@@ -131,15 +131,20 @@ export const jwt = (options?: JwtOptions | undefined) => {
throw new APIError("NOT_FOUND");
}
const adapter = getJwksAdapter(ctx.context.adapter, options);
const adapter = getJwksAdapter(ctx.context.adapter);
const keySets = await adapter.getAllKeys(ctx);
let keySets = await adapter.getAllKeys(ctx);
if (keySets.length === 0) {
if (!keySets || keySets?.length === 0) {
const key = await createJwk(ctx, options);
keySets.push(key);
keySets = [key];
}
if (!keySets?.length) {
throw new BetterAuthError(
"No key sets found. Make sure you have a key in your database.",
);
}
const keyPairConfig = options?.jwks?.keyPairConfig;
const defaultCrv = keyPairConfig
? "crv" in keyPairConfig

View File

@@ -4,7 +4,7 @@ import { createAuthClient } from "../../client";
import { getTestInstance } from "../../test-utils/test-instance";
import { jwt } from ".";
import { jwtClient } from "./client";
import type { JWKOptions, JwtOptions } from "./types";
import type { JWKOptions, Jwk, JwtOptions } from "./types";
import { generateExportedKeyPair } from "./utils";
describe("jwt", async () => {
@@ -701,3 +701,41 @@ describe("jwt - remote url", async () => {
expect(jwtHeader).toBeTruthy();
});
});
describe("jwt - custom adapter", async () => {
it("should use custom adapter", async () => {
const storage: Jwk[] = [];
const { auth } = await getTestInstance({
plugins: [
jwt({
adapter: {
getJwks: async () => {
return storage;
},
getLatestKey: async () => {
return storage[0] ?? null;
},
createJwk: async (data) => {
const key = {
...data,
id: crypto.randomUUID(),
createdAt: new Date(),
};
storage.push(key);
return key;
},
},
}),
],
});
const token = await auth.api.signJWT({
body: {
payload: {
sub: "123",
},
},
});
expect(token?.token).toBeDefined();
expect(storage.length).toBe(1);
});
});

View File

@@ -51,8 +51,11 @@ export async function signJWT(
return options.jwt.sign(jwtPayload);
}
const adapter = getJwksAdapter(ctx.context.adapter);
let key = await adapter.getLatestKey();
const adapter = getJwksAdapter(ctx.context.adapter, options);
let key = await adapter.getLatestKey(ctx);
if (!key) {
key = await createJwk(ctx, options);
}
const privateKeyEncryptionEnabled =
!options?.jwks?.disablePrivateKeyEncryption;

View File

@@ -127,7 +127,9 @@ export interface JwtOptions {
* @param ctx - The context of the request
* @returns The JWKS
*/
getJwks?: (ctx: GenericEndpointContext) => Promise<Jwk[]>;
getJwks?: (
ctx: GenericEndpointContext,
) => Promise<Jwk[] | null | undefined>;
/**
* A custom function to get the latest key from the database or
* other source
@@ -137,19 +139,21 @@ export interface JwtOptions {
* @param ctx - The context of the request
* @returns The latest key
*/
getLatestKey?: (ctx: GenericEndpointContext) => Promise<Jwk>;
getLatestKey?: (
ctx: GenericEndpointContext,
) => Promise<Jwk | null | undefined>;
/**
* A custom function to create a new key in the database or
* other source
*
* This will override the default createJwk from the database
*
* @param webKey - The web key to create
* @param data - The key to create
* @returns The created key
*/
createJwk?: (
ctx: GenericEndpointContext,
data: Omit<Jwk, "id">,
ctx: GenericEndpointContext,
) => Promise<Jwk>;
};
}

View File

@@ -162,8 +162,8 @@ export async function createJwk(
createdAt: new Date(),
};
const adapter = getJwksAdapter(ctx.context.adapter);
const key = await adapter.createJwk(jwk as Jwk);
const adapter = getJwksAdapter(ctx.context.adapter, options);
const key = await adapter.createJwk(ctx, jwk as Jwk);
return key;
}