diff --git a/docs/content/docs/plugins/sso.mdx b/docs/content/docs/plugins/sso.mdx
index 2fa6cc7d5a..79b9366422 100644
--- a/docs/content/docs/plugins/sso.mdx
+++ b/docs/content/docs/plugins/sso.mdx
@@ -704,6 +704,122 @@ mapping: {
}
```
+## Domain verification
+
+Domain verification allows your application to automatically trust a new SSO provider
+by automatically validating ownership via the associated domain:
+
+
+
+```ts title="auth.ts"
+const authClient = createAuthClient({
+ plugins: [
+ ssoClient({ // [!code highlight]
+ domainVerification: { // [!code highlight]
+ enabled: true // [!code highlight]
+ } // [!code highlight]
+ }) // [!code highlight]
+ ]
+})
+```
+
+
+
+```ts title="auth-client.ts"
+const auth = betterAuth({
+ plugins: [
+ sso({ // [!code highlight]
+ domainVerification: { // [!code highlight]
+ enabled: true // [!code highlight]
+ } // [!code highlight]
+ }) // [!code highlight]
+ ]
+});
+```
+
+
+
+Once enabled, make sure you migrate the database schema (again).
+
+
+
+ ```bash
+ npx @better-auth/cli migrate
+ ```
+
+
+ ```bash
+ npx @better-auth/cli generate
+ ```
+
+
+
+See the [Schema](#schema-for-domain-verification) section to add the fields manually.
+
+### Verify your domain
+
+When domain verification is enabled, every new SSO provider will be untrusted at first.
+This means that new sign-ups or sign-ins will be allowed until the domain ownership has been verified.
+
+To verify your ownership over a domain, follow these steps:
+
+
+
+#### Acquire verification token
+When an SSO provider is registered, a **verification token** will be issued to the provider (it will be returned as part of the response).
+You can use this token to prove ownership over the domain.
+
+
+#### Create `TXT` DNS record
+To do this, you'll need to add a `TXT` record to your domain's DNS settings:
+
+ * **Host:** `better-auth-token-{your-provider-id}` (**Note:** This assumes the default token prefix, which can be customized through the `domainVerification.tokenPrefix` option)
+ * **Value:** The verification token you were given.
+
+**Save the record and wait for it to propagate.** This can take up to 48 hours, but it's usually much faster.
+
+
+#### Submit a validation request
+**Once the DNS record has propagated**, you can submit a validation request (See below)
+
+
+
+### Domain validation request
+
+Once you have configured your domain, you can use your `auth` instance to submit a validation request.
+This request will either result in a rejection (could not prove your ownership over the domain)
+or if the verification is successful, your SSO provider domain will be marked as verified.
+
+
+```ts
+type verifyDomain = {
+ /**
+ * The provider id
+ */
+ providerId: string = "acme-corp"
+}
+```
+
+
+### Creating a new verification token
+
+Every domain verification token will have a default expiry of 1 week since the moment it was issued
+or the moment when the SSO provider was registered.
+
+After that time, the token will expire and cannot longer be used. When that happens,
+you can create a new verification token:
+
+
+```ts
+type requestDomainVerification = {
+ /**
+ * The provider id
+ */
+ providerId: string = "acme-corp"
+}
+```
+
+
### SAML Endpoints
The plugin automatically creates the following SAML endpoints:
@@ -726,7 +842,17 @@ The plugin requires additional fields in the `ssoProvider` table to store the pr
{ name: "samlConfig", type: "string", description: "The SAML configuration (JSON string)", isRequired: false },
{ name: "userId", type: "string", description: "The user ID", isRequired: true, references: { model: "user", field: "id" } },
{ name: "providerId", type: "string", description: "The provider ID. Used to identify a provider and to generate a redirect URL.", isRequired: true, isUnique: true },
- { name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false },
+ { name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false }
+ ]}
+/>
+
+### If you have enabled domain verification:
+
+The `ssoProvider` schema is extended as follows:
+
+
@@ -790,6 +916,23 @@ If you want to allow account linking for specific trusted providers, enable the
type: "number | function",
default: 10,
},
+ domainVerification: {
+ description: "Configure the domain verification feature",
+ type: "object",
+ properties: {
+ enabled: {
+ description: "Enables or disables the domain verification feature",
+ type: "boolean",
+ required: false
+ },
+ tokenPrefix: {
+ description: "Prefix to append to the verification token identifier",
+ type: "string",
+ required: false,
+ default: "better-auth-token-"
+ },
+ },
+ },
defaultSSO: {
description: "Configure a default SSO provider for testing and development. This provider will be used when no matching provider is found in the database.",
type: "object",
diff --git a/packages/sso/src/client.ts b/packages/sso/src/client.ts
index 0b6bd310e6..9d2541b4cc 100644
--- a/packages/sso/src/client.ts
+++ b/packages/sso/src/client.ts
@@ -1,8 +1,25 @@
import type { BetterAuthClientPlugin } from "better-auth";
-import type { sso } from "./index";
-export const ssoClient = () => {
+import type { SSOPlugin } from "./index";
+
+interface SSOClientOptions {
+ domainVerification?:
+ | {
+ enabled: boolean;
+ }
+ | undefined;
+}
+
+export const ssoClient = (
+ options?: CO | undefined,
+) => {
return {
id: "sso-client",
- $InferServerPlugin: {} as ReturnType,
+ $InferServerPlugin: {} as SSOPlugin<{
+ domainVerification: {
+ enabled: CO["domainVerification"] extends { enabled: true }
+ ? true
+ : false;
+ };
+ }>,
} satisfies BetterAuthClientPlugin;
};
diff --git a/packages/sso/src/domain-verification.test.ts b/packages/sso/src/domain-verification.test.ts
new file mode 100644
index 0000000000..ccf8b92d8e
--- /dev/null
+++ b/packages/sso/src/domain-verification.test.ts
@@ -0,0 +1,550 @@
+import { betterAuth } from "better-auth";
+import { memoryAdapter } from "better-auth/adapters/memory";
+import { createAuthClient } from "better-auth/client";
+import { setCookieToHeader } from "better-auth/cookies";
+import { bearer, organization } from "better-auth/plugins";
+import { afterEach, describe, expect, it, vi } from "vitest";
+import { sso } from ".";
+import { ssoClient } from "./client";
+import type { SSOOptions } from "./types";
+
+const dnsMock = vi.hoisted(() => {
+ return {
+ resolveTxt: vi.fn(),
+ };
+});
+
+vi.mock("node:dns/promises", () => {
+ return {
+ ...dnsMock,
+ default: dnsMock,
+ };
+});
+
+describe("Domain verification", async () => {
+ type TestUser = { email: string; password: string; name: string };
+ const testUser: TestUser = {
+ email: "test@email.com",
+ password: "password",
+ name: "Test User",
+ };
+
+ const createTestAuth = (options?: SSOOptions) => {
+ const data = {
+ user: [],
+ session: [],
+ verification: [],
+ account: [],
+ ssoProvider: [],
+ member: [],
+ organization: [],
+ };
+
+ const memory = memoryAdapter(data);
+
+ const ssoOptions = {
+ ...options,
+ domainVerification: {
+ ...options?.domainVerification,
+ enabled: true,
+ },
+ } satisfies SSOOptions;
+
+ const auth = betterAuth({
+ database: memory,
+ baseURL: "http://localhost:3000",
+ emailAndPassword: {
+ enabled: true,
+ },
+ plugins: [sso(ssoOptions), organization()],
+ });
+
+ const authClient = createAuthClient({
+ baseURL: "http://localhost:3000",
+ plugins: [bearer(), ssoClient({ domainVerification: { enabled: true } })],
+ fetchOptions: {
+ customFetchImpl: async (url, init) => {
+ return auth.handler(new Request(url, init));
+ },
+ },
+ });
+
+ async function createOrganization(name: string, headers: Headers) {
+ return await auth.api.createOrganization({
+ body: {
+ name,
+ slug: name,
+ },
+ headers,
+ });
+ }
+
+ async function getAuthHeaders(user: TestUser, organizationId?: string) {
+ const headers = new Headers();
+ const response = await authClient.signUp.email({
+ email: user.email,
+ password: user.password,
+ name: user.name,
+ });
+
+ if (response.data && organizationId) {
+ await auth.api.addMember({
+ body: {
+ userId: response.data.user.id,
+ role: "member",
+ },
+ headers,
+ });
+ }
+
+ await authClient.signIn.email(user, {
+ throw: true,
+ onSuccess: setCookieToHeader(headers),
+ });
+
+ return headers;
+ }
+
+ async function registerSSOProvider(
+ headers: Headers,
+ organizationId?: string,
+ ) {
+ return auth.api.registerSSOProvider({
+ body: {
+ providerId: "saml-provider-1",
+ issuer: "http://hello.com:8081",
+ domain: "http://hello.com:8081",
+ samlConfig: {
+ entryPoint: "http://idp.com:",
+ cert: "the-cert",
+ callbackUrl: "http://hello.com:8081/api/sso/saml2/callback",
+ spMetadata: {},
+ },
+ organizationId,
+ },
+ headers,
+ });
+ }
+
+ return {
+ auth,
+ authClient,
+ registerSSOProvider,
+ getAuthHeaders,
+ createOrganization,
+ };
+ };
+
+ afterEach(() => {
+ vi.clearAllMocks();
+ vi.useRealTimers();
+ });
+
+ describe("POST /sso/request-domain-verification", () => {
+ it("should return unauthorized when session is missing", async () => {
+ const { auth } = createTestAuth();
+ const response = await auth.api.requestDomainVerification({
+ body: {
+ providerId: "the-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(401);
+ });
+
+ it("should return not found when no provider is found", async () => {
+ const { auth, getAuthHeaders } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const response = await auth.api.requestDomainVerification({
+ body: {
+ providerId: "unknown",
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(404);
+ expect(await response.json()).toEqual({
+ message: "Provider not found",
+ code: "PROVIDER_NOT_FOUND",
+ });
+ });
+
+ it("should return the existing active verification token", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ vi.useFakeTimers({ toFake: ["Date"] });
+
+ const newAuthHeaders = await getAuthHeaders(testUser);
+
+ const response = await auth.api.requestDomainVerification({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: newAuthHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(201);
+ expect(await response.json()).toEqual({
+ domainVerificationToken: provider.domainVerificationToken,
+ });
+ });
+
+ it("should return forbidden if user does not own the provider", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ const notOwnerHeaders = await getAuthHeaders({
+ name: "other",
+ email: "other@test.com",
+ password: "password",
+ });
+ const response = await auth.api.requestDomainVerification({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: notOwnerHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(403);
+ expect(await response.json()).toEqual({
+ message:
+ "User must be owner of or belong to the SSO provider organization",
+ code: "INSUFICCIENT_ACCESS",
+ });
+ });
+
+ it("should return forbidden if user does not belong to the provider organization", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
+ createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+
+ const orgA = await createOrganization("org-a", headers);
+ const orgB = await createOrganization("org-b", headers);
+
+ const provider = await registerSSOProvider(headers, orgA?.id);
+
+ const notOrgHeaders = await getAuthHeaders(
+ {
+ name: "other",
+ email: "other@test.com",
+ password: "password",
+ },
+ orgB?.id,
+ );
+
+ const response = await auth.api.requestDomainVerification({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: notOrgHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(403);
+ expect(await response.json()).toEqual({
+ message:
+ "User must be owner of or belong to the SSO provider organization",
+ code: "INSUFICCIENT_ACCESS",
+ });
+ });
+
+ it("should return a new domain verification token", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ vi.useFakeTimers({ toFake: ["Date"] });
+ vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds
+
+ const newHeaders = await getAuthHeaders(testUser);
+ const response = await auth.api.requestDomainVerification({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: newHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(201);
+ expect(await response.json()).toMatchObject({
+ domainVerificationToken: expect.any(String),
+ });
+ });
+
+ it("should fail to create a new token on an already verified domain", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ dnsMock.resolveTxt.mockResolvedValue([
+ [
+ `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
+ ],
+ ]);
+
+ const domainVerificationResponse = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(domainVerificationResponse.status).toBe(204);
+
+ const domainVerificationSubmissionResponse =
+ await auth.api.requestDomainVerification({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(domainVerificationSubmissionResponse.status).toBe(409);
+ expect(await domainVerificationSubmissionResponse.json()).toEqual({
+ message: "Domain has already been verified",
+ code: "DOMAIN_VERIFIED",
+ });
+ });
+ });
+
+ describe("POST /sso/verify-domain", () => {
+ it("should return unauthorized when session is missing", async () => {
+ const { auth } = createTestAuth();
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: "the-provider",
+ },
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(401);
+ });
+
+ it("should return not found when no provider is found", async () => {
+ const { auth, getAuthHeaders } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: "unknown",
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(404);
+ expect(await response.json()).toEqual({
+ message: "Provider not found",
+ code: "PROVIDER_NOT_FOUND",
+ });
+ });
+
+ it("should return not found when no pending verification is found", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ vi.useFakeTimers({ toFake: ["Date"] });
+ vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds
+
+ const newAuthHeaders = await getAuthHeaders(testUser);
+
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: newAuthHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(404);
+ expect(await response.json()).toEqual({
+ message: "No pending domain verification exists",
+ code: "NO_PENDING_VERIFICATION",
+ });
+ });
+
+ it("should return bad gateway when unable to verify domain", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ dnsMock.resolveTxt.mockResolvedValue([
+ ["google-site-verification=the-token"],
+ ]);
+
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(502);
+ expect(await response.json()).toEqual({
+ message: "Unable to verify domain ownership. Try again later",
+ code: "DOMAIN_VERIFICATION_FAILED",
+ });
+ });
+
+ it("should return forbidden if user does not own the provider", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ const notOwnerHeaders = await getAuthHeaders({
+ name: "other",
+ email: "other@test.com",
+ password: "password",
+ });
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: notOwnerHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(403);
+ expect(await response.json()).toEqual({
+ message:
+ "User must be owner of or belong to the SSO provider organization",
+ code: "INSUFICCIENT_ACCESS",
+ });
+ });
+
+ it("should return forbidden if user does not belong to the provider organization", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
+ createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const orgA = await createOrganization("org-a", headers);
+ const orgB = await createOrganization("org-b", headers);
+
+ const provider = await registerSSOProvider(headers, orgA?.id);
+
+ const notOrgHeaders = await getAuthHeaders(
+ {
+ name: "other",
+ email: "other@test.com",
+ password: "password",
+ },
+ orgB?.id,
+ );
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers: notOrgHeaders,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(403);
+ expect(await response.json()).toEqual({
+ message:
+ "User must be owner of or belong to the SSO provider organization",
+ code: "INSUFICCIENT_ACCESS",
+ });
+ });
+
+ it("should verify a provider domain ownership", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ expect(provider.domain).toBe("http://hello.com:8081");
+ expect(provider.domainVerified).toBe(false);
+ expect(provider.domainVerificationToken).toBeTypeOf("string");
+
+ dnsMock.resolveTxt.mockResolvedValue([
+ ["google-site-verification=the-token"],
+ [
+ "v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
+ ],
+ [
+ `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
+ ],
+ ]);
+
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(204);
+ });
+
+ it("should verify a provider domain ownership (custom token verification prefix)", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth({
+ domainVerification: { tokenPrefix: "auth-prefix" },
+ });
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ dnsMock.resolveTxt.mockResolvedValue([
+ ["google-site-verification=the-token"],
+ [
+ "v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
+ ],
+ [`auth-prefix-saml-provider-1=${provider.domainVerificationToken}`],
+ ]);
+
+ const response = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(response.status).toBe(204);
+ });
+
+ it("should fail to verify an already verified domain", async () => {
+ const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
+ const headers = await getAuthHeaders(testUser);
+ const provider = await registerSSOProvider(headers);
+
+ dnsMock.resolveTxt.mockResolvedValue([
+ [
+ `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
+ ],
+ ]);
+
+ const firstResponse = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(firstResponse.status).toBe(204);
+
+ const secondResponse = await auth.api.verifyDomain({
+ body: {
+ providerId: provider.providerId,
+ },
+ headers,
+ asResponse: true,
+ });
+
+ expect(secondResponse.status).toBe(409);
+ expect(await secondResponse.json()).toEqual({
+ message: "Domain has already been verified",
+ code: "DOMAIN_VERIFIED",
+ });
+ });
+ });
+});
diff --git a/packages/sso/src/index.ts b/packages/sso/src/index.ts
index 9a49f1dc47..5e45274d93 100644
--- a/packages/sso/src/index.ts
+++ b/packages/sso/src/index.ts
@@ -1,6 +1,10 @@
import type { BetterAuthPlugin } from "better-auth";
import { XMLValidator } from "fast-xml-parser";
import * as saml from "samlify";
+import {
+ requestDomainVerification,
+ verifyDomain,
+} from "./routes/domain-verification";
import {
acsEndpoint,
callbackSSO,
@@ -25,33 +29,72 @@ const fastValidator = {
saml.setSchemaValidator(fastValidator);
-type SSOEndpoints = {
+type DomainVerificationEndpoints = {
+ requestDomainVerification: ReturnType;
+ verifyDomain: ReturnType;
+};
+
+type SSOEndpoints = {
spMetadata: ReturnType;
- registerSSOProvider: ReturnType;
+ registerSSOProvider: ReturnType>;
signInSSO: ReturnType;
callbackSSO: ReturnType;
callbackSSOSAML: ReturnType;
acsEndpoint: ReturnType;
};
+export type SSOPlugin = {
+ id: "sso";
+ endpoints: SSOEndpoints &
+ (O extends { domainVerification: { enabled: true } }
+ ? DomainVerificationEndpoints
+ : {});
+};
+
+export function sso<
+ O extends SSOOptions & {
+ domainVerification?: { enabled: true };
+ },
+>(
+ options?: O | undefined,
+): {
+ id: "sso";
+ endpoints: SSOEndpoints & DomainVerificationEndpoints;
+ schema: any;
+ options: O;
+};
export function sso(
options?: O | undefined,
): {
id: "sso";
- endpoints: SSOEndpoints;
+ endpoints: SSOEndpoints;
};
export function sso(options?: O | undefined): any {
+ let endpoints = {
+ spMetadata: spMetadata(),
+ registerSSOProvider: registerSSOProvider(options as O),
+ signInSSO: signInSSO(options as O),
+ callbackSSO: callbackSSO(options as O),
+ callbackSSOSAML: callbackSSOSAML(options as O),
+ acsEndpoint: acsEndpoint(options as O),
+ };
+
+ if (options?.domainVerification?.enabled) {
+ const domainVerificationEndpoints = {
+ requestDomainVerification: requestDomainVerification(options as O),
+ verifyDomain: verifyDomain(options as O),
+ };
+
+ endpoints = {
+ ...endpoints,
+ ...domainVerificationEndpoints,
+ };
+ }
+
return {
id: "sso",
- endpoints: {
- spMetadata: spMetadata(),
- registerSSOProvider: registerSSOProvider(options),
- signInSSO: signInSSO(options),
- callbackSSO: callbackSSO(options),
- callbackSSOSAML: callbackSSOSAML(options),
- acsEndpoint: acsEndpoint(options),
- },
+ endpoints,
schema: {
ssoProvider: {
modelName: options?.modelName ?? "ssoProvider",
@@ -95,6 +138,9 @@ export function sso(options?: O | undefined): any {
required: true,
fieldName: options?.fields?.domain ?? "domain",
},
+ ...(options?.domainVerification?.enabled
+ ? { domainVerified: { type: "boolean", required: false } }
+ : {}),
},
},
},
diff --git a/packages/sso/src/routes/domain-verification.ts b/packages/sso/src/routes/domain-verification.ts
new file mode 100644
index 0000000000..843c934393
--- /dev/null
+++ b/packages/sso/src/routes/domain-verification.ts
@@ -0,0 +1,275 @@
+import type { Verification } from "better-auth";
+import {
+ APIError,
+ createAuthEndpoint,
+ sessionMiddleware,
+} from "better-auth/api";
+import { generateRandomString } from "better-auth/crypto";
+import * as z from "zod/v4";
+import type { SSOOptions, SSOProvider } from "../types";
+
+export const requestDomainVerification = (options: SSOOptions) => {
+ return createAuthEndpoint(
+ "/sso/request-domain-verification",
+ {
+ method: "POST",
+ body: z.object({
+ providerId: z.string(),
+ }),
+ metadata: {
+ openapi: {
+ summary: "Request a domain verification",
+ description:
+ "Request a domain verification for the given SSO provider",
+ responses: {
+ "404": {
+ description: "Provider not found",
+ },
+ "409": {
+ description: "Domain has already been verified",
+ },
+ "201": {
+ description: "Domain submitted for verification",
+ },
+ },
+ },
+ },
+ use: [sessionMiddleware],
+ },
+ async (ctx) => {
+ const body = ctx.body;
+ const provider = await ctx.context.adapter.findOne<
+ SSOProvider
+ >({
+ model: "ssoProvider",
+ where: [{ field: "providerId", value: body.providerId }],
+ });
+
+ if (!provider) {
+ throw new APIError("NOT_FOUND", {
+ message: "Provider not found",
+ code: "PROVIDER_NOT_FOUND",
+ });
+ }
+
+ const userId = ctx.context.session.user.id;
+ let isOrgMember = true;
+ if (provider.organizationId) {
+ const membershipsCount = await ctx.context.adapter.count({
+ model: "member",
+ where: [
+ { field: "userId", value: userId },
+ { field: "organizationId", value: provider.organizationId },
+ ],
+ });
+
+ isOrgMember = membershipsCount > 0;
+ }
+
+ if (provider.userId !== userId || !isOrgMember) {
+ throw new APIError("FORBIDDEN", {
+ message:
+ "User must be owner of or belong to the SSO provider organization",
+ code: "INSUFICCIENT_ACCESS",
+ });
+ }
+
+ if ("domainVerified" in provider && provider.domainVerified) {
+ throw new APIError("CONFLICT", {
+ message: "Domain has already been verified",
+ code: "DOMAIN_VERIFIED",
+ });
+ }
+
+ const activeVerification =
+ await ctx.context.adapter.findOne({
+ model: "verification",
+ where: [
+ {
+ field: "identifier",
+ value: options.domainVerification?.tokenPrefix
+ ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
+ : `better-auth-token-${provider.providerId}`,
+ },
+ { field: "expiresAt", value: new Date(), operator: "gt" },
+ ],
+ });
+
+ if (activeVerification) {
+ ctx.setStatus(201);
+ return ctx.json({ domainVerificationToken: activeVerification.value });
+ }
+
+ const domainVerificationToken = generateRandomString(24);
+ await ctx.context.adapter.create({
+ model: "verification",
+ data: {
+ identifier: options.domainVerification?.tokenPrefix
+ ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
+ : `better-auth-token-${provider.providerId}`,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ value: domainVerificationToken,
+ expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
+ },
+ });
+
+ ctx.setStatus(201);
+ return ctx.json({
+ domainVerificationToken,
+ });
+ },
+ );
+};
+
+export const verifyDomain = (options: SSOOptions) => {
+ return createAuthEndpoint(
+ "/sso/verify-domain",
+ {
+ method: "POST",
+ body: z.object({
+ providerId: z.string(),
+ }),
+ metadata: {
+ openapi: {
+ summary: "Verify the provider domain ownership",
+ description: "Verify the provider domain ownership via DNS records",
+ responses: {
+ "404": {
+ description: "Provider not found",
+ },
+ "409": {
+ description:
+ "Domain has already been verified or no pending verification exists",
+ },
+ "502": {
+ description:
+ "Unable to verify domain ownership due to upstream validator error",
+ },
+ "204": {
+ description: "Domain ownership was verified",
+ },
+ },
+ },
+ },
+ use: [sessionMiddleware],
+ },
+ async (ctx) => {
+ const body = ctx.body;
+ const provider = await ctx.context.adapter.findOne<
+ SSOProvider
+ >({
+ model: "ssoProvider",
+ where: [{ field: "providerId", value: body.providerId }],
+ });
+
+ if (!provider) {
+ throw new APIError("NOT_FOUND", {
+ message: "Provider not found",
+ code: "PROVIDER_NOT_FOUND",
+ });
+ }
+
+ const userId = ctx.context.session.user.id;
+ let isOrgMember = true;
+ if (provider.organizationId) {
+ const membershipsCount = await ctx.context.adapter.count({
+ model: "member",
+ where: [
+ { field: "userId", value: userId },
+ { field: "organizationId", value: provider.organizationId },
+ ],
+ });
+
+ isOrgMember = membershipsCount > 0;
+ }
+
+ if (provider.userId !== userId || !isOrgMember) {
+ throw new APIError("FORBIDDEN", {
+ message:
+ "User must be owner of or belong to the SSO provider organization",
+ code: "INSUFICCIENT_ACCESS",
+ });
+ }
+
+ if ("domainVerified" in provider && provider.domainVerified) {
+ throw new APIError("CONFLICT", {
+ message: "Domain has already been verified",
+ code: "DOMAIN_VERIFIED",
+ });
+ }
+
+ const activeVerification =
+ await ctx.context.adapter.findOne({
+ model: "verification",
+ where: [
+ {
+ field: "identifier",
+ value: options.domainVerification?.tokenPrefix
+ ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
+ : `better-auth-token-${provider.providerId}`,
+ },
+ { field: "expiresAt", value: new Date(), operator: "gt" },
+ ],
+ });
+
+ if (!activeVerification) {
+ throw new APIError("NOT_FOUND", {
+ message: "No pending domain verification exists",
+ code: "NO_PENDING_VERIFICATION",
+ });
+ }
+
+ let records: string[] = [];
+ let dns: typeof import("node:dns/promises");
+
+ try {
+ dns = await import("node:dns/promises");
+ } catch (error) {
+ ctx.context.logger.error(
+ "The core node:dns module is required for the domain verification feature",
+ error,
+ );
+ throw new APIError("INTERNAL_SERVER_ERROR", {
+ message: "Unable to verify domain ownership due to server error",
+ code: "DOMAIN_VERIFICATION_FAILED",
+ });
+ }
+
+ try {
+ const dnsRecords = await dns.resolveTxt(
+ new URL(provider.domain).hostname,
+ );
+ records = dnsRecords.flat();
+ } catch (error) {
+ ctx.context.logger.warn(
+ "DNS resolution failure while validating domain ownership",
+ error,
+ );
+ }
+
+ const record = records.find((record) =>
+ record.includes(
+ `${activeVerification.identifier}=${activeVerification.value}`,
+ ),
+ );
+ if (!record) {
+ throw new APIError("BAD_GATEWAY", {
+ message: "Unable to verify domain ownership. Try again later",
+ code: "DOMAIN_VERIFICATION_FAILED",
+ });
+ }
+
+ await ctx.context.adapter.update>({
+ model: "ssoProvider",
+ where: [{ field: "providerId", value: provider.providerId }],
+ update: {
+ domainVerified: true,
+ },
+ });
+
+ ctx.setStatus(204);
+ return;
+ },
+ );
+};
diff --git a/packages/sso/src/routes/sso.ts b/packages/sso/src/routes/sso.ts
index 84d86e7ba6..1945d57226 100644
--- a/packages/sso/src/routes/sso.ts
+++ b/packages/sso/src/routes/sso.ts
@@ -1,5 +1,5 @@
import { BetterFetchError, betterFetch } from "@better-fetch/fetch";
-import type { Account, Session, User } from "better-auth";
+import type { Account, Session, User, Verification } from "better-auth";
import {
createAuthorizationURL,
generateState,
@@ -13,6 +13,7 @@ import {
sessionMiddleware,
} from "better-auth/api";
import { setSessionCookie } from "better-auth/cookies";
+import { generateRandomString } from "better-auth/crypto";
import { handleOAuthUserInfo } from "better-auth/oauth2";
import { decodeJwt } from "jose";
import * as saml from "samlify";
@@ -126,7 +127,7 @@ export const spMetadata = () => {
);
};
-export const registerSSOProvider = (options?: SSOOptions) => {
+export const registerSSOProvider = (options: O) => {
return createAuthEndpoint(
"/sso/register",
{
@@ -358,6 +359,16 @@ export const registerSSOProvider = (options?: SSOOptions) => {
description:
"The domain of the provider, used for email matching",
},
+ domainVerified: {
+ type: "boolean",
+ description:
+ "A boolean indicating whether the domain has been verified or not",
+ },
+ domainVerificationToken: {
+ type: "string",
+ description:
+ "Domain verification token. It can be used to prove ownership over the SSO domain",
+ },
oidcConfig: {
type: "object",
properties: {
@@ -586,12 +597,13 @@ export const registerSSOProvider = (options?: SSOOptions) => {
const provider = await ctx.context.adapter.create<
Record,
- SSOProvider
+ SSOProvider
>({
model: "ssoProvider",
data: {
issuer: body.issuer,
domain: body.domain,
+ domainVerified: false,
oidcConfig: body.oidcConfig
? JSON.stringify({
issuer: body.issuer,
@@ -640,6 +652,34 @@ export const registerSSOProvider = (options?: SSOOptions) => {
},
});
+ let domainVerificationToken: string | undefined;
+ let domainVerified: boolean | undefined;
+
+ if (options?.domainVerification?.enabled) {
+ domainVerified = false;
+ domainVerificationToken = generateRandomString(24);
+
+ await ctx.context.adapter.create({
+ model: "verification",
+ data: {
+ identifier: options.domainVerification?.tokenPrefix
+ ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
+ : `better-auth-token-${provider.providerId}`,
+ createdAt: new Date(),
+ updatedAt: new Date(),
+ value: domainVerificationToken,
+ expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
+ },
+ });
+ }
+
+ type SSOProviderReturn = O["domainVerification"] extends { enabled: true }
+ ? {
+ domainVerified: boolean;
+ domainVerificationToken: string;
+ } & SSOProvider
+ : SSOProvider;
+
return ctx.json({
...provider,
oidcConfig: JSON.parse(
@@ -649,7 +689,11 @@ export const registerSSOProvider = (options?: SSOOptions) => {
provider.samlConfig as unknown as string,
) as SAMLConfig,
redirectURI: `${ctx.context.baseURL}/sso/callback/${provider.providerId}`,
- });
+ ...(options?.domainVerification?.enabled ? { domainVerified } : {}),
+ ...(options?.domainVerification?.enabled
+ ? { domainVerificationToken }
+ : {}),
+ } as unknown as SSOProviderReturn);
},
);
};
@@ -840,7 +884,7 @@ export const signInSSO = (options?: SSOOptions) => {
return res.id;
});
}
- let provider: SSOProvider | null = null;
+ let provider: SSOProvider | null = null;
if (options?.defaultSSO?.length) {
// Find matching default SSO provider by providerId
const matchingDefault = providerId
@@ -862,7 +906,10 @@ export const signInSSO = (options?: SSOOptions) => {
oidcConfig: matchingDefault.oidcConfig,
samlConfig: matchingDefault.samlConfig,
domain: matchingDefault.domain,
- };
+ ...(options.domainVerification?.enabled
+ ? { domainVerified: true }
+ : {}),
+ } as SSOProvider;
}
}
if (!providerId && !orgId && !domain) {
@@ -873,7 +920,7 @@ export const signInSSO = (options?: SSOOptions) => {
// Try to find provider in database
if (!provider) {
provider = await ctx.context.adapter
- .findOne({
+ .findOne>({
model: "ssoProvider",
where: [
{
@@ -925,6 +972,15 @@ export const signInSSO = (options?: SSOOptions) => {
}
}
+ if (
+ options?.domainVerification?.enabled &&
+ !("domainVerified" in provider && provider.domainVerified)
+ ) {
+ throw new APIError("UNAUTHORIZED", {
+ message: "Provider domain has not been verified",
+ });
+ }
+
if (provider.oidcConfig && body.providerType !== "saml") {
let finalAuthUrl = provider.oidcConfig.authorizationEndpoint;
if (!finalAuthUrl && provider.oidcConfig.discoveryEndpoint) {
@@ -1064,7 +1120,7 @@ export const callbackSSO = (options?: SSOOptions) => {
}?error=${error}&error_description=${error_description}`,
);
}
- let provider: SSOProvider | null = null;
+ let provider: SSOProvider | null = null;
if (options?.defaultSSO?.length) {
const matchingDefault = options.defaultSSO.find(
(defaultProvider) =>
@@ -1075,7 +1131,10 @@ export const callbackSSO = (options?: SSOOptions) => {
...matchingDefault,
issuer: matchingDefault.oidcConfig?.issuer || "",
userId: "default",
- };
+ ...(options.domainVerification?.enabled
+ ? { domainVerified: true }
+ : {}),
+ } as SSOProvider;
}
}
if (!provider) {
@@ -1099,7 +1158,7 @@ export const callbackSSO = (options?: SSOOptions) => {
...res,
oidcConfig:
safeJsonParse(res.oidcConfig) || undefined,
- } as SSOProvider;
+ } as SSOProvider;
});
}
if (!provider) {
@@ -1110,6 +1169,15 @@ export const callbackSSO = (options?: SSOOptions) => {
);
}
+ if (
+ options?.domainVerification?.enabled &&
+ !("domainVerified" in provider && provider.domainVerified)
+ ) {
+ throw new APIError("UNAUTHORIZED", {
+ message: "Provider domain has not been verified",
+ });
+ }
+
let config = provider.oidcConfig;
if (!config) {
@@ -1405,7 +1473,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
async (ctx) => {
const { SAMLResponse, RelayState } = ctx.body;
const { providerId } = ctx.params;
- let provider: SSOProvider | null = null;
+ let provider: SSOProvider | null = null;
if (options?.defaultSSO?.length) {
const matchingDefault = options.defaultSSO.find(
(defaultProvider) => defaultProvider.providerId === providerId,
@@ -1415,12 +1483,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
...matchingDefault,
userId: "default",
issuer: matchingDefault.samlConfig?.issuer || "",
- };
+ ...(options.domainVerification?.enabled
+ ? { domainVerified: true }
+ : {}),
+ } as SSOProvider;
}
}
if (!provider) {
provider = await ctx.context.adapter
- .findOne({
+ .findOne>({
model: "ssoProvider",
where: [{ field: "providerId", value: providerId }],
})
@@ -1443,6 +1514,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
});
}
+ if (
+ options?.domainVerification?.enabled &&
+ !("domainVerified" in provider && provider.domainVerified)
+ ) {
+ throw new APIError("UNAUTHORIZED", {
+ message: "Provider domain has not been verified",
+ });
+ }
+
const parsedSamlConfig = safeJsonParse(
provider.samlConfig as unknown as string,
);
@@ -1736,7 +1816,7 @@ export const acsEndpoint = (options?: SSOOptions) => {
const { providerId } = ctx.params;
// If defaultSSO is configured, use it as the provider
- let provider: SSOProvider | null = null;
+ let provider: SSOProvider | null = null;
if (options?.defaultSSO?.length) {
// For ACS endpoint, we can use the first default provider or try to match by providerId
@@ -1753,11 +1833,14 @@ export const acsEndpoint = (options?: SSOOptions) => {
userId: "default",
samlConfig: matchingDefault.samlConfig,
domain: matchingDefault.domain,
+ ...(options.domainVerification?.enabled
+ ? { domainVerified: true }
+ : {}),
};
}
} else {
provider = await ctx.context.adapter
- .findOne({
+ .findOne>({
model: "ssoProvider",
where: [
{
@@ -1785,6 +1868,15 @@ export const acsEndpoint = (options?: SSOOptions) => {
});
}
+ if (
+ options?.domainVerification?.enabled &&
+ !("domainVerified" in provider && provider.domainVerified)
+ ) {
+ throw new APIError("UNAUTHORIZED", {
+ message: "Provider domain has not been verified",
+ });
+ }
+
const parsedSamlConfig = provider.samlConfig;
// Configure SP and IdP
const sp = saml.ServiceProvider({
diff --git a/packages/sso/src/types.ts b/packages/sso/src/types.ts
index ae18ddbf9b..0d7e637272 100644
--- a/packages/sso/src/types.ts
+++ b/packages/sso/src/types.ts
@@ -81,7 +81,7 @@ export interface SAMLConfig {
mapping?: SAMLMapping | undefined;
}
-export type SSOProvider = {
+type BaseSSOProvider = {
issuer: string;
oidcConfig?: OIDCConfig | undefined;
samlConfig?: SAMLConfig | undefined;
@@ -91,6 +91,13 @@ export type SSOProvider = {
domain: string;
};
+export type SSOProvider =
+ O["domainVerification"] extends { enabled: true }
+ ? {
+ domainVerified: boolean;
+ } & BaseSSOProvider
+ : BaseSSOProvider;
+
export interface SSOOptions {
/**
* custom function to provision a user when they sign in with an SSO provider.
@@ -112,7 +119,7 @@ export interface SSOOptions {
/**
* The SSO provider
*/
- provider: SSOProvider;
+ provider: SSOProvider;
}) => Promise)
| undefined;
/**
@@ -138,7 +145,7 @@ export interface SSOOptions {
/**
* The SSO provider
*/
- provider: SSOProvider;
+ provider: SSOProvider;
}) => Promise<"member" | "admin">;
}
| undefined;
@@ -228,4 +235,22 @@ export interface SSOOptions {
* @default false
*/
trustEmailVerified?: boolean | undefined;
+ /**
+ * Enable domain verification on SSO providers
+ *
+ * When this option is enabled, new SSO providers will require the associated domain to be verified by the owner
+ * prior to allowing sign-ins.
+ */
+ domainVerification?: {
+ /**
+ * Enables or disables the domain verification feature
+ */
+ enabled?: boolean;
+ /**
+ * Prefix used to generate the domain verification token
+ *
+ * @default "better-auth-token-"
+ */
+ tokenPrefix?: string;
+ };
}