diff --git a/docs/content/docs/guides/saml-sso-with-okta.mdx b/docs/content/docs/guides/saml-sso-with-okta.mdx
index 460edcea7e..478144e7fb 100644
--- a/docs/content/docs/guides/saml-sso-with-okta.mdx
+++ b/docs/content/docs/guides/saml-sso-with-okta.mdx
@@ -22,12 +22,16 @@ In this setup:
4. Select "SAML 2.0" as the Sign-in method
5. Configure the following settings:
- - **Single Sign-on URL**: Your Better Auth ACS endpoint (e.g., `http://localhost:3000/api/auth/sso/saml2/sp/acs/sso`). while `sso` being your providerId
+ - **Single Sign-on URL**: Your Better Auth callback endpoint (e.g., `http://localhost:3000/api/auth/sso/saml2/callback/sso`). Note: `sso` is your `providerId`
- **Audience URI (SP Entity ID)**: Your Better Auth metadata URL (e.g., `http://localhost:3000/api/auth/sso/saml2/sp/metadata`)
- **Name ID format**: Email Address or any of your choice.
6. Download the IdP metadata XML file and certificate
+
+**IdP-Initiated SSO**: If you want users to access your app from the Okta dashboard, make sure the **Single Sign-on URL** points to the callback endpoint (`/api/auth/sso/saml2/callback/{providerId}`). Better Auth automatically handles both SP-initiated and IdP-initiated flows.
+
+
### Step 2: Configure Better Auth
Here’s an example configuration for Okta in a dev environment:
@@ -149,6 +153,7 @@ await authClient.signIn.sso({
- Never use these certificates in production
- The example uses `localhost:3000` - adjust URLs for your environment
- For production, always use proper IdP providers like Okta, Azure AD, or OneLogin
+ - The `callbackUrl` in your SAML config should point to your app's destination (e.g., `/dashboard`), not the callback route itself
### Step 5: Dynamically Registering SAML Providers
diff --git a/docs/content/docs/plugins/sso.mdx b/docs/content/docs/plugins/sso.mdx
index 070806a923..c88de966b1 100644
--- a/docs/content/docs/plugins/sso.mdx
+++ b/docs/content/docs/plugins/sso.mdx
@@ -1242,7 +1242,25 @@ type requestDomainVerification = {
The plugin automatically creates the following SAML endpoints:
- **SP Metadata**: `/api/auth/sso/saml2/sp/metadata?providerId={providerId}`
-- **SAML Callback**: `/api/auth/sso/saml2/callback/{providerId}`
+- **SAML Callback**: `/api/auth/sso/saml2/callback/{providerId}` (supports both GET and POST)
+
+### SAML Callback URL Configuration
+
+The SAML callback endpoint (`/api/auth/sso/saml2/callback/{providerId}`) handles both **SP-initiated** and **IdP-initiated** SSO flows:
+
+- **SP-initiated**: User clicks "Sign in with SSO" in your app → redirects to IdP → IdP POSTs SAMLResponse to callback
+- **IdP-initiated**: User clicks app icon in IdP dashboard (Okta, Azure AD, etc.) → IdP POSTs SAMLResponse to callback
+
+**Important**: The `callbackUrl` in your SAML configuration should point to your application's destination URL (e.g., `/dashboard`), **not** the callback route itself. Better Auth automatically handles the callback route and redirects users to your specified `callbackUrl` after successful authentication.
+
+```ts
+samlConfig: {
+ callbackUrl: "/dashboard", // Correct - points to your app destination
+ // callbackUrl: "/api/auth/sso/saml2/callback/my-provider" // Incorrect - don't point to callback route
+}
+```
+
+The callback route supports both GET and POST methods automatically, so you don't need to create any additional route handlers in your framework.
## Schema
@@ -1273,6 +1291,27 @@ The `ssoProvider` schema is extended as follows:
]}
/>
+### IdP-Initiated SAML SSO
+
+Better Auth supports **IdP-initiated SSO flows**, where users access your application directly from their Identity Provider dashboard (e.g., Okta, Azure AD, OneLogin). This is common in enterprise environments where IT admins prefer centralized app access.
+
+**How it works:**
+
+1. User clicks your app icon in the IdP dashboard
+2. IdP POSTs SAMLResponse to `/api/auth/sso/saml2/callback/{providerId}`
+3. Better Auth processes the assertion, creates a session, and redirects to your `callbackUrl`
+4. Browser follows the redirect with a GET request (handled automatically)
+
+**No additional configuration required** - the callback route automatically handles both GET and POST requests.
+
+
+If you previously created a manual GET handler for the SAML callback route as a workaround, you can remove it after upgrading. Better Auth now handles GET requests automatically.
+
+
+
+**Security:** Better Auth validates all redirect URLs to prevent open redirect attacks. Only relative paths (e.g., `/dashboard`) and URLs matching your configured `trustedOrigins` are allowed. Malicious URLs like `https://evil.com` or protocol-relative URLs (`//evil.com`) are automatically blocked.
+
+
For a detailed guide on setting up SAML SSO with examples for Okta and testing with DummyIDP, see our [SAML SSO with Okta](/docs/guides/saml-sso-with-okta).
## Options
diff --git a/packages/better-auth/src/api/index.ts b/packages/better-auth/src/api/index.ts
index cebe6e77bb..d7142f50b7 100644
--- a/packages/better-auth/src/api/index.ts
+++ b/packages/better-auth/src/api/index.ts
@@ -6,6 +6,7 @@ import type {
} from "@better-auth/core";
import type { InternalLogger } from "@better-auth/core/env";
import { logger } from "@better-auth/core/env";
+import { normalizePathname } from "@better-auth/core/utils/url";
import type { Endpoint, Middleware } from "better-call";
import { createRouter } from "better-call";
import type { UnionToIntersection } from "../types/helper";
@@ -275,14 +276,7 @@ export const router = (
async onRequest(req) {
//handle disabled paths
const disabledPaths = ctx.options.disabledPaths || [];
- const pathname = new URL(req.url).pathname.replace(/\/+$/, "") || "/";
-
- const normalizedPath =
- basePath === "/"
- ? pathname
- : pathname.startsWith(basePath)
- ? pathname.slice(basePath.length).replace(/\/+$/, "") || "/"
- : pathname;
+ const normalizedPath = normalizePathname(req.url, basePath);
if (disabledPaths.includes(normalizedPath)) {
return new Response("Not Found", { status: 404 });
}
diff --git a/packages/better-auth/src/api/middlewares/origin-check.ts b/packages/better-auth/src/api/middlewares/origin-check.ts
index c9f54e66cd..0cae6f0071 100644
--- a/packages/better-auth/src/api/middlewares/origin-check.ts
+++ b/packages/better-auth/src/api/middlewares/origin-check.ts
@@ -2,16 +2,18 @@ import type { GenericEndpointContext } from "@better-auth/core";
import { createAuthMiddleware } from "@better-auth/core/api";
import { APIError, BASE_ERROR_CODES } from "@better-auth/core/error";
import { deprecate } from "@better-auth/core/utils/deprecate";
+import { normalizePathname } from "@better-auth/core/utils/url";
import { matchesOriginPattern } from "../../auth/trusted-origins";
/**
* Checks if CSRF should be skipped for backward compatibility.
* Previously, disableOriginCheck also disabled CSRF checks.
* This maintains that behavior when disableCSRFCheck isn't explicitly set.
+ * Only triggers for skipOriginCheck === true, not for path arrays.
*/
function shouldSkipCSRFForBackwardCompat(ctx: GenericEndpointContext): boolean {
return (
- ctx.context.skipOriginCheck &&
+ ctx.context.skipOriginCheck === true &&
ctx.context.options.advanced?.disableCSRFCheck === undefined
);
}
@@ -197,6 +199,22 @@ async function validateOrigin(
return;
}
+ const skipOriginCheck = ctx.context.skipOriginCheck;
+ if (Array.isArray(skipOriginCheck)) {
+ try {
+ const basePath = new URL(ctx.context.baseURL).pathname;
+ const currentPath = normalizePathname(ctx.request.url, basePath);
+ const shouldSkipPath = skipOriginCheck.some((skipPath) =>
+ currentPath.startsWith(skipPath),
+ );
+ if (shouldSkipPath) {
+ return;
+ }
+ } catch {
+ // If parsing fails, don't skip - continue with validation
+ }
+ }
+
const shouldValidate = forceValidate || useCookies;
if (!shouldValidate) {
diff --git a/packages/better-auth/src/api/rate-limiter/index.ts b/packages/better-auth/src/api/rate-limiter/index.ts
index fe3a27cd4a..f67a86f28d 100644
--- a/packages/better-auth/src/api/rate-limiter/index.ts
+++ b/packages/better-auth/src/api/rate-limiter/index.ts
@@ -3,6 +3,7 @@ import type {
BetterAuthRateLimitStorage,
} from "@better-auth/core";
import { safeJSONParse } from "@better-auth/core/utils/json";
+import { normalizePathname } from "@better-auth/core/utils/url";
import type { RateLimit } from "../../types";
import { getIp } from "../../utils/get-request-ip";
import { wildcardMatch } from "../../utils/wildcard";
@@ -159,9 +160,8 @@ export async function onRequestRateLimit(req: Request, ctx: AuthContext) {
if (!ctx.rateLimit.enabled) {
return;
}
- const path = new URL(req.url).pathname
- .replace(ctx.options.basePath || "/api/auth", "")
- .replace(/\/+$/, "");
+ const basePath = new URL(ctx.baseURL).pathname;
+ const path = normalizePathname(req.url, basePath);
let currentWindow = ctx.rateLimit.window;
let currentMax = ctx.rateLimit.max;
const ip = getIp(req, ctx.options);
diff --git a/packages/better-auth/src/context/create-context.ts b/packages/better-auth/src/context/create-context.ts
index 13eb6027e9..e494c9892f 100644
--- a/packages/better-auth/src/context/create-context.ts
+++ b/packages/better-auth/src/context/create-context.ts
@@ -179,6 +179,8 @@ export async function createAuthContext(
const hasPluginFn = (id: string) => pluginIds.has(id);
+ const trustedOrigins = await getTrustedOrigins(options);
+
const ctx: AuthContext = {
appName: options.appName || "Better Auth",
socialProviders: providers,
@@ -190,7 +192,7 @@ export async function createAuthContext(
skipStateCookieCheck: !!options.account?.skipStateCookieCheck,
},
tables,
- trustedOrigins: await getTrustedOrigins(options),
+ trustedOrigins,
isTrustedOrigin(
url: string,
settings?: {
diff --git a/packages/better-auth/src/context/helpers.ts b/packages/better-auth/src/context/helpers.ts
index 2a6fa07484..bb8f17a6f6 100644
--- a/packages/better-auth/src/context/helpers.ts
+++ b/packages/better-auth/src/context/helpers.ts
@@ -64,7 +64,7 @@ export async function getTrustedOrigins(
options: BetterAuthOptions,
request?: Request,
): Promise {
- const baseURL = getBaseURL(options.baseURL, options.basePath);
+ const baseURL = getBaseURL(options.baseURL, options.basePath, request);
const trustedOrigins: (string | undefined | null)[] = baseURL
? [new URL(baseURL).origin]
: [];
diff --git a/packages/better-auth/src/oauth2/state.ts b/packages/better-auth/src/oauth2/state.ts
index 9f47f0db3d..decac154a9 100644
--- a/packages/better-auth/src/oauth2/state.ts
+++ b/packages/better-auth/src/oauth2/state.ts
@@ -1,12 +1,9 @@
import type { GenericEndpointContext } from "@better-auth/core";
import { APIError, BASE_ERROR_CODES } from "@better-auth/core/error";
-import * as z from "zod";
import { setOAuthState } from "../api/middlewares/oauth";
-import {
- generateRandomString,
- symmetricDecrypt,
- symmetricEncrypt,
-} from "../crypto";
+import { generateRandomString } from "../crypto";
+import type { StateData } from "../state";
+import { generateGenericState, parseGenericState, StateError } from "../state";
export async function generateState(
c: GenericEndpointContext,
@@ -24,200 +21,55 @@ export async function generateState(
}
const codeVerifier = generateRandomString(128);
- const state = generateRandomString(32);
- const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
- const stateData = {
+ const stateData: StateData = {
...(additionalData ? additionalData : {}),
callbackURL,
codeVerifier,
errorURL: c.body?.errorCallbackURL,
newUserURL: c.body?.newUserCallbackURL,
link,
- /**
- * This is the actual expiry time of the state
- */
expiresAt: Date.now() + 10 * 60 * 1000,
requestSignUp: c.body?.requestSignUp,
- state,
};
await setOAuthState(stateData);
- if (storeStateStrategy === "cookie") {
- // Store state data in an encrypted cookie
- const encryptedData = await symmetricEncrypt({
- key: c.context.secret,
- data: JSON.stringify(stateData),
+ try {
+ return generateGenericState(c, stateData);
+ } catch (error) {
+ c.context.logger.error("Failed to create verification", error);
+ throw new APIError("INTERNAL_SERVER_ERROR", {
+ message: "Unable to create verification",
+ cause: error,
});
-
- const stateCookie = c.context.createAuthCookie("oauth_state", {
- maxAge: 10 * 60 * 1000, // 10 minutes
- });
-
- c.setCookie(stateCookie.name, encryptedData, stateCookie.attributes);
-
- return {
- state,
- codeVerifier,
- };
}
-
- // Default: database strategy
- const stateCookie = c.context.createAuthCookie("state", {
- maxAge: 5 * 60 * 1000, // 5 minutes
- });
- await c.setSignedCookie(
- stateCookie.name,
- state,
- c.context.secret,
- stateCookie.attributes,
- );
-
- const expiresAt = new Date();
- expiresAt.setMinutes(expiresAt.getMinutes() + 10);
- const verification = await c.context.internalAdapter.createVerificationValue({
- value: JSON.stringify(stateData),
- identifier: state,
- expiresAt,
- });
- if (!verification) {
- c.context.logger.error(
- "Unable to create verification. Make sure the database adapter is properly working and there is a verification table in the database",
- );
- throw APIError.from(
- "INTERNAL_SERVER_ERROR",
- BASE_ERROR_CODES.FAILED_TO_CREATE_VERIFICATION,
- );
- }
- return {
- state: verification.identifier,
- codeVerifier,
- };
}
export async function parseState(c: GenericEndpointContext) {
const state = c.query.state || c.body.state;
- const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
+ const errorURL =
+ c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- const stateDataSchema = z.looseObject({
- callbackURL: z.string(),
- codeVerifier: z.string(),
- errorURL: z.string().optional(),
- newUserURL: z.string().optional(),
- expiresAt: z.number(),
- link: z
- .object({
- email: z.string(),
- userId: z.coerce.string(),
- })
- .optional(),
- requestSignUp: z.boolean().optional(),
- state: z.string().optional(),
- });
+ let parsedData: StateData;
- let parsedData: z.infer;
- /**
- * This is generally cause security issue and should only be used in
- * dev or staging environments. It's currently used by the oauth-proxy
- * plugin
- */
- const skipStateCookieCheck = c.context.oauthConfig?.skipStateCookieCheck;
- if (storeStateStrategy === "cookie") {
- // Retrieve state data from encrypted cookie
- const stateCookie = c.context.createAuthCookie("oauth_state");
- const encryptedData = c.getCookie(stateCookie.name);
+ try {
+ parsedData = await parseGenericState(c, state);
+ } catch (error) {
+ c.context.logger.error("Failed to parse state", error);
- if (!encryptedData) {
- c.context.logger.error("State Mismatch. OAuth state cookie not found", {
- state,
- });
- const errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- throw c.redirect(`${errorURL}?error=please_restart_the_process`);
- }
-
- try {
- const decryptedData = await symmetricDecrypt({
- key: c.context.secret,
- data: encryptedData,
- });
-
- parsedData = stateDataSchema.parse(JSON.parse(decryptedData));
- } catch (error) {
- c.context.logger.error("Failed to decrypt or parse OAuth state cookie", {
- error,
- });
- const errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- throw c.redirect(`${errorURL}?error=please_restart_the_process`);
- }
-
- const skipStateCookieCheck = c.context.oauthConfig?.skipStateCookieCheck;
if (
- !skipStateCookieCheck &&
- parsedData.state &&
- parsedData.state !== state
+ error instanceof StateError &&
+ error.code === "state_security_mismatch"
) {
- c.context.logger.error("State Mismatch. State parameter does not match", {
- expected: parsedData.state,
- received: state,
- });
- const errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=state_mismatch`);
}
- // Clear the cookie after successful parsing
- c.setCookie(stateCookie.name, "", {
- maxAge: 0,
- });
- } else {
- // Default: database strategy
- const data = await c.context.internalAdapter.findVerificationValue(state);
- if (!data) {
- c.context.logger.error("State Mismatch. Verification not found", {
- state,
- });
- const errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- throw c.redirect(`${errorURL}?error=please_restart_the_process`);
- }
-
- parsedData = stateDataSchema.parse(JSON.parse(data.value));
-
- const stateCookie = c.context.createAuthCookie("state");
- const stateCookieValue = await c.getSignedCookie(
- stateCookie.name,
- c.context.secret,
- );
-
- if (
- !skipStateCookieCheck &&
- (!stateCookieValue || stateCookieValue !== state)
- ) {
- const errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- throw c.redirect(`${errorURL}?error=state_mismatch`);
- }
- c.setCookie(stateCookie.name, "", {
- maxAge: 0,
- });
-
- // Delete verification value after retrieval
- await c.context.internalAdapter.deleteVerificationValue(data.id);
+ throw c.redirect(`${errorURL}?error=please_restart_the_process`);
}
if (!parsedData.errorURL) {
- parsedData.errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- }
-
- // Check expiration
- if (parsedData.expiresAt < Date.now()) {
- const errorURL =
- c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
- throw c.redirect(`${errorURL}?error=please_restart_the_process`);
+ parsedData.errorURL = errorURL;
}
if (parsedData) {
diff --git a/packages/better-auth/src/social.test.ts b/packages/better-auth/src/social.test.ts
index 335dfac380..e82a008945 100644
--- a/packages/better-auth/src/social.test.ts
+++ b/packages/better-auth/src/social.test.ts
@@ -785,7 +785,6 @@ describe("signin", async () => {
expiresAt: expect.any(Number),
invitedBy: "user-123",
errorURL: "http://localhost:3000/api/auth/error",
- state: expect.any(String),
});
});
diff --git a/packages/better-auth/src/state.ts b/packages/better-auth/src/state.ts
new file mode 100644
index 0000000000..5da6430466
--- /dev/null
+++ b/packages/better-auth/src/state.ts
@@ -0,0 +1,221 @@
+import type { GenericEndpointContext } from "@better-auth/core";
+import { BetterAuthError } from "@better-auth/core/error";
+import * as z from "zod";
+import {
+ generateRandomString,
+ symmetricDecrypt,
+ symmetricEncrypt,
+} from "./crypto";
+
+const stateDataSchema = z.looseObject({
+ callbackURL: z.string(),
+ codeVerifier: z.string(),
+ errorURL: z.string().optional(),
+ newUserURL: z.string().optional(),
+ expiresAt: z.number(),
+ link: z
+ .object({
+ email: z.string(),
+ userId: z.coerce.string(),
+ })
+ .optional(),
+ requestSignUp: z.boolean().optional(),
+});
+
+export type StateData = z.infer;
+
+export type StateErrorCode =
+ | "state_generation_error"
+ | "state_invalid"
+ | "state_mismatch"
+ | "state_security_mismatch";
+
+export class StateError extends BetterAuthError {
+ code: string;
+ details?: Record;
+
+ constructor(
+ message: string,
+ options: ErrorOptions & {
+ code: StateErrorCode;
+ details?: Record;
+ },
+ ) {
+ super(message, options);
+ this.code = options.code;
+ this.details = options.details;
+ }
+}
+
+export async function generateGenericState(
+ c: GenericEndpointContext,
+ stateData: StateData,
+ settings?: { cookieName: string },
+) {
+ const state = generateRandomString(32);
+ const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
+
+ if (storeStateStrategy === "cookie") {
+ // Store state data in an encrypted cookie
+
+ const encryptedData = await symmetricEncrypt({
+ key: c.context.secret,
+ data: JSON.stringify(stateData),
+ });
+
+ const stateCookie = c.context.createAuthCookie(
+ settings?.cookieName ?? "oauth_state",
+ {
+ maxAge: 10 * 60 * 1000, // 10 minutes
+ },
+ );
+
+ c.setCookie(stateCookie.name, encryptedData, stateCookie.attributes);
+
+ return {
+ state,
+ codeVerifier: stateData.codeVerifier,
+ };
+ }
+
+ // Default: database strategy
+
+ const stateCookie = c.context.createAuthCookie(
+ settings?.cookieName ?? "state",
+ {
+ maxAge: 5 * 60 * 1000, // 5 minutes
+ },
+ );
+
+ await c.setSignedCookie(
+ stateCookie.name,
+ state,
+ c.context.secret,
+ stateCookie.attributes,
+ );
+
+ const expiresAt = new Date();
+ expiresAt.setMinutes(expiresAt.getMinutes() + 10);
+
+ const verification = await c.context.internalAdapter.createVerificationValue({
+ value: JSON.stringify(stateData),
+ identifier: state,
+ expiresAt,
+ });
+
+ if (!verification) {
+ throw new StateError(
+ "Unable to create verification. Make sure the database adapter is properly working and there is a verification table in the database",
+ {
+ code: "state_generation_error",
+ },
+ );
+ }
+
+ return {
+ state: verification.identifier,
+ codeVerifier: stateData.codeVerifier,
+ };
+}
+
+export async function parseGenericState(
+ c: GenericEndpointContext,
+ state: string,
+ settings?: { cookieName: string },
+) {
+ const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
+ let parsedData: StateData;
+
+ if (storeStateStrategy === "cookie") {
+ // Retrieve state data from encrypted cookie
+ const stateCookie = c.context.createAuthCookie(
+ settings?.cookieName ?? "oauth_state",
+ );
+ const encryptedData = c.getCookie(stateCookie.name);
+
+ if (!encryptedData) {
+ throw new StateError("State mismatch: auth state cookie not found", {
+ code: "state_mismatch",
+ details: { state },
+ });
+ }
+
+ try {
+ const decryptedData = await symmetricDecrypt({
+ key: c.context.secret,
+ data: encryptedData,
+ });
+
+ parsedData = stateDataSchema.parse(JSON.parse(decryptedData));
+ } catch (error) {
+ throw new StateError(
+ "State invalid: Failed to decrypt or parse auth state",
+ {
+ code: "state_invalid",
+ details: { state },
+ cause: error,
+ },
+ );
+ }
+
+ // Clear the cookie after successful parsing
+ c.setCookie(stateCookie.name, "", {
+ maxAge: 0,
+ });
+ } else {
+ // Default: database strategy
+ const data = await c.context.internalAdapter.findVerificationValue(state);
+ if (!data) {
+ throw new StateError("State mismatch: verification not found", {
+ code: "state_mismatch",
+ details: { state },
+ });
+ }
+
+ parsedData = stateDataSchema.parse(JSON.parse(data.value));
+
+ const stateCookie = c.context.createAuthCookie(
+ settings?.cookieName ?? "state",
+ );
+
+ const stateCookieValue = await c.getSignedCookie(
+ stateCookie.name,
+ c.context.secret,
+ );
+
+ /**
+ * This is generally cause security issue and should only be used in
+ * dev or staging environments. It's currently used by the oauth-proxy
+ * plugin
+ */
+ const skipStateCookieCheck = c.context.oauthConfig.skipStateCookieCheck;
+ if (
+ !skipStateCookieCheck &&
+ (!stateCookieValue || stateCookieValue !== state)
+ ) {
+ throw new StateError("State mismatch: State not persisted correctly", {
+ code: "state_security_mismatch",
+ details: { state },
+ });
+ }
+
+ c.setCookie(stateCookie.name, "", {
+ maxAge: 0,
+ });
+
+ // Delete verification value after retrieval
+ await c.context.internalAdapter.deleteVerificationValue(data.id);
+ }
+
+ // Check expiration
+ if (parsedData.expiresAt < Date.now()) {
+ throw new StateError("Invalid state: request expired", {
+ code: "state_mismatch",
+ details: {
+ expiresAt: parsedData.expiresAt,
+ },
+ });
+ }
+
+ return parsedData;
+}
diff --git a/packages/better-auth/src/utils/index.ts b/packages/better-auth/src/utils/index.ts
index 38232d0c40..ac4f86af76 100644
--- a/packages/better-auth/src/utils/index.ts
+++ b/packages/better-auth/src/utils/index.ts
@@ -1,2 +1,4 @@
export * from "../oauth2/state";
+export type { StateData } from "../state";
+export { generateGenericState, parseGenericState } from "../state";
export * from "./hide-metadata";
diff --git a/packages/core/src/types/context.ts b/packages/core/src/types/context.ts
index b3b28cee06..afa9f754fc 100644
--- a/packages/core/src/types/context.ts
+++ b/packages/core/src/types/context.ts
@@ -308,17 +308,18 @@ export type AuthContext =
payload: Record;
}) => Promise;
/**
- * This skips the origin check for all requests.
+ * Skip origin check for requests.
*
- * set to true by default for `test` environments and `false`
- * for other environments.
+ * - `true`: Skip for ALL requests (DANGEROUS - disables CSRF protection)
+ * - `string[]`: Skip only for specific paths (e.g., SAML callbacks)
+ * - `false`: Enable origin check (default)
*
- * It's inferred from the `options.advanced?.disableCSRFCheck`
- * option or `options.advanced?.disableOriginCheck` option.
+ * Paths support prefix matching (e.g., "/sso/saml2/callback" matches
+ * "/sso/saml2/callback/provider-name").
*
- * @default false
+ * @default false (true in test environments)
*/
- skipOriginCheck: boolean;
+ skipOriginCheck: boolean | string[];
/**
* This skips the CSRF check for all requests.
*
diff --git a/packages/core/src/utils/url.ts b/packages/core/src/utils/url.ts
new file mode 100644
index 0000000000..a895ef4ccf
--- /dev/null
+++ b/packages/core/src/utils/url.ts
@@ -0,0 +1,43 @@
+/**
+ * Normalizes a request pathname by removing the basePath prefix and trailing slashes.
+ * This is useful for matching paths against configured path lists.
+ *
+ * @param requestUrl - The full request URL
+ * @param basePath - The base path of the auth API (e.g., "/api/auth")
+ * @returns The normalized path without basePath prefix or trailing slashes,
+ * or "/" if URL parsing fails
+ *
+ * @example
+ * normalizePathname("http://localhost:3000/api/auth/sso/saml2/callback/provider1", "/api/auth")
+ * // Returns: "/sso/saml2/callback/provider1"
+ *
+ * normalizePathname("http://localhost:3000/sso/saml2/callback/provider1/", "/")
+ * // Returns: "/sso/saml2/callback/provider1"
+ */
+export function normalizePathname(
+ requestUrl: string,
+ basePath: string,
+): string {
+ let pathname: string;
+ try {
+ pathname = new URL(requestUrl).pathname.replace(/\/+$/, "") || "/";
+ } catch {
+ return "/";
+ }
+
+ if (basePath === "/" || basePath === "") {
+ return pathname;
+ }
+
+ // Check for exact match or proper path boundary (basePath followed by "/" or end)
+ // This prevents "/api/auth" from matching "/api/authevil/..."
+ if (pathname === basePath) {
+ return "/";
+ }
+
+ if (pathname.startsWith(basePath + "/")) {
+ return pathname.slice(basePath.length).replace(/\/+$/, "") || "/";
+ }
+
+ return pathname;
+}
diff --git a/packages/sso/src/index.ts b/packages/sso/src/index.ts
index 137dd17a46..290d8e04a4 100644
--- a/packages/sso/src/index.ts
+++ b/packages/sso/src/index.ts
@@ -103,6 +103,16 @@ export type SSOPlugin = {
: {});
};
+/**
+ * SAML endpoint paths that should skip origin check validation.
+ * These endpoints receive POST requests from external Identity Providers,
+ * which won't have a matching Origin header.
+ */
+const SAML_SKIP_ORIGIN_CHECK_PATHS = [
+ "/sso/saml2/callback", // SP-initiated SSO callback (prefix matches /callback/:providerId)
+ "/sso/saml2/sp/acs", // IdP-initiated SSO ACS (prefix matches /sp/acs/:providerId)
+];
+
export function sso<
O extends SSOOptions & {
domainVerification?: { enabled: true };
@@ -148,6 +158,18 @@ export function sso(options?: O | undefined): any {
return {
id: "sso",
+ init(ctx) {
+ const existing = ctx.skipOriginCheck;
+ if (existing === true) {
+ return {};
+ }
+ const existingPaths = Array.isArray(existing) ? existing : [];
+ return {
+ context: {
+ skipOriginCheck: [...existingPaths, ...SAML_SKIP_ORIGIN_CHECK_PATHS],
+ },
+ };
+ },
endpoints,
hooks: {
after: [
diff --git a/packages/sso/src/routes/sso.ts b/packages/sso/src/routes/sso.ts
index 4197ca76d7..753bc3a5e5 100644
--- a/packages/sso/src/routes/sso.ts
+++ b/packages/sso/src/routes/sso.ts
@@ -11,6 +11,7 @@ import {
import {
APIError,
createAuthEndpoint,
+ getSessionFromCtx,
sessionMiddleware,
} from "better-auth/api";
import { setSessionCookie } from "better-auth/cookies";
@@ -52,6 +53,7 @@ import {
validateSAMLAlgorithms,
validateSingleAssertion,
} from "../saml";
+import { generateRelayState, parseRelayState } from "../saml-state";
import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "../types";
import { safeJsonParse, validateEmailDomain } from "../utils";
@@ -165,6 +167,8 @@ const spMetadataQuerySchema = z.object({
format: z.enum(["xml", "json"]).default("xml"),
});
+type RelayState = Awaited>;
+
export const spMetadata = () => {
return createAuthEndpoint(
"/sso/saml2/sp/metadata",
@@ -1293,6 +1297,12 @@ export const signInSSO = (options?: SSOOptions) => {
});
}
+ const { state: relayState } = await generateRelayState(
+ ctx,
+ undefined,
+ false,
+ );
+
const shouldSaveRequest =
loginRequest.id && options?.saml?.enableInResponseToValidation;
if (shouldSaveRequest) {
@@ -1311,9 +1321,7 @@ export const signInSSO = (options?: SSOOptions) => {
}
return ctx.json({
- url: `${loginRequest.context}&RelayState=${encodeURIComponent(
- body.callbackURL,
- )}`,
+ url: `${loginRequest.context}&RelayState=${encodeURIComponent(relayState)}`,
redirect: true,
});
}
@@ -1683,12 +1691,71 @@ const callbackSSOSAMLBodySchema = z.object({
RelayState: z.string().optional(),
});
+/**
+ * Validates and returns a safe redirect URL.
+ * - Prevents open redirect attacks by validating against trusted origins
+ * - Prevents redirect loops by checking if URL points to callback route
+ * - Falls back to appOrigin if URL is invalid or unsafe
+ */
+const getSafeRedirectUrl = (
+ url: string | undefined,
+ callbackPath: string,
+ appOrigin: string,
+ isTrustedOrigin: (
+ url: string,
+ settings?: { allowRelativePaths: boolean },
+ ) => boolean,
+): string => {
+ if (!url) {
+ return appOrigin;
+ }
+
+ if (url.startsWith("/") && !url.startsWith("//")) {
+ try {
+ const absoluteUrl = new URL(url, appOrigin);
+ if (absoluteUrl.origin !== appOrigin) {
+ return appOrigin;
+ }
+ const callbackPathname = new URL(callbackPath).pathname;
+ if (absoluteUrl.pathname === callbackPathname) {
+ return appOrigin;
+ }
+ } catch {
+ return appOrigin;
+ }
+ return url;
+ }
+
+ if (!isTrustedOrigin(url, { allowRelativePaths: false })) {
+ return appOrigin;
+ }
+
+ try {
+ const callbackPathname = new URL(callbackPath).pathname;
+ const urlPathname = new URL(url).pathname;
+ if (urlPathname === callbackPathname) {
+ return appOrigin;
+ }
+ } catch {
+ if (url === callbackPath || url.startsWith(`${callbackPath}?`)) {
+ return appOrigin;
+ }
+ }
+
+ return url;
+};
+
export const callbackSSOSAML = (options?: SSOOptions) => {
return createAuthEndpoint(
"/sso/saml2/callback/:providerId",
{
- method: "POST",
- body: callbackSSOSAMLBodySchema,
+ method: ["GET", "POST"],
+ body: callbackSSOSAMLBodySchema.optional(),
+ query: z
+ .object({
+ RelayState: z.string().optional(),
+ })
+ .optional(),
metadata: {
...HIDE_METADATA,
allowedMediaTypes: [
@@ -1699,7 +1766,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
operationId: "handleSAMLCallback",
summary: "Callback URL for SAML provider",
description:
- "This endpoint is used as the callback URL for SAML providers.",
+ "This endpoint is used as the callback URL for SAML providers. Supports both GET and POST methods for IdP-initiated and SP-initiated flows.",
responses: {
"302": {
description: "Redirects to the callback URL",
@@ -1715,8 +1782,41 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
},
},
async (ctx) => {
- const { SAMLResponse, RelayState } = ctx.body;
const { providerId } = ctx.params;
+ const appOrigin = new URL(ctx.context.baseURL).origin;
+ const errorURL =
+ ctx.context.options.onAPIError?.errorURL || `${appOrigin}/error`;
+ const currentCallbackPath = `${ctx.context.baseURL}/sso/saml2/callback/${providerId}`;
+
+ // Determine if this is a GET request by checking both method AND body presence
+ // When called via auth.api.*, ctx.method may not be reliable, so we also check for body
+ const isGetRequest = ctx.method === "GET" && !ctx.body?.SAMLResponse;
+
+ if (isGetRequest) {
+ const session = await getSessionFromCtx(ctx);
+
+ if (!session?.session) {
+ throw ctx.redirect(`${errorURL}?error=invalid_request`);
+ }
+
+ const relayState = ctx.query?.RelayState as string | undefined;
+ const safeRedirectUrl = getSafeRedirectUrl(
+ relayState,
+ currentCallbackPath,
+ appOrigin,
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
+ );
+
+ throw ctx.redirect(safeRedirectUrl);
+ }
+
+ if (!ctx.body?.SAMLResponse) {
+ throw new APIError("BAD_REQUEST", {
+ message: "SAMLResponse is required for POST requests",
+ });
+ }
+
+ const { SAMLResponse } = ctx.body;
const maxResponseSize =
options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
@@ -1726,6 +1826,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
});
}
+ let relayState: RelayState | null = null;
+ if (ctx.body.RelayState) {
+ try {
+ relayState = await parseRelayState(ctx);
+ } catch {
+ relayState = null;
+ }
+ }
let provider: SSOProvider | null = null;
if (options?.defaultSSO?.length) {
const matchingDefault = options.defaultSSO.find(
@@ -1846,7 +1954,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
parsedResponse = await sp.parseLoginResponse(idp, "post", {
body: {
SAMLResponse,
- RelayState: RelayState || undefined,
+ RelayState: ctx.body.RelayState || undefined,
},
});
@@ -1909,7 +2017,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
{ inResponseTo, providerId: provider.providerId },
);
const redirectUrl =
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
+ relayState?.callbackURL ||
+ parsedSamlConfig.callbackUrl ||
+ ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID`,
);
@@ -1929,7 +2039,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
`${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`,
);
const redirectUrl =
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
+ relayState?.callbackURL ||
+ parsedSamlConfig.callbackUrl ||
+ ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=invalid_saml_response&error_description=Provider+mismatch`,
);
@@ -1944,7 +2056,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
{ providerId: provider.providerId },
);
const redirectUrl =
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
+ relayState?.callbackURL ||
+ parsedSamlConfig.callbackUrl ||
+ ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed`,
);
@@ -1997,7 +2111,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
},
);
const redirectUrl =
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
+ relayState?.callbackURL ||
+ parsedSamlConfig.callbackUrl ||
+ ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=replay_detected&error_description=SAML+assertion+has+already+been+used`,
);
@@ -2071,7 +2187,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
validateEmailDomain(userInfo.email as string, provider.domain));
const callbackUrl =
- RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
+ relayState?.callbackURL ||
+ parsedSamlConfig.callbackUrl ||
+ ctx.context.baseURL;
const result = await handleOAuthUserInfo(ctx, {
userInfo: {
@@ -2122,7 +2240,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
});
await setSessionCookie(ctx, { session, user });
- throw ctx.redirect(callbackUrl);
+
+ const safeRedirectUrl = getSafeRedirectUrl(
+ relayState?.callbackURL || parsedSamlConfig.callbackUrl,
+ currentCallbackPath,
+ appOrigin,
+ (url, settings) => ctx.context.isTrustedOrigin(url, settings),
+ );
+ throw ctx.redirect(safeRedirectUrl);
},
);
};
diff --git a/packages/sso/src/saml-state.ts b/packages/sso/src/saml-state.ts
new file mode 100644
index 0000000000..098d78e162
--- /dev/null
+++ b/packages/sso/src/saml-state.ts
@@ -0,0 +1,78 @@
+import type { GenericEndpointContext, StateData } from "better-auth";
+import { generateGenericState, parseGenericState } from "better-auth";
+import { generateRandomString } from "better-auth/crypto";
+import { APIError } from "better-call";
+
+export async function generateRelayState(
+ c: GenericEndpointContext,
+ link:
+ | {
+ email: string;
+ userId: string;
+ }
+ | undefined,
+ additionalData: Record | false | undefined,
+) {
+ const callbackURL = c.body.callbackURL;
+ if (!callbackURL) {
+ throw new APIError("BAD_REQUEST", {
+ message: "callbackURL is required",
+ });
+ }
+
+ const codeVerifier = generateRandomString(128);
+ const stateData: StateData = {
+ ...(additionalData ? additionalData : {}),
+ callbackURL,
+ codeVerifier,
+ errorURL: c.body.errorCallbackURL,
+ newUserURL: c.body.newUserCallbackURL,
+ link,
+ /**
+ * This is the actual expiry time of the state
+ */
+ expiresAt: Date.now() + 10 * 60 * 1000,
+ requestSignUp: c.body.requestSignUp,
+ };
+
+ try {
+ return generateGenericState(c, stateData, {
+ cookieName: "relay_state",
+ });
+ } catch (error) {
+ c.context.logger.error(
+ "Failed to create verification for relay state",
+ error,
+ );
+ throw new APIError("INTERNAL_SERVER_ERROR", {
+ message: "State error: Unable to create verification for relay state",
+ cause: error,
+ });
+ }
+}
+
+export async function parseRelayState(c: GenericEndpointContext) {
+ const state = c.body.RelayState;
+ const errorURL =
+ c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
+
+ let parsedData: StateData;
+
+ try {
+ parsedData = await parseGenericState(c, state, {
+ cookieName: "relay_state",
+ });
+ } catch (error) {
+ c.context.logger.error("Failed to parse relay state", error);
+ throw new APIError("BAD_REQUEST", {
+ message: "State error: failed to validate relay state",
+ cause: error,
+ });
+ }
+
+ if (!parsedData.errorURL) {
+ parsedData.errorURL = errorURL;
+ }
+
+ return parsedData;
+}
diff --git a/packages/sso/src/saml.test.ts b/packages/sso/src/saml.test.ts
index 9b5139b747..e195e1088e 100644
--- a/packages/sso/src/saml.test.ts
+++ b/packages/sso/src/saml.test.ts
@@ -1,9 +1,9 @@
import { randomUUID } from "node:crypto";
import type { createServer } from "node:http";
-import { base64 } from "@better-auth/utils/base64";
import { betterFetch } from "@better-fetch/fetch";
import { betterAuth } from "better-auth";
import { memoryAdapter } from "better-auth/adapters/memory";
+import { APIError } from "better-auth/api";
import { createAuthClient } from "better-auth/client";
import { setCookieToHeader } from "better-auth/cookies";
import { bearer } from "better-auth/plugins";
@@ -1100,6 +1100,222 @@ describe("SAML SSO", async () => {
});
});
+ it("should initiate SAML login and validate RelayState", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "saml-provider-1",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: "http://localhost:8081/api/sso/saml2/idp/post",
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ const response = await auth.api.signInSSO({
+ body: {
+ providerId: "saml-provider-1",
+ callbackURL: "http://localhost:3000/dashboard",
+ },
+ returnHeaders: true,
+ });
+
+ const signInResponse = response.response;
+ expect(signInResponse).toEqual({
+ url: expect.stringContaining("http://localhost:8081"),
+ redirect: true,
+ });
+
+ let samlResponse: any;
+ await betterFetch(signInResponse?.url, {
+ onSuccess: async (context) => {
+ samlResponse = await context.data;
+ },
+ });
+
+ let samlRedirectUrl = new URL(signInResponse?.url);
+ const callbackResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: samlRedirectUrl.searchParams.get("RelayState") ?? "",
+ },
+ headers: {
+ Cookie: response.headers.get("set-cookie") ?? "",
+ },
+ params: {
+ providerId: "saml-provider-1",
+ },
+ asResponse: true,
+ });
+
+ expect(callbackResponse.headers.get("location")).toContain("dashboard");
+ });
+
+ it("should initiate SAML login and fallback to callbackUrl on invalid RelayState", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "saml-provider-1",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: "http://localhost:8081/api/sso/saml2/idp/post",
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ const response = await auth.api.signInSSO({
+ body: {
+ providerId: "saml-provider-1",
+ callbackURL: "http://localhost:3000/dashboard",
+ },
+ returnHeaders: true,
+ });
+
+ const signInResponse = response.response;
+ expect(signInResponse).toEqual({
+ url: expect.stringContaining("http://localhost:8081"),
+ redirect: true,
+ });
+
+ let samlResponse: any;
+ await betterFetch(signInResponse?.url, {
+ onSuccess: async (context) => {
+ samlResponse = await context.data;
+ },
+ });
+
+ const callbackResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: "not-the-right-relay-state",
+ },
+ headers: {
+ Cookie: response.headers.get("set-cookie") ?? "",
+ },
+ params: {
+ providerId: "saml-provider-1",
+ },
+ asResponse: true,
+ });
+
+ expect(callbackResponse.status).toBe(302);
+ expect(callbackResponse.headers.get("location")).toBe(
+ "http://localhost:3000/dashboard",
+ );
+ });
+
+ it("should initiate SAML login and signup user when disableImplicitSignUp is true but requestSignup is explicitly enabled", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso({ disableImplicitSignUp: true })],
+ });
+
+ const { headers } = await signInWithTestUser();
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "saml-provider-1",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: "http://localhost:8081/api/sso/saml2/idp/post",
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ const response = await auth.api.signInSSO({
+ body: {
+ providerId: "saml-provider-1",
+ callbackURL: "http://localhost:3000/dashboard",
+ requestSignUp: true,
+ },
+ returnHeaders: true,
+ });
+
+ const signInResponse = response.response;
+ expect(signInResponse).toEqual({
+ url: expect.stringContaining("http://localhost:8081"),
+ redirect: true,
+ });
+
+ let samlResponse: any;
+ await betterFetch(signInResponse?.url, {
+ onSuccess: async (context) => {
+ samlResponse = await context.data;
+ },
+ });
+
+ let samlRedirectUrl = new URL(signInResponse?.url);
+ const callbackResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: samlRedirectUrl.searchParams.get("RelayState") ?? "",
+ },
+ headers: {
+ Cookie: response.headers.get("set-cookie") ?? "",
+ },
+ params: {
+ providerId: "saml-provider-1",
+ },
+ asResponse: true,
+ });
+
+ expect(callbackResponse.headers.get("location")).toContain("dashboard");
+ });
+
it("should reject SAML sign-in when disableImplicitSignUp is true and user doesn't exist", async () => {
const { auth: authWithDisabledSignUp, signInWithTestUser } =
await getTestInstance({
@@ -1293,7 +1509,6 @@ describe("SAML SSO", async () => {
},
body: new URLSearchParams({
SAMLResponse: samlResponse.samlResponse,
- RelayState: "http://localhost:3000/dashboard",
}),
},
),
@@ -1374,7 +1589,6 @@ describe("SAML SSO", async () => {
},
body: new URLSearchParams({
SAMLResponse: samlResponse.samlResponse,
- RelayState: "http://localhost:3000/dashboard",
}),
},
),
@@ -1442,7 +1656,6 @@ describe("SAML SSO", async () => {
},
body: new URLSearchParams({
SAMLResponse: samlResponse.samlResponse,
- RelayState: "http://localhost:3000/dashboard",
}),
},
),
@@ -1509,7 +1722,6 @@ describe("SAML SSO", async () => {
},
body: new URLSearchParams({
SAMLResponse: samlResponse.samlResponse,
- RelayState: "http://localhost:3000/dashboard",
}),
},
),
@@ -1569,7 +1781,6 @@ describe("SAML SSO", async () => {
},
body: new URLSearchParams({
SAMLResponse: samlResponse.samlResponse,
- RelayState: "http://localhost:3000/dashboard",
}),
},
),
@@ -1638,7 +1849,6 @@ describe("SAML SSO", async () => {
},
body: new URLSearchParams({
SAMLResponse: samlResponse.samlResponse,
- RelayState: "http://localhost:3000/dashboard",
}),
},
),
@@ -1978,8 +2188,8 @@ describe("SSO Provider Config Parsing", () => {
});
});
-describe("SAML SSO - Signature Validation Security", () => {
- it("should reject unsigned SAML response with forged NameID", async () => {
+describe("SAML SSO - IdP Initiated Flow", () => {
+ it("should handle IdP-initiated flow with GET after POST redirect", async () => {
const { auth, signInWithTestUser } = await getTestInstance({
plugins: [sso()],
});
@@ -1988,11 +2198,14 @@ describe("SAML SSO - Signature Validation Security", () => {
await auth.api.registerSSOProvider({
body: {
- providerId: "security-test-provider",
+ providerId: "idp-initiated-provider",
issuer: "http://localhost:8081",
domain: "http://localhost:8081",
samlConfig: {
- entryPoint: "http://localhost:8081/api/sso/saml2/idp/post",
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
cert: certificate,
callbackUrl: "http://localhost:3000/dashboard",
wantAssertionsSigned: false,
@@ -2011,49 +2224,161 @@ describe("SAML SSO - Signature Validation Security", () => {
headers,
});
- const forgedSamlResponse = `
-
- http://localhost:8081
-
-
-
-
- http://localhost:8081
-
- attacker-forged@evil.com
-
-
-
- http://localhost:3001
-
-
-
-
- urn:oasis:names:tc:SAML:2.0:ac:classes:Password
-
-
-
-
- `;
-
- const encodedForgedResponse = base64.encode(forgedSamlResponse);
-
- await expect(
- auth.api.callbackSSOSAML({
- body: {
- SAMLResponse: encodedForgedResponse,
- RelayState: "http://localhost:3000/dashboard",
- },
- params: {
- providerId: "security-test-provider",
- },
- }),
- ).rejects.toMatchObject({
- status: "BAD_REQUEST",
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
});
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: "http://localhost:3000/dashboard",
+ },
+ params: {
+ providerId: "idp-initiated-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(postResponse).toBeInstanceOf(Response);
+ expect(postResponse.status).toBe(302);
+ const redirectLocation = postResponse.headers.get("location");
+ expect(redirectLocation).toBe("http://localhost:3000/dashboard");
+
+ const cookieHeader = postResponse.headers.get("set-cookie");
+ const getResponse = await auth.api.callbackSSOSAML({
+ method: "GET",
+ query: {
+ RelayState: "http://localhost:3000/dashboard",
+ },
+ params: {
+ providerId: "idp-initiated-provider",
+ },
+ headers: cookieHeader ? { cookie: cookieHeader } : undefined,
+ asResponse: true,
+ });
+
+ expect(getResponse).toBeInstanceOf(Response);
+ expect(getResponse.status).toBe(302);
+ const getRedirectLocation = getResponse.headers.get("location");
+ expect(getRedirectLocation).toBe("http://localhost:3000/dashboard");
});
- it("should reject SAML response with invalid signature", async () => {
+ it("should reject direct GET request without session", async () => {
+ const { auth } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const getResponse = await auth.api
+ .callbackSSOSAML({
+ method: "GET",
+ params: {
+ providerId: "test-provider",
+ },
+ asResponse: true,
+ })
+ .catch((e) => {
+ if (e instanceof APIError && e.status === "FOUND") {
+ return new Response(null, {
+ status: e.statusCode,
+ headers: e.headers || new Headers(),
+ });
+ }
+ throw e;
+ });
+
+ expect(getResponse).toBeInstanceOf(Response);
+ expect(getResponse.status).toBe(302);
+ const redirectLocation = getResponse.headers.get("location");
+ expect(redirectLocation).toContain("/error");
+ expect(redirectLocation).toContain("error=invalid_request");
+ });
+
+ it("should prevent redirect loop when callbackUrl points to callback route", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+
+ const callbackRouteUrl =
+ "http://localhost:3000/api/auth/sso/saml2/callback/loop-test-provider";
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "loop-test-provider",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
+ cert: certificate,
+ callbackUrl: callbackRouteUrl,
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
+ });
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ },
+ params: {
+ providerId: "loop-test-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(postResponse).toBeInstanceOf(Response);
+ expect(postResponse.status).toBe(302);
+ const redirectLocation = postResponse.headers.get("location");
+ expect(redirectLocation).not.toBe(callbackRouteUrl);
+ expect(redirectLocation).toBe("http://localhost:3000");
+ });
+
+ it("should handle GET request with RelayState in query", async () => {
const { auth, signInWithTestUser } = await getTestInstance({
plugins: [sso()],
});
@@ -2062,11 +2387,14 @@ describe("SAML SSO - Signature Validation Security", () => {
await auth.api.registerSSOProvider({
body: {
- providerId: "invalid-sig-provider",
+ providerId: "relaystate-provider",
issuer: "http://localhost:8081",
domain: "http://localhost:8081",
samlConfig: {
- entryPoint: "http://localhost:8081/api/sso/saml2/idp/post",
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
cert: certificate,
callbackUrl: "http://localhost:3000/dashboard",
wantAssertionsSigned: false,
@@ -2085,47 +2413,442 @@ describe("SAML SSO - Signature Validation Security", () => {
headers,
});
- const responseWithBadSignature = `
-
- http://localhost:8081
-
-
-
-
-
-
- FAKE_DIGEST_VALUE
-
-
- INVALID_SIGNATURE_VALUE_THAT_SHOULD_FAIL_VERIFICATION
-
-
-
-
-
- http://localhost:8081
-
- forged-admin@company.com
-
-
-
- `;
-
- const encodedBadSigResponse = base64.encode(responseWithBadSignature);
-
- await expect(
- auth.api.callbackSSOSAML({
- body: {
- SAMLResponse: encodedBadSigResponse,
- RelayState: "http://localhost:3000/dashboard",
- },
- params: {
- providerId: "invalid-sig-provider",
- },
- }),
- ).rejects.toMatchObject({
- status: "BAD_REQUEST",
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
});
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: "http://localhost:3000/custom-path",
+ },
+ params: {
+ providerId: "relaystate-provider",
+ },
+ asResponse: true,
+ });
+
+ const cookieHeader = postResponse.headers.get("set-cookie");
+ const getResponse = await auth.api.callbackSSOSAML({
+ method: "GET",
+ query: {
+ RelayState: "http://localhost:3000/custom-path",
+ },
+ params: {
+ providerId: "relaystate-provider",
+ },
+ headers: cookieHeader ? { cookie: cookieHeader } : undefined,
+ asResponse: true,
+ });
+
+ expect(getResponse).toBeInstanceOf(Response);
+ expect(getResponse.status).toBe(302);
+ const redirectLocation = getResponse.headers.get("location");
+ expect(redirectLocation).toBe("http://localhost:3000/custom-path");
+ });
+
+ it("should handle GET request when POST redirects to callback URL (original issue scenario)", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+
+ const callbackRouteUrl =
+ "http://localhost:3000/api/auth/sso/saml2/callback/issue-6615-provider";
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "issue-6615-provider",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
+ });
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: callbackRouteUrl,
+ },
+ params: {
+ providerId: "issue-6615-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(postResponse).toBeInstanceOf(Response);
+ expect(postResponse.status).toBe(302);
+ const postRedirectLocation = postResponse.headers.get("location");
+ expect(postRedirectLocation).not.toBe(callbackRouteUrl);
+ expect(postRedirectLocation).toBe("http://localhost:3000/dashboard");
+
+ const cookieHeader = postResponse.headers.get("set-cookie");
+ const getResponse = await auth.api.callbackSSOSAML({
+ method: "GET",
+ params: {
+ providerId: "issue-6615-provider",
+ },
+ headers: cookieHeader ? { cookie: cookieHeader } : undefined,
+ asResponse: true,
+ });
+
+ expect(getResponse).toBeInstanceOf(Response);
+ expect(getResponse.status).toBe(302);
+ const getRedirectLocation = getResponse.headers.get("location");
+ expect(getRedirectLocation).toBe("http://localhost:3000");
+ });
+
+ it("should prevent open redirect with malicious RelayState URL", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "open-redirect-test-provider",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
+ });
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ // Test POST with malicious RelayState - raw RelayState is not trusted
+ // Falls back to parsedSamlConfig.callbackUrl
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: "https://evil.com/phishing",
+ },
+ params: {
+ providerId: "open-redirect-test-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(postResponse).toBeInstanceOf(Response);
+ expect(postResponse.status).toBe(302);
+ const postRedirectLocation = postResponse.headers.get("location");
+ // Should NOT redirect to evil.com - raw RelayState is ignored
+ expect(postRedirectLocation).not.toContain("evil.com");
+ // Falls back to samlConfig.callbackUrl
+ expect(postRedirectLocation).toBe("http://localhost:3000/dashboard");
+ });
+
+ it("should prevent open redirect via GET with malicious RelayState", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "open-redirect-get-provider",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
+ });
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ // First do POST to establish session
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ },
+ params: {
+ providerId: "open-redirect-get-provider",
+ },
+ asResponse: true,
+ });
+
+ const cookieHeader = postResponse.headers.get("set-cookie");
+
+ // Test GET with malicious RelayState in query params
+ const getResponse = await auth.api.callbackSSOSAML({
+ method: "GET",
+ query: {
+ RelayState: "https://evil.com/steal-cookies",
+ },
+ params: {
+ providerId: "open-redirect-get-provider",
+ },
+ headers: cookieHeader ? { cookie: cookieHeader } : undefined,
+ asResponse: true,
+ });
+
+ expect(getResponse).toBeInstanceOf(Response);
+ expect(getResponse.status).toBe(302);
+ const getRedirectLocation = getResponse.headers.get("location");
+ // Should NOT redirect to evil.com
+ expect(getRedirectLocation).not.toContain("evil.com");
+ expect(getRedirectLocation).toBe("http://localhost:3000");
+ });
+
+ it("should allow relative path redirects", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "relative-path-provider",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
+ });
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: "/dashboard/settings",
+ },
+ params: {
+ providerId: "relative-path-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(postResponse).toBeInstanceOf(Response);
+ expect(postResponse.status).toBe(302);
+ const redirectLocation = postResponse.headers.get("location");
+ expect(redirectLocation).toBe("http://localhost:3000/dashboard");
+ });
+
+ it("should block protocol-relative URL attacks (//evil.com)", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "protocol-relative-provider",
+ issuer: "http://localhost:8081",
+ domain: "http://localhost:8081",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl.replace(
+ "/idp/metadata",
+ "/idp/post",
+ ),
+ cert: certificate,
+ callbackUrl: "http://localhost:3000/dashboard",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ idpMetadata: {
+ metadata: idpMetadata,
+ },
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ identifierFormat:
+ "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
+ },
+ },
+ headers,
+ });
+
+ let samlResponse:
+ | { samlResponse: string; entityEndpoint?: string }
+ | undefined;
+ await betterFetch("http://localhost:8081/api/sso/saml2/idp/post", {
+ onSuccess: async (context) => {
+ samlResponse = context.data as {
+ samlResponse: string;
+ entityEndpoint?: string;
+ };
+ },
+ });
+
+ if (!samlResponse?.samlResponse) {
+ throw new Error("Failed to get SAML response from mock IdP");
+ }
+
+ // Test POST with protocol-relative URL - raw RelayState is not trusted
+ // Falls back to parsedSamlConfig.callbackUrl
+ const postResponse = await auth.api.callbackSSOSAML({
+ method: "POST",
+ body: {
+ SAMLResponse: samlResponse.samlResponse,
+ RelayState: "//evil.com/phishing",
+ },
+ params: {
+ providerId: "protocol-relative-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(postResponse).toBeInstanceOf(Response);
+ expect(postResponse.status).toBe(302);
+ const redirectLocation = postResponse.headers.get("location");
+ // Should NOT redirect to evil.com - raw RelayState is ignored
+ expect(redirectLocation).not.toContain("evil.com");
+ // Falls back to samlConfig.callbackUrl
+ expect(redirectLocation).toBe("http://localhost:3000/dashboard");
});
});
@@ -2358,6 +3081,328 @@ describe("SAML SSO - Timestamp Validation", () => {
});
});
+describe("SAML ACS Origin Check Bypass", () => {
+ describe("Positive: SAML endpoints allow external IdP origins", () => {
+ it("should allow SAML callback POST from external IdP origin", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+ const { headers } = await signInWithTestUser();
+
+ // Register SAML provider with full config
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "origin-bypass-callback",
+ issuer: "http://localhost:8081",
+ domain: "origin-bypass.com",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl,
+ cert: certificate,
+ callbackUrl: "http://localhost:8081/api/auth/sso/saml2/callback",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ },
+ },
+ headers,
+ });
+
+ // POST to callback with external Origin header (simulating IdP POST)
+ // Origin check should be bypassed for SAML callback endpoints
+ const callbackRes = await auth.handler(
+ new Request(
+ "http://localhost:8081/api/auth/sso/saml2/callback/origin-bypass-callback",
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/x-www-form-urlencoded",
+ Origin: "http://external-idp.example.com", // External IdP origin - would normally be blocked
+ Cookie: headers.get("cookie") || "",
+ },
+ body: new URLSearchParams({
+ SAMLResponse: Buffer.from(" ").toString(
+ "base64",
+ ),
+ RelayState: "",
+ }).toString(),
+ },
+ ),
+ );
+
+ // Should NOT return 403 Forbidden (origin check bypassed)
+ // May return other errors (400, 500) due to invalid SAML response, but NOT origin rejection
+ expect(callbackRes.status).not.toBe(403);
+ });
+
+ it("should allow ACS endpoint POST from external IdP origin", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+ const { headers } = await signInWithTestUser();
+
+ // Register SAML provider with full config
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "origin-bypass-acs",
+ issuer: "http://localhost:8081",
+ domain: "origin-bypass-acs.com",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl,
+ cert: certificate,
+ callbackUrl: "http://localhost:8081/api/auth/sso/saml2/sp/acs",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ },
+ },
+ headers,
+ });
+
+ // POST to ACS with external Origin header
+ const acsRes = await auth.handler(
+ new Request(
+ "http://localhost:8081/api/auth/sso/saml2/sp/acs/origin-bypass-acs",
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/x-www-form-urlencoded",
+ Origin: "http://idp.external.com", // External IdP origin
+ Cookie: headers.get("cookie") || "",
+ },
+ body: new URLSearchParams({
+ SAMLResponse: Buffer.from(" ").toString(
+ "base64",
+ ),
+ }).toString(),
+ },
+ ),
+ );
+
+ // Should NOT return 403 Forbidden
+ expect(acsRes.status).not.toBe(403);
+ });
+ });
+
+ describe("Negative: Non-SAML endpoints remain protected", () => {
+ it("should block POST to sign-up with untrusted origin when origin check is enabled", async () => {
+ const { auth } = await getTestInstance({
+ plugins: [sso()],
+ advanced: {
+ disableCSRFCheck: false,
+ disableOriginCheck: false,
+ },
+ });
+
+ // Origin check applies when cookies are present and check is enabled
+ const signUpRes = await auth.handler(
+ new Request("http://localhost:8081/api/auth/sign-up/email", {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ Origin: "http://attacker.com",
+ Cookie: "better-auth.session_token=fake-session",
+ },
+ body: JSON.stringify({
+ email: "victim@example.com",
+ password: "password123",
+ name: "Victim",
+ }),
+ }),
+ );
+
+ expect(signUpRes.status).toBe(403);
+ });
+ });
+
+ describe("Edge cases", () => {
+ it("should allow GET requests to SAML metadata regardless of origin", async () => {
+ const { auth } = await getTestInstance({
+ plugins: [sso()],
+ });
+
+ // GET requests always bypass origin check
+ const metadataRes = await auth.handler(
+ new Request("http://localhost:8081/api/auth/sso/saml2/sp/metadata", {
+ method: "GET",
+ headers: {
+ Origin: "http://any-origin.com",
+ },
+ }),
+ );
+
+ expect(metadataRes.status).not.toBe(403);
+ });
+
+ it("should not redirect to malicious RelayState URLs", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "relay-security-test",
+ issuer: "http://localhost:8081",
+ domain: "relay-security.com",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl,
+ cert: certificate,
+ callbackUrl: "http://localhost:8081/api/auth/sso/saml2/callback",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ },
+ },
+ headers,
+ });
+
+ // Even with origin bypass, malicious RelayState should be rejected
+ const callbackRes = await auth.handler(
+ new Request(
+ "http://localhost:8081/api/auth/sso/saml2/callback/relay-security-test",
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/x-www-form-urlencoded",
+ Origin: "http://idp.example.com",
+ },
+ body: new URLSearchParams({
+ SAMLResponse: Buffer.from(" ").toString(
+ "base64",
+ ),
+ RelayState: "http://malicious-site.com/steal-token",
+ }).toString(),
+ },
+ ),
+ );
+
+ // Should NOT redirect to malicious URL
+ if (callbackRes.status === 302) {
+ const location = callbackRes.headers.get("Location");
+ expect(location).not.toContain("malicious-site.com");
+ }
+ });
+ });
+});
+
+describe("SAML Response Security", () => {
+ it("should reject forged/unsigned SAML responses", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "security-test-provider",
+ issuer: "http://localhost:8081",
+ domain: "security-test.com",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl,
+ cert: certificate,
+ callbackUrl: "http://localhost:8081/api/auth/sso/saml2/callback",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ },
+ },
+ headers,
+ });
+
+ const forgedSAMLResponse = `
+
+
+
+ attacker@evil.com
+
+
+
+ `;
+
+ const callbackRes = await auth.handler(
+ new Request(
+ "http://localhost:8081/api/auth/sso/saml2/callback/security-test-provider",
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ body: new URLSearchParams({
+ SAMLResponse: Buffer.from(forgedSAMLResponse).toString("base64"),
+ RelayState: "",
+ }).toString(),
+ },
+ ),
+ );
+
+ expect(callbackRes.status).toBe(400);
+ const body = await callbackRes.json();
+ expect(body.message).toBe("Invalid SAML response");
+ });
+
+ it("should reject SAML response with tampered nameID", async () => {
+ const { auth, signInWithTestUser } = await getTestInstance({
+ plugins: [sso()],
+ });
+ const { headers } = await signInWithTestUser();
+
+ await auth.api.registerSSOProvider({
+ body: {
+ providerId: "tamper-test-provider",
+ issuer: "http://localhost:8081",
+ domain: "tamper-test.com",
+ samlConfig: {
+ entryPoint: sharedMockIdP.metadataUrl,
+ cert: certificate,
+ callbackUrl: "http://localhost:8081/api/auth/sso/saml2/callback",
+ wantAssertionsSigned: false,
+ signatureAlgorithm: "sha256",
+ digestAlgorithm: "sha256",
+ spMetadata: {
+ metadata: spMetadata,
+ },
+ },
+ },
+ headers,
+ });
+
+ const tamperedResponse = `
+
+ admin@victim.com
+ `;
+
+ const callbackRes = await auth.handler(
+ new Request(
+ "http://localhost:8081/api/auth/sso/saml2/callback/tamper-test-provider",
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/x-www-form-urlencoded",
+ },
+ body: new URLSearchParams({
+ SAMLResponse: Buffer.from(tamperedResponse).toString("base64"),
+ RelayState: "",
+ }).toString(),
+ },
+ ),
+ );
+
+ expect(callbackRes.status).toBe(400);
+ });
+});
+
describe("SAML SSO - Size Limit Validation", () => {
it("should export default size limit constants", async () => {
const { DEFAULT_MAX_SAML_RESPONSE_SIZE, DEFAULT_MAX_SAML_METADATA_SIZE } =
@@ -2408,7 +3453,6 @@ describe("SAML SSO - Assertion Replay Protection", () => {
},
});
- // First submission should succeed
const firstResponse = await auth.handler(
new Request(
"http://localhost:3000/api/auth/sso/saml2/callback/replay-test-provider",
@@ -2429,7 +3473,6 @@ describe("SAML SSO - Assertion Replay Protection", () => {
const firstLocation = firstResponse.headers.get("location") || "";
expect(firstLocation).not.toContain("error");
- // Second submission (replay) should be rejected
const replayResponse = await auth.handler(
new Request(
"http://localhost:3000/api/auth/sso/saml2/callback/replay-test-provider",
@@ -2490,7 +3533,6 @@ describe("SAML SSO - Assertion Replay Protection", () => {
},
});
- // First submission to ACS endpoint should succeed
const firstResponse = await auth.handler(
new Request(
"http://localhost:3000/api/auth/sso/saml2/sp/acs/acs-replay-test-provider",
@@ -2511,7 +3553,6 @@ describe("SAML SSO - Assertion Replay Protection", () => {
const firstLocation = firstResponse.headers.get("location") || "";
expect(firstLocation).not.toContain("error");
- // Second submission (replay) to ACS endpoint should be rejected
const replayResponse = await auth.handler(
new Request(
"http://localhost:3000/api/auth/sso/saml2/sp/acs/acs-replay-test-provider",
@@ -2572,7 +3613,6 @@ describe("SAML SSO - Assertion Replay Protection", () => {
},
});
- // First: Submit to callback endpoint (should succeed)
const callbackResponse = await auth.handler(
new Request(
"http://localhost:3000/api/auth/sso/saml2/callback/cross-endpoint-provider",
@@ -2592,7 +3632,6 @@ describe("SAML SSO - Assertion Replay Protection", () => {
expect(callbackResponse.status).toBe(302);
expect(callbackResponse.headers.get("location")).not.toContain("error");
- // Second: Replay same assertion to ACS endpoint (should be rejected)
const acsReplayResponse = await auth.handler(
new Request(
"http://localhost:3000/api/auth/sso/saml2/sp/acs/cross-endpoint-provider",