From efcb6e73ccd9fdf2fdd5cf4d6881dc416ce7173e Mon Sep 17 00:00:00 2001 From: Bereket Engida <86073083+Bekacru@users.noreply.github.com> Date: Sat, 16 Aug 2025 14:30:23 -0700 Subject: [PATCH] feat(jwt): sign with jwt for artbitrary payload (#4041) --- packages/better-auth/src/plugins/jwt/index.ts | 34 ++++++++++- .../better-auth/src/plugins/jwt/jwt.test.ts | 57 +++++++++++++++++++ packages/better-auth/src/plugins/jwt/sign.ts | 56 +++++++++++------- 3 files changed, 124 insertions(+), 23 deletions(-) diff --git a/packages/better-auth/src/plugins/jwt/index.ts b/packages/better-auth/src/plugins/jwt/index.ts index 05066ed73b..d1fec2343a 100644 --- a/packages/better-auth/src/plugins/jwt/index.ts +++ b/packages/better-auth/src/plugins/jwt/index.ts @@ -6,8 +6,8 @@ import type { } from "../../types"; import { type Jwk, schema } from "./schema"; import { getJwksAdapter } from "./adapter"; -import { getJwtToken } from "./sign"; -import { exportJWK, generateKeyPair, type JWK } from "jose"; +import { getJwtToken, signJWT } from "./sign"; +import { exportJWK, generateKeyPair, type JWK, type JWTPayload } from "jose"; import { createAuthEndpoint, createAuthMiddleware, @@ -15,6 +15,7 @@ import { } from "../../api"; import { symmetricEncrypt } from "../../crypto"; import { mergeSchema } from "../../db/schema"; +import z from "zod"; type JWKOptions = | { @@ -327,6 +328,35 @@ export const jwt = (options?: JwtOptions) => { }); }, ), + signJWT: createAuthEndpoint( + "/sign-jwt", + { + method: "POST", + metadata: { + SERVER_ONLY: true, + $Infer: { + body: {} as { + payload: JWTPayload; + overrideOptions?: JwtOptions; + }, + }, + }, + body: z.object({ + payload: z.record(z.string(), z.any()), + overrideOptions: z.record(z.string(), z.any()).optional(), + }), + }, + async (c) => { + const jwt = await signJWT(c, { + options: { + ...options, + ...c.body.overrideOptions, + }, + payload: c.body.payload, + }); + return c.json({ token: jwt }); + }, + ), }, hooks: { after: [ diff --git a/packages/better-auth/src/plugins/jwt/jwt.test.ts b/packages/better-auth/src/plugins/jwt/jwt.test.ts index 5e0254d5bb..43d62c5021 100644 --- a/packages/better-auth/src/plugins/jwt/jwt.test.ts +++ b/packages/better-auth/src/plugins/jwt/jwt.test.ts @@ -368,3 +368,60 @@ describe("jwt", async (it) => { } } }); + +describe("signJWT", async (it) => { + const { auth } = await getTestInstance({ + plugins: [jwt()], + logger: { + level: "error", + }, + }); + + it("should sign a JWT", async () => { + const jwt = await auth.api.signJWT({ + body: { + payload: { + sub: "123", + exp: 1000, + iat: 1000, + iss: "https://example.com", + aud: "https://example.com", + custom: "custom", + }, + }, + }); + expect(jwt?.token).toBeDefined(); + }); + + it("should be a valid JWT", async () => { + const jwt = await auth.api.signJWT({ + body: { + payload: { + sub: "123", + exp: 1000, + iat: 1000, + iss: "https://example.com", + aud: "https://example.com", + custom: "custom", + }, + }, + }); + const jwks = await auth.api.getJwks(); + const publicWebKey = await importJWK({ + ...jwks.keys[0], + alg: "EdDSA", + }); + const decoded = await jwtVerify(jwt?.token!, publicWebKey); + expect(decoded).toMatchObject({ + payload: { + iss: "https://example.com", + aud: "https://example.com", + sub: "123", + exp: expect.any(Number), + iat: expect.any(Number), + custom: "custom", + }, + protectedHeader: { alg: "EdDSA", kid: jwks.keys[0].kid }, + }); + }); +}); diff --git a/packages/better-auth/src/plugins/jwt/sign.ts b/packages/better-auth/src/plugins/jwt/sign.ts index 329d0f8a00..c8146c7dac 100644 --- a/packages/better-auth/src/plugins/jwt/sign.ts +++ b/packages/better-auth/src/plugins/jwt/sign.ts @@ -1,4 +1,4 @@ -import { importJWK, SignJWT } from "jose"; +import { importJWK, SignJWT, type JWTPayload } from "jose"; import type { GenericEndpointContext } from "../../types"; import { BetterAuthError } from "../../error"; import { symmetricDecrypt, symmetricEncrypt } from "../../crypto"; @@ -6,10 +6,14 @@ import { generateExportedKeyPair, type JwtOptions } from "."; import type { Jwk } from "./schema"; import { getJwksAdapter } from "./adapter"; -export async function getJwtToken( +export async function signJWT( ctx: GenericEndpointContext, - options?: JwtOptions, + config: { + options?: JwtOptions; + payload: JWTPayload; + }, ) { + const { options, payload } = config; const adapter = getJwksAdapter(ctx.context.adapter); let key = await adapter.getLatestKey(); @@ -49,29 +53,39 @@ export async function getJwtToken( ); }) : key.privateKey; + const alg = options?.jwks?.keyPairConfig?.alg ?? "EdDSA"; + const privateKey = await importJWK(JSON.parse(privateWebKey), alg); - const privateKey = await importJWK( - JSON.parse(privateWebKey), - options?.jwks?.keyPairConfig?.alg ?? "EdDSA", - ); + const jwt = await new SignJWT({ + iss: options?.jwt?.issuer ?? ctx.context.options.baseURL!, + aud: options?.jwt?.audience ?? ctx.context.options.baseURL!, + ...payload, + }) + .setIssuedAt() + .setExpirationTime(options?.jwt?.expirationTime ?? "15m") + .setProtectedHeader({ + alg, + kid: key.id, + }) + .sign(privateKey); + return jwt; +} +export async function getJwtToken( + ctx: GenericEndpointContext, + options?: JwtOptions, +) { const payload = !options?.jwt?.definePayload ? ctx.context.session!.user : await options?.jwt.definePayload(ctx.context.session!); - const jwt = await new SignJWT(payload) - .setProtectedHeader({ - alg: options?.jwks?.keyPairConfig?.alg ?? "EdDSA", - kid: key.id, - }) - .setIssuedAt() - .setIssuer(options?.jwt?.issuer ?? ctx.context.options.baseURL!) - .setAudience(options?.jwt?.audience ?? ctx.context.options.baseURL!) - .setExpirationTime(options?.jwt?.expirationTime ?? "15m") - .setSubject( - (await options?.jwt?.getSubject?.(ctx.context.session!)) ?? + return await signJWT(ctx, { + options, + payload: { + ...payload, + sub: + (await options?.jwt?.getSubject?.(ctx.context.session!)) ?? ctx.context.session!.user.id, - ) - .sign(privateKey); - return jwt; + }, + }); }