mirror of
https://github.com/better-auth/better-auth.git
synced 2026-05-28 01:46:45 -05:00
refactor: redirect middleware and origins
This commit is contained in:
@@ -18,14 +18,17 @@ export const csrfMiddleware = createAuthMiddleware(
|
||||
) {
|
||||
return;
|
||||
}
|
||||
const origin = ctx.headers?.get("origin") || "";
|
||||
const originHeader = ctx.headers?.get("origin") || "";
|
||||
/**
|
||||
* If origin is the same as baseURL or if the
|
||||
* origin is in the trustedOrigins then we
|
||||
* don't need to check the CSRF token.
|
||||
*/
|
||||
if (ctx.context.trustedOrigins.includes(origin)) {
|
||||
return;
|
||||
if (originHeader) {
|
||||
const origin = new URL(originHeader).origin;
|
||||
if (ctx.context.trustedOrigins.includes(origin)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const csrfToken = ctx.body?.csrfToken;
|
||||
|
||||
@@ -23,7 +23,7 @@ describe("redirectURLMiddleware", async (it) => {
|
||||
callbackURL: "http://malicious.com",
|
||||
});
|
||||
expect(res.error?.status).toBe(403);
|
||||
expect(res.error?.message).toBe("Invalid callback URL");
|
||||
expect(res.error?.message).toBe("Invalid callbackURL");
|
||||
});
|
||||
|
||||
it("should allow trusted origins", async (ctx) => {
|
||||
@@ -58,7 +58,6 @@ describe("redirectURLMiddleware", async (it) => {
|
||||
},
|
||||
});
|
||||
expect(res.error?.status).toBe(403);
|
||||
expect(res.error?.message).toBe("Invalid callback URL");
|
||||
|
||||
const res2 = await client.signIn.email({
|
||||
email: testUser.email,
|
||||
@@ -71,7 +70,7 @@ describe("redirectURLMiddleware", async (it) => {
|
||||
},
|
||||
});
|
||||
expect(res2.error?.status).toBe(403);
|
||||
expect(res2.error?.message).toBe("Invalid callback URL");
|
||||
expect(res2.error?.message).toBe("Invalid currentURL");
|
||||
});
|
||||
|
||||
it("shouldn't allow untrusted redirectTo", async (ctx) => {
|
||||
@@ -86,6 +85,6 @@ describe("redirectURLMiddleware", async (it) => {
|
||||
redirectTo: "http://malicious.com",
|
||||
});
|
||||
expect(res.error?.status).toBe(403);
|
||||
expect(res.error?.message).toBe("Invalid callback URL");
|
||||
expect(res.error?.message).toBe("Invalid callbackURL");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,43 +3,33 @@ import { createAuthMiddleware } from "../call";
|
||||
import { logger } from "../../utils/logger";
|
||||
|
||||
/**
|
||||
* This middleware is used to validate the callbackURL and currentURL.
|
||||
* It checks if the callbackURL is a valid URL and if it's in the trustedOrigins
|
||||
* to avoid open redirect attacks.
|
||||
* Middleware to validate callbackURL and currentURL against trustedOrigins,
|
||||
* preventing open redirect attacks.
|
||||
*/
|
||||
export const redirectURLMiddleware = createAuthMiddleware(async (ctx) => {
|
||||
const callbackURL =
|
||||
ctx.body?.callbackURL ||
|
||||
ctx.query?.callbackURL ||
|
||||
ctx.query?.redirectTo ||
|
||||
ctx.body?.redirectTo;
|
||||
const clientCurrentURL = ctx.headers?.get("referer");
|
||||
const currentURL =
|
||||
ctx.query?.currentURL || clientCurrentURL || ctx.context.baseURL;
|
||||
const trustedOrigins = ctx.context.trustedOrigins;
|
||||
const { body, query, headers, context } = ctx;
|
||||
|
||||
if (callbackURL?.includes("http")) {
|
||||
const callbackOrigin = new URL(callbackURL).origin;
|
||||
if (!trustedOrigins.includes(callbackOrigin)) {
|
||||
logger.error("Invalid callback URL", {
|
||||
callbackURL,
|
||||
trustedOrigins,
|
||||
});
|
||||
throw new APIError("FORBIDDEN", {
|
||||
message: "Invalid callback URL",
|
||||
});
|
||||
const callbackURL =
|
||||
body?.callbackURL ||
|
||||
query?.callbackURL ||
|
||||
query?.redirectTo ||
|
||||
body?.redirectTo;
|
||||
const currentURL =
|
||||
query?.currentURL || headers?.get("referer") || context.baseURL;
|
||||
const trustedOrigins = context.trustedOrigins;
|
||||
|
||||
const validateURL = (url: string | undefined, label: string) => {
|
||||
if (url?.startsWith("http")) {
|
||||
const isTrustedOrigin = trustedOrigins.some((origin) =>
|
||||
url.startsWith(origin),
|
||||
);
|
||||
if (!isTrustedOrigin) {
|
||||
logger.error(`Invalid ${label}`, { [label]: url, trustedOrigins });
|
||||
throw new APIError("FORBIDDEN", { message: `Invalid ${label}` });
|
||||
}
|
||||
}
|
||||
}
|
||||
if (currentURL !== ctx.context.baseURL) {
|
||||
const currentURLOrigin = new URL(currentURL).origin;
|
||||
if (!trustedOrigins.includes(currentURLOrigin)) {
|
||||
logger.error("Invalid current URL", {
|
||||
currentURL,
|
||||
trustedOrigins,
|
||||
});
|
||||
throw new APIError("FORBIDDEN", {
|
||||
message: "Invalid callback URL",
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
validateURL(callbackURL, "callbackURL");
|
||||
validateURL(currentURL, "currentURL");
|
||||
});
|
||||
|
||||
@@ -3,7 +3,7 @@ import { getEndpoints, router } from "./api";
|
||||
import { init } from "./init";
|
||||
import type { BetterAuthOptions } from "./types/options";
|
||||
import type { InferPluginTypes, InferSession, InferUser } from "./types";
|
||||
import { getBaseURL } from "./utils/base-url";
|
||||
import { getBaseURL } from "./utils/url";
|
||||
|
||||
type InferAPI<API> = Omit<
|
||||
API,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { createFetch } from "@better-fetch/fetch";
|
||||
import { getBaseURL } from "../utils/base-url";
|
||||
import { getBaseURL } from "../utils/url";
|
||||
import { type Atom } from "nanostores";
|
||||
import type { AtomListener, ClientOptions } from "./types";
|
||||
import { addCurrentURL, csrfPlugin, redirectPlugin } from "./fetch-plugins";
|
||||
|
||||
@@ -12,7 +12,7 @@ import type {
|
||||
SecondaryStorage,
|
||||
} from "./types";
|
||||
import { defu } from "defu";
|
||||
import { getBaseURL } from "./utils/base-url";
|
||||
import { getBaseURL } from "./utils/url";
|
||||
import { DEFAULT_SECRET } from "./utils/constants";
|
||||
import {
|
||||
type BetterAuthCookies,
|
||||
@@ -216,7 +216,7 @@ function getTrustedOrigins(options: BetterAuthOptions) {
|
||||
"Base URL can not be empty. Please add `BETTER_AUTH_URL` in your environment variables or pass it in your auth config.",
|
||||
);
|
||||
}
|
||||
const trustedOrigins = [baseURL];
|
||||
const trustedOrigins = [new URL(baseURL).origin];
|
||||
if (options.trustedOrigins) {
|
||||
trustedOrigins.push(...options.trustedOrigins);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { sha256 } from "oslo/crypto";
|
||||
import { getBaseURL } from "../utils/base-url";
|
||||
import { getBaseURL } from "../utils/url";
|
||||
import { base64url } from "oslo/encoding";
|
||||
import type { OAuth2Tokens } from "./types";
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { BetterAuthPlugin } from "../../types";
|
||||
import { setSessionCookie } from "../../cookies";
|
||||
import { z } from "zod";
|
||||
import { generateId } from "../../utils/id";
|
||||
import { getOrigin } from "../../utils/base-url";
|
||||
import { getOrigin } from "../../utils/url";
|
||||
|
||||
export interface AnonymousOptions {
|
||||
/**
|
||||
|
||||
@@ -133,7 +133,7 @@ describe("Social Providers", async () => {
|
||||
},
|
||||
);
|
||||
expect(signInRes.error?.status).toBe(403);
|
||||
expect(signInRes.error?.message).toBe("Invalid callback URL");
|
||||
expect(signInRes.error?.message).toBe("Invalid callbackURL");
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import { parseSetCookieHeader } from "../cookies";
|
||||
import type { SuccessContext } from "@better-fetch/fetch";
|
||||
import { getAdapter } from "../db/utils";
|
||||
import Database from "better-sqlite3";
|
||||
import { getBaseURL } from "../utils/base-url";
|
||||
import { getBaseURL } from "../utils/url";
|
||||
|
||||
export async function getTestInstance<
|
||||
O extends Partial<BetterAuthOptions>,
|
||||
|
||||
Reference in New Issue
Block a user