From a247bae3454c349a38b107b9b481a9345fbc8a4e Mon Sep 17 00:00:00 2001 From: Bereket Engida Date: Sun, 20 Oct 2024 17:27:31 +0300 Subject: [PATCH] refactor: redirect middleware and origins --- .../better-auth/src/api/middlewares/csrf.ts | 9 ++- .../src/api/middlewares/redirect.test.ts | 7 +-- .../src/api/middlewares/redirect.ts | 60 ++++++++----------- packages/better-auth/src/auth.ts | 2 +- packages/better-auth/src/client/config.ts | 2 +- packages/better-auth/src/init.ts | 4 +- packages/better-auth/src/oauth2/utils.ts | 2 +- .../src/plugins/anonymous/index.ts | 2 +- .../src/social-providers/social.test.ts | 2 +- .../src/test-utils/test-instance.ts | 2 +- .../src/utils/{base-url.ts => url.ts} | 0 11 files changed, 42 insertions(+), 50 deletions(-) rename packages/better-auth/src/utils/{base-url.ts => url.ts} (100%) diff --git a/packages/better-auth/src/api/middlewares/csrf.ts b/packages/better-auth/src/api/middlewares/csrf.ts index 1dcd9e54ac..959ccfc677 100644 --- a/packages/better-auth/src/api/middlewares/csrf.ts +++ b/packages/better-auth/src/api/middlewares/csrf.ts @@ -18,14 +18,17 @@ export const csrfMiddleware = createAuthMiddleware( ) { return; } - const origin = ctx.headers?.get("origin") || ""; + const originHeader = ctx.headers?.get("origin") || ""; /** * If origin is the same as baseURL or if the * origin is in the trustedOrigins then we * don't need to check the CSRF token. */ - if (ctx.context.trustedOrigins.includes(origin)) { - return; + if (originHeader) { + const origin = new URL(originHeader).origin; + if (ctx.context.trustedOrigins.includes(origin)) { + return; + } } const csrfToken = ctx.body?.csrfToken; diff --git a/packages/better-auth/src/api/middlewares/redirect.test.ts b/packages/better-auth/src/api/middlewares/redirect.test.ts index 7d3593d910..c001a87589 100644 --- a/packages/better-auth/src/api/middlewares/redirect.test.ts +++ b/packages/better-auth/src/api/middlewares/redirect.test.ts @@ -23,7 +23,7 @@ describe("redirectURLMiddleware", async (it) => { callbackURL: "http://malicious.com", }); expect(res.error?.status).toBe(403); - expect(res.error?.message).toBe("Invalid callback URL"); + expect(res.error?.message).toBe("Invalid callbackURL"); }); it("should allow trusted origins", async (ctx) => { @@ -58,7 +58,6 @@ describe("redirectURLMiddleware", async (it) => { }, }); expect(res.error?.status).toBe(403); - expect(res.error?.message).toBe("Invalid callback URL"); const res2 = await client.signIn.email({ email: testUser.email, @@ -71,7 +70,7 @@ describe("redirectURLMiddleware", async (it) => { }, }); expect(res2.error?.status).toBe(403); - expect(res2.error?.message).toBe("Invalid callback URL"); + expect(res2.error?.message).toBe("Invalid currentURL"); }); it("shouldn't allow untrusted redirectTo", async (ctx) => { @@ -86,6 +85,6 @@ describe("redirectURLMiddleware", async (it) => { redirectTo: "http://malicious.com", }); expect(res.error?.status).toBe(403); - expect(res.error?.message).toBe("Invalid callback URL"); + expect(res.error?.message).toBe("Invalid callbackURL"); }); }); diff --git a/packages/better-auth/src/api/middlewares/redirect.ts b/packages/better-auth/src/api/middlewares/redirect.ts index af2813ef97..461ae4fa3d 100644 --- a/packages/better-auth/src/api/middlewares/redirect.ts +++ b/packages/better-auth/src/api/middlewares/redirect.ts @@ -3,43 +3,33 @@ import { createAuthMiddleware } from "../call"; import { logger } from "../../utils/logger"; /** - * This middleware is used to validate the callbackURL and currentURL. - * It checks if the callbackURL is a valid URL and if it's in the trustedOrigins - * to avoid open redirect attacks. + * Middleware to validate callbackURL and currentURL against trustedOrigins, + * preventing open redirect attacks. */ export const redirectURLMiddleware = createAuthMiddleware(async (ctx) => { - const callbackURL = - ctx.body?.callbackURL || - ctx.query?.callbackURL || - ctx.query?.redirectTo || - ctx.body?.redirectTo; - const clientCurrentURL = ctx.headers?.get("referer"); - const currentURL = - ctx.query?.currentURL || clientCurrentURL || ctx.context.baseURL; - const trustedOrigins = ctx.context.trustedOrigins; + const { body, query, headers, context } = ctx; - if (callbackURL?.includes("http")) { - const callbackOrigin = new URL(callbackURL).origin; - if (!trustedOrigins.includes(callbackOrigin)) { - logger.error("Invalid callback URL", { - callbackURL, - trustedOrigins, - }); - throw new APIError("FORBIDDEN", { - message: "Invalid callback URL", - }); + const callbackURL = + body?.callbackURL || + query?.callbackURL || + query?.redirectTo || + body?.redirectTo; + const currentURL = + query?.currentURL || headers?.get("referer") || context.baseURL; + const trustedOrigins = context.trustedOrigins; + + const validateURL = (url: string | undefined, label: string) => { + if (url?.startsWith("http")) { + const isTrustedOrigin = trustedOrigins.some((origin) => + url.startsWith(origin), + ); + if (!isTrustedOrigin) { + logger.error(`Invalid ${label}`, { [label]: url, trustedOrigins }); + throw new APIError("FORBIDDEN", { message: `Invalid ${label}` }); + } } - } - if (currentURL !== ctx.context.baseURL) { - const currentURLOrigin = new URL(currentURL).origin; - if (!trustedOrigins.includes(currentURLOrigin)) { - logger.error("Invalid current URL", { - currentURL, - trustedOrigins, - }); - throw new APIError("FORBIDDEN", { - message: "Invalid callback URL", - }); - } - } + }; + + validateURL(callbackURL, "callbackURL"); + validateURL(currentURL, "currentURL"); }); diff --git a/packages/better-auth/src/auth.ts b/packages/better-auth/src/auth.ts index 842c5f305b..0f760985b6 100644 --- a/packages/better-auth/src/auth.ts +++ b/packages/better-auth/src/auth.ts @@ -3,7 +3,7 @@ import { getEndpoints, router } from "./api"; import { init } from "./init"; import type { BetterAuthOptions } from "./types/options"; import type { InferPluginTypes, InferSession, InferUser } from "./types"; -import { getBaseURL } from "./utils/base-url"; +import { getBaseURL } from "./utils/url"; type InferAPI = Omit< API, diff --git a/packages/better-auth/src/client/config.ts b/packages/better-auth/src/client/config.ts index 8e42868cef..59bfb0acdd 100644 --- a/packages/better-auth/src/client/config.ts +++ b/packages/better-auth/src/client/config.ts @@ -1,5 +1,5 @@ import { createFetch } from "@better-fetch/fetch"; -import { getBaseURL } from "../utils/base-url"; +import { getBaseURL } from "../utils/url"; import { type Atom } from "nanostores"; import type { AtomListener, ClientOptions } from "./types"; import { addCurrentURL, csrfPlugin, redirectPlugin } from "./fetch-plugins"; diff --git a/packages/better-auth/src/init.ts b/packages/better-auth/src/init.ts index 7b9077115f..89121bab24 100644 --- a/packages/better-auth/src/init.ts +++ b/packages/better-auth/src/init.ts @@ -12,7 +12,7 @@ import type { SecondaryStorage, } from "./types"; import { defu } from "defu"; -import { getBaseURL } from "./utils/base-url"; +import { getBaseURL } from "./utils/url"; import { DEFAULT_SECRET } from "./utils/constants"; import { type BetterAuthCookies, @@ -216,7 +216,7 @@ function getTrustedOrigins(options: BetterAuthOptions) { "Base URL can not be empty. Please add `BETTER_AUTH_URL` in your environment variables or pass it in your auth config.", ); } - const trustedOrigins = [baseURL]; + const trustedOrigins = [new URL(baseURL).origin]; if (options.trustedOrigins) { trustedOrigins.push(...options.trustedOrigins); } diff --git a/packages/better-auth/src/oauth2/utils.ts b/packages/better-auth/src/oauth2/utils.ts index 6dc36cb547..7387577582 100644 --- a/packages/better-auth/src/oauth2/utils.ts +++ b/packages/better-auth/src/oauth2/utils.ts @@ -1,5 +1,5 @@ import { sha256 } from "oslo/crypto"; -import { getBaseURL } from "../utils/base-url"; +import { getBaseURL } from "../utils/url"; import { base64url } from "oslo/encoding"; import type { OAuth2Tokens } from "./types"; diff --git a/packages/better-auth/src/plugins/anonymous/index.ts b/packages/better-auth/src/plugins/anonymous/index.ts index ba2692370a..516f291f87 100644 --- a/packages/better-auth/src/plugins/anonymous/index.ts +++ b/packages/better-auth/src/plugins/anonymous/index.ts @@ -3,7 +3,7 @@ import type { BetterAuthPlugin } from "../../types"; import { setSessionCookie } from "../../cookies"; import { z } from "zod"; import { generateId } from "../../utils/id"; -import { getOrigin } from "../../utils/base-url"; +import { getOrigin } from "../../utils/url"; export interface AnonymousOptions { /** diff --git a/packages/better-auth/src/social-providers/social.test.ts b/packages/better-auth/src/social-providers/social.test.ts index b1abab924a..ecd6522acd 100644 --- a/packages/better-auth/src/social-providers/social.test.ts +++ b/packages/better-auth/src/social-providers/social.test.ts @@ -133,7 +133,7 @@ describe("Social Providers", async () => { }, ); expect(signInRes.error?.status).toBe(403); - expect(signInRes.error?.message).toBe("Invalid callback URL"); + expect(signInRes.error?.message).toBe("Invalid callbackURL"); }); }); diff --git a/packages/better-auth/src/test-utils/test-instance.ts b/packages/better-auth/src/test-utils/test-instance.ts index 76f0a57891..44ff937f51 100644 --- a/packages/better-auth/src/test-utils/test-instance.ts +++ b/packages/better-auth/src/test-utils/test-instance.ts @@ -9,7 +9,7 @@ import { parseSetCookieHeader } from "../cookies"; import type { SuccessContext } from "@better-fetch/fetch"; import { getAdapter } from "../db/utils"; import Database from "better-sqlite3"; -import { getBaseURL } from "../utils/base-url"; +import { getBaseURL } from "../utils/url"; export async function getTestInstance< O extends Partial, diff --git a/packages/better-auth/src/utils/base-url.ts b/packages/better-auth/src/utils/url.ts similarity index 100% rename from packages/better-auth/src/utils/base-url.ts rename to packages/better-auth/src/utils/url.ts