diff --git a/packages/better-auth/src/plugins/two-factor/backup-codes/index.ts b/packages/better-auth/src/plugins/two-factor/backup-codes/index.ts index 271c83825f..967d77d2ec 100644 --- a/packages/better-auth/src/plugins/two-factor/backup-codes/index.ts +++ b/packages/better-auth/src/plugins/two-factor/backup-codes/index.ts @@ -4,8 +4,13 @@ import { createAuthEndpoint } from "../../../api/call"; import { sessionMiddleware } from "../../../api"; import { symmetricDecrypt, symmetricEncrypt } from "../../../crypto"; import { verifyTwoFactorMiddleware } from "../verify-middleware"; -import type { TwoFactorProvider, UserWithTwoFactor } from "../types"; +import type { + TwoFactorProvider, + TwoFactorTable, + UserWithTwoFactor, +} from "../types"; import { APIError } from "better-call"; +import { setSessionCookie } from "../../../cookies"; export interface BackupCodeOptions { /** @@ -52,21 +57,21 @@ export async function generateBackupCodes( export async function verifyBackupCode( data: { - user: UserWithTwoFactor; + backupCodes: string; code: string; }, key: string, ) { - const codes = await getBackupCodes(data.user, key); + const codes = await getBackupCodes(data.backupCodes, key); if (!codes) { return false; } return codes.includes(data.code); } -export async function getBackupCodes(user: UserWithTwoFactor, key: string) { +export async function getBackupCodes(backupCodes: string, key: string) { const secret = Buffer.from( - await symmetricDecrypt({ key, data: user.twoFactorBackupCodes }), + await symmetricDecrypt({ key, data: backupCodes }), ).toString("utf-8"); const data = JSON.parse(secret); const result = z.array(z.string()).safeParse(data); @@ -76,7 +81,10 @@ export async function getBackupCodes(user: UserWithTwoFactor, key: string) { return null; } -export const backupCode2fa = (options?: BackupCodeOptions) => { +export const backupCode2fa = ( + options: BackupCodeOptions, + twoFactorTable: string, +) => { return { id: "backup_code", endpoints: { @@ -87,13 +95,32 @@ export const backupCode2fa = (options?: BackupCodeOptions) => { method: "POST", body: z.object({ code: z.string(), + /** + * Disable setting the session cookie + */ + disableSession: z.boolean().optional(), }), use: [verifyTwoFactorMiddleware], }, async (ctx) => { + const user = ctx.context.session.user as UserWithTwoFactor; + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor) { + throw new APIError("BAD_REQUEST", { + message: "Backup codes aren't enabled", + }); + } const validate = verifyBackupCode( { - user: ctx.context.session.user, + backupCodes: twoFactor.backupCodes, code: ctx.body.code, }, ctx.context.secret, @@ -103,7 +130,13 @@ export const backupCode2fa = (options?: BackupCodeOptions) => { message: "Invalid backup code", }); } - return ctx.json({ status: true }); + if (!ctx.body.disableSession) { + await setSessionCookie(ctx, ctx.context.session.id); + } + return ctx.json({ + user: user, + session: ctx.context.session, + }); }, ), generateBackupCodes: createAuthEndpoint( @@ -113,19 +146,24 @@ export const backupCode2fa = (options?: BackupCodeOptions) => { use: [sessionMiddleware], }, async (ctx) => { + const user = ctx.context.session.user as UserWithTwoFactor; + if (!user.twoFactorEnabled) { + throw new APIError("BAD_REQUEST", { + message: "Two factor isn't enabled", + }); + } const backupCodes = await generateBackupCodes( ctx.context.secret, options, ); await ctx.context.adapter.update({ - model: "user", + model: twoFactorTable, update: { - twoFactorEnabled: true, - twoFactorBackupCodes: backupCodes.encryptedBackupCodes, + backupCodes: backupCodes.encryptedBackupCodes, }, where: [ { - field: "id", + field: "userId", value: ctx.context.session.user.id, }, ], @@ -144,7 +182,29 @@ export const backupCode2fa = (options?: BackupCodeOptions) => { }, async (ctx) => { const user = ctx.context.session.user as UserWithTwoFactor; - const backupCodes = getBackupCodes(user, ctx.context.secret); + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor) { + throw new APIError("BAD_REQUEST", { + message: "Backup codes aren't enabled", + }); + } + const backupCodes = getBackupCodes( + twoFactor.backupCodes, + ctx.context.secret, + ); + if (!backupCodes) { + throw new APIError("BAD_REQUEST", { + message: "Backup codes aren't enabled", + }); + } return ctx.json({ status: true, backupCodes: backupCodes, diff --git a/packages/better-auth/src/plugins/two-factor/client.ts b/packages/better-auth/src/plugins/two-factor/client.ts index 1e003c6e7b..0ab5c613d2 100644 --- a/packages/better-auth/src/plugins/two-factor/client.ts +++ b/packages/better-auth/src/plugins/two-factor/client.ts @@ -31,6 +31,7 @@ export const twoFactorClient = ( "/two-factor/disable": "POST", "/two-factor/enable": "POST", "/two-factor/send-otp": "POST", + "/two-factor/generate-backup-codes": "POST", }, fetchPlugins: [ { diff --git a/packages/better-auth/src/plugins/two-factor/index.ts b/packages/better-auth/src/plugins/two-factor/index.ts index 0515bf93c5..6e5a9151d6 100644 --- a/packages/better-auth/src/plugins/two-factor/index.ts +++ b/packages/better-auth/src/plugins/two-factor/index.ts @@ -14,12 +14,28 @@ import { validatePassword } from "../../utils/password"; import { APIError } from "better-call"; export const twoFactor = (options?: TwoFactorOptions) => { - const totp = totp2fa({ - issuer: options?.issuer || "better-auth", - ...options?.totpOptions, - }); - const backupCode = backupCode2fa(options?.backupCodeOptions); - const otp = otp2fa(options?.otpOptions); + const opts = { + twoFactorTable: options?.twoFactorTable || "twoFactor", + }; + const totp = totp2fa( + { + issuer: options?.issuer || "better-auth", + ...options?.totpOptions, + }, + opts.twoFactorTable, + ); + const backupCode = backupCode2fa( + { + ...options?.backupCodeOptions, + }, + opts.twoFactorTable, + ); + const otp = otp2fa( + { + ...options?.otpOptions, + }, + opts.twoFactorTable, + ); return { id: "two-factor", endpoints: { @@ -56,19 +72,17 @@ export const twoFactor = (options?: TwoFactorOptions) => { ctx.context.secret, options?.backupCodeOptions, ); - await ctx.context.adapter.update({ - model: "user", - update: { - twoFactorSecret: encryptedSecret, - twoFactorEnabled: true, - twoFactorBackupCodes: backupCodes.encryptedBackupCodes, + await ctx.context.internalAdapter.updateUser(user.id, { + twoFactorEnabled: true, + }); + + const res = await ctx.context.adapter.create({ + model: opts.twoFactorTable, + data: { + secret: encryptedSecret, + backupCodes: backupCodes.encryptedBackupCodes, + userId: user.id, }, - where: [ - { - field: "id", - value: user.id, - }, - ], }); return ctx.json({ status: true }); }, @@ -94,14 +108,14 @@ export const twoFactor = (options?: TwoFactorOptions) => { message: "Invalid password", }); } - await ctx.context.adapter.update({ - model: "user", - update: { - twoFactorEnabled: false, - }, + await ctx.context.internalAdapter.updateUser(user.id, { + twoFactorEnabled: false, + }); + await ctx.context.adapter.delete({ + model: opts.twoFactorTable, where: [ { - field: "id", + field: "userId", value: user.id, }, ], @@ -220,16 +234,29 @@ export const twoFactor = (options?: TwoFactorOptions) => { required: false, defaultValue: false, }, - twoFactorSecret: { + }, + }, + twoFactor: { + fields: { + secret: { type: "string", - required: false, + required: true, returned: false, }, - twoFactorBackupCodes: { + backupCodes: { type: "string", - required: false, + required: true, returned: false, }, + userId: { + type: "string", + required: true, + returned: false, + references: { + model: "user", + field: "id", + }, + }, }, }, }, diff --git a/packages/better-auth/src/plugins/two-factor/otp/index.ts b/packages/better-auth/src/plugins/two-factor/otp/index.ts index 99b28f4808..d025dbf177 100644 --- a/packages/better-auth/src/plugins/two-factor/otp/index.ts +++ b/packages/better-auth/src/plugins/two-factor/otp/index.ts @@ -3,7 +3,11 @@ import { TOTPController } from "oslo/otp"; import { z } from "zod"; import { createAuthEndpoint } from "../../../api/call"; import { verifyTwoFactorMiddleware } from "../verify-middleware"; -import type { TwoFactorProvider, UserWithTwoFactor } from "../types"; +import type { + TwoFactorProvider, + TwoFactorTable, + UserWithTwoFactor, +} from "../types"; import { TimeSpan } from "oslo"; export interface OTPOptions { @@ -27,8 +31,9 @@ export interface OTPOptions { /** * The otp adapter is created from the totp adapter. */ -export const otp2fa = (options?: OTPOptions) => { +export const otp2fa = (options: OTPOptions, twoFactorTable: string) => { const opts = { + ...options, period: new TimeSpan(options?.period || 3, "m"), }; const totp = new TOTPController({ @@ -54,7 +59,21 @@ export const otp2fa = (options?: OTPOptions) => { }); } const user = ctx.context.session.user as UserWithTwoFactor; - const code = await totp.generate(Buffer.from(user.twoFactorSecret)); + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor) { + throw new APIError("BAD_REQUEST", { + message: "totp isn't enabled", + }); + } + const code = await totp.generate(Buffer.from(twoFactor.secret)); await options.sendOTP(user, code); return ctx.json({ status: true }); }, @@ -76,7 +95,21 @@ export const otp2fa = (options?: OTPOptions) => { message: "two factor isn't enabled", }); } - const toCheckOtp = await totp.generate(Buffer.from(user.twoFactorSecret)); + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor) { + throw new APIError("BAD_REQUEST", { + message: "totp isn't enabled", + }); + } + const toCheckOtp = await totp.generate(Buffer.from(twoFactor.secret)); if (toCheckOtp === ctx.body.code) { return ctx.context.valid(); } else { diff --git a/packages/better-auth/src/plugins/two-factor/totp/index.ts b/packages/better-auth/src/plugins/two-factor/totp/index.ts index e2fc1bfe2f..a9d9d52337 100644 --- a/packages/better-auth/src/plugins/two-factor/totp/index.ts +++ b/packages/better-auth/src/plugins/two-factor/totp/index.ts @@ -7,7 +7,11 @@ import { sessionMiddleware } from "../../../api"; import { symmetricDecrypt } from "../../../crypto"; import type { BackupCodeOptions } from "../backup-codes"; import { verifyTwoFactorMiddleware } from "../verify-middleware"; -import type { TwoFactorProvider, UserWithTwoFactor } from "../types"; +import type { + TwoFactorProvider, + TwoFactorTable, + UserWithTwoFactor, +} from "../types"; export type TOTPOptions = { /** @@ -31,8 +35,9 @@ export type TOTPOptions = { backupCodes?: BackupCodeOptions; }; -export const totp2fa = (options: TOTPOptions) => { +export const totp2fa = (options: TOTPOptions, twoFactorTable: string) => { const opts = { + ...options, digits: 6, period: new TimeSpan(options?.period || 30, "s"), }; @@ -52,9 +57,23 @@ export const totp2fa = (options: TOTPOptions) => { message: "totp isn't configured", }); } - const session = ctx.context.session.user as UserWithTwoFactor; + const user = ctx.context.session.user as UserWithTwoFactor; + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor) { + throw new APIError("BAD_REQUEST", { + message: "totp isn't enabled", + }); + } const totp = new TOTPController(opts); - const code = await totp.generate(Buffer.from(session.twoFactorSecret)); + const code = await totp.generate(Buffer.from(twoFactor.secret)); return { code }; }, ); @@ -75,7 +94,16 @@ export const totp2fa = (options: TOTPOptions) => { }); } const user = ctx.context.session.user as UserWithTwoFactor; - if (!user.twoFactorSecret) { + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor || !user.twoFactorEnabled) { throw new APIError("BAD_REQUEST", { message: "totp isn't enabled", }); @@ -84,7 +112,7 @@ export const totp2fa = (options: TOTPOptions) => { totpURI: createTOTPKeyURI( options?.issuer || "BetterAuth", user.email, - Buffer.from(user.twoFactorSecret), + Buffer.from(twoFactor.secret), opts, ), }; @@ -110,10 +138,25 @@ export const totp2fa = (options: TOTPOptions) => { message: "totp isn't configured", }); } + const user = ctx.context.session.user as UserWithTwoFactor; + const twoFactor = await ctx.context.adapter.findOne({ + model: twoFactorTable, + where: [ + { + field: "userId", + value: user.id, + }, + ], + }); + if (!twoFactor || !twoFactor.enabled) { + throw new APIError("BAD_REQUEST", { + message: "totp isn't enabled", + }); + } const totp = new TOTPController(opts); const decrypted = await symmetricDecrypt({ key: ctx.context.secret, - data: ctx.context.session.user.twoFactorSecret, + data: twoFactor.secret, }); const secret = Buffer.from(decrypted); const status = await totp.verify(ctx.body.code, secret); diff --git a/packages/better-auth/src/plugins/two-factor/two-factor.test.ts b/packages/better-auth/src/plugins/two-factor/two-factor.test.ts index d0aadf014d..bda4d8b21e 100644 --- a/packages/better-auth/src/plugins/two-factor/two-factor.test.ts +++ b/packages/better-auth/src/plugins/two-factor/two-factor.test.ts @@ -3,7 +3,7 @@ import { getTestInstance } from "../../test-utils/test-instance"; import { twoFactor, twoFactorClient } from "."; import { createAuthClient } from "../../client"; import { parseSetCookieHeader } from "../../cookies"; -import type { UserWithTwoFactor } from "./types"; +import type { TwoFactorTable, UserWithTwoFactor } from "./types"; describe("two factor", async () => { let OTP = ""; @@ -39,6 +39,7 @@ describe("two factor", async () => { if (!session) { throw new Error("No session"); } + it("should enable two factor", async () => { const res = await client.twoFactor.enable({ password: testUser.password, @@ -57,10 +58,18 @@ describe("two factor", async () => { }, ], }); - + const twoFactor = await db.findOne({ + model: "twoFactor", + where: [ + { + field: "userId", + value: session.data?.user.id as string, + }, + ], + }); expect(dbUser?.twoFactorEnabled).toBe(true); - expect(dbUser?.twoFactorSecret).toBeDefined(); - expect(dbUser?.twoFactorBackupCodes).toBeDefined(); + expect(twoFactor?.secret).toBeDefined(); + expect(twoFactor?.backupCodes).toBeDefined(); }); it("should require two factor", async () => { @@ -72,6 +81,8 @@ describe("two factor", async () => { const parsed = parseSetCookieHeader( context.response.headers.get("Set-Cookie") || "", ); + expect(parsed.get("better-auth.session_token")?.value).toBe(""); + expect(parsed.get("better-auth.two-factor")?.value).toBeDefined(); headers.append( "cookie", `better-auth.two-factor=${ @@ -111,7 +122,54 @@ describe("two factor", async () => { }, }, }); - expect(verifyRes.data?.status).toBe(true); + expect(verifyRes.data?.session).toBeDefined(); + }); + + let backupCodes: string[] = []; + it("should generate backup codes", async () => { + await client.twoFactor.enable({ + password: testUser.password, + fetchOptions: { + headers, + }, + }); + const backupCodesRes = await client.twoFactor.generateBackupCodes({ + fetchOptions: { + headers, + }, + }); + expect(backupCodesRes.data?.backupCodes).toBeDefined(); + backupCodes = backupCodesRes.data?.backupCodes || []; + }); + + it("should allow sign in with backup code", async () => { + await client.signIn.email({ + email: testUser.email, + password: testUser.password, + fetchOptions: { + onSuccess(context) { + const parsed = parseSetCookieHeader( + context.response.headers.get("Set-Cookie") || "", + ); + const token = parsed.get("better-auth.session_token")?.value; + expect(token).toBe(""); + }, + }, + }); + const backupCode = backupCodes[0]; + await client.twoFactor.verifyBackupCode({ + code: backupCode, + fetchOptions: { + headers, + onSuccess(context) { + const parsed = parseSetCookieHeader( + context.response.headers.get("Set-Cookie") || "", + ); + const token = parsed.get("better-auth.session_token")?.value; + expect(token?.length).toBeGreaterThan(0); + }, + }, + }); }); it("should trust device", async () => { @@ -198,5 +256,11 @@ describe("two factor", async () => { ], }); expect(dbUser?.twoFactorEnabled).toBe(false); + + const signInRes = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + }); + expect(signInRes.data?.user).toBeDefined(); }); }); diff --git a/packages/better-auth/src/plugins/two-factor/types.ts b/packages/better-auth/src/plugins/two-factor/types.ts index 20b8502168..626d211296 100644 --- a/packages/better-auth/src/plugins/two-factor/types.ts +++ b/packages/better-auth/src/plugins/two-factor/types.ts @@ -22,6 +22,11 @@ export interface TwoFactorOptions { * Backup code options */ backupCodeOptions?: BackupCodeOptions; + /** + * Table name for two factor authentication. + * @default "userTwoFactor" + */ + twoFactorTable?: string; } export interface UserWithTwoFactor extends User { @@ -29,18 +34,16 @@ export interface UserWithTwoFactor extends User { * If the user has enabled two factor authentication. */ twoFactorEnabled: boolean; - /** - * The secret used to generate the TOTP or OTP. - */ - twoFactorSecret: string; - /** - * List of backup codes separated by a - * comma - */ - twoFactorBackupCodes: string; } export interface TwoFactorProvider { id: LiteralString; endpoints?: Record; } + +export interface TwoFactorTable { + userId: string; + secret: string; + backupCodes: string; + enabled: boolean; +} diff --git a/packages/better-auth/src/plugins/two-factor/verify-middleware.ts b/packages/better-auth/src/plugins/two-factor/verify-middleware.ts index dfa54bd275..febe712ee0 100644 --- a/packages/better-auth/src/plugins/two-factor/verify-middleware.ts +++ b/packages/better-auth/src/plugins/two-factor/verify-middleware.ts @@ -106,9 +106,14 @@ export const verifyTwoFactorMiddleware = createAuthMiddleware( status: true, callbackURL: ctx.body.callbackURL, redirect: true, + session, + user, }); } - return ctx.json({ status: true }); + return ctx.json({ + session, + user, + }); }, invalid: async () => { throw new APIError("UNAUTHORIZED", {