diff --git a/docs/content/docs/plugins/sso.mdx b/docs/content/docs/plugins/sso.mdx index 2fa6cc7d5a..79b9366422 100644 --- a/docs/content/docs/plugins/sso.mdx +++ b/docs/content/docs/plugins/sso.mdx @@ -704,6 +704,122 @@ mapping: { } ``` +## Domain verification + +Domain verification allows your application to automatically trust a new SSO provider +by automatically validating ownership via the associated domain: + + + +```ts title="auth.ts" +const authClient = createAuthClient({ + plugins: [ + ssoClient({ // [!code highlight] + domainVerification: { // [!code highlight] + enabled: true // [!code highlight] + } // [!code highlight] + }) // [!code highlight] + ] +}) +``` + + + +```ts title="auth-client.ts" +const auth = betterAuth({ + plugins: [ + sso({ // [!code highlight] + domainVerification: { // [!code highlight] + enabled: true // [!code highlight] + } // [!code highlight] + }) // [!code highlight] + ] +}); +``` + + + +Once enabled, make sure you migrate the database schema (again). + + + + ```bash + npx @better-auth/cli migrate + ``` + + + ```bash + npx @better-auth/cli generate + ``` + + + +See the [Schema](#schema-for-domain-verification) section to add the fields manually. + +### Verify your domain + +When domain verification is enabled, every new SSO provider will be untrusted at first. +This means that new sign-ups or sign-ins will be allowed until the domain ownership has been verified. + +To verify your ownership over a domain, follow these steps: + + + +#### Acquire verification token +When an SSO provider is registered, a **verification token** will be issued to the provider (it will be returned as part of the response). +You can use this token to prove ownership over the domain. + + +#### Create `TXT` DNS record +To do this, you'll need to add a `TXT` record to your domain's DNS settings: + + * **Host:** `better-auth-token-{your-provider-id}` (**Note:** This assumes the default token prefix, which can be customized through the `domainVerification.tokenPrefix` option) + * **Value:** The verification token you were given. + +**Save the record and wait for it to propagate.** This can take up to 48 hours, but it's usually much faster. + + +#### Submit a validation request +**Once the DNS record has propagated**, you can submit a validation request (See below) + + + +### Domain validation request + +Once you have configured your domain, you can use your `auth` instance to submit a validation request. +This request will either result in a rejection (could not prove your ownership over the domain) +or if the verification is successful, your SSO provider domain will be marked as verified. + + +```ts +type verifyDomain = { + /** + * The provider id + */ + providerId: string = "acme-corp" +} +``` + + +### Creating a new verification token + +Every domain verification token will have a default expiry of 1 week since the moment it was issued +or the moment when the SSO provider was registered. + +After that time, the token will expire and cannot longer be used. When that happens, +you can create a new verification token: + + +```ts +type requestDomainVerification = { + /** + * The provider id + */ + providerId: string = "acme-corp" +} +``` + + ### SAML Endpoints The plugin automatically creates the following SAML endpoints: @@ -726,7 +842,17 @@ The plugin requires additional fields in the `ssoProvider` table to store the pr { name: "samlConfig", type: "string", description: "The SAML configuration (JSON string)", isRequired: false }, { name: "userId", type: "string", description: "The user ID", isRequired: true, references: { model: "user", field: "id" } }, { name: "providerId", type: "string", description: "The provider ID. Used to identify a provider and to generate a redirect URL.", isRequired: true, isUnique: true }, - { name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false }, + { name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false } + ]} +/> + +### If you have enabled domain verification: + +The `ssoProvider` schema is extended as follows: + + @@ -790,6 +916,23 @@ If you want to allow account linking for specific trusted providers, enable the type: "number | function", default: 10, }, + domainVerification: { + description: "Configure the domain verification feature", + type: "object", + properties: { + enabled: { + description: "Enables or disables the domain verification feature", + type: "boolean", + required: false + }, + tokenPrefix: { + description: "Prefix to append to the verification token identifier", + type: "string", + required: false, + default: "better-auth-token-" + }, + }, + }, defaultSSO: { description: "Configure a default SSO provider for testing and development. This provider will be used when no matching provider is found in the database.", type: "object", diff --git a/packages/sso/src/client.ts b/packages/sso/src/client.ts index 0b6bd310e6..9d2541b4cc 100644 --- a/packages/sso/src/client.ts +++ b/packages/sso/src/client.ts @@ -1,8 +1,25 @@ import type { BetterAuthClientPlugin } from "better-auth"; -import type { sso } from "./index"; -export const ssoClient = () => { +import type { SSOPlugin } from "./index"; + +interface SSOClientOptions { + domainVerification?: + | { + enabled: boolean; + } + | undefined; +} + +export const ssoClient = ( + options?: CO | undefined, +) => { return { id: "sso-client", - $InferServerPlugin: {} as ReturnType, + $InferServerPlugin: {} as SSOPlugin<{ + domainVerification: { + enabled: CO["domainVerification"] extends { enabled: true } + ? true + : false; + }; + }>, } satisfies BetterAuthClientPlugin; }; diff --git a/packages/sso/src/domain-verification.test.ts b/packages/sso/src/domain-verification.test.ts new file mode 100644 index 0000000000..ccf8b92d8e --- /dev/null +++ b/packages/sso/src/domain-verification.test.ts @@ -0,0 +1,550 @@ +import { betterAuth } from "better-auth"; +import { memoryAdapter } from "better-auth/adapters/memory"; +import { createAuthClient } from "better-auth/client"; +import { setCookieToHeader } from "better-auth/cookies"; +import { bearer, organization } from "better-auth/plugins"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { sso } from "."; +import { ssoClient } from "./client"; +import type { SSOOptions } from "./types"; + +const dnsMock = vi.hoisted(() => { + return { + resolveTxt: vi.fn(), + }; +}); + +vi.mock("node:dns/promises", () => { + return { + ...dnsMock, + default: dnsMock, + }; +}); + +describe("Domain verification", async () => { + type TestUser = { email: string; password: string; name: string }; + const testUser: TestUser = { + email: "test@email.com", + password: "password", + name: "Test User", + }; + + const createTestAuth = (options?: SSOOptions) => { + const data = { + user: [], + session: [], + verification: [], + account: [], + ssoProvider: [], + member: [], + organization: [], + }; + + const memory = memoryAdapter(data); + + const ssoOptions = { + ...options, + domainVerification: { + ...options?.domainVerification, + enabled: true, + }, + } satisfies SSOOptions; + + const auth = betterAuth({ + database: memory, + baseURL: "http://localhost:3000", + emailAndPassword: { + enabled: true, + }, + plugins: [sso(ssoOptions), organization()], + }); + + const authClient = createAuthClient({ + baseURL: "http://localhost:3000", + plugins: [bearer(), ssoClient({ domainVerification: { enabled: true } })], + fetchOptions: { + customFetchImpl: async (url, init) => { + return auth.handler(new Request(url, init)); + }, + }, + }); + + async function createOrganization(name: string, headers: Headers) { + return await auth.api.createOrganization({ + body: { + name, + slug: name, + }, + headers, + }); + } + + async function getAuthHeaders(user: TestUser, organizationId?: string) { + const headers = new Headers(); + const response = await authClient.signUp.email({ + email: user.email, + password: user.password, + name: user.name, + }); + + if (response.data && organizationId) { + await auth.api.addMember({ + body: { + userId: response.data.user.id, + role: "member", + }, + headers, + }); + } + + await authClient.signIn.email(user, { + throw: true, + onSuccess: setCookieToHeader(headers), + }); + + return headers; + } + + async function registerSSOProvider( + headers: Headers, + organizationId?: string, + ) { + return auth.api.registerSSOProvider({ + body: { + providerId: "saml-provider-1", + issuer: "http://hello.com:8081", + domain: "http://hello.com:8081", + samlConfig: { + entryPoint: "http://idp.com:", + cert: "the-cert", + callbackUrl: "http://hello.com:8081/api/sso/saml2/callback", + spMetadata: {}, + }, + organizationId, + }, + headers, + }); + } + + return { + auth, + authClient, + registerSSOProvider, + getAuthHeaders, + createOrganization, + }; + }; + + afterEach(() => { + vi.clearAllMocks(); + vi.useRealTimers(); + }); + + describe("POST /sso/request-domain-verification", () => { + it("should return unauthorized when session is missing", async () => { + const { auth } = createTestAuth(); + const response = await auth.api.requestDomainVerification({ + body: { + providerId: "the-provider", + }, + asResponse: true, + }); + + expect(response.status).toBe(401); + }); + + it("should return not found when no provider is found", async () => { + const { auth, getAuthHeaders } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const response = await auth.api.requestDomainVerification({ + body: { + providerId: "unknown", + }, + headers, + asResponse: true, + }); + + expect(response.status).toBe(404); + expect(await response.json()).toEqual({ + message: "Provider not found", + code: "PROVIDER_NOT_FOUND", + }); + }); + + it("should return the existing active verification token", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + vi.useFakeTimers({ toFake: ["Date"] }); + + const newAuthHeaders = await getAuthHeaders(testUser); + + const response = await auth.api.requestDomainVerification({ + body: { + providerId: provider.providerId, + }, + headers: newAuthHeaders, + asResponse: true, + }); + + expect(response.status).toBe(201); + expect(await response.json()).toEqual({ + domainVerificationToken: provider.domainVerificationToken, + }); + }); + + it("should return forbidden if user does not own the provider", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + const notOwnerHeaders = await getAuthHeaders({ + name: "other", + email: "other@test.com", + password: "password", + }); + const response = await auth.api.requestDomainVerification({ + body: { + providerId: provider.providerId, + }, + headers: notOwnerHeaders, + asResponse: true, + }); + + expect(response.status).toBe(403); + expect(await response.json()).toEqual({ + message: + "User must be owner of or belong to the SSO provider organization", + code: "INSUFICCIENT_ACCESS", + }); + }); + + it("should return forbidden if user does not belong to the provider organization", async () => { + const { auth, getAuthHeaders, registerSSOProvider, createOrganization } = + createTestAuth(); + const headers = await getAuthHeaders(testUser); + + const orgA = await createOrganization("org-a", headers); + const orgB = await createOrganization("org-b", headers); + + const provider = await registerSSOProvider(headers, orgA?.id); + + const notOrgHeaders = await getAuthHeaders( + { + name: "other", + email: "other@test.com", + password: "password", + }, + orgB?.id, + ); + + const response = await auth.api.requestDomainVerification({ + body: { + providerId: provider.providerId, + }, + headers: notOrgHeaders, + asResponse: true, + }); + + expect(response.status).toBe(403); + expect(await response.json()).toEqual({ + message: + "User must be owner of or belong to the SSO provider organization", + code: "INSUFICCIENT_ACCESS", + }); + }); + + it("should return a new domain verification token", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + vi.useFakeTimers({ toFake: ["Date"] }); + vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds + + const newHeaders = await getAuthHeaders(testUser); + const response = await auth.api.requestDomainVerification({ + body: { + providerId: provider.providerId, + }, + headers: newHeaders, + asResponse: true, + }); + + expect(response.status).toBe(201); + expect(await response.json()).toMatchObject({ + domainVerificationToken: expect.any(String), + }); + }); + + it("should fail to create a new token on an already verified domain", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + dnsMock.resolveTxt.mockResolvedValue([ + [ + `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`, + ], + ]); + + const domainVerificationResponse = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(domainVerificationResponse.status).toBe(204); + + const domainVerificationSubmissionResponse = + await auth.api.requestDomainVerification({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(domainVerificationSubmissionResponse.status).toBe(409); + expect(await domainVerificationSubmissionResponse.json()).toEqual({ + message: "Domain has already been verified", + code: "DOMAIN_VERIFIED", + }); + }); + }); + + describe("POST /sso/verify-domain", () => { + it("should return unauthorized when session is missing", async () => { + const { auth } = createTestAuth(); + const response = await auth.api.verifyDomain({ + body: { + providerId: "the-provider", + }, + asResponse: true, + }); + + expect(response.status).toBe(401); + }); + + it("should return not found when no provider is found", async () => { + const { auth, getAuthHeaders } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const response = await auth.api.verifyDomain({ + body: { + providerId: "unknown", + }, + headers, + asResponse: true, + }); + + expect(response.status).toBe(404); + expect(await response.json()).toEqual({ + message: "Provider not found", + code: "PROVIDER_NOT_FOUND", + }); + }); + + it("should return not found when no pending verification is found", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + vi.useFakeTimers({ toFake: ["Date"] }); + vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds + + const newAuthHeaders = await getAuthHeaders(testUser); + + const response = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers: newAuthHeaders, + asResponse: true, + }); + + expect(response.status).toBe(404); + expect(await response.json()).toEqual({ + message: "No pending domain verification exists", + code: "NO_PENDING_VERIFICATION", + }); + }); + + it("should return bad gateway when unable to verify domain", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + dnsMock.resolveTxt.mockResolvedValue([ + ["google-site-verification=the-token"], + ]); + + const response = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(response.status).toBe(502); + expect(await response.json()).toEqual({ + message: "Unable to verify domain ownership. Try again later", + code: "DOMAIN_VERIFICATION_FAILED", + }); + }); + + it("should return forbidden if user does not own the provider", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + const notOwnerHeaders = await getAuthHeaders({ + name: "other", + email: "other@test.com", + password: "password", + }); + const response = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers: notOwnerHeaders, + asResponse: true, + }); + + expect(response.status).toBe(403); + expect(await response.json()).toEqual({ + message: + "User must be owner of or belong to the SSO provider organization", + code: "INSUFICCIENT_ACCESS", + }); + }); + + it("should return forbidden if user does not belong to the provider organization", async () => { + const { auth, getAuthHeaders, registerSSOProvider, createOrganization } = + createTestAuth(); + const headers = await getAuthHeaders(testUser); + const orgA = await createOrganization("org-a", headers); + const orgB = await createOrganization("org-b", headers); + + const provider = await registerSSOProvider(headers, orgA?.id); + + const notOrgHeaders = await getAuthHeaders( + { + name: "other", + email: "other@test.com", + password: "password", + }, + orgB?.id, + ); + const response = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers: notOrgHeaders, + asResponse: true, + }); + + expect(response.status).toBe(403); + expect(await response.json()).toEqual({ + message: + "User must be owner of or belong to the SSO provider organization", + code: "INSUFICCIENT_ACCESS", + }); + }); + + it("should verify a provider domain ownership", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + expect(provider.domain).toBe("http://hello.com:8081"); + expect(provider.domainVerified).toBe(false); + expect(provider.domainVerificationToken).toBeTypeOf("string"); + + dnsMock.resolveTxt.mockResolvedValue([ + ["google-site-verification=the-token"], + [ + "v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all", + ], + [ + `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`, + ], + ]); + + const response = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(response.status).toBe(204); + }); + + it("should verify a provider domain ownership (custom token verification prefix)", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth({ + domainVerification: { tokenPrefix: "auth-prefix" }, + }); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + dnsMock.resolveTxt.mockResolvedValue([ + ["google-site-verification=the-token"], + [ + "v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all", + ], + [`auth-prefix-saml-provider-1=${provider.domainVerificationToken}`], + ]); + + const response = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(response.status).toBe(204); + }); + + it("should fail to verify an already verified domain", async () => { + const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth(); + const headers = await getAuthHeaders(testUser); + const provider = await registerSSOProvider(headers); + + dnsMock.resolveTxt.mockResolvedValue([ + [ + `better-auth-token-saml-provider-1=${provider.domainVerificationToken}`, + ], + ]); + + const firstResponse = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(firstResponse.status).toBe(204); + + const secondResponse = await auth.api.verifyDomain({ + body: { + providerId: provider.providerId, + }, + headers, + asResponse: true, + }); + + expect(secondResponse.status).toBe(409); + expect(await secondResponse.json()).toEqual({ + message: "Domain has already been verified", + code: "DOMAIN_VERIFIED", + }); + }); + }); +}); diff --git a/packages/sso/src/index.ts b/packages/sso/src/index.ts index 9a49f1dc47..5e45274d93 100644 --- a/packages/sso/src/index.ts +++ b/packages/sso/src/index.ts @@ -1,6 +1,10 @@ import type { BetterAuthPlugin } from "better-auth"; import { XMLValidator } from "fast-xml-parser"; import * as saml from "samlify"; +import { + requestDomainVerification, + verifyDomain, +} from "./routes/domain-verification"; import { acsEndpoint, callbackSSO, @@ -25,33 +29,72 @@ const fastValidator = { saml.setSchemaValidator(fastValidator); -type SSOEndpoints = { +type DomainVerificationEndpoints = { + requestDomainVerification: ReturnType; + verifyDomain: ReturnType; +}; + +type SSOEndpoints = { spMetadata: ReturnType; - registerSSOProvider: ReturnType; + registerSSOProvider: ReturnType>; signInSSO: ReturnType; callbackSSO: ReturnType; callbackSSOSAML: ReturnType; acsEndpoint: ReturnType; }; +export type SSOPlugin = { + id: "sso"; + endpoints: SSOEndpoints & + (O extends { domainVerification: { enabled: true } } + ? DomainVerificationEndpoints + : {}); +}; + +export function sso< + O extends SSOOptions & { + domainVerification?: { enabled: true }; + }, +>( + options?: O | undefined, +): { + id: "sso"; + endpoints: SSOEndpoints & DomainVerificationEndpoints; + schema: any; + options: O; +}; export function sso( options?: O | undefined, ): { id: "sso"; - endpoints: SSOEndpoints; + endpoints: SSOEndpoints; }; export function sso(options?: O | undefined): any { + let endpoints = { + spMetadata: spMetadata(), + registerSSOProvider: registerSSOProvider(options as O), + signInSSO: signInSSO(options as O), + callbackSSO: callbackSSO(options as O), + callbackSSOSAML: callbackSSOSAML(options as O), + acsEndpoint: acsEndpoint(options as O), + }; + + if (options?.domainVerification?.enabled) { + const domainVerificationEndpoints = { + requestDomainVerification: requestDomainVerification(options as O), + verifyDomain: verifyDomain(options as O), + }; + + endpoints = { + ...endpoints, + ...domainVerificationEndpoints, + }; + } + return { id: "sso", - endpoints: { - spMetadata: spMetadata(), - registerSSOProvider: registerSSOProvider(options), - signInSSO: signInSSO(options), - callbackSSO: callbackSSO(options), - callbackSSOSAML: callbackSSOSAML(options), - acsEndpoint: acsEndpoint(options), - }, + endpoints, schema: { ssoProvider: { modelName: options?.modelName ?? "ssoProvider", @@ -95,6 +138,9 @@ export function sso(options?: O | undefined): any { required: true, fieldName: options?.fields?.domain ?? "domain", }, + ...(options?.domainVerification?.enabled + ? { domainVerified: { type: "boolean", required: false } } + : {}), }, }, }, diff --git a/packages/sso/src/routes/domain-verification.ts b/packages/sso/src/routes/domain-verification.ts new file mode 100644 index 0000000000..843c934393 --- /dev/null +++ b/packages/sso/src/routes/domain-verification.ts @@ -0,0 +1,275 @@ +import type { Verification } from "better-auth"; +import { + APIError, + createAuthEndpoint, + sessionMiddleware, +} from "better-auth/api"; +import { generateRandomString } from "better-auth/crypto"; +import * as z from "zod/v4"; +import type { SSOOptions, SSOProvider } from "../types"; + +export const requestDomainVerification = (options: SSOOptions) => { + return createAuthEndpoint( + "/sso/request-domain-verification", + { + method: "POST", + body: z.object({ + providerId: z.string(), + }), + metadata: { + openapi: { + summary: "Request a domain verification", + description: + "Request a domain verification for the given SSO provider", + responses: { + "404": { + description: "Provider not found", + }, + "409": { + description: "Domain has already been verified", + }, + "201": { + description: "Domain submitted for verification", + }, + }, + }, + }, + use: [sessionMiddleware], + }, + async (ctx) => { + const body = ctx.body; + const provider = await ctx.context.adapter.findOne< + SSOProvider + >({ + model: "ssoProvider", + where: [{ field: "providerId", value: body.providerId }], + }); + + if (!provider) { + throw new APIError("NOT_FOUND", { + message: "Provider not found", + code: "PROVIDER_NOT_FOUND", + }); + } + + const userId = ctx.context.session.user.id; + let isOrgMember = true; + if (provider.organizationId) { + const membershipsCount = await ctx.context.adapter.count({ + model: "member", + where: [ + { field: "userId", value: userId }, + { field: "organizationId", value: provider.organizationId }, + ], + }); + + isOrgMember = membershipsCount > 0; + } + + if (provider.userId !== userId || !isOrgMember) { + throw new APIError("FORBIDDEN", { + message: + "User must be owner of or belong to the SSO provider organization", + code: "INSUFICCIENT_ACCESS", + }); + } + + if ("domainVerified" in provider && provider.domainVerified) { + throw new APIError("CONFLICT", { + message: "Domain has already been verified", + code: "DOMAIN_VERIFIED", + }); + } + + const activeVerification = + await ctx.context.adapter.findOne({ + model: "verification", + where: [ + { + field: "identifier", + value: options.domainVerification?.tokenPrefix + ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}` + : `better-auth-token-${provider.providerId}`, + }, + { field: "expiresAt", value: new Date(), operator: "gt" }, + ], + }); + + if (activeVerification) { + ctx.setStatus(201); + return ctx.json({ domainVerificationToken: activeVerification.value }); + } + + const domainVerificationToken = generateRandomString(24); + await ctx.context.adapter.create({ + model: "verification", + data: { + identifier: options.domainVerification?.tokenPrefix + ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}` + : `better-auth-token-${provider.providerId}`, + createdAt: new Date(), + updatedAt: new Date(), + value: domainVerificationToken, + expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week + }, + }); + + ctx.setStatus(201); + return ctx.json({ + domainVerificationToken, + }); + }, + ); +}; + +export const verifyDomain = (options: SSOOptions) => { + return createAuthEndpoint( + "/sso/verify-domain", + { + method: "POST", + body: z.object({ + providerId: z.string(), + }), + metadata: { + openapi: { + summary: "Verify the provider domain ownership", + description: "Verify the provider domain ownership via DNS records", + responses: { + "404": { + description: "Provider not found", + }, + "409": { + description: + "Domain has already been verified or no pending verification exists", + }, + "502": { + description: + "Unable to verify domain ownership due to upstream validator error", + }, + "204": { + description: "Domain ownership was verified", + }, + }, + }, + }, + use: [sessionMiddleware], + }, + async (ctx) => { + const body = ctx.body; + const provider = await ctx.context.adapter.findOne< + SSOProvider + >({ + model: "ssoProvider", + where: [{ field: "providerId", value: body.providerId }], + }); + + if (!provider) { + throw new APIError("NOT_FOUND", { + message: "Provider not found", + code: "PROVIDER_NOT_FOUND", + }); + } + + const userId = ctx.context.session.user.id; + let isOrgMember = true; + if (provider.organizationId) { + const membershipsCount = await ctx.context.adapter.count({ + model: "member", + where: [ + { field: "userId", value: userId }, + { field: "organizationId", value: provider.organizationId }, + ], + }); + + isOrgMember = membershipsCount > 0; + } + + if (provider.userId !== userId || !isOrgMember) { + throw new APIError("FORBIDDEN", { + message: + "User must be owner of or belong to the SSO provider organization", + code: "INSUFICCIENT_ACCESS", + }); + } + + if ("domainVerified" in provider && provider.domainVerified) { + throw new APIError("CONFLICT", { + message: "Domain has already been verified", + code: "DOMAIN_VERIFIED", + }); + } + + const activeVerification = + await ctx.context.adapter.findOne({ + model: "verification", + where: [ + { + field: "identifier", + value: options.domainVerification?.tokenPrefix + ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}` + : `better-auth-token-${provider.providerId}`, + }, + { field: "expiresAt", value: new Date(), operator: "gt" }, + ], + }); + + if (!activeVerification) { + throw new APIError("NOT_FOUND", { + message: "No pending domain verification exists", + code: "NO_PENDING_VERIFICATION", + }); + } + + let records: string[] = []; + let dns: typeof import("node:dns/promises"); + + try { + dns = await import("node:dns/promises"); + } catch (error) { + ctx.context.logger.error( + "The core node:dns module is required for the domain verification feature", + error, + ); + throw new APIError("INTERNAL_SERVER_ERROR", { + message: "Unable to verify domain ownership due to server error", + code: "DOMAIN_VERIFICATION_FAILED", + }); + } + + try { + const dnsRecords = await dns.resolveTxt( + new URL(provider.domain).hostname, + ); + records = dnsRecords.flat(); + } catch (error) { + ctx.context.logger.warn( + "DNS resolution failure while validating domain ownership", + error, + ); + } + + const record = records.find((record) => + record.includes( + `${activeVerification.identifier}=${activeVerification.value}`, + ), + ); + if (!record) { + throw new APIError("BAD_GATEWAY", { + message: "Unable to verify domain ownership. Try again later", + code: "DOMAIN_VERIFICATION_FAILED", + }); + } + + await ctx.context.adapter.update>({ + model: "ssoProvider", + where: [{ field: "providerId", value: provider.providerId }], + update: { + domainVerified: true, + }, + }); + + ctx.setStatus(204); + return; + }, + ); +}; diff --git a/packages/sso/src/routes/sso.ts b/packages/sso/src/routes/sso.ts index 84d86e7ba6..1945d57226 100644 --- a/packages/sso/src/routes/sso.ts +++ b/packages/sso/src/routes/sso.ts @@ -1,5 +1,5 @@ import { BetterFetchError, betterFetch } from "@better-fetch/fetch"; -import type { Account, Session, User } from "better-auth"; +import type { Account, Session, User, Verification } from "better-auth"; import { createAuthorizationURL, generateState, @@ -13,6 +13,7 @@ import { sessionMiddleware, } from "better-auth/api"; import { setSessionCookie } from "better-auth/cookies"; +import { generateRandomString } from "better-auth/crypto"; import { handleOAuthUserInfo } from "better-auth/oauth2"; import { decodeJwt } from "jose"; import * as saml from "samlify"; @@ -126,7 +127,7 @@ export const spMetadata = () => { ); }; -export const registerSSOProvider = (options?: SSOOptions) => { +export const registerSSOProvider = (options: O) => { return createAuthEndpoint( "/sso/register", { @@ -358,6 +359,16 @@ export const registerSSOProvider = (options?: SSOOptions) => { description: "The domain of the provider, used for email matching", }, + domainVerified: { + type: "boolean", + description: + "A boolean indicating whether the domain has been verified or not", + }, + domainVerificationToken: { + type: "string", + description: + "Domain verification token. It can be used to prove ownership over the SSO domain", + }, oidcConfig: { type: "object", properties: { @@ -586,12 +597,13 @@ export const registerSSOProvider = (options?: SSOOptions) => { const provider = await ctx.context.adapter.create< Record, - SSOProvider + SSOProvider >({ model: "ssoProvider", data: { issuer: body.issuer, domain: body.domain, + domainVerified: false, oidcConfig: body.oidcConfig ? JSON.stringify({ issuer: body.issuer, @@ -640,6 +652,34 @@ export const registerSSOProvider = (options?: SSOOptions) => { }, }); + let domainVerificationToken: string | undefined; + let domainVerified: boolean | undefined; + + if (options?.domainVerification?.enabled) { + domainVerified = false; + domainVerificationToken = generateRandomString(24); + + await ctx.context.adapter.create({ + model: "verification", + data: { + identifier: options.domainVerification?.tokenPrefix + ? `${options.domainVerification?.tokenPrefix}-${provider.providerId}` + : `better-auth-token-${provider.providerId}`, + createdAt: new Date(), + updatedAt: new Date(), + value: domainVerificationToken, + expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week + }, + }); + } + + type SSOProviderReturn = O["domainVerification"] extends { enabled: true } + ? { + domainVerified: boolean; + domainVerificationToken: string; + } & SSOProvider + : SSOProvider; + return ctx.json({ ...provider, oidcConfig: JSON.parse( @@ -649,7 +689,11 @@ export const registerSSOProvider = (options?: SSOOptions) => { provider.samlConfig as unknown as string, ) as SAMLConfig, redirectURI: `${ctx.context.baseURL}/sso/callback/${provider.providerId}`, - }); + ...(options?.domainVerification?.enabled ? { domainVerified } : {}), + ...(options?.domainVerification?.enabled + ? { domainVerificationToken } + : {}), + } as unknown as SSOProviderReturn); }, ); }; @@ -840,7 +884,7 @@ export const signInSSO = (options?: SSOOptions) => { return res.id; }); } - let provider: SSOProvider | null = null; + let provider: SSOProvider | null = null; if (options?.defaultSSO?.length) { // Find matching default SSO provider by providerId const matchingDefault = providerId @@ -862,7 +906,10 @@ export const signInSSO = (options?: SSOOptions) => { oidcConfig: matchingDefault.oidcConfig, samlConfig: matchingDefault.samlConfig, domain: matchingDefault.domain, - }; + ...(options.domainVerification?.enabled + ? { domainVerified: true } + : {}), + } as SSOProvider; } } if (!providerId && !orgId && !domain) { @@ -873,7 +920,7 @@ export const signInSSO = (options?: SSOOptions) => { // Try to find provider in database if (!provider) { provider = await ctx.context.adapter - .findOne({ + .findOne>({ model: "ssoProvider", where: [ { @@ -925,6 +972,15 @@ export const signInSSO = (options?: SSOOptions) => { } } + if ( + options?.domainVerification?.enabled && + !("domainVerified" in provider && provider.domainVerified) + ) { + throw new APIError("UNAUTHORIZED", { + message: "Provider domain has not been verified", + }); + } + if (provider.oidcConfig && body.providerType !== "saml") { let finalAuthUrl = provider.oidcConfig.authorizationEndpoint; if (!finalAuthUrl && provider.oidcConfig.discoveryEndpoint) { @@ -1064,7 +1120,7 @@ export const callbackSSO = (options?: SSOOptions) => { }?error=${error}&error_description=${error_description}`, ); } - let provider: SSOProvider | null = null; + let provider: SSOProvider | null = null; if (options?.defaultSSO?.length) { const matchingDefault = options.defaultSSO.find( (defaultProvider) => @@ -1075,7 +1131,10 @@ export const callbackSSO = (options?: SSOOptions) => { ...matchingDefault, issuer: matchingDefault.oidcConfig?.issuer || "", userId: "default", - }; + ...(options.domainVerification?.enabled + ? { domainVerified: true } + : {}), + } as SSOProvider; } } if (!provider) { @@ -1099,7 +1158,7 @@ export const callbackSSO = (options?: SSOOptions) => { ...res, oidcConfig: safeJsonParse(res.oidcConfig) || undefined, - } as SSOProvider; + } as SSOProvider; }); } if (!provider) { @@ -1110,6 +1169,15 @@ export const callbackSSO = (options?: SSOOptions) => { ); } + if ( + options?.domainVerification?.enabled && + !("domainVerified" in provider && provider.domainVerified) + ) { + throw new APIError("UNAUTHORIZED", { + message: "Provider domain has not been verified", + }); + } + let config = provider.oidcConfig; if (!config) { @@ -1405,7 +1473,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => { async (ctx) => { const { SAMLResponse, RelayState } = ctx.body; const { providerId } = ctx.params; - let provider: SSOProvider | null = null; + let provider: SSOProvider | null = null; if (options?.defaultSSO?.length) { const matchingDefault = options.defaultSSO.find( (defaultProvider) => defaultProvider.providerId === providerId, @@ -1415,12 +1483,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => { ...matchingDefault, userId: "default", issuer: matchingDefault.samlConfig?.issuer || "", - }; + ...(options.domainVerification?.enabled + ? { domainVerified: true } + : {}), + } as SSOProvider; } } if (!provider) { provider = await ctx.context.adapter - .findOne({ + .findOne>({ model: "ssoProvider", where: [{ field: "providerId", value: providerId }], }) @@ -1443,6 +1514,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => { }); } + if ( + options?.domainVerification?.enabled && + !("domainVerified" in provider && provider.domainVerified) + ) { + throw new APIError("UNAUTHORIZED", { + message: "Provider domain has not been verified", + }); + } + const parsedSamlConfig = safeJsonParse( provider.samlConfig as unknown as string, ); @@ -1736,7 +1816,7 @@ export const acsEndpoint = (options?: SSOOptions) => { const { providerId } = ctx.params; // If defaultSSO is configured, use it as the provider - let provider: SSOProvider | null = null; + let provider: SSOProvider | null = null; if (options?.defaultSSO?.length) { // For ACS endpoint, we can use the first default provider or try to match by providerId @@ -1753,11 +1833,14 @@ export const acsEndpoint = (options?: SSOOptions) => { userId: "default", samlConfig: matchingDefault.samlConfig, domain: matchingDefault.domain, + ...(options.domainVerification?.enabled + ? { domainVerified: true } + : {}), }; } } else { provider = await ctx.context.adapter - .findOne({ + .findOne>({ model: "ssoProvider", where: [ { @@ -1785,6 +1868,15 @@ export const acsEndpoint = (options?: SSOOptions) => { }); } + if ( + options?.domainVerification?.enabled && + !("domainVerified" in provider && provider.domainVerified) + ) { + throw new APIError("UNAUTHORIZED", { + message: "Provider domain has not been verified", + }); + } + const parsedSamlConfig = provider.samlConfig; // Configure SP and IdP const sp = saml.ServiceProvider({ diff --git a/packages/sso/src/types.ts b/packages/sso/src/types.ts index ae18ddbf9b..0d7e637272 100644 --- a/packages/sso/src/types.ts +++ b/packages/sso/src/types.ts @@ -81,7 +81,7 @@ export interface SAMLConfig { mapping?: SAMLMapping | undefined; } -export type SSOProvider = { +type BaseSSOProvider = { issuer: string; oidcConfig?: OIDCConfig | undefined; samlConfig?: SAMLConfig | undefined; @@ -91,6 +91,13 @@ export type SSOProvider = { domain: string; }; +export type SSOProvider = + O["domainVerification"] extends { enabled: true } + ? { + domainVerified: boolean; + } & BaseSSOProvider + : BaseSSOProvider; + export interface SSOOptions { /** * custom function to provision a user when they sign in with an SSO provider. @@ -112,7 +119,7 @@ export interface SSOOptions { /** * The SSO provider */ - provider: SSOProvider; + provider: SSOProvider; }) => Promise) | undefined; /** @@ -138,7 +145,7 @@ export interface SSOOptions { /** * The SSO provider */ - provider: SSOProvider; + provider: SSOProvider; }) => Promise<"member" | "admin">; } | undefined; @@ -228,4 +235,22 @@ export interface SSOOptions { * @default false */ trustEmailVerified?: boolean | undefined; + /** + * Enable domain verification on SSO providers + * + * When this option is enabled, new SSO providers will require the associated domain to be verified by the owner + * prior to allowing sign-ins. + */ + domainVerification?: { + /** + * Enables or disables the domain verification feature + */ + enabled?: boolean; + /** + * Prefix used to generate the domain verification token + * + * @default "better-auth-token-" + */ + tokenPrefix?: string; + }; }