refactor: redirect middleware and origins

This commit is contained in:
Bereket Engida
2024-10-20 17:27:31 +03:00
parent 9fe8eca15f
commit a247bae345
11 changed files with 42 additions and 50 deletions

View File

@@ -18,14 +18,17 @@ export const csrfMiddleware = createAuthMiddleware(
) {
return;
}
const origin = ctx.headers?.get("origin") || "";
const originHeader = ctx.headers?.get("origin") || "";
/**
* If origin is the same as baseURL or if the
* origin is in the trustedOrigins then we
* don't need to check the CSRF token.
*/
if (ctx.context.trustedOrigins.includes(origin)) {
return;
if (originHeader) {
const origin = new URL(originHeader).origin;
if (ctx.context.trustedOrigins.includes(origin)) {
return;
}
}
const csrfToken = ctx.body?.csrfToken;

View File

@@ -23,7 +23,7 @@ describe("redirectURLMiddleware", async (it) => {
callbackURL: "http://malicious.com",
});
expect(res.error?.status).toBe(403);
expect(res.error?.message).toBe("Invalid callback URL");
expect(res.error?.message).toBe("Invalid callbackURL");
});
it("should allow trusted origins", async (ctx) => {
@@ -58,7 +58,6 @@ describe("redirectURLMiddleware", async (it) => {
},
});
expect(res.error?.status).toBe(403);
expect(res.error?.message).toBe("Invalid callback URL");
const res2 = await client.signIn.email({
email: testUser.email,
@@ -71,7 +70,7 @@ describe("redirectURLMiddleware", async (it) => {
},
});
expect(res2.error?.status).toBe(403);
expect(res2.error?.message).toBe("Invalid callback URL");
expect(res2.error?.message).toBe("Invalid currentURL");
});
it("shouldn't allow untrusted redirectTo", async (ctx) => {
@@ -86,6 +85,6 @@ describe("redirectURLMiddleware", async (it) => {
redirectTo: "http://malicious.com",
});
expect(res.error?.status).toBe(403);
expect(res.error?.message).toBe("Invalid callback URL");
expect(res.error?.message).toBe("Invalid callbackURL");
});
});

View File

@@ -3,43 +3,33 @@ import { createAuthMiddleware } from "../call";
import { logger } from "../../utils/logger";
/**
* This middleware is used to validate the callbackURL and currentURL.
* It checks if the callbackURL is a valid URL and if it's in the trustedOrigins
* to avoid open redirect attacks.
* Middleware to validate callbackURL and currentURL against trustedOrigins,
* preventing open redirect attacks.
*/
export const redirectURLMiddleware = createAuthMiddleware(async (ctx) => {
const callbackURL =
ctx.body?.callbackURL ||
ctx.query?.callbackURL ||
ctx.query?.redirectTo ||
ctx.body?.redirectTo;
const clientCurrentURL = ctx.headers?.get("referer");
const currentURL =
ctx.query?.currentURL || clientCurrentURL || ctx.context.baseURL;
const trustedOrigins = ctx.context.trustedOrigins;
const { body, query, headers, context } = ctx;
if (callbackURL?.includes("http")) {
const callbackOrigin = new URL(callbackURL).origin;
if (!trustedOrigins.includes(callbackOrigin)) {
logger.error("Invalid callback URL", {
callbackURL,
trustedOrigins,
});
throw new APIError("FORBIDDEN", {
message: "Invalid callback URL",
});
const callbackURL =
body?.callbackURL ||
query?.callbackURL ||
query?.redirectTo ||
body?.redirectTo;
const currentURL =
query?.currentURL || headers?.get("referer") || context.baseURL;
const trustedOrigins = context.trustedOrigins;
const validateURL = (url: string | undefined, label: string) => {
if (url?.startsWith("http")) {
const isTrustedOrigin = trustedOrigins.some((origin) =>
url.startsWith(origin),
);
if (!isTrustedOrigin) {
logger.error(`Invalid ${label}`, { [label]: url, trustedOrigins });
throw new APIError("FORBIDDEN", { message: `Invalid ${label}` });
}
}
}
if (currentURL !== ctx.context.baseURL) {
const currentURLOrigin = new URL(currentURL).origin;
if (!trustedOrigins.includes(currentURLOrigin)) {
logger.error("Invalid current URL", {
currentURL,
trustedOrigins,
});
throw new APIError("FORBIDDEN", {
message: "Invalid callback URL",
});
}
}
};
validateURL(callbackURL, "callbackURL");
validateURL(currentURL, "currentURL");
});

View File

@@ -3,7 +3,7 @@ import { getEndpoints, router } from "./api";
import { init } from "./init";
import type { BetterAuthOptions } from "./types/options";
import type { InferPluginTypes, InferSession, InferUser } from "./types";
import { getBaseURL } from "./utils/base-url";
import { getBaseURL } from "./utils/url";
type InferAPI<API> = Omit<
API,

View File

@@ -1,5 +1,5 @@
import { createFetch } from "@better-fetch/fetch";
import { getBaseURL } from "../utils/base-url";
import { getBaseURL } from "../utils/url";
import { type Atom } from "nanostores";
import type { AtomListener, ClientOptions } from "./types";
import { addCurrentURL, csrfPlugin, redirectPlugin } from "./fetch-plugins";

View File

@@ -12,7 +12,7 @@ import type {
SecondaryStorage,
} from "./types";
import { defu } from "defu";
import { getBaseURL } from "./utils/base-url";
import { getBaseURL } from "./utils/url";
import { DEFAULT_SECRET } from "./utils/constants";
import {
type BetterAuthCookies,
@@ -216,7 +216,7 @@ function getTrustedOrigins(options: BetterAuthOptions) {
"Base URL can not be empty. Please add `BETTER_AUTH_URL` in your environment variables or pass it in your auth config.",
);
}
const trustedOrigins = [baseURL];
const trustedOrigins = [new URL(baseURL).origin];
if (options.trustedOrigins) {
trustedOrigins.push(...options.trustedOrigins);
}

View File

@@ -1,5 +1,5 @@
import { sha256 } from "oslo/crypto";
import { getBaseURL } from "../utils/base-url";
import { getBaseURL } from "../utils/url";
import { base64url } from "oslo/encoding";
import type { OAuth2Tokens } from "./types";

View File

@@ -3,7 +3,7 @@ import type { BetterAuthPlugin } from "../../types";
import { setSessionCookie } from "../../cookies";
import { z } from "zod";
import { generateId } from "../../utils/id";
import { getOrigin } from "../../utils/base-url";
import { getOrigin } from "../../utils/url";
export interface AnonymousOptions {
/**

View File

@@ -133,7 +133,7 @@ describe("Social Providers", async () => {
},
);
expect(signInRes.error?.status).toBe(403);
expect(signInRes.error?.message).toBe("Invalid callback URL");
expect(signInRes.error?.message).toBe("Invalid callbackURL");
});
});

View File

@@ -9,7 +9,7 @@ import { parseSetCookieHeader } from "../cookies";
import type { SuccessContext } from "@better-fetch/fetch";
import { getAdapter } from "../db/utils";
import Database from "better-sqlite3";
import { getBaseURL } from "../utils/base-url";
import { getBaseURL } from "../utils/url";
export async function getTestInstance<
O extends Partial<BetterAuthOptions>,