mirror of
https://github.com/better-auth/better-auth.git
synced 2026-05-22 22:32:01 -05:00
feat(sso): add domain verification for SSO providers (#5910)
Co-authored-by: Maxwell <145994855+ping-maxwell@users.noreply.github.com>
This commit is contained in:
@@ -704,6 +704,122 @@ mapping: {
|
||||
}
|
||||
```
|
||||
|
||||
## Domain verification
|
||||
|
||||
Domain verification allows your application to automatically trust a new SSO provider
|
||||
by automatically validating ownership via the associated domain:
|
||||
|
||||
<Tabs items={["client", "server"]}>
|
||||
<Tab value="client">
|
||||
```ts title="auth.ts"
|
||||
const authClient = createAuthClient({
|
||||
plugins: [
|
||||
ssoClient({ // [!code highlight]
|
||||
domainVerification: { // [!code highlight]
|
||||
enabled: true // [!code highlight]
|
||||
} // [!code highlight]
|
||||
}) // [!code highlight]
|
||||
]
|
||||
})
|
||||
```
|
||||
</Tab>
|
||||
|
||||
<Tab value="server">
|
||||
```ts title="auth-client.ts"
|
||||
const auth = betterAuth({
|
||||
plugins: [
|
||||
sso({ // [!code highlight]
|
||||
domainVerification: { // [!code highlight]
|
||||
enabled: true // [!code highlight]
|
||||
} // [!code highlight]
|
||||
}) // [!code highlight]
|
||||
]
|
||||
});
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
Once enabled, make sure you migrate the database schema (again).
|
||||
|
||||
<Tabs items={["migrate", "generate"]}>
|
||||
<Tab value="migrate">
|
||||
```bash
|
||||
npx @better-auth/cli migrate
|
||||
```
|
||||
</Tab>
|
||||
<Tab value="generate">
|
||||
```bash
|
||||
npx @better-auth/cli generate
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
See the [Schema](#schema-for-domain-verification) section to add the fields manually.
|
||||
|
||||
### Verify your domain
|
||||
|
||||
When domain verification is enabled, every new SSO provider will be untrusted at first.
|
||||
This means that new sign-ups or sign-ins will be allowed until the domain ownership has been verified.
|
||||
|
||||
To verify your ownership over a domain, follow these steps:
|
||||
|
||||
<Steps>
|
||||
<Step>
|
||||
#### Acquire verification token
|
||||
When an SSO provider is registered, a **verification token** will be issued to the provider (it will be returned as part of the response).
|
||||
You can use this token to prove ownership over the domain.
|
||||
</Step>
|
||||
<Step>
|
||||
#### Create `TXT` DNS record
|
||||
To do this, you'll need to add a `TXT` record to your domain's DNS settings:
|
||||
|
||||
* **Host:** `better-auth-token-{your-provider-id}` (**Note:** This assumes the default token prefix, which can be customized through the `domainVerification.tokenPrefix` option)
|
||||
* **Value:** The verification token you were given.
|
||||
|
||||
**Save the record and wait for it to propagate.** This can take up to 48 hours, but it's usually much faster.
|
||||
</Step>
|
||||
<Step>
|
||||
#### Submit a validation request
|
||||
**Once the DNS record has propagated**, you can submit a validation request (See below)
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
### Domain validation request
|
||||
|
||||
Once you have configured your domain, you can use your `auth` instance to submit a validation request.
|
||||
This request will either result in a rejection (could not prove your ownership over the domain)
|
||||
or if the verification is successful, your SSO provider domain will be marked as verified.
|
||||
|
||||
<APIMethod path="/sso/verify-domain" method="POST" requireSession>
|
||||
```ts
|
||||
type verifyDomain = {
|
||||
/**
|
||||
* The provider id
|
||||
*/
|
||||
providerId: string = "acme-corp"
|
||||
}
|
||||
```
|
||||
</APIMethod>
|
||||
|
||||
### Creating a new verification token
|
||||
|
||||
Every domain verification token will have a default expiry of 1 week since the moment it was issued
|
||||
or the moment when the SSO provider was registered.
|
||||
|
||||
After that time, the token will expire and cannot longer be used. When that happens,
|
||||
you can create a new verification token:
|
||||
|
||||
<APIMethod path="/sso/request-domain-verification" method="POST" requireSession>
|
||||
```ts
|
||||
type requestDomainVerification = {
|
||||
/**
|
||||
* The provider id
|
||||
*/
|
||||
providerId: string = "acme-corp"
|
||||
}
|
||||
```
|
||||
</APIMethod>
|
||||
|
||||
### SAML Endpoints
|
||||
|
||||
The plugin automatically creates the following SAML endpoints:
|
||||
@@ -726,7 +842,17 @@ The plugin requires additional fields in the `ssoProvider` table to store the pr
|
||||
{ name: "samlConfig", type: "string", description: "The SAML configuration (JSON string)", isRequired: false },
|
||||
{ name: "userId", type: "string", description: "The user ID", isRequired: true, references: { model: "user", field: "id" } },
|
||||
{ name: "providerId", type: "string", description: "The provider ID. Used to identify a provider and to generate a redirect URL.", isRequired: true, isUnique: true },
|
||||
{ name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false },
|
||||
{ name: "organizationId", type: "string", description: "The organization Id. If provider is linked to an organization.", isRequired: false }
|
||||
]}
|
||||
/>
|
||||
|
||||
### If you have enabled domain verification:
|
||||
|
||||
The `ssoProvider` schema is extended as follows:
|
||||
|
||||
<DatabaseTable
|
||||
fields={[
|
||||
{ name: "domainVerified", type: "boolean", description: "A flag indicating whether the provider domain has been verified.", isRequired: false },
|
||||
]}
|
||||
/>
|
||||
|
||||
@@ -790,6 +916,23 @@ If you want to allow account linking for specific trusted providers, enable the
|
||||
type: "number | function",
|
||||
default: 10,
|
||||
},
|
||||
domainVerification: {
|
||||
description: "Configure the domain verification feature",
|
||||
type: "object",
|
||||
properties: {
|
||||
enabled: {
|
||||
description: "Enables or disables the domain verification feature",
|
||||
type: "boolean",
|
||||
required: false
|
||||
},
|
||||
tokenPrefix: {
|
||||
description: "Prefix to append to the verification token identifier",
|
||||
type: "string",
|
||||
required: false,
|
||||
default: "better-auth-token-"
|
||||
},
|
||||
},
|
||||
},
|
||||
defaultSSO: {
|
||||
description: "Configure a default SSO provider for testing and development. This provider will be used when no matching provider is found in the database.",
|
||||
type: "object",
|
||||
|
||||
@@ -1,8 +1,25 @@
|
||||
import type { BetterAuthClientPlugin } from "better-auth";
|
||||
import type { sso } from "./index";
|
||||
export const ssoClient = () => {
|
||||
import type { SSOPlugin } from "./index";
|
||||
|
||||
interface SSOClientOptions {
|
||||
domainVerification?:
|
||||
| {
|
||||
enabled: boolean;
|
||||
}
|
||||
| undefined;
|
||||
}
|
||||
|
||||
export const ssoClient = <CO extends SSOClientOptions>(
|
||||
options?: CO | undefined,
|
||||
) => {
|
||||
return {
|
||||
id: "sso-client",
|
||||
$InferServerPlugin: {} as ReturnType<typeof sso>,
|
||||
$InferServerPlugin: {} as SSOPlugin<{
|
||||
domainVerification: {
|
||||
enabled: CO["domainVerification"] extends { enabled: true }
|
||||
? true
|
||||
: false;
|
||||
};
|
||||
}>,
|
||||
} satisfies BetterAuthClientPlugin;
|
||||
};
|
||||
|
||||
550
packages/sso/src/domain-verification.test.ts
Normal file
550
packages/sso/src/domain-verification.test.ts
Normal file
@@ -0,0 +1,550 @@
|
||||
import { betterAuth } from "better-auth";
|
||||
import { memoryAdapter } from "better-auth/adapters/memory";
|
||||
import { createAuthClient } from "better-auth/client";
|
||||
import { setCookieToHeader } from "better-auth/cookies";
|
||||
import { bearer, organization } from "better-auth/plugins";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { sso } from ".";
|
||||
import { ssoClient } from "./client";
|
||||
import type { SSOOptions } from "./types";
|
||||
|
||||
const dnsMock = vi.hoisted(() => {
|
||||
return {
|
||||
resolveTxt: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("node:dns/promises", () => {
|
||||
return {
|
||||
...dnsMock,
|
||||
default: dnsMock,
|
||||
};
|
||||
});
|
||||
|
||||
describe("Domain verification", async () => {
|
||||
type TestUser = { email: string; password: string; name: string };
|
||||
const testUser: TestUser = {
|
||||
email: "test@email.com",
|
||||
password: "password",
|
||||
name: "Test User",
|
||||
};
|
||||
|
||||
const createTestAuth = (options?: SSOOptions) => {
|
||||
const data = {
|
||||
user: [],
|
||||
session: [],
|
||||
verification: [],
|
||||
account: [],
|
||||
ssoProvider: [],
|
||||
member: [],
|
||||
organization: [],
|
||||
};
|
||||
|
||||
const memory = memoryAdapter(data);
|
||||
|
||||
const ssoOptions = {
|
||||
...options,
|
||||
domainVerification: {
|
||||
...options?.domainVerification,
|
||||
enabled: true,
|
||||
},
|
||||
} satisfies SSOOptions;
|
||||
|
||||
const auth = betterAuth({
|
||||
database: memory,
|
||||
baseURL: "http://localhost:3000",
|
||||
emailAndPassword: {
|
||||
enabled: true,
|
||||
},
|
||||
plugins: [sso(ssoOptions), organization()],
|
||||
});
|
||||
|
||||
const authClient = createAuthClient({
|
||||
baseURL: "http://localhost:3000",
|
||||
plugins: [bearer(), ssoClient({ domainVerification: { enabled: true } })],
|
||||
fetchOptions: {
|
||||
customFetchImpl: async (url, init) => {
|
||||
return auth.handler(new Request(url, init));
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
async function createOrganization(name: string, headers: Headers) {
|
||||
return await auth.api.createOrganization({
|
||||
body: {
|
||||
name,
|
||||
slug: name,
|
||||
},
|
||||
headers,
|
||||
});
|
||||
}
|
||||
|
||||
async function getAuthHeaders(user: TestUser, organizationId?: string) {
|
||||
const headers = new Headers();
|
||||
const response = await authClient.signUp.email({
|
||||
email: user.email,
|
||||
password: user.password,
|
||||
name: user.name,
|
||||
});
|
||||
|
||||
if (response.data && organizationId) {
|
||||
await auth.api.addMember({
|
||||
body: {
|
||||
userId: response.data.user.id,
|
||||
role: "member",
|
||||
},
|
||||
headers,
|
||||
});
|
||||
}
|
||||
|
||||
await authClient.signIn.email(user, {
|
||||
throw: true,
|
||||
onSuccess: setCookieToHeader(headers),
|
||||
});
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
async function registerSSOProvider(
|
||||
headers: Headers,
|
||||
organizationId?: string,
|
||||
) {
|
||||
return auth.api.registerSSOProvider({
|
||||
body: {
|
||||
providerId: "saml-provider-1",
|
||||
issuer: "http://hello.com:8081",
|
||||
domain: "http://hello.com:8081",
|
||||
samlConfig: {
|
||||
entryPoint: "http://idp.com:",
|
||||
cert: "the-cert",
|
||||
callbackUrl: "http://hello.com:8081/api/sso/saml2/callback",
|
||||
spMetadata: {},
|
||||
},
|
||||
organizationId,
|
||||
},
|
||||
headers,
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
auth,
|
||||
authClient,
|
||||
registerSSOProvider,
|
||||
getAuthHeaders,
|
||||
createOrganization,
|
||||
};
|
||||
};
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe("POST /sso/request-domain-verification", () => {
|
||||
it("should return unauthorized when session is missing", async () => {
|
||||
const { auth } = createTestAuth();
|
||||
const response = await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: "the-provider",
|
||||
},
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(401);
|
||||
});
|
||||
|
||||
it("should return not found when no provider is found", async () => {
|
||||
const { auth, getAuthHeaders } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const response = await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: "unknown",
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(await response.json()).toEqual({
|
||||
message: "Provider not found",
|
||||
code: "PROVIDER_NOT_FOUND",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return the existing active verification token", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
vi.useFakeTimers({ toFake: ["Date"] });
|
||||
|
||||
const newAuthHeaders = await getAuthHeaders(testUser);
|
||||
|
||||
const response = await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: newAuthHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(201);
|
||||
expect(await response.json()).toEqual({
|
||||
domainVerificationToken: provider.domainVerificationToken,
|
||||
});
|
||||
});
|
||||
|
||||
it("should return forbidden if user does not own the provider", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
const notOwnerHeaders = await getAuthHeaders({
|
||||
name: "other",
|
||||
email: "other@test.com",
|
||||
password: "password",
|
||||
});
|
||||
const response = await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: notOwnerHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(await response.json()).toEqual({
|
||||
message:
|
||||
"User must be owner of or belong to the SSO provider organization",
|
||||
code: "INSUFICCIENT_ACCESS",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return forbidden if user does not belong to the provider organization", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
|
||||
createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
|
||||
const orgA = await createOrganization("org-a", headers);
|
||||
const orgB = await createOrganization("org-b", headers);
|
||||
|
||||
const provider = await registerSSOProvider(headers, orgA?.id);
|
||||
|
||||
const notOrgHeaders = await getAuthHeaders(
|
||||
{
|
||||
name: "other",
|
||||
email: "other@test.com",
|
||||
password: "password",
|
||||
},
|
||||
orgB?.id,
|
||||
);
|
||||
|
||||
const response = await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: notOrgHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(await response.json()).toEqual({
|
||||
message:
|
||||
"User must be owner of or belong to the SSO provider organization",
|
||||
code: "INSUFICCIENT_ACCESS",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return a new domain verification token", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
vi.useFakeTimers({ toFake: ["Date"] });
|
||||
vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds
|
||||
|
||||
const newHeaders = await getAuthHeaders(testUser);
|
||||
const response = await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: newHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(201);
|
||||
expect(await response.json()).toMatchObject({
|
||||
domainVerificationToken: expect.any(String),
|
||||
});
|
||||
});
|
||||
|
||||
it("should fail to create a new token on an already verified domain", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
dnsMock.resolveTxt.mockResolvedValue([
|
||||
[
|
||||
`better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
|
||||
],
|
||||
]);
|
||||
|
||||
const domainVerificationResponse = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(domainVerificationResponse.status).toBe(204);
|
||||
|
||||
const domainVerificationSubmissionResponse =
|
||||
await auth.api.requestDomainVerification({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(domainVerificationSubmissionResponse.status).toBe(409);
|
||||
expect(await domainVerificationSubmissionResponse.json()).toEqual({
|
||||
message: "Domain has already been verified",
|
||||
code: "DOMAIN_VERIFIED",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("POST /sso/verify-domain", () => {
|
||||
it("should return unauthorized when session is missing", async () => {
|
||||
const { auth } = createTestAuth();
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: "the-provider",
|
||||
},
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(401);
|
||||
});
|
||||
|
||||
it("should return not found when no provider is found", async () => {
|
||||
const { auth, getAuthHeaders } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: "unknown",
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(await response.json()).toEqual({
|
||||
message: "Provider not found",
|
||||
code: "PROVIDER_NOT_FOUND",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return not found when no pending verification is found", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
vi.useFakeTimers({ toFake: ["Date"] });
|
||||
vi.advanceTimersByTime(Date.now() + 3600 * 24 * 7 * 1000 + 10); // advance 1 week + 10 seconds
|
||||
|
||||
const newAuthHeaders = await getAuthHeaders(testUser);
|
||||
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: newAuthHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
expect(await response.json()).toEqual({
|
||||
message: "No pending domain verification exists",
|
||||
code: "NO_PENDING_VERIFICATION",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return bad gateway when unable to verify domain", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
dnsMock.resolveTxt.mockResolvedValue([
|
||||
["google-site-verification=the-token"],
|
||||
]);
|
||||
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(502);
|
||||
expect(await response.json()).toEqual({
|
||||
message: "Unable to verify domain ownership. Try again later",
|
||||
code: "DOMAIN_VERIFICATION_FAILED",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return forbidden if user does not own the provider", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
const notOwnerHeaders = await getAuthHeaders({
|
||||
name: "other",
|
||||
email: "other@test.com",
|
||||
password: "password",
|
||||
});
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: notOwnerHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(await response.json()).toEqual({
|
||||
message:
|
||||
"User must be owner of or belong to the SSO provider organization",
|
||||
code: "INSUFICCIENT_ACCESS",
|
||||
});
|
||||
});
|
||||
|
||||
it("should return forbidden if user does not belong to the provider organization", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider, createOrganization } =
|
||||
createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const orgA = await createOrganization("org-a", headers);
|
||||
const orgB = await createOrganization("org-b", headers);
|
||||
|
||||
const provider = await registerSSOProvider(headers, orgA?.id);
|
||||
|
||||
const notOrgHeaders = await getAuthHeaders(
|
||||
{
|
||||
name: "other",
|
||||
email: "other@test.com",
|
||||
password: "password",
|
||||
},
|
||||
orgB?.id,
|
||||
);
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers: notOrgHeaders,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(await response.json()).toEqual({
|
||||
message:
|
||||
"User must be owner of or belong to the SSO provider organization",
|
||||
code: "INSUFICCIENT_ACCESS",
|
||||
});
|
||||
});
|
||||
|
||||
it("should verify a provider domain ownership", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
expect(provider.domain).toBe("http://hello.com:8081");
|
||||
expect(provider.domainVerified).toBe(false);
|
||||
expect(provider.domainVerificationToken).toBeTypeOf("string");
|
||||
|
||||
dnsMock.resolveTxt.mockResolvedValue([
|
||||
["google-site-verification=the-token"],
|
||||
[
|
||||
"v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
|
||||
],
|
||||
[
|
||||
`better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
|
||||
],
|
||||
]);
|
||||
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(204);
|
||||
});
|
||||
|
||||
it("should verify a provider domain ownership (custom token verification prefix)", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth({
|
||||
domainVerification: { tokenPrefix: "auth-prefix" },
|
||||
});
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
dnsMock.resolveTxt.mockResolvedValue([
|
||||
["google-site-verification=the-token"],
|
||||
[
|
||||
"v=spf1 ip4:50.242.118.232/29 include:_spf.google.com include:mail.zendesk.com ~all",
|
||||
],
|
||||
[`auth-prefix-saml-provider-1=${provider.domainVerificationToken}`],
|
||||
]);
|
||||
|
||||
const response = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(response.status).toBe(204);
|
||||
});
|
||||
|
||||
it("should fail to verify an already verified domain", async () => {
|
||||
const { auth, getAuthHeaders, registerSSOProvider } = createTestAuth();
|
||||
const headers = await getAuthHeaders(testUser);
|
||||
const provider = await registerSSOProvider(headers);
|
||||
|
||||
dnsMock.resolveTxt.mockResolvedValue([
|
||||
[
|
||||
`better-auth-token-saml-provider-1=${provider.domainVerificationToken}`,
|
||||
],
|
||||
]);
|
||||
|
||||
const firstResponse = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(firstResponse.status).toBe(204);
|
||||
|
||||
const secondResponse = await auth.api.verifyDomain({
|
||||
body: {
|
||||
providerId: provider.providerId,
|
||||
},
|
||||
headers,
|
||||
asResponse: true,
|
||||
});
|
||||
|
||||
expect(secondResponse.status).toBe(409);
|
||||
expect(await secondResponse.json()).toEqual({
|
||||
message: "Domain has already been verified",
|
||||
code: "DOMAIN_VERIFIED",
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,10 @@
|
||||
import type { BetterAuthPlugin } from "better-auth";
|
||||
import { XMLValidator } from "fast-xml-parser";
|
||||
import * as saml from "samlify";
|
||||
import {
|
||||
requestDomainVerification,
|
||||
verifyDomain,
|
||||
} from "./routes/domain-verification";
|
||||
import {
|
||||
acsEndpoint,
|
||||
callbackSSO,
|
||||
@@ -25,33 +29,72 @@ const fastValidator = {
|
||||
|
||||
saml.setSchemaValidator(fastValidator);
|
||||
|
||||
type SSOEndpoints = {
|
||||
type DomainVerificationEndpoints = {
|
||||
requestDomainVerification: ReturnType<typeof requestDomainVerification>;
|
||||
verifyDomain: ReturnType<typeof verifyDomain>;
|
||||
};
|
||||
|
||||
type SSOEndpoints<O extends SSOOptions> = {
|
||||
spMetadata: ReturnType<typeof spMetadata>;
|
||||
registerSSOProvider: ReturnType<typeof registerSSOProvider>;
|
||||
registerSSOProvider: ReturnType<typeof registerSSOProvider<O>>;
|
||||
signInSSO: ReturnType<typeof signInSSO>;
|
||||
callbackSSO: ReturnType<typeof callbackSSO>;
|
||||
callbackSSOSAML: ReturnType<typeof callbackSSOSAML>;
|
||||
acsEndpoint: ReturnType<typeof acsEndpoint>;
|
||||
};
|
||||
|
||||
export type SSOPlugin<O extends SSOOptions> = {
|
||||
id: "sso";
|
||||
endpoints: SSOEndpoints<O> &
|
||||
(O extends { domainVerification: { enabled: true } }
|
||||
? DomainVerificationEndpoints
|
||||
: {});
|
||||
};
|
||||
|
||||
export function sso<
|
||||
O extends SSOOptions & {
|
||||
domainVerification?: { enabled: true };
|
||||
},
|
||||
>(
|
||||
options?: O | undefined,
|
||||
): {
|
||||
id: "sso";
|
||||
endpoints: SSOEndpoints<O> & DomainVerificationEndpoints;
|
||||
schema: any;
|
||||
options: O;
|
||||
};
|
||||
export function sso<O extends SSOOptions>(
|
||||
options?: O | undefined,
|
||||
): {
|
||||
id: "sso";
|
||||
endpoints: SSOEndpoints;
|
||||
endpoints: SSOEndpoints<O>;
|
||||
};
|
||||
|
||||
export function sso<O extends SSOOptions>(options?: O | undefined): any {
|
||||
let endpoints = {
|
||||
spMetadata: spMetadata(),
|
||||
registerSSOProvider: registerSSOProvider(options as O),
|
||||
signInSSO: signInSSO(options as O),
|
||||
callbackSSO: callbackSSO(options as O),
|
||||
callbackSSOSAML: callbackSSOSAML(options as O),
|
||||
acsEndpoint: acsEndpoint(options as O),
|
||||
};
|
||||
|
||||
if (options?.domainVerification?.enabled) {
|
||||
const domainVerificationEndpoints = {
|
||||
requestDomainVerification: requestDomainVerification(options as O),
|
||||
verifyDomain: verifyDomain(options as O),
|
||||
};
|
||||
|
||||
endpoints = {
|
||||
...endpoints,
|
||||
...domainVerificationEndpoints,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
id: "sso",
|
||||
endpoints: {
|
||||
spMetadata: spMetadata(),
|
||||
registerSSOProvider: registerSSOProvider(options),
|
||||
signInSSO: signInSSO(options),
|
||||
callbackSSO: callbackSSO(options),
|
||||
callbackSSOSAML: callbackSSOSAML(options),
|
||||
acsEndpoint: acsEndpoint(options),
|
||||
},
|
||||
endpoints,
|
||||
schema: {
|
||||
ssoProvider: {
|
||||
modelName: options?.modelName ?? "ssoProvider",
|
||||
@@ -95,6 +138,9 @@ export function sso<O extends SSOOptions>(options?: O | undefined): any {
|
||||
required: true,
|
||||
fieldName: options?.fields?.domain ?? "domain",
|
||||
},
|
||||
...(options?.domainVerification?.enabled
|
||||
? { domainVerified: { type: "boolean", required: false } }
|
||||
: {}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
275
packages/sso/src/routes/domain-verification.ts
Normal file
275
packages/sso/src/routes/domain-verification.ts
Normal file
@@ -0,0 +1,275 @@
|
||||
import type { Verification } from "better-auth";
|
||||
import {
|
||||
APIError,
|
||||
createAuthEndpoint,
|
||||
sessionMiddleware,
|
||||
} from "better-auth/api";
|
||||
import { generateRandomString } from "better-auth/crypto";
|
||||
import * as z from "zod/v4";
|
||||
import type { SSOOptions, SSOProvider } from "../types";
|
||||
|
||||
export const requestDomainVerification = (options: SSOOptions) => {
|
||||
return createAuthEndpoint(
|
||||
"/sso/request-domain-verification",
|
||||
{
|
||||
method: "POST",
|
||||
body: z.object({
|
||||
providerId: z.string(),
|
||||
}),
|
||||
metadata: {
|
||||
openapi: {
|
||||
summary: "Request a domain verification",
|
||||
description:
|
||||
"Request a domain verification for the given SSO provider",
|
||||
responses: {
|
||||
"404": {
|
||||
description: "Provider not found",
|
||||
},
|
||||
"409": {
|
||||
description: "Domain has already been verified",
|
||||
},
|
||||
"201": {
|
||||
description: "Domain submitted for verification",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
use: [sessionMiddleware],
|
||||
},
|
||||
async (ctx) => {
|
||||
const body = ctx.body;
|
||||
const provider = await ctx.context.adapter.findOne<
|
||||
SSOProvider<SSOOptions>
|
||||
>({
|
||||
model: "ssoProvider",
|
||||
where: [{ field: "providerId", value: body.providerId }],
|
||||
});
|
||||
|
||||
if (!provider) {
|
||||
throw new APIError("NOT_FOUND", {
|
||||
message: "Provider not found",
|
||||
code: "PROVIDER_NOT_FOUND",
|
||||
});
|
||||
}
|
||||
|
||||
const userId = ctx.context.session.user.id;
|
||||
let isOrgMember = true;
|
||||
if (provider.organizationId) {
|
||||
const membershipsCount = await ctx.context.adapter.count({
|
||||
model: "member",
|
||||
where: [
|
||||
{ field: "userId", value: userId },
|
||||
{ field: "organizationId", value: provider.organizationId },
|
||||
],
|
||||
});
|
||||
|
||||
isOrgMember = membershipsCount > 0;
|
||||
}
|
||||
|
||||
if (provider.userId !== userId || !isOrgMember) {
|
||||
throw new APIError("FORBIDDEN", {
|
||||
message:
|
||||
"User must be owner of or belong to the SSO provider organization",
|
||||
code: "INSUFICCIENT_ACCESS",
|
||||
});
|
||||
}
|
||||
|
||||
if ("domainVerified" in provider && provider.domainVerified) {
|
||||
throw new APIError("CONFLICT", {
|
||||
message: "Domain has already been verified",
|
||||
code: "DOMAIN_VERIFIED",
|
||||
});
|
||||
}
|
||||
|
||||
const activeVerification =
|
||||
await ctx.context.adapter.findOne<Verification>({
|
||||
model: "verification",
|
||||
where: [
|
||||
{
|
||||
field: "identifier",
|
||||
value: options.domainVerification?.tokenPrefix
|
||||
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
|
||||
: `better-auth-token-${provider.providerId}`,
|
||||
},
|
||||
{ field: "expiresAt", value: new Date(), operator: "gt" },
|
||||
],
|
||||
});
|
||||
|
||||
if (activeVerification) {
|
||||
ctx.setStatus(201);
|
||||
return ctx.json({ domainVerificationToken: activeVerification.value });
|
||||
}
|
||||
|
||||
const domainVerificationToken = generateRandomString(24);
|
||||
await ctx.context.adapter.create<Verification>({
|
||||
model: "verification",
|
||||
data: {
|
||||
identifier: options.domainVerification?.tokenPrefix
|
||||
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
|
||||
: `better-auth-token-${provider.providerId}`,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
value: domainVerificationToken,
|
||||
expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
|
||||
},
|
||||
});
|
||||
|
||||
ctx.setStatus(201);
|
||||
return ctx.json({
|
||||
domainVerificationToken,
|
||||
});
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
export const verifyDomain = (options: SSOOptions) => {
|
||||
return createAuthEndpoint(
|
||||
"/sso/verify-domain",
|
||||
{
|
||||
method: "POST",
|
||||
body: z.object({
|
||||
providerId: z.string(),
|
||||
}),
|
||||
metadata: {
|
||||
openapi: {
|
||||
summary: "Verify the provider domain ownership",
|
||||
description: "Verify the provider domain ownership via DNS records",
|
||||
responses: {
|
||||
"404": {
|
||||
description: "Provider not found",
|
||||
},
|
||||
"409": {
|
||||
description:
|
||||
"Domain has already been verified or no pending verification exists",
|
||||
},
|
||||
"502": {
|
||||
description:
|
||||
"Unable to verify domain ownership due to upstream validator error",
|
||||
},
|
||||
"204": {
|
||||
description: "Domain ownership was verified",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
use: [sessionMiddleware],
|
||||
},
|
||||
async (ctx) => {
|
||||
const body = ctx.body;
|
||||
const provider = await ctx.context.adapter.findOne<
|
||||
SSOProvider<SSOOptions>
|
||||
>({
|
||||
model: "ssoProvider",
|
||||
where: [{ field: "providerId", value: body.providerId }],
|
||||
});
|
||||
|
||||
if (!provider) {
|
||||
throw new APIError("NOT_FOUND", {
|
||||
message: "Provider not found",
|
||||
code: "PROVIDER_NOT_FOUND",
|
||||
});
|
||||
}
|
||||
|
||||
const userId = ctx.context.session.user.id;
|
||||
let isOrgMember = true;
|
||||
if (provider.organizationId) {
|
||||
const membershipsCount = await ctx.context.adapter.count({
|
||||
model: "member",
|
||||
where: [
|
||||
{ field: "userId", value: userId },
|
||||
{ field: "organizationId", value: provider.organizationId },
|
||||
],
|
||||
});
|
||||
|
||||
isOrgMember = membershipsCount > 0;
|
||||
}
|
||||
|
||||
if (provider.userId !== userId || !isOrgMember) {
|
||||
throw new APIError("FORBIDDEN", {
|
||||
message:
|
||||
"User must be owner of or belong to the SSO provider organization",
|
||||
code: "INSUFICCIENT_ACCESS",
|
||||
});
|
||||
}
|
||||
|
||||
if ("domainVerified" in provider && provider.domainVerified) {
|
||||
throw new APIError("CONFLICT", {
|
||||
message: "Domain has already been verified",
|
||||
code: "DOMAIN_VERIFIED",
|
||||
});
|
||||
}
|
||||
|
||||
const activeVerification =
|
||||
await ctx.context.adapter.findOne<Verification>({
|
||||
model: "verification",
|
||||
where: [
|
||||
{
|
||||
field: "identifier",
|
||||
value: options.domainVerification?.tokenPrefix
|
||||
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
|
||||
: `better-auth-token-${provider.providerId}`,
|
||||
},
|
||||
{ field: "expiresAt", value: new Date(), operator: "gt" },
|
||||
],
|
||||
});
|
||||
|
||||
if (!activeVerification) {
|
||||
throw new APIError("NOT_FOUND", {
|
||||
message: "No pending domain verification exists",
|
||||
code: "NO_PENDING_VERIFICATION",
|
||||
});
|
||||
}
|
||||
|
||||
let records: string[] = [];
|
||||
let dns: typeof import("node:dns/promises");
|
||||
|
||||
try {
|
||||
dns = await import("node:dns/promises");
|
||||
} catch (error) {
|
||||
ctx.context.logger.error(
|
||||
"The core node:dns module is required for the domain verification feature",
|
||||
error,
|
||||
);
|
||||
throw new APIError("INTERNAL_SERVER_ERROR", {
|
||||
message: "Unable to verify domain ownership due to server error",
|
||||
code: "DOMAIN_VERIFICATION_FAILED",
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const dnsRecords = await dns.resolveTxt(
|
||||
new URL(provider.domain).hostname,
|
||||
);
|
||||
records = dnsRecords.flat();
|
||||
} catch (error) {
|
||||
ctx.context.logger.warn(
|
||||
"DNS resolution failure while validating domain ownership",
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
const record = records.find((record) =>
|
||||
record.includes(
|
||||
`${activeVerification.identifier}=${activeVerification.value}`,
|
||||
),
|
||||
);
|
||||
if (!record) {
|
||||
throw new APIError("BAD_GATEWAY", {
|
||||
message: "Unable to verify domain ownership. Try again later",
|
||||
code: "DOMAIN_VERIFICATION_FAILED",
|
||||
});
|
||||
}
|
||||
|
||||
await ctx.context.adapter.update<SSOProvider<SSOOptions>>({
|
||||
model: "ssoProvider",
|
||||
where: [{ field: "providerId", value: provider.providerId }],
|
||||
update: {
|
||||
domainVerified: true,
|
||||
},
|
||||
});
|
||||
|
||||
ctx.setStatus(204);
|
||||
return;
|
||||
},
|
||||
);
|
||||
};
|
||||
@@ -1,5 +1,5 @@
|
||||
import { BetterFetchError, betterFetch } from "@better-fetch/fetch";
|
||||
import type { Account, Session, User } from "better-auth";
|
||||
import type { Account, Session, User, Verification } from "better-auth";
|
||||
import {
|
||||
createAuthorizationURL,
|
||||
generateState,
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
sessionMiddleware,
|
||||
} from "better-auth/api";
|
||||
import { setSessionCookie } from "better-auth/cookies";
|
||||
import { generateRandomString } from "better-auth/crypto";
|
||||
import { handleOAuthUserInfo } from "better-auth/oauth2";
|
||||
import { decodeJwt } from "jose";
|
||||
import * as saml from "samlify";
|
||||
@@ -126,7 +127,7 @@ export const spMetadata = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export const registerSSOProvider = (options?: SSOOptions) => {
|
||||
export const registerSSOProvider = <O extends SSOOptions>(options: O) => {
|
||||
return createAuthEndpoint(
|
||||
"/sso/register",
|
||||
{
|
||||
@@ -358,6 +359,16 @@ export const registerSSOProvider = (options?: SSOOptions) => {
|
||||
description:
|
||||
"The domain of the provider, used for email matching",
|
||||
},
|
||||
domainVerified: {
|
||||
type: "boolean",
|
||||
description:
|
||||
"A boolean indicating whether the domain has been verified or not",
|
||||
},
|
||||
domainVerificationToken: {
|
||||
type: "string",
|
||||
description:
|
||||
"Domain verification token. It can be used to prove ownership over the SSO domain",
|
||||
},
|
||||
oidcConfig: {
|
||||
type: "object",
|
||||
properties: {
|
||||
@@ -586,12 +597,13 @@ export const registerSSOProvider = (options?: SSOOptions) => {
|
||||
|
||||
const provider = await ctx.context.adapter.create<
|
||||
Record<string, any>,
|
||||
SSOProvider
|
||||
SSOProvider<O>
|
||||
>({
|
||||
model: "ssoProvider",
|
||||
data: {
|
||||
issuer: body.issuer,
|
||||
domain: body.domain,
|
||||
domainVerified: false,
|
||||
oidcConfig: body.oidcConfig
|
||||
? JSON.stringify({
|
||||
issuer: body.issuer,
|
||||
@@ -640,6 +652,34 @@ export const registerSSOProvider = (options?: SSOOptions) => {
|
||||
},
|
||||
});
|
||||
|
||||
let domainVerificationToken: string | undefined;
|
||||
let domainVerified: boolean | undefined;
|
||||
|
||||
if (options?.domainVerification?.enabled) {
|
||||
domainVerified = false;
|
||||
domainVerificationToken = generateRandomString(24);
|
||||
|
||||
await ctx.context.adapter.create<Verification>({
|
||||
model: "verification",
|
||||
data: {
|
||||
identifier: options.domainVerification?.tokenPrefix
|
||||
? `${options.domainVerification?.tokenPrefix}-${provider.providerId}`
|
||||
: `better-auth-token-${provider.providerId}`,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
value: domainVerificationToken,
|
||||
expiresAt: new Date(Date.now() + 3600 * 24 * 7 * 1000), // 1 week
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
type SSOProviderReturn = O["domainVerification"] extends { enabled: true }
|
||||
? {
|
||||
domainVerified: boolean;
|
||||
domainVerificationToken: string;
|
||||
} & SSOProvider<O>
|
||||
: SSOProvider<O>;
|
||||
|
||||
return ctx.json({
|
||||
...provider,
|
||||
oidcConfig: JSON.parse(
|
||||
@@ -649,7 +689,11 @@ export const registerSSOProvider = (options?: SSOOptions) => {
|
||||
provider.samlConfig as unknown as string,
|
||||
) as SAMLConfig,
|
||||
redirectURI: `${ctx.context.baseURL}/sso/callback/${provider.providerId}`,
|
||||
});
|
||||
...(options?.domainVerification?.enabled ? { domainVerified } : {}),
|
||||
...(options?.domainVerification?.enabled
|
||||
? { domainVerificationToken }
|
||||
: {}),
|
||||
} as unknown as SSOProviderReturn);
|
||||
},
|
||||
);
|
||||
};
|
||||
@@ -840,7 +884,7 @@ export const signInSSO = (options?: SSOOptions) => {
|
||||
return res.id;
|
||||
});
|
||||
}
|
||||
let provider: SSOProvider | null = null;
|
||||
let provider: SSOProvider<SSOOptions> | null = null;
|
||||
if (options?.defaultSSO?.length) {
|
||||
// Find matching default SSO provider by providerId
|
||||
const matchingDefault = providerId
|
||||
@@ -862,7 +906,10 @@ export const signInSSO = (options?: SSOOptions) => {
|
||||
oidcConfig: matchingDefault.oidcConfig,
|
||||
samlConfig: matchingDefault.samlConfig,
|
||||
domain: matchingDefault.domain,
|
||||
};
|
||||
...(options.domainVerification?.enabled
|
||||
? { domainVerified: true }
|
||||
: {}),
|
||||
} as SSOProvider<SSOOptions>;
|
||||
}
|
||||
}
|
||||
if (!providerId && !orgId && !domain) {
|
||||
@@ -873,7 +920,7 @@ export const signInSSO = (options?: SSOOptions) => {
|
||||
// Try to find provider in database
|
||||
if (!provider) {
|
||||
provider = await ctx.context.adapter
|
||||
.findOne<SSOProvider>({
|
||||
.findOne<SSOProvider<SSOOptions>>({
|
||||
model: "ssoProvider",
|
||||
where: [
|
||||
{
|
||||
@@ -925,6 +972,15 @@ export const signInSSO = (options?: SSOOptions) => {
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
options?.domainVerification?.enabled &&
|
||||
!("domainVerified" in provider && provider.domainVerified)
|
||||
) {
|
||||
throw new APIError("UNAUTHORIZED", {
|
||||
message: "Provider domain has not been verified",
|
||||
});
|
||||
}
|
||||
|
||||
if (provider.oidcConfig && body.providerType !== "saml") {
|
||||
let finalAuthUrl = provider.oidcConfig.authorizationEndpoint;
|
||||
if (!finalAuthUrl && provider.oidcConfig.discoveryEndpoint) {
|
||||
@@ -1064,7 +1120,7 @@ export const callbackSSO = (options?: SSOOptions) => {
|
||||
}?error=${error}&error_description=${error_description}`,
|
||||
);
|
||||
}
|
||||
let provider: SSOProvider | null = null;
|
||||
let provider: SSOProvider<SSOOptions> | null = null;
|
||||
if (options?.defaultSSO?.length) {
|
||||
const matchingDefault = options.defaultSSO.find(
|
||||
(defaultProvider) =>
|
||||
@@ -1075,7 +1131,10 @@ export const callbackSSO = (options?: SSOOptions) => {
|
||||
...matchingDefault,
|
||||
issuer: matchingDefault.oidcConfig?.issuer || "",
|
||||
userId: "default",
|
||||
};
|
||||
...(options.domainVerification?.enabled
|
||||
? { domainVerified: true }
|
||||
: {}),
|
||||
} as SSOProvider<SSOOptions>;
|
||||
}
|
||||
}
|
||||
if (!provider) {
|
||||
@@ -1099,7 +1158,7 @@ export const callbackSSO = (options?: SSOOptions) => {
|
||||
...res,
|
||||
oidcConfig:
|
||||
safeJsonParse<OIDCConfig>(res.oidcConfig) || undefined,
|
||||
} as SSOProvider;
|
||||
} as SSOProvider<SSOOptions>;
|
||||
});
|
||||
}
|
||||
if (!provider) {
|
||||
@@ -1110,6 +1169,15 @@ export const callbackSSO = (options?: SSOOptions) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
options?.domainVerification?.enabled &&
|
||||
!("domainVerified" in provider && provider.domainVerified)
|
||||
) {
|
||||
throw new APIError("UNAUTHORIZED", {
|
||||
message: "Provider domain has not been verified",
|
||||
});
|
||||
}
|
||||
|
||||
let config = provider.oidcConfig;
|
||||
|
||||
if (!config) {
|
||||
@@ -1405,7 +1473,7 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
|
||||
async (ctx) => {
|
||||
const { SAMLResponse, RelayState } = ctx.body;
|
||||
const { providerId } = ctx.params;
|
||||
let provider: SSOProvider | null = null;
|
||||
let provider: SSOProvider<SSOOptions> | null = null;
|
||||
if (options?.defaultSSO?.length) {
|
||||
const matchingDefault = options.defaultSSO.find(
|
||||
(defaultProvider) => defaultProvider.providerId === providerId,
|
||||
@@ -1415,12 +1483,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
|
||||
...matchingDefault,
|
||||
userId: "default",
|
||||
issuer: matchingDefault.samlConfig?.issuer || "",
|
||||
};
|
||||
...(options.domainVerification?.enabled
|
||||
? { domainVerified: true }
|
||||
: {}),
|
||||
} as SSOProvider<SSOOptions>;
|
||||
}
|
||||
}
|
||||
if (!provider) {
|
||||
provider = await ctx.context.adapter
|
||||
.findOne<SSOProvider>({
|
||||
.findOne<SSOProvider<SSOOptions>>({
|
||||
model: "ssoProvider",
|
||||
where: [{ field: "providerId", value: providerId }],
|
||||
})
|
||||
@@ -1443,6 +1514,15 @@ export const callbackSSOSAML = (options?: SSOOptions) => {
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
options?.domainVerification?.enabled &&
|
||||
!("domainVerified" in provider && provider.domainVerified)
|
||||
) {
|
||||
throw new APIError("UNAUTHORIZED", {
|
||||
message: "Provider domain has not been verified",
|
||||
});
|
||||
}
|
||||
|
||||
const parsedSamlConfig = safeJsonParse<SAMLConfig>(
|
||||
provider.samlConfig as unknown as string,
|
||||
);
|
||||
@@ -1736,7 +1816,7 @@ export const acsEndpoint = (options?: SSOOptions) => {
|
||||
const { providerId } = ctx.params;
|
||||
|
||||
// If defaultSSO is configured, use it as the provider
|
||||
let provider: SSOProvider | null = null;
|
||||
let provider: SSOProvider<SSOOptions> | null = null;
|
||||
|
||||
if (options?.defaultSSO?.length) {
|
||||
// For ACS endpoint, we can use the first default provider or try to match by providerId
|
||||
@@ -1753,11 +1833,14 @@ export const acsEndpoint = (options?: SSOOptions) => {
|
||||
userId: "default",
|
||||
samlConfig: matchingDefault.samlConfig,
|
||||
domain: matchingDefault.domain,
|
||||
...(options.domainVerification?.enabled
|
||||
? { domainVerified: true }
|
||||
: {}),
|
||||
};
|
||||
}
|
||||
} else {
|
||||
provider = await ctx.context.adapter
|
||||
.findOne<SSOProvider>({
|
||||
.findOne<SSOProvider<SSOOptions>>({
|
||||
model: "ssoProvider",
|
||||
where: [
|
||||
{
|
||||
@@ -1785,6 +1868,15 @@ export const acsEndpoint = (options?: SSOOptions) => {
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
options?.domainVerification?.enabled &&
|
||||
!("domainVerified" in provider && provider.domainVerified)
|
||||
) {
|
||||
throw new APIError("UNAUTHORIZED", {
|
||||
message: "Provider domain has not been verified",
|
||||
});
|
||||
}
|
||||
|
||||
const parsedSamlConfig = provider.samlConfig;
|
||||
// Configure SP and IdP
|
||||
const sp = saml.ServiceProvider({
|
||||
|
||||
@@ -81,7 +81,7 @@ export interface SAMLConfig {
|
||||
mapping?: SAMLMapping | undefined;
|
||||
}
|
||||
|
||||
export type SSOProvider = {
|
||||
type BaseSSOProvider = {
|
||||
issuer: string;
|
||||
oidcConfig?: OIDCConfig | undefined;
|
||||
samlConfig?: SAMLConfig | undefined;
|
||||
@@ -91,6 +91,13 @@ export type SSOProvider = {
|
||||
domain: string;
|
||||
};
|
||||
|
||||
export type SSOProvider<O extends SSOOptions> =
|
||||
O["domainVerification"] extends { enabled: true }
|
||||
? {
|
||||
domainVerified: boolean;
|
||||
} & BaseSSOProvider
|
||||
: BaseSSOProvider;
|
||||
|
||||
export interface SSOOptions {
|
||||
/**
|
||||
* custom function to provision a user when they sign in with an SSO provider.
|
||||
@@ -112,7 +119,7 @@ export interface SSOOptions {
|
||||
/**
|
||||
* The SSO provider
|
||||
*/
|
||||
provider: SSOProvider;
|
||||
provider: SSOProvider<SSOOptions>;
|
||||
}) => Promise<void>)
|
||||
| undefined;
|
||||
/**
|
||||
@@ -138,7 +145,7 @@ export interface SSOOptions {
|
||||
/**
|
||||
* The SSO provider
|
||||
*/
|
||||
provider: SSOProvider;
|
||||
provider: SSOProvider<SSOOptions>;
|
||||
}) => Promise<"member" | "admin">;
|
||||
}
|
||||
| undefined;
|
||||
@@ -228,4 +235,22 @@ export interface SSOOptions {
|
||||
* @default false
|
||||
*/
|
||||
trustEmailVerified?: boolean | undefined;
|
||||
/**
|
||||
* Enable domain verification on SSO providers
|
||||
*
|
||||
* When this option is enabled, new SSO providers will require the associated domain to be verified by the owner
|
||||
* prior to allowing sign-ins.
|
||||
*/
|
||||
domainVerification?: {
|
||||
/**
|
||||
* Enables or disables the domain verification feature
|
||||
*/
|
||||
enabled?: boolean;
|
||||
/**
|
||||
* Prefix used to generate the domain verification token
|
||||
*
|
||||
* @default "better-auth-token-"
|
||||
*/
|
||||
tokenPrefix?: string;
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user