fix(saml): IdP-Initiated Callback Routing (#6675)

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Alex Yang <himself65@users.noreply.github.com>
This commit is contained in:
Paola Estefanía de Campos
2026-01-14 02:54:37 +01:00
committed by GitHub
parent eb2d831b16
commit 34c8a4bd2a
17 changed files with 1747 additions and 307 deletions

View File

@@ -22,12 +22,16 @@ In this setup:
4. Select "SAML 2.0" as the Sign-in method
5. Configure the following settings:
- **Single Sign-on URL**: Your Better Auth ACS endpoint (e.g., `http://localhost:3000/api/auth/sso/saml2/sp/acs/sso`). while `sso` being your providerId
- **Single Sign-on URL**: Your Better Auth callback endpoint (e.g., `http://localhost:3000/api/auth/sso/saml2/callback/sso`). Note: `sso` is your `providerId`
- **Audience URI (SP Entity ID)**: Your Better Auth metadata URL (e.g., `http://localhost:3000/api/auth/sso/saml2/sp/metadata`)
- **Name ID format**: Email Address or any of your choice.
6. Download the IdP metadata XML file and certificate
<Callout type="info">
**IdP-Initiated SSO**: If you want users to access your app from the Okta dashboard, make sure the **Single Sign-on URL** points to the callback endpoint (`/api/auth/sso/saml2/callback/{providerId}`). Better Auth automatically handles both SP-initiated and IdP-initiated flows.
</Callout>
### Step 2: Configure Better Auth
Heres an example configuration for Okta in a dev environment:
@@ -149,6 +153,7 @@ await authClient.signIn.sso({
- Never use these certificates in production
- The example uses `localhost:3000` - adjust URLs for your environment
- For production, always use proper IdP providers like Okta, Azure AD, or OneLogin
- The `callbackUrl` in your SAML config should point to your app's destination (e.g., `/dashboard`), not the callback route itself
### Step 5: Dynamically Registering SAML Providers

View File

@@ -1242,7 +1242,25 @@ type requestDomainVerification = {
The plugin automatically creates the following SAML endpoints:
- **SP Metadata**: `/api/auth/sso/saml2/sp/metadata?providerId={providerId}`
- **SAML Callback**: `/api/auth/sso/saml2/callback/{providerId}`
- **SAML Callback**: `/api/auth/sso/saml2/callback/{providerId}` (supports both GET and POST)
### SAML Callback URL Configuration
The SAML callback endpoint (`/api/auth/sso/saml2/callback/{providerId}`) handles both **SP-initiated** and **IdP-initiated** SSO flows:
- **SP-initiated**: User clicks "Sign in with SSO" in your app → redirects to IdP → IdP POSTs SAMLResponse to callback
- **IdP-initiated**: User clicks app icon in IdP dashboard (Okta, Azure AD, etc.) → IdP POSTs SAMLResponse to callback
**Important**: The `callbackUrl` in your SAML configuration should point to your application's destination URL (e.g., `/dashboard`), **not** the callback route itself. Better Auth automatically handles the callback route and redirects users to your specified `callbackUrl` after successful authentication.
```ts
samlConfig: {
callbackUrl: "/dashboard", // Correct - points to your app destination
// callbackUrl: "/api/auth/sso/saml2/callback/my-provider" // Incorrect - don't point to callback route
}
```
The callback route supports both GET and POST methods automatically, so you don't need to create any additional route handlers in your framework.
## Schema
@@ -1273,6 +1291,27 @@ The `ssoProvider` schema is extended as follows:
]}
/>
### IdP-Initiated SAML SSO
Better Auth supports **IdP-initiated SSO flows**, where users access your application directly from their Identity Provider dashboard (e.g., Okta, Azure AD, OneLogin). This is common in enterprise environments where IT admins prefer centralized app access.
**How it works:**
1. User clicks your app icon in the IdP dashboard
2. IdP POSTs SAMLResponse to `/api/auth/sso/saml2/callback/{providerId}`
3. Better Auth processes the assertion, creates a session, and redirects to your `callbackUrl`
4. Browser follows the redirect with a GET request (handled automatically)
**No additional configuration required** - the callback route automatically handles both GET and POST requests.
<Callout type="info">
If you previously created a manual GET handler for the SAML callback route as a workaround, you can remove it after upgrading. Better Auth now handles GET requests automatically.
</Callout>
<Callout type="info">
**Security:** Better Auth validates all redirect URLs to prevent open redirect attacks. Only relative paths (e.g., `/dashboard`) and URLs matching your configured `trustedOrigins` are allowed. Malicious URLs like `https://evil.com` or protocol-relative URLs (`//evil.com`) are automatically blocked.
</Callout>
For a detailed guide on setting up SAML SSO with examples for Okta and testing with DummyIDP, see our [SAML SSO with Okta](/docs/guides/saml-sso-with-okta).
## Options

View File

@@ -6,6 +6,7 @@ import type {
} from "@better-auth/core";
import type { InternalLogger } from "@better-auth/core/env";
import { logger } from "@better-auth/core/env";
import { normalizePathname } from "@better-auth/core/utils/url";
import type { Endpoint, Middleware } from "better-call";
import { createRouter } from "better-call";
import type { UnionToIntersection } from "../types/helper";
@@ -275,14 +276,7 @@ export const router = <Option extends BetterAuthOptions>(
async onRequest(req) {
//handle disabled paths
const disabledPaths = ctx.options.disabledPaths || [];
const pathname = new URL(req.url).pathname.replace(/\/+$/, "") || "/";
const normalizedPath =
basePath === "/"
? pathname
: pathname.startsWith(basePath)
? pathname.slice(basePath.length).replace(/\/+$/, "") || "/"
: pathname;
const normalizedPath = normalizePathname(req.url, basePath);
if (disabledPaths.includes(normalizedPath)) {
return new Response("Not Found", { status: 404 });
}

View File

@@ -2,16 +2,18 @@ import type { GenericEndpointContext } from "@better-auth/core";
import { createAuthMiddleware } from "@better-auth/core/api";
import { APIError, BASE_ERROR_CODES } from "@better-auth/core/error";
import { deprecate } from "@better-auth/core/utils/deprecate";
import { normalizePathname } from "@better-auth/core/utils/url";
import { matchesOriginPattern } from "../../auth/trusted-origins";
/**
* Checks if CSRF should be skipped for backward compatibility.
* Previously, disableOriginCheck also disabled CSRF checks.
* This maintains that behavior when disableCSRFCheck isn't explicitly set.
* Only triggers for skipOriginCheck === true, not for path arrays.
*/
function shouldSkipCSRFForBackwardCompat(ctx: GenericEndpointContext): boolean {
return (
ctx.context.skipOriginCheck &&
ctx.context.skipOriginCheck === true &&
ctx.context.options.advanced?.disableCSRFCheck === undefined
);
}
@@ -197,6 +199,22 @@ async function validateOrigin(
return;
}
const skipOriginCheck = ctx.context.skipOriginCheck;
if (Array.isArray(skipOriginCheck)) {
try {
const basePath = new URL(ctx.context.baseURL).pathname;
const currentPath = normalizePathname(ctx.request.url, basePath);
const shouldSkipPath = skipOriginCheck.some((skipPath) =>
currentPath.startsWith(skipPath),
);
if (shouldSkipPath) {
return;
}
} catch {
// If parsing fails, don't skip - continue with validation
}
}
const shouldValidate = forceValidate || useCookies;
if (!shouldValidate) {

View File

@@ -3,6 +3,7 @@ import type {
BetterAuthRateLimitStorage,
} from "@better-auth/core";
import { safeJSONParse } from "@better-auth/core/utils/json";
import { normalizePathname } from "@better-auth/core/utils/url";
import type { RateLimit } from "../../types";
import { getIp } from "../../utils/get-request-ip";
import { wildcardMatch } from "../../utils/wildcard";
@@ -159,9 +160,8 @@ export async function onRequestRateLimit(req: Request, ctx: AuthContext) {
if (!ctx.rateLimit.enabled) {
return;
}
const path = new URL(req.url).pathname
.replace(ctx.options.basePath || "/api/auth", "")
.replace(/\/+$/, "");
const basePath = new URL(ctx.baseURL).pathname;
const path = normalizePathname(req.url, basePath);
let currentWindow = ctx.rateLimit.window;
let currentMax = ctx.rateLimit.max;
const ip = getIp(req, ctx.options);

View File

@@ -179,6 +179,8 @@ export async function createAuthContext(
const hasPluginFn = (id: string) => pluginIds.has(id);
const trustedOrigins = await getTrustedOrigins(options);
const ctx: AuthContext = {
appName: options.appName || "Better Auth",
socialProviders: providers,
@@ -190,7 +192,7 @@ export async function createAuthContext(
skipStateCookieCheck: !!options.account?.skipStateCookieCheck,
},
tables,
trustedOrigins: await getTrustedOrigins(options),
trustedOrigins,
isTrustedOrigin(
url: string,
settings?: {

View File

@@ -64,7 +64,7 @@ export async function getTrustedOrigins(
options: BetterAuthOptions,
request?: Request,
): Promise<string[]> {
const baseURL = getBaseURL(options.baseURL, options.basePath);
const baseURL = getBaseURL(options.baseURL, options.basePath, request);
const trustedOrigins: (string | undefined | null)[] = baseURL
? [new URL(baseURL).origin]
: [];

View File

@@ -1,12 +1,9 @@
import type { GenericEndpointContext } from "@better-auth/core";
import { APIError, BASE_ERROR_CODES } from "@better-auth/core/error";
import * as z from "zod";
import { setOAuthState } from "../api/middlewares/oauth";
import {
generateRandomString,
symmetricDecrypt,
symmetricEncrypt,
} from "../crypto";
import { generateRandomString } from "../crypto";
import type { StateData } from "../state";
import { generateGenericState, parseGenericState, StateError } from "../state";
export async function generateState(
c: GenericEndpointContext,
@@ -24,200 +21,55 @@ export async function generateState(
}
const codeVerifier = generateRandomString(128);
const state = generateRandomString(32);
const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
const stateData = {
const stateData: StateData = {
...(additionalData ? additionalData : {}),
callbackURL,
codeVerifier,
errorURL: c.body?.errorCallbackURL,
newUserURL: c.body?.newUserCallbackURL,
link,
/**
* This is the actual expiry time of the state
*/
expiresAt: Date.now() + 10 * 60 * 1000,
requestSignUp: c.body?.requestSignUp,
state,
};
await setOAuthState(stateData);
if (storeStateStrategy === "cookie") {
// Store state data in an encrypted cookie
const encryptedData = await symmetricEncrypt({
key: c.context.secret,
data: JSON.stringify(stateData),
try {
return generateGenericState(c, stateData);
} catch (error) {
c.context.logger.error("Failed to create verification", error);
throw new APIError("INTERNAL_SERVER_ERROR", {
message: "Unable to create verification",
cause: error,
});
const stateCookie = c.context.createAuthCookie("oauth_state", {
maxAge: 10 * 60 * 1000, // 10 minutes
});
c.setCookie(stateCookie.name, encryptedData, stateCookie.attributes);
return {
state,
codeVerifier,
};
}
// Default: database strategy
const stateCookie = c.context.createAuthCookie("state", {
maxAge: 5 * 60 * 1000, // 5 minutes
});
await c.setSignedCookie(
stateCookie.name,
state,
c.context.secret,
stateCookie.attributes,
);
const expiresAt = new Date();
expiresAt.setMinutes(expiresAt.getMinutes() + 10);
const verification = await c.context.internalAdapter.createVerificationValue({
value: JSON.stringify(stateData),
identifier: state,
expiresAt,
});
if (!verification) {
c.context.logger.error(
"Unable to create verification. Make sure the database adapter is properly working and there is a verification table in the database",
);
throw APIError.from(
"INTERNAL_SERVER_ERROR",
BASE_ERROR_CODES.FAILED_TO_CREATE_VERIFICATION,
);
}
return {
state: verification.identifier,
codeVerifier,
};
}
export async function parseState(c: GenericEndpointContext) {
const state = c.query.state || c.body.state;
const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
const stateDataSchema = z.looseObject({
callbackURL: z.string(),
codeVerifier: z.string(),
errorURL: z.string().optional(),
newUserURL: z.string().optional(),
expiresAt: z.number(),
link: z
.object({
email: z.string(),
userId: z.coerce.string(),
})
.optional(),
requestSignUp: z.boolean().optional(),
state: z.string().optional(),
});
let parsedData: StateData;
let parsedData: z.infer<typeof stateDataSchema>;
/**
* This is generally cause security issue and should only be used in
* dev or staging environments. It's currently used by the oauth-proxy
* plugin
*/
const skipStateCookieCheck = c.context.oauthConfig?.skipStateCookieCheck;
if (storeStateStrategy === "cookie") {
// Retrieve state data from encrypted cookie
const stateCookie = c.context.createAuthCookie("oauth_state");
const encryptedData = c.getCookie(stateCookie.name);
try {
parsedData = await parseGenericState(c, state);
} catch (error) {
c.context.logger.error("Failed to parse state", error);
if (!encryptedData) {
c.context.logger.error("State Mismatch. OAuth state cookie not found", {
state,
});
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=please_restart_the_process`);
}
try {
const decryptedData = await symmetricDecrypt({
key: c.context.secret,
data: encryptedData,
});
parsedData = stateDataSchema.parse(JSON.parse(decryptedData));
} catch (error) {
c.context.logger.error("Failed to decrypt or parse OAuth state cookie", {
error,
});
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=please_restart_the_process`);
}
const skipStateCookieCheck = c.context.oauthConfig?.skipStateCookieCheck;
if (
!skipStateCookieCheck &&
parsedData.state &&
parsedData.state !== state
error instanceof StateError &&
error.code === "state_security_mismatch"
) {
c.context.logger.error("State Mismatch. State parameter does not match", {
expected: parsedData.state,
received: state,
});
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=state_mismatch`);
}
// Clear the cookie after successful parsing
c.setCookie(stateCookie.name, "", {
maxAge: 0,
});
} else {
// Default: database strategy
const data = await c.context.internalAdapter.findVerificationValue(state);
if (!data) {
c.context.logger.error("State Mismatch. Verification not found", {
state,
});
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=please_restart_the_process`);
}
parsedData = stateDataSchema.parse(JSON.parse(data.value));
const stateCookie = c.context.createAuthCookie("state");
const stateCookieValue = await c.getSignedCookie(
stateCookie.name,
c.context.secret,
);
if (
!skipStateCookieCheck &&
(!stateCookieValue || stateCookieValue !== state)
) {
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=state_mismatch`);
}
c.setCookie(stateCookie.name, "", {
maxAge: 0,
});
// Delete verification value after retrieval
await c.context.internalAdapter.deleteVerificationValue(data.id);
throw c.redirect(`${errorURL}?error=please_restart_the_process`);
}
if (!parsedData.errorURL) {
parsedData.errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
}
// Check expiration
if (parsedData.expiresAt < Date.now()) {
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
throw c.redirect(`${errorURL}?error=please_restart_the_process`);
parsedData.errorURL = errorURL;
}
if (parsedData) {

View File

@@ -785,7 +785,6 @@ describe("signin", async () => {
expiresAt: expect.any(Number),
invitedBy: "user-123",
errorURL: "http://localhost:3000/api/auth/error",
state: expect.any(String),
});
});

View File

@@ -0,0 +1,221 @@
import type { GenericEndpointContext } from "@better-auth/core";
import { BetterAuthError } from "@better-auth/core/error";
import * as z from "zod";
import {
generateRandomString,
symmetricDecrypt,
symmetricEncrypt,
} from "./crypto";
const stateDataSchema = z.looseObject({
callbackURL: z.string(),
codeVerifier: z.string(),
errorURL: z.string().optional(),
newUserURL: z.string().optional(),
expiresAt: z.number(),
link: z
.object({
email: z.string(),
userId: z.coerce.string(),
})
.optional(),
requestSignUp: z.boolean().optional(),
});
export type StateData = z.infer<typeof stateDataSchema>;
export type StateErrorCode =
| "state_generation_error"
| "state_invalid"
| "state_mismatch"
| "state_security_mismatch";
export class StateError extends BetterAuthError {
code: string;
details?: Record<string, any>;
constructor(
message: string,
options: ErrorOptions & {
code: StateErrorCode;
details?: Record<string, any>;
},
) {
super(message, options);
this.code = options.code;
this.details = options.details;
}
}
export async function generateGenericState(
c: GenericEndpointContext,
stateData: StateData,
settings?: { cookieName: string },
) {
const state = generateRandomString(32);
const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
if (storeStateStrategy === "cookie") {
// Store state data in an encrypted cookie
const encryptedData = await symmetricEncrypt({
key: c.context.secret,
data: JSON.stringify(stateData),
});
const stateCookie = c.context.createAuthCookie(
settings?.cookieName ?? "oauth_state",
{
maxAge: 10 * 60 * 1000, // 10 minutes
},
);
c.setCookie(stateCookie.name, encryptedData, stateCookie.attributes);
return {
state,
codeVerifier: stateData.codeVerifier,
};
}
// Default: database strategy
const stateCookie = c.context.createAuthCookie(
settings?.cookieName ?? "state",
{
maxAge: 5 * 60 * 1000, // 5 minutes
},
);
await c.setSignedCookie(
stateCookie.name,
state,
c.context.secret,
stateCookie.attributes,
);
const expiresAt = new Date();
expiresAt.setMinutes(expiresAt.getMinutes() + 10);
const verification = await c.context.internalAdapter.createVerificationValue({
value: JSON.stringify(stateData),
identifier: state,
expiresAt,
});
if (!verification) {
throw new StateError(
"Unable to create verification. Make sure the database adapter is properly working and there is a verification table in the database",
{
code: "state_generation_error",
},
);
}
return {
state: verification.identifier,
codeVerifier: stateData.codeVerifier,
};
}
export async function parseGenericState(
c: GenericEndpointContext,
state: string,
settings?: { cookieName: string },
) {
const storeStateStrategy = c.context.oauthConfig.storeStateStrategy;
let parsedData: StateData;
if (storeStateStrategy === "cookie") {
// Retrieve state data from encrypted cookie
const stateCookie = c.context.createAuthCookie(
settings?.cookieName ?? "oauth_state",
);
const encryptedData = c.getCookie(stateCookie.name);
if (!encryptedData) {
throw new StateError("State mismatch: auth state cookie not found", {
code: "state_mismatch",
details: { state },
});
}
try {
const decryptedData = await symmetricDecrypt({
key: c.context.secret,
data: encryptedData,
});
parsedData = stateDataSchema.parse(JSON.parse(decryptedData));
} catch (error) {
throw new StateError(
"State invalid: Failed to decrypt or parse auth state",
{
code: "state_invalid",
details: { state },
cause: error,
},
);
}
// Clear the cookie after successful parsing
c.setCookie(stateCookie.name, "", {
maxAge: 0,
});
} else {
// Default: database strategy
const data = await c.context.internalAdapter.findVerificationValue(state);
if (!data) {
throw new StateError("State mismatch: verification not found", {
code: "state_mismatch",
details: { state },
});
}
parsedData = stateDataSchema.parse(JSON.parse(data.value));
const stateCookie = c.context.createAuthCookie(
settings?.cookieName ?? "state",
);
const stateCookieValue = await c.getSignedCookie(
stateCookie.name,
c.context.secret,
);
/**
* This is generally cause security issue and should only be used in
* dev or staging environments. It's currently used by the oauth-proxy
* plugin
*/
const skipStateCookieCheck = c.context.oauthConfig.skipStateCookieCheck;
if (
!skipStateCookieCheck &&
(!stateCookieValue || stateCookieValue !== state)
) {
throw new StateError("State mismatch: State not persisted correctly", {
code: "state_security_mismatch",
details: { state },
});
}
c.setCookie(stateCookie.name, "", {
maxAge: 0,
});
// Delete verification value after retrieval
await c.context.internalAdapter.deleteVerificationValue(data.id);
}
// Check expiration
if (parsedData.expiresAt < Date.now()) {
throw new StateError("Invalid state: request expired", {
code: "state_mismatch",
details: {
expiresAt: parsedData.expiresAt,
},
});
}
return parsedData;
}

View File

@@ -1,2 +1,4 @@
export * from "../oauth2/state";
export type { StateData } from "../state";
export { generateGenericState, parseGenericState } from "../state";
export * from "./hide-metadata";

View File

@@ -308,17 +308,18 @@ export type AuthContext<Options extends BetterAuthOptions = BetterAuthOptions> =
payload: Record<string, any>;
}) => Promise<void>;
/**
* This skips the origin check for all requests.
* Skip origin check for requests.
*
* set to true by default for `test` environments and `false`
* for other environments.
* - `true`: Skip for ALL requests (DANGEROUS - disables CSRF protection)
* - `string[]`: Skip only for specific paths (e.g., SAML callbacks)
* - `false`: Enable origin check (default)
*
* It's inferred from the `options.advanced?.disableCSRFCheck`
* option or `options.advanced?.disableOriginCheck` option.
* Paths support prefix matching (e.g., "/sso/saml2/callback" matches
* "/sso/saml2/callback/provider-name").
*
* @default false
* @default false (true in test environments)
*/
skipOriginCheck: boolean;
skipOriginCheck: boolean | string[];
/**
* This skips the CSRF check for all requests.
*

View File

@@ -0,0 +1,43 @@
/**
* Normalizes a request pathname by removing the basePath prefix and trailing slashes.
* This is useful for matching paths against configured path lists.
*
* @param requestUrl - The full request URL
* @param basePath - The base path of the auth API (e.g., "/api/auth")
* @returns The normalized path without basePath prefix or trailing slashes,
* or "/" if URL parsing fails
*
* @example
* normalizePathname("http://localhost:3000/api/auth/sso/saml2/callback/provider1", "/api/auth")
* // Returns: "/sso/saml2/callback/provider1"
*
* normalizePathname("http://localhost:3000/sso/saml2/callback/provider1/", "/")
* // Returns: "/sso/saml2/callback/provider1"
*/
export function normalizePathname(
requestUrl: string,
basePath: string,
): string {
let pathname: string;
try {
pathname = new URL(requestUrl).pathname.replace(/\/+$/, "") || "/";
} catch {
return "/";
}
if (basePath === "/" || basePath === "") {
return pathname;
}
// Check for exact match or proper path boundary (basePath followed by "/" or end)
// This prevents "/api/auth" from matching "/api/authevil/..."
if (pathname === basePath) {
return "/";
}
if (pathname.startsWith(basePath + "/")) {
return pathname.slice(basePath.length).replace(/\/+$/, "") || "/";
}
return pathname;
}

View File

@@ -103,6 +103,16 @@ export type SSOPlugin<O extends SSOOptions> = {
: {});
};
/**
* SAML endpoint paths that should skip origin check validation.
* These endpoints receive POST requests from external Identity Providers,
* which won't have a matching Origin header.
*/
const SAML_SKIP_ORIGIN_CHECK_PATHS = [
"/sso/saml2/callback", // SP-initiated SSO callback (prefix matches /callback/:providerId)
"/sso/saml2/sp/acs", // IdP-initiated SSO ACS (prefix matches /sp/acs/:providerId)
];
export function sso<
O extends SSOOptions & {
domainVerification?: { enabled: true };
@@ -148,6 +158,18 @@ export function sso<O extends SSOOptions>(options?: O | undefined): any {
return {
id: "sso",
init(ctx) {
const existing = ctx.skipOriginCheck;
if (existing === true) {
return {};
}
const existingPaths = Array.isArray(existing) ? existing : [];
return {
context: {
skipOriginCheck: [...existingPaths, ...SAML_SKIP_ORIGIN_CHECK_PATHS],
},
};
},
endpoints,
hooks: {
after: [

View File

@@ -11,6 +11,7 @@ import {
import {
APIError,
createAuthEndpoint,
getSessionFromCtx,
sessionMiddleware,
} from "better-auth/api";
import { setSessionCookie } from "better-auth/cookies";
@@ -52,6 +53,7 @@ import {
validateSAMLAlgorithms,
validateSingleAssertion,
} from "../saml";
import { generateRelayState, parseRelayState } from "../saml-state";
import type { OIDCConfig, SAMLConfig, SSOOptions, SSOProvider } from "../types";
import { safeJsonParse, validateEmailDomain } from "../utils";
@@ -165,6 +167,8 @@ const spMetadataQuerySchema = z.object({
format: z.enum(["xml", "json"]).default("xml"),
});
type RelayState = Awaited<ReturnType<typeof parseRelayState>>;
export const spMetadata = () => {
return createAuthEndpoint(
"/sso/saml2/sp/metadata",
@@ -1293,6 +1297,12 @@ export const signInSSO = (options?: SSOOptions) => {
});
}
const { state: relayState } = await generateRelayState(
ctx,
undefined,
false,
);
const shouldSaveRequest =
loginRequest.id && options?.saml?.enableInResponseToValidation;
if (shouldSaveRequest) {
@@ -1311,9 +1321,7 @@ export const signInSSO = (options?: SSOOptions) => {
}
return ctx.json({
url: `${loginRequest.context}&RelayState=${encodeURIComponent(
body.callbackURL,
)}`,
url: `${loginRequest.context}&RelayState=${encodeURIComponent(relayState)}`,
redirect: true,
});
}
@@ -1683,12 +1691,71 @@ const callbackSSOSAMLBodySchema = z.object({
RelayState: z.string().optional(),
});
/**
* Validates and returns a safe redirect URL.
* - Prevents open redirect attacks by validating against trusted origins
* - Prevents redirect loops by checking if URL points to callback route
* - Falls back to appOrigin if URL is invalid or unsafe
*/
const getSafeRedirectUrl = (
url: string | undefined,
callbackPath: string,
appOrigin: string,
isTrustedOrigin: (
url: string,
settings?: { allowRelativePaths: boolean },
) => boolean,
): string => {
if (!url) {
return appOrigin;
}
if (url.startsWith("/") && !url.startsWith("//")) {
try {
const absoluteUrl = new URL(url, appOrigin);
if (absoluteUrl.origin !== appOrigin) {
return appOrigin;
}
const callbackPathname = new URL(callbackPath).pathname;
if (absoluteUrl.pathname === callbackPathname) {
return appOrigin;
}
} catch {
return appOrigin;
}
return url;
}
if (!isTrustedOrigin(url, { allowRelativePaths: false })) {
return appOrigin;
}
try {
const callbackPathname = new URL(callbackPath).pathname;
const urlPathname = new URL(url).pathname;
if (urlPathname === callbackPathname) {
return appOrigin;
}
} catch {
if (url === callbackPath || url.startsWith(`${callbackPath}?`)) {
return appOrigin;
}
}
return url;
};
export const callbackSSOSAML = (options?: SSOOptions) => {
return createAuthEndpoint(
"/sso/saml2/callback/:providerId",
{
method: "POST",
body: callbackSSOSAMLBodySchema,
method: ["GET", "POST"],
body: callbackSSOSAMLBodySchema.optional(),
query: z
.object({
RelayState: z.string().optional(),
})
.optional(),
metadata: {
...HIDE_METADATA,
allowedMediaTypes: [
@@ -1699,7 +1766,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
operationId: "handleSAMLCallback",
summary: "Callback URL for SAML provider",
description:
"This endpoint is used as the callback URL for SAML providers.",
"This endpoint is used as the callback URL for SAML providers. Supports both GET and POST methods for IdP-initiated and SP-initiated flows.",
responses: {
"302": {
description: "Redirects to the callback URL",
@@ -1715,8 +1782,41 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
},
},
async (ctx) => {
const { SAMLResponse, RelayState } = ctx.body;
const { providerId } = ctx.params;
const appOrigin = new URL(ctx.context.baseURL).origin;
const errorURL =
ctx.context.options.onAPIError?.errorURL || `${appOrigin}/error`;
const currentCallbackPath = `${ctx.context.baseURL}/sso/saml2/callback/${providerId}`;
// Determine if this is a GET request by checking both method AND body presence
// When called via auth.api.*, ctx.method may not be reliable, so we also check for body
const isGetRequest = ctx.method === "GET" && !ctx.body?.SAMLResponse;
if (isGetRequest) {
const session = await getSessionFromCtx(ctx);
if (!session?.session) {
throw ctx.redirect(`${errorURL}?error=invalid_request`);
}
const relayState = ctx.query?.RelayState as string | undefined;
const safeRedirectUrl = getSafeRedirectUrl(
relayState,
currentCallbackPath,
appOrigin,
(url, settings) => ctx.context.isTrustedOrigin(url, settings),
);
throw ctx.redirect(safeRedirectUrl);
}
if (!ctx.body?.SAMLResponse) {
throw new APIError("BAD_REQUEST", {
message: "SAMLResponse is required for POST requests",
});
}
const { SAMLResponse } = ctx.body;
const maxResponseSize =
options?.saml?.maxResponseSize ?? DEFAULT_MAX_SAML_RESPONSE_SIZE;
@@ -1726,6 +1826,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
});
}
let relayState: RelayState | null = null;
if (ctx.body.RelayState) {
try {
relayState = await parseRelayState(ctx);
} catch {
relayState = null;
}
}
let provider: SSOProvider<SSOOptions> | null = null;
if (options?.defaultSSO?.length) {
const matchingDefault = options.defaultSSO.find(
@@ -1846,7 +1954,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
parsedResponse = await sp.parseLoginResponse(idp, "post", {
body: {
SAMLResponse,
RelayState: RelayState || undefined,
RelayState: ctx.body.RelayState || undefined,
},
});
@@ -1909,7 +2017,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
{ inResponseTo, providerId: provider.providerId },
);
const redirectUrl =
RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
relayState?.callbackURL ||
parsedSamlConfig.callbackUrl ||
ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=invalid_saml_response&error_description=Unknown+or+expired+request+ID`,
);
@@ -1929,7 +2039,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
`${AUTHN_REQUEST_KEY_PREFIX}${inResponseTo}`,
);
const redirectUrl =
RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
relayState?.callbackURL ||
parsedSamlConfig.callbackUrl ||
ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=invalid_saml_response&error_description=Provider+mismatch`,
);
@@ -1944,7 +2056,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
{ providerId: provider.providerId },
);
const redirectUrl =
RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
relayState?.callbackURL ||
parsedSamlConfig.callbackUrl ||
ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=unsolicited_response&error_description=IdP-initiated+SSO+not+allowed`,
);
@@ -1997,7 +2111,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
},
);
const redirectUrl =
RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
relayState?.callbackURL ||
parsedSamlConfig.callbackUrl ||
ctx.context.baseURL;
throw ctx.redirect(
`${redirectUrl}?error=replay_detected&error_description=SAML+assertion+has+already+been+used`,
);
@@ -2071,7 +2187,9 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
validateEmailDomain(userInfo.email as string, provider.domain));
const callbackUrl =
RelayState || parsedSamlConfig.callbackUrl || ctx.context.baseURL;
relayState?.callbackURL ||
parsedSamlConfig.callbackUrl ||
ctx.context.baseURL;
const result = await handleOAuthUserInfo(ctx, {
userInfo: {
@@ -2122,7 +2240,14 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
});
await setSessionCookie(ctx, { session, user });
throw ctx.redirect(callbackUrl);
const safeRedirectUrl = getSafeRedirectUrl(
relayState?.callbackURL || parsedSamlConfig.callbackUrl,
currentCallbackPath,
appOrigin,
(url, settings) => ctx.context.isTrustedOrigin(url, settings),
);
throw ctx.redirect(safeRedirectUrl);
},
);
};

View File

@@ -0,0 +1,78 @@
import type { GenericEndpointContext, StateData } from "better-auth";
import { generateGenericState, parseGenericState } from "better-auth";
import { generateRandomString } from "better-auth/crypto";
import { APIError } from "better-call";
export async function generateRelayState(
c: GenericEndpointContext,
link:
| {
email: string;
userId: string;
}
| undefined,
additionalData: Record<string, any> | false | undefined,
) {
const callbackURL = c.body.callbackURL;
if (!callbackURL) {
throw new APIError("BAD_REQUEST", {
message: "callbackURL is required",
});
}
const codeVerifier = generateRandomString(128);
const stateData: StateData = {
...(additionalData ? additionalData : {}),
callbackURL,
codeVerifier,
errorURL: c.body.errorCallbackURL,
newUserURL: c.body.newUserCallbackURL,
link,
/**
* This is the actual expiry time of the state
*/
expiresAt: Date.now() + 10 * 60 * 1000,
requestSignUp: c.body.requestSignUp,
};
try {
return generateGenericState(c, stateData, {
cookieName: "relay_state",
});
} catch (error) {
c.context.logger.error(
"Failed to create verification for relay state",
error,
);
throw new APIError("INTERNAL_SERVER_ERROR", {
message: "State error: Unable to create verification for relay state",
cause: error,
});
}
}
export async function parseRelayState(c: GenericEndpointContext) {
const state = c.body.RelayState;
const errorURL =
c.context.options.onAPIError?.errorURL || `${c.context.baseURL}/error`;
let parsedData: StateData;
try {
parsedData = await parseGenericState(c, state, {
cookieName: "relay_state",
});
} catch (error) {
c.context.logger.error("Failed to parse relay state", error);
throw new APIError("BAD_REQUEST", {
message: "State error: failed to validate relay state",
cause: error,
});
}
if (!parsedData.errorURL) {
parsedData.errorURL = errorURL;
}
return parsedData;
}

File diff suppressed because it is too large Load Diff