diff --git a/packages/better-auth/src/auth/auth.test.ts b/packages/better-auth/src/auth/auth.test.ts index 1cb4ebe226..1651082c8c 100644 --- a/packages/better-auth/src/auth/auth.test.ts +++ b/packages/better-auth/src/auth/auth.test.ts @@ -1,6 +1,11 @@ -import { createAuthEndpoint } from "@better-auth/core/api"; +import { + createAuthEndpoint, + createAuthMiddleware, +} from "@better-auth/core/api"; import type { router } from "better-auth/api"; -import { describe, expectTypeOf, test } from "vitest"; +import { describe, expect, expectTypeOf, test } from "vitest"; +import { createAuthClient } from "../client"; +import { getTestInstance } from "../test-utils"; import type { Auth } from "../types"; import { betterAuth } from "./auth"; @@ -62,3 +67,60 @@ describe("auth type", () => { expectTypeOf().toEqualTypeOf<{ data: { message: string } }>(); }); }); + +describe("auth with trusted proxy headers", () => { + test("shouldn't infer base url from proxy headers if trusted", async () => { + let baseURL: string | undefined; + const { auth, customFetchImpl } = await getTestInstance({ + baseURL: undefined, + advanced: { + trustedProxyHeaders: true, + }, + hooks: { + before: createAuthMiddleware(async (ctx) => { + baseURL = ctx.context.baseURL; + }), + }, + }); + const client = createAuthClient({ + fetchOptions: { + customFetchImpl, + }, + baseURL: "http://localhost:3000", + }); + const res = await client.$fetch("/ok", { + headers: { + "x-forwarded-host": "localhost:3001", + "x-forwarded-proto": "http", + }, + }); + expect(baseURL).toBe("http://localhost:3001/api/auth"); + }); + test("shouldn't infer base url from proxy headers if not trusted", async () => { + let baseURL: string | undefined; + const { customFetchImpl } = await getTestInstance({ + baseURL: undefined, + advanced: { + trustedProxyHeaders: false, + }, + hooks: { + before: createAuthMiddleware(async (ctx) => { + baseURL = ctx.context.baseURL; + }), + }, + }); + const client = createAuthClient({ + fetchOptions: { + customFetchImpl, + }, + baseURL: "http://localhost:3000", + }); + const res = await client.$fetch("/ok", { + headers: { + "x-forwarded-host": "localhost:3001", + "x-forwarded-proto": "http", + }, + }); + expect(baseURL).toBe("http://localhost:3000/api/auth"); + }); +}); diff --git a/packages/better-auth/src/auth/base.ts b/packages/better-auth/src/auth/base.ts index 1b8d2f4c9a..1bbef23e46 100644 --- a/packages/better-auth/src/auth/base.ts +++ b/packages/better-auth/src/auth/base.ts @@ -30,7 +30,13 @@ export const createBetterAuth = ( const ctx = await authContext; const basePath = ctx.options.basePath || "/api/auth"; if (!ctx.options.baseURL) { - const baseURL = getBaseURL(undefined, basePath, request); + const baseURL = getBaseURL( + undefined, + basePath, + request, + undefined, + ctx.options.advanced?.trustedProxyHeaders, + ); if (baseURL) { ctx.baseURL = baseURL; ctx.options.baseURL = getOrigin(ctx.baseURL) || undefined; @@ -40,6 +46,7 @@ export const createBetterAuth = ( ); } } + ctx.trustedOrigins = [ ...(options.trustedOrigins ? Array.isArray(options.trustedOrigins) diff --git a/packages/better-auth/src/utils/url.ts b/packages/better-auth/src/utils/url.ts index 7fd69dc01f..4a6ea05abe 100644 --- a/packages/better-auth/src/utils/url.ts +++ b/packages/better-auth/src/utils/url.ts @@ -55,6 +55,7 @@ export function getBaseURL( path?: string, request?: Request, loadEnv?: boolean, + trustedProxyHeaders?: boolean | undefined, ) { if (url) { return withPath(url, path); @@ -76,7 +77,7 @@ export function getBaseURL( const fromRequest = request?.headers.get("x-forwarded-host"); const fromRequestProto = request?.headers.get("x-forwarded-proto"); - if (fromRequest && fromRequestProto) { + if (fromRequest && fromRequestProto && trustedProxyHeaders) { return withPath(`${fromRequestProto}://${fromRequest}`, path); } diff --git a/packages/core/src/types/init-options.ts b/packages/core/src/types/init-options.ts index f486f9a940..030c72924c 100644 --- a/packages/core/src/types/init-options.ts +++ b/packages/core/src/types/init-options.ts @@ -257,6 +257,20 @@ export type BetterAuthAdvancedOptions = { generateId?: GenerateIdFn | false | "serial" | "uuid"; } | undefined; + /** + * Trusted proxy headers + * + + * - `x-forwarded-host` + * - `x-forwarded-proto` + * + * If set to `true` and no `baseURL` option is provided, we will use the headers to infer the + * base URL. + * + * ⚠︎ This may expose your application to security vulnerabilities if not + * used correctly. Please use this with caution. + */ + trustedProxyHeaders?: boolean | undefined; }; export type BetterAuthOptions = {