feat(sso): add domain verification for SSO providers (#5910)

Co-authored-by: Maxwell <145994855+ping-maxwell@users.noreply.github.com>
This commit is contained in:
Jonathan Samines
2025-11-19 02:17:12 -06:00
committed by GitHub
parent 5b6a91895f
commit da9657e53b
7 changed files with 1181 additions and 33 deletions

View File

@@ -704,6 +704,122 @@ mapping: {
}
```
## Domain verification
Domain verification allows your application to automatically trust a new SSO provider
by automatically validating ownership via the associated domain:
<Tabs items={["client", "server"]}>
<Tab value="client">
```ts title="auth.ts"
const authClient = createAuthClient({
plugins: [
ssoClient({ // [!code highlight]
domainVerification: { // [!code highlight]
enabled: true // [!code highlight]
} // [!code highlight]
}) // [!code highlight]
]
})
```
</Tab>
<Tab value="server">
```ts title="auth-client.ts"
const auth = betterAuth({
plugins: [
sso({ // [!code highlight]
domainVerification: { // [!code highlight]
enabled: true // [!code highlight]
} // [!code highlight]
}) // [!code highlight]
]
});
```
</Tab>
</Tabs>
Once enabled, make sure you migrate the database schema (again).
<Tabs items={["migrate", "generate"]}>
<Tab value="migrate">
```bash
npx @better-auth/cli migrate
```
</Tab>
<Tab value="generate">
```bash
npx @better-auth/cli generate
```
</Tab>
</Tabs>
See the [Schema](#schema-for-domain-verification) section to add the fields manually.
### Verify your domain
When domain verification is enabled, every new SSO provider will be untrusted at first.
This means that new sign-ups or sign-ins will be allowed until the domain ownership has been verified.
To verify your ownership over a domain, follow these steps:
<Steps>
<Step>
#### Acquire verification token
When an SSO provider is registered, a **verification token** will be issued to the provider (it will be returned as part of the response).
You can use this token to prove ownership over the domain.
</Step>
<Step>
#### Create `TXT` DNS record
To do this, you'll need to add a `TXT` record to your domain's DNS settings:
* **Host:** `better-auth-token-{your-provider-id}` (**Note:** This assumes the default token prefix, which can be customized through the `domainVerification.tokenPrefix` option)
* **Value:** The verification token you were given.
**Save the record and wait for it to propagate.** This can take up to 48 hours, but it's usually much faster.
</Step>
<Step>
#### Submit a validation request
**Once the DNS record has propagated**, you can submit a validation request (See below)
</Step>
</Steps>
### Domain validation request
Once you have configured your domain, you can use your `auth` instance to submit a validation request.
This request will either result in a rejection (could not prove your ownership over the domain)
or if the verification is successful, your SSO provider domain will be marked as verified.
<APIMethod path="/sso/verify-domain" method="POST" requireSession>
```ts
type verifyDomain = {
/**
* The provider id
*/
providerId: string = "acme-corp"
}
```
</APIMethod>
### Creating a new verification token
Every domain verification token will have a default expiry of 1 week since the moment it was issued
or the moment when the SSO provider was registered.
After that time, the token will expire and cannot longer be used. When that happens,
you can create a new verification token:
<APIMethod path="/sso/request-domain-verification" method="POST" requireSession>
```ts
type requestDomainVerification = {
/**
* The provider id
*/
providerId: string = "acme-corp"
}
```
</APIMethod>
### SAML Endpoints
The plugin automatically creates the following SAML endpoints:
@@ -726,7 +842,17 @@ The plugin requires additional fields in the `ssoProvider` table to store the pr
{ name: "samlConfig", type: "string", description: "The SAML configuration (JSON string)", isRequired: false },
{ name: "userId", type: "string", description: "The user ID", isRequired: true, references: { model: "user", field: "id" } },
{ name: "providerId", type: "string", description: "The provider ID. Used to identify a provider and to generate a redirect URL.", isRequired: true, isUnique: true },
{ name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false },
{ name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false }
]}
/>
### If you have enabled domain verification:
The `ssoProvider` schema is extended as follows:
<DatabaseTable
fields={[
{ name: "domainVerified", type: "boolean", description: "A flag indicating whether the provider domain has been verified.", isRequired: false },
]}
/>
@@ -790,6 +916,23 @@ If you want to allow account linking for specific trusted providers, enable the
type: "number | function",
default: 10,
},
domainVerification: {
description: "Configure the domain verification feature",
type: "object",
properties: {
enabled: {
description: "Enables or disables the domain verification feature",
type: "boolean",
required: false
},
tokenPrefix: {
description: "Prefix to append to the verification token identifier",
type: "string",
required: false,
default: "better-auth-token-"
},
},
},
defaultSSO: {
description: "Configure a default SSO provider for testing and development. This provider will be used when no matching provider is found in the database.",
type: "object",

View File

@@ -1,8 +1,25 @@
import type { BetterAuthClientPlugin } from "better-auth";
import type { sso } from "./index";
export const ssoClient = () => {
import type { SSOPlugin } from "./index";
interface SSOClientOptions {
domainVerification?:
| {
enabled: boolean;
}
| undefined;
}
export const ssoClient = <CO extends SSOClientOptions>(
options?: CO | undefined,
) => {
return {
id: "sso-client",
$InferServerPlugin: {} as ReturnType<typeof sso>,
$InferServerPlugin: {} as SSOPlugin<{
domainVerification: {
enabled: CO["domainVerification"] extends { enabled: true }
? true
: false;
};
}>,
} satisfies BetterAuthClientPlugin;
};

View File

@@ -0,0 +1,550 @@
import { betterAuth } from "better-auth";
import { memoryAdapter } from "better-auth/adapters/memory";
import { createAuthClient } from "better-auth/client";
import { setCookieToHeader } from "better-auth/cookies";
import { bearer, organization } from "better-auth/plugins";
import { afterEach, describe, expect, it, vi } from "vitest";
import { sso } from ".";
import { ssoClient } from "./client";
import type { SSOOptions } from "./types";
const dnsMock = vi.hoisted(() => {
return {
resolveTxt: vi.fn(),
};
});
vi.mock("node:dns/promises", () => {
return {
...dnsMock,
default: dnsMock,
};
});
describe("Domain verification", async () => {
type TestUser = { email: string; password: string; name: string };
const testUser: TestUser = {
email: "test@email.com",
password: "password",
name: "Test User",
};
const createTestAuth = (options?: SSOOptions) => {
const data = {
user: [],
session: [],
verification: [],
account: [],
ssoProvider: [],
member: [],
organization: [],
};
const memory = memoryAdapter(data);
const ssoOptions = {
...options,
domainVerification: {
...options?.domainVerification,
enabled: true,
},
} satisfies SSOOptions;
const auth = betterAuth({
database: memory,
baseURL: "http://localhost:3000",
emailAndPassword: {
enabled: true,
},
plugins: [sso(ssoOptions), organization()],
});
const authClient = createAuthClient({
baseURL: "http://localhost:3000",
plugins: [bearer(), ssoClient({ domainVerification: { enabled: true } })],
fetchOptions: {
customFetchImpl: async (url, init) => {
return auth.handler(new Request(url, init));
},
},
});
async function createOrganization(name: string, headers: Headers) {
return await auth.api.createOrganization({
body: {
name,
slug: name,
},
headers,
});
}
async function getAuthHeaders(user: TestUser, organizationId?: string) {
const headers = new Headers();
const response = await authClient.signUp.email({
email: user.email,
password: user.password,
name: user.name,
});
if (response.data && organizationId) {
await auth.api.addMember({
body: {
userId: response.data.user.id,
role: "member",
},
headers,
});
}
await authClient.signIn.email(user, {
throw: true,
onSuccess: setCookieToHeader(headers),
});
return headers;
}
async function registerSSOProvider(
headers: Headers,
organizationId?: string,
) {
return auth.api.registerSSOProvider({
body: {
providerId: "saml-provider-1",
issuer: "http://hello.com:8081",
domain: "http://hello.com:8081",
samlConfig: {
entryPoint: "http://idp.com:",
cert: "the-cert",
callbackUrl: "http://hello.com:8081/api/sso/saml2/callback",
spMetadata: {},
},
organizationId,
},
headers,
});
}
return {
auth,
authClient,
registerSSOProvider,
getAuthHeaders,
createOrganization,
};
};
afterEach(() => {
vi.clearAllMocks();
vi.useRealTimers();
});
describe("POST /sso/request-domain-verification", () => {
it("should return unauthorized when session is missing", async () => {
const { auth } = createTestAuth();
const response = await auth.api.requestDomainVerification({
body: {
providerId: "the-provider",
},
asResponse: true,
});
expect(response.status).toBe(401);
});
it("should return not found when no provider is found", async () => {
const { auth, getAuthHeaders } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const response = await auth.api.requestDomainVerification({
body: {
providerId: "unknown",
},
headers,
asResponse: true,
});
expect(response.status).toBe(404);
expect(await response.json()).toEqual({
message: "Provider not found",
code: "PROVIDER_NOT_FOUND",
});
});
it("should return the existing active verification token", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
vi.useFakeTimers({ toFake: ["Date"] });
const newAuthHeaders = await getAuthHeaders(testUser);
const response = await auth.api.requestDomainVerification({
body: {
providerId: provider.providerId,
},
headers: newAuthHeaders,
asResponse: true,
});
expect(response.status).toBe(201);
expect(await response.json()).toEqual({
domainVerificationToken: provider.domainVerificationToken,
});
});
it("should return forbidden if user does not own the provider", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
const notOwnerHeaders = await getAuthHeaders({
name: "other",
email: "other@test.com",
password: "password",
});
const response = await auth.api.requestDomainVerification({
body: {
providerId: provider.providerId,
},
headers: notOwnerHeaders,
asResponse: true,
});
expect(response.status).toBe(403);
expect(await response.json()).toEqual({
message:
"User must be owner of or belong to the SSO provider organization",
code: "INSUFICCIENT_ACCESS",
});
});
it("should return forbidden if user does not belong to the provider organization", async () => {
const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
createTestAuth();
const headers = await getAuthHeaders(testUser);
const orgA = await createOrganization("org-a", headers);
const orgB = await createOrganization("org-b", headers);
const provider = await registerSSOProvider(headers, orgA?.id);
const notOrgHeaders = await getAuthHeaders(
{
name: "other",
email: "other@test.com",
password: "password",
},
orgB?.id,
);
const response = await auth.api.requestDomainVerification({
body: {
providerId: provider.providerId,
},
headers: notOrgHeaders,
asResponse: true,
});
expect(response.status).toBe(403);
expect(await response.json()).toEqual({
message:
"User must be owner of or belong to the SSO provider organization",
code: "INSUFICCIENT_ACCESS",
});
});
it("should return a new domain verification token", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
vi.useFakeTimers({ toFake: ["Date"] });
vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds
const newHeaders = await getAuthHeaders(testUser);
const response = await auth.api.requestDomainVerification({
body: {
providerId: provider.providerId,
},
headers: newHeaders,
asResponse: true,
});
expect(response.status).toBe(201);
expect(await response.json()).toMatchObject({
domainVerificationToken: expect.any(String),
});
});
it("should fail to create a new token on an already verified domain", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
dnsMock.resolveTxt.mockResolvedValue([
[
`better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
],
]);
const domainVerificationResponse = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(domainVerificationResponse.status).toBe(204);
const domainVerificationSubmissionResponse =
await auth.api.requestDomainVerification({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(domainVerificationSubmissionResponse.status).toBe(409);
expect(await domainVerificationSubmissionResponse.json()).toEqual({
message: "Domain has already been verified",
code: "DOMAIN_VERIFIED",
});
});
});
describe("POST /sso/verify-domain", () => {
it("should return unauthorized when session is missing", async () => {
const { auth } = createTestAuth();
const response = await auth.api.verifyDomain({
body: {
providerId: "the-provider",
},
asResponse: true,
});
expect(response.status).toBe(401);
});
it("should return not found when no provider is found", async () => {
const { auth, getAuthHeaders } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const response = await auth.api.verifyDomain({
body: {
providerId: "unknown",
},
headers,
asResponse: true,
});
expect(response.status).toBe(404);
expect(await response.json()).toEqual({
message: "Provider not found",
code: "PROVIDER_NOT_FOUND",
});
});
it("should return not found when no pending verification is found", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
vi.useFakeTimers({ toFake: ["Date"] });
vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds
const newAuthHeaders = await getAuthHeaders(testUser);
const response = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers: newAuthHeaders,
asResponse: true,
});
expect(response.status).toBe(404);
expect(await response.json()).toEqual({
message: "No pending domain verification exists",
code: "NO_PENDING_VERIFICATION",
});
});
it("should return bad gateway when unable to verify domain", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
dnsMock.resolveTxt.mockResolvedValue([
["google-site-verification=the-token"],
]);
const response = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(response.status).toBe(502);
expect(await response.json()).toEqual({
message: "Unable to verify domain ownership. Try again later",
code: "DOMAIN_VERIFICATION_FAILED",
});
});
it("should return forbidden if user does not own the provider", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
const notOwnerHeaders = await getAuthHeaders({
name: "other",
email: "other@test.com",
password: "password",
});
const response = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers: notOwnerHeaders,
asResponse: true,
});
expect(response.status).toBe(403);
expect(await response.json()).toEqual({
message:
"User must be owner of or belong to the SSO provider organization",
code: "INSUFICCIENT_ACCESS",
});
});
it("should return forbidden if user does not belong to the provider organization", async () => {
const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
createTestAuth();
const headers = await getAuthHeaders(testUser);
const orgA = await createOrganization("org-a", headers);
const orgB = await createOrganization("org-b", headers);
const provider = await registerSSOProvider(headers, orgA?.id);
const notOrgHeaders = await getAuthHeaders(
{
name: "other",
email: "other@test.com",
password: "password",
},
orgB?.id,
);
const response = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers: notOrgHeaders,
asResponse: true,
});
expect(response.status).toBe(403);
expect(await response.json()).toEqual({
message:
"User must be owner of or belong to the SSO provider organization",
code: "INSUFICCIENT_ACCESS",
});
});
it("should verify a provider domain ownership", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
expect(provider.domain).toBe("http://hello.com:8081");
expect(provider.domainVerified).toBe(false);
expect(provider.domainVerificationToken).toBeTypeOf("string");
dnsMock.resolveTxt.mockResolvedValue([
["google-site-verification=the-token"],
[
"v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
],
[
`better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
],
]);
const response = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(response.status).toBe(204);
});
it("should verify a provider domain ownership (custom token verification prefix)", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth({
domainVerification: { tokenPrefix: "auth-prefix" },
});
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
dnsMock.resolveTxt.mockResolvedValue([
["google-site-verification=the-token"],
[
"v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
],
[`auth-prefix-saml-provider-1=${provider.domainVerificationToken}`],
]);
const response = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(response.status).toBe(204);
});
it("should fail to verify an already verified domain", async () => {
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
const headers = await getAuthHeaders(testUser);
const provider = await registerSSOProvider(headers);
dnsMock.resolveTxt.mockResolvedValue([
[
`better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
],
]);
const firstResponse = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(firstResponse.status).toBe(204);
const secondResponse = await auth.api.verifyDomain({
body: {
providerId: provider.providerId,
},
headers,
asResponse: true,
});
expect(secondResponse.status).toBe(409);
expect(await secondResponse.json()).toEqual({
message: "Domain has already been verified",
code: "DOMAIN_VERIFIED",
});
});
});
});

View File

@@ -1,6 +1,10 @@
import type { BetterAuthPlugin } from "better-auth";
import { XMLValidator } from "fast-xml-parser";
import * as saml from "samlify";
import {
requestDomainVerification,
verifyDomain,
} from "./routes/domain-verification";
import {
acsEndpoint,
callbackSSO,
@@ -25,33 +29,72 @@ const fastValidator = {
saml.setSchemaValidator(fastValidator);
type SSOEndpoints = {
type DomainVerificationEndpoints = {
requestDomainVerification: ReturnType<typeof requestDomainVerification>;
verifyDomain: ReturnType<typeof verifyDomain>;
};
type SSOEndpoints<O extends SSOOptions> = {
spMetadata: ReturnType<typeof spMetadata>;
registerSSOProvider: ReturnType<typeof registerSSOProvider>;
registerSSOProvider: ReturnType<typeof registerSSOProvider<O>>;
signInSSO: ReturnType<typeof signInSSO>;
callbackSSO: ReturnType<typeof callbackSSO>;
callbackSSOSAML: ReturnType<typeof callbackSSOSAML>;
acsEndpoint: ReturnType<typeof acsEndpoint>;
};
export type SSOPlugin<O extends SSOOptions> = {
id: "sso";
endpoints: SSOEndpoints<O> &
(O extends { domainVerification: { enabled: true } }
? DomainVerificationEndpoints
: {});
};
export function sso<
O extends SSOOptions & {
domainVerification?: { enabled: true };
},
>(
options?: O | undefined,
): {
id: "sso";
endpoints: SSOEndpoints<O> & DomainVerificationEndpoints;
schema: any;
options: O;
};
export function sso<O extends SSOOptions>(
options?: O | undefined,
): {
id: "sso";
endpoints: SSOEndpoints;
endpoints: SSOEndpoints<O>;
};
export function sso<O extends SSOOptions>(options?: O | undefined): any {
let endpoints = {
spMetadata: spMetadata(),
registerSSOProvider: registerSSOProvider(options as O),
signInSSO: signInSSO(options as O),
callbackSSO: callbackSSO(options as O),
callbackSSOSAML: callbackSSOSAML(options as O),
acsEndpoint: acsEndpoint(options as O),
};
if (options?.domainVerification?.enabled) {
const domainVerificationEndpoints = {
requestDomainVerification: requestDomainVerification(options as O),
verifyDomain: verifyDomain(options as O),
};
endpoints = {
...endpoints,
...domainVerificationEndpoints,
};
}
return {
id: "sso",
endpoints: {
spMetadata: spMetadata(),
registerSSOProvider: registerSSOProvider(options),
signInSSO: signInSSO(options),
callbackSSO: callbackSSO(options),
callbackSSOSAML: callbackSSOSAML(options),
acsEndpoint: acsEndpoint(options),
},
endpoints,
schema: {
ssoProvider: {
modelName: options?.modelName ?? "ssoProvider",
@@ -95,6 +138,9 @@ export function sso<O extends SSOOptions>(options?: O | undefined): any {
required: true,
fieldName: options?.fields?.domain ?? "domain",
},
...(options?.domainVerification?.enabled
? { domainVerified: { type: "boolean", required: false } }
: {}),
},
},
},

View File

@@ -0,0 +1,275 @@
import type { Verification } from "better-auth";
import {
APIError,
createAuthEndpoint,
sessionMiddleware,
} from "better-auth/api";
import { generateRandomString } from "better-auth/crypto";
import * as z from "zod/v4";
import type { SSOOptions, SSOProvider } from "../types";
export const requestDomainVerification = (options: SSOOptions) => {
return createAuthEndpoint(
"/sso/request-domain-verification",
{
method: "POST",
body: z.object({
providerId: z.string(),
}),
metadata: {
openapi: {
summary: "Request a domain verification",
description:
"Request a domain verification for the given SSO provider",
responses: {
"404": {
description: "Provider not found",
},
"409": {
description: "Domain has already been verified",
},
"201": {
description: "Domain submitted for verification",
},
},
},
},
use: [sessionMiddleware],
},
async (ctx) => {
const body = ctx.body;
const provider = await ctx.context.adapter.findOne<
SSOProvider<SSOOptions>
>({
model: "ssoProvider",
where: [{ field: "providerId", value: body.providerId }],
});
if (!provider) {
throw new APIError("NOT_FOUND", {
message: "Provider not found",
code: "PROVIDER_NOT_FOUND",
});
}
const userId = ctx.context.session.user.id;
let isOrgMember = true;
if (provider.organizationId) {
const membershipsCount = await ctx.context.adapter.count({
model: "member",
where: [
{ field: "userId", value: userId },
{ field: "organizationId", value: provider.organizationId },
],
});
isOrgMember = membershipsCount > 0;
}
if (provider.userId !== userId || !isOrgMember) {
throw new APIError("FORBIDDEN", {
message:
"User must be owner of or belong to the SSO provider organization",
code: "INSUFICCIENT_ACCESS",
});
}
if ("domainVerified" in provider && provider.domainVerified) {
throw new APIError("CONFLICT", {
message: "Domain has already been verified",
code: "DOMAIN_VERIFIED",
});
}
const activeVerification =
await ctx.context.adapter.findOne<Verification>({
model: "verification",
where: [
{
field: "identifier",
value: options.domainVerification?.tokenPrefix
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
: `better-auth-token-${provider.providerId}`,
},
{ field: "expiresAt", value: new Date(), operator: "gt" },
],
});
if (activeVerification) {
ctx.setStatus(201);
return ctx.json({ domainVerificationToken: activeVerification.value });
}
const domainVerificationToken = generateRandomString(24);
await ctx.context.adapter.create<Verification>({
model: "verification",
data: {
identifier: options.domainVerification?.tokenPrefix
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
: `better-auth-token-${provider.providerId}`,
createdAt: new Date(),
updatedAt: new Date(),
value: domainVerificationToken,
expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
},
});
ctx.setStatus(201);
return ctx.json({
domainVerificationToken,
});
},
);
};
export const verifyDomain = (options: SSOOptions) => {
return createAuthEndpoint(
"/sso/verify-domain",
{
method: "POST",
body: z.object({
providerId: z.string(),
}),
metadata: {
openapi: {
summary: "Verify the provider domain ownership",
description: "Verify the provider domain ownership via DNS records",
responses: {
"404": {
description: "Provider not found",
},
"409": {
description:
"Domain has already been verified or no pending verification exists",
},
"502": {
description:
"Unable to verify domain ownership due to upstream validator error",
},
"204": {
description: "Domain ownership was verified",
},
},
},
},
use: [sessionMiddleware],
},
async (ctx) => {
const body = ctx.body;
const provider = await ctx.context.adapter.findOne<
SSOProvider<SSOOptions>
>({
model: "ssoProvider",
where: [{ field: "providerId", value: body.providerId }],
});
if (!provider) {
throw new APIError("NOT_FOUND", {
message: "Provider not found",
code: "PROVIDER_NOT_FOUND",
});
}
const userId = ctx.context.session.user.id;
let isOrgMember = true;
if (provider.organizationId) {
const membershipsCount = await ctx.context.adapter.count({
model: "member",
where: [
{ field: "userId", value: userId },
{ field: "organizationId", value: provider.organizationId },
],
});
isOrgMember = membershipsCount > 0;
}
if (provider.userId !== userId || !isOrgMember) {
throw new APIError("FORBIDDEN", {
message:
"User must be owner of or belong to the SSO provider organization",
code: "INSUFICCIENT_ACCESS",
});
}
if ("domainVerified" in provider && provider.domainVerified) {
throw new APIError("CONFLICT", {
message: "Domain has already been verified",
code: "DOMAIN_VERIFIED",
});
}
const activeVerification =
await ctx.context.adapter.findOne<Verification>({
model: "verification",
where: [
{
field: "identifier",
value: options.domainVerification?.tokenPrefix
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
: `better-auth-token-${provider.providerId}`,
},
{ field: "expiresAt", value: new Date(), operator: "gt" },
],
});
if (!activeVerification) {
throw new APIError("NOT_FOUND", {
message: "No pending domain verification exists",
code: "NO_PENDING_VERIFICATION",
});
}
let records: string[] = [];
let dns: typeof import("node:dns/promises");
try {
dns = await import("node:dns/promises");
} catch (error) {
ctx.context.logger.error(
"The core node:dns module is required for the domain verification feature",
error,
);
throw new APIError("INTERNAL_SERVER_ERROR", {
message: "Unable to verify domain ownership due to server error",
code: "DOMAIN_VERIFICATION_FAILED",
});
}
try {
const dnsRecords = await dns.resolveTxt(
new URL(provider.domain).hostname,
);
records = dnsRecords.flat();
} catch (error) {
ctx.context.logger.warn(
"DNS resolution failure while validating domain ownership",
error,
);
}
const record = records.find((record) =>
record.includes(
`${activeVerification.identifier}=${activeVerification.value}`,
),
);
if (!record) {
throw new APIError("BAD_GATEWAY", {
message: "Unable to verify domain ownership. Try again later",
code: "DOMAIN_VERIFICATION_FAILED",
});
}
await ctx.context.adapter.update<SSOProvider<SSOOptions>>({
model: "ssoProvider",
where: [{ field: "providerId", value: provider.providerId }],
update: {
domainVerified: true,
},
});
ctx.setStatus(204);
return;
},
);
};

View File

@@ -1,5 +1,5 @@
import { BetterFetchError, betterFetch } from "@better-fetch/fetch";
import type { Account, Session, User } from "better-auth";
import type { Account, Session, User, Verification } from "better-auth";
import {
createAuthorizationURL,
generateState,
@@ -13,6 +13,7 @@ import {
sessionMiddleware,
} from "better-auth/api";
import { setSessionCookie } from "better-auth/cookies";
import { generateRandomString } from "better-auth/crypto";
import { handleOAuthUserInfo } from "better-auth/oauth2";
import { decodeJwt } from "jose";
import * as saml from "samlify";
@@ -126,7 +127,7 @@ export const spMetadata = () => {
);
};
export const registerSSOProvider = (options?: SSOOptions) => {
export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
return createAuthEndpoint(
"/sso/register",
{
@@ -358,6 +359,16 @@ export const registerSSOProvider = (options?: SSOOptions) => {
description:
"The domain of the provider, used for email matching",
},
domainVerified: {
type: "boolean",
description:
"A boolean indicating whether the domain has been verified or not",
},
domainVerificationToken: {
type: "string",
description:
"Domain verification token. It can be used to prove ownership over the SSO domain",
},
oidcConfig: {
type: "object",
properties: {
@@ -586,12 +597,13 @@ export const registerSSOProvider = (options?: SSOOptions) => {
const provider = await ctx.context.adapter.create<
Record<string, any>,
SSOProvider
SSOProvider<O>
>({
model: "ssoProvider",
data: {
issuer: body.issuer,
domain: body.domain,
domainVerified: false,
oidcConfig: body.oidcConfig
? JSON.stringify({
issuer: body.issuer,
@@ -640,6 +652,34 @@ export const registerSSOProvider = (options?: SSOOptions) => {
},
});
let domainVerificationToken: string | undefined;
let domainVerified: boolean | undefined;
if (options?.domainVerification?.enabled) {
domainVerified = false;
domainVerificationToken = generateRandomString(24);
await ctx.context.adapter.create<Verification>({
model: "verification",
data: {
identifier: options.domainVerification?.tokenPrefix
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
: `better-auth-token-${provider.providerId}`,
createdAt: new Date(),
updatedAt: new Date(),
value: domainVerificationToken,
expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
},
});
}
type SSOProviderReturn = O["domainVerification"] extends { enabled: true }
? {
domainVerified: boolean;
domainVerificationToken: string;
} & SSOProvider<O>
: SSOProvider<O>;
return ctx.json({
...provider,
oidcConfig: JSON.parse(
@@ -649,7 +689,11 @@ export const registerSSOProvider = (options?: SSOOptions) => {
provider.samlConfig as unknown as string,
) as SAMLConfig,
redirectURI: `${ctx.context.baseURL}/sso/callback/${provider.providerId}`,
});
...(options?.domainVerification?.enabled ? { domainVerified } : {}),
...(options?.domainVerification?.enabled
? { domainVerificationToken }
: {}),
} as unknown as SSOProviderReturn);
},
);
};
@@ -840,7 +884,7 @@ export const signInSSO = (options?: SSOOptions) => {
return res.id;
});
}
let provider: SSOProvider | null = null;
let provider: SSOProvider<SSOOptions> | null = null;
if (options?.defaultSSO?.length) {
// Find matching default SSO provider by providerId
const matchingDefault = providerId
@@ -862,7 +906,10 @@ export const signInSSO = (options?: SSOOptions) => {
oidcConfig: matchingDefault.oidcConfig,
samlConfig: matchingDefault.samlConfig,
domain: matchingDefault.domain,
};
...(options.domainVerification?.enabled
? { domainVerified: true }
: {}),
} as SSOProvider<SSOOptions>;
}
}
if (!providerId && !orgId && !domain) {
@@ -873,7 +920,7 @@ export const signInSSO = (options?: SSOOptions) => {
// Try to find provider in database
if (!provider) {
provider = await ctx.context.adapter
.findOne<SSOProvider>({
.findOne<SSOProvider<SSOOptions>>({
model: "ssoProvider",
where: [
{
@@ -925,6 +972,15 @@ export const signInSSO = (options?: SSOOptions) => {
}
}
if (
options?.domainVerification?.enabled &&
!("domainVerified" in provider && provider.domainVerified)
) {
throw new APIError("UNAUTHORIZED", {
message: "Provider domain has not been verified",
});
}
if (provider.oidcConfig && body.providerType !== "saml") {
let finalAuthUrl = provider.oidcConfig.authorizationEndpoint;
if (!finalAuthUrl && provider.oidcConfig.discoveryEndpoint) {
@@ -1064,7 +1120,7 @@ export const callbackSSO = (options?: SSOOptions) => {
}?error=${error}&error_description=${error_description}`,
);
}
let provider: SSOProvider | null = null;
let provider: SSOProvider<SSOOptions> | null = null;
if (options?.defaultSSO?.length) {
const matchingDefault = options.defaultSSO.find(
(defaultProvider) =>
@@ -1075,7 +1131,10 @@ export const callbackSSO = (options?: SSOOptions) => {
...matchingDefault,
issuer: matchingDefault.oidcConfig?.issuer || "",
userId: "default",
};
...(options.domainVerification?.enabled
? { domainVerified: true }
: {}),
} as SSOProvider<SSOOptions>;
}
}
if (!provider) {
@@ -1099,7 +1158,7 @@ export const callbackSSO = (options?: SSOOptions) => {
...res,
oidcConfig:
safeJsonParse<OIDCConfig>(res.oidcConfig) || undefined,
} as SSOProvider;
} as SSOProvider<SSOOptions>;
});
}
if (!provider) {
@@ -1110,6 +1169,15 @@ export const callbackSSO = (options?: SSOOptions) => {
);
}
if (
options?.domainVerification?.enabled &&
!("domainVerified" in provider && provider.domainVerified)
) {
throw new APIError("UNAUTHORIZED", {
message: "Provider domain has not been verified",
});
}
let config = provider.oidcConfig;
if (!config) {
@@ -1405,7 +1473,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
async (ctx) => {
const { SAMLResponse, RelayState } = ctx.body;
const { providerId } = ctx.params;
let provider: SSOProvider | null = null;
let provider: SSOProvider<SSOOptions> | null = null;
if (options?.defaultSSO?.length) {
const matchingDefault = options.defaultSSO.find(
(defaultProvider) => defaultProvider.providerId === providerId,
@@ -1415,12 +1483,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
...matchingDefault,
userId: "default",
issuer: matchingDefault.samlConfig?.issuer || "",
};
...(options.domainVerification?.enabled
? { domainVerified: true }
: {}),
} as SSOProvider<SSOOptions>;
}
}
if (!provider) {
provider = await ctx.context.adapter
.findOne<SSOProvider>({
.findOne<SSOProvider<SSOOptions>>({
model: "ssoProvider",
where: [{ field: "providerId", value: providerId }],
})
@@ -1443,6 +1514,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
});
}
if (
options?.domainVerification?.enabled &&
!("domainVerified" in provider && provider.domainVerified)
) {
throw new APIError("UNAUTHORIZED", {
message: "Provider domain has not been verified",
});
}
const parsedSamlConfig = safeJsonParse<SAMLConfig>(
provider.samlConfig as unknown as string,
);
@@ -1736,7 +1816,7 @@ export const acsEndpoint = (options?: SSOOptions) => {
const { providerId } = ctx.params;
// If defaultSSO is configured, use it as the provider
let provider: SSOProvider | null = null;
let provider: SSOProvider<SSOOptions> | null = null;
if (options?.defaultSSO?.length) {
// For ACS endpoint, we can use the first default provider or try to match by providerId
@@ -1753,11 +1833,14 @@ export const acsEndpoint = (options?: SSOOptions) => {
userId: "default",
samlConfig: matchingDefault.samlConfig,
domain: matchingDefault.domain,
...(options.domainVerification?.enabled
? { domainVerified: true }
: {}),
};
}
} else {
provider = await ctx.context.adapter
.findOne<SSOProvider>({
.findOne<SSOProvider<SSOOptions>>({
model: "ssoProvider",
where: [
{
@@ -1785,6 +1868,15 @@ export const acsEndpoint = (options?: SSOOptions) => {
});
}
if (
options?.domainVerification?.enabled &&
!("domainVerified" in provider && provider.domainVerified)
) {
throw new APIError("UNAUTHORIZED", {
message: "Provider domain has not been verified",
});
}
const parsedSamlConfig = provider.samlConfig;
// Configure SP and IdP
const sp = saml.ServiceProvider({

View File

@@ -81,7 +81,7 @@ export interface SAMLConfig {
mapping?: SAMLMapping | undefined;
}
export type SSOProvider = {
type BaseSSOProvider = {
issuer: string;
oidcConfig?: OIDCConfig | undefined;
samlConfig?: SAMLConfig | undefined;
@@ -91,6 +91,13 @@ export type SSOProvider = {
domain: string;
};
export type SSOProvider<O extends SSOOptions> =
O["domainVerification"] extends { enabled: true }
? {
domainVerified: boolean;
} & BaseSSOProvider
: BaseSSOProvider;
export interface SSOOptions {
/**
* custom function to provision a user when they sign in with an SSO provider.
@@ -112,7 +119,7 @@ export interface SSOOptions {
/**
* The SSO provider
*/
provider: SSOProvider;
provider: SSOProvider<SSOOptions>;
}) => Promise<void>)
| undefined;
/**
@@ -138,7 +145,7 @@ export interface SSOOptions {
/**
* The SSO provider
*/
provider: SSOProvider;
provider: SSOProvider<SSOOptions>;
}) => Promise<"member" | "admin">;
}
| undefined;
@@ -228,4 +235,22 @@ export interface SSOOptions {
* @default false
*/
trustEmailVerified?: boolean | undefined;
/**
* Enable domain verification on SSO providers
*
* When this option is enabled, new SSO providers will require the associated domain to be verified by the owner
* prior to allowing sign-ins.
*/
domainVerification?: {
/**
* Enables or disables the domain verification feature
*/
enabled?: boolean;
/**
* Prefix used to generate the domain verification token
*
* @default "better-auth-token-"
*/
tokenPrefix?: string;
};
}