diff --git a/packages/better-auth/src/plugins/organization/routes/crud-org.test.ts b/packages/better-auth/src/plugins/organization/routes/crud-org.test.ts index b5ea3d1840..72234c5cb7 100644 --- a/packages/better-auth/src/plugins/organization/routes/crud-org.test.ts +++ b/packages/better-auth/src/plugins/organization/routes/crud-org.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { getTestInstance } from "../../../test-utils/test-instance"; import { organization } from "../organization"; import { createAuthClient } from "../../../client"; @@ -252,3 +252,194 @@ describe("get-full-organization", async () => { expect(fullOrg.data?.members.length).toBeLessThanOrEqual(6); }); }); + +describe("organization hooks", async () => { + it("should apply beforeCreateOrganization hook", async () => { + const beforeCreateOrganization = vi.fn(); + const { auth, signInWithTestUser } = await getTestInstance( + { + plugins: [ + organization({ + organizationHooks: { + beforeCreateOrganization: async (data) => { + beforeCreateOrganization(); + return { + data: { + ...data.organization, + metadata: { + hookCalled: true, + }, + }, + }; + }, + }, + }), + ], + }, + { + clientOptions: { + plugins: [organizationClient()], + }, + }, + ); + const { headers } = await signInWithTestUser(); + const result = await auth.api.createOrganization({ + body: { + name: "test", + slug: "test", + }, + headers, + }); + expect(beforeCreateOrganization).toHaveBeenCalled(); + expect(result?.metadata).toEqual({ + hookCalled: true, + }); + }); + + it("should apply afterCreateOrganization hook", async () => { + const afterCreateOrganization = vi.fn(); + const { auth, signInWithTestUser } = await getTestInstance({ + plugins: [ + organization({ + organizationHooks: { + afterCreateOrganization: async (data) => { + afterCreateOrganization(); + }, + }, + }), + ], + }); + const { headers } = await signInWithTestUser(); + const result = await auth.api.createOrganization({ + body: { + name: "test", + slug: "test", + }, + headers, + }); + expect(afterCreateOrganization).toHaveBeenCalled(); + }); + + it("should apply beforeAddMember hook", async () => { + const beforeAddMember = vi.fn(); + const { auth, signInWithTestUser } = await getTestInstance({ + plugins: [ + organization({ + organizationHooks: { + beforeAddMember: async (data) => { + beforeAddMember(); + return { + data: { + role: "changed-role", + }, + }; + }, + }, + }), + ], + }); + const { headers } = await signInWithTestUser(); + await auth.api.createOrganization({ + body: { + name: "test", + slug: "test", + }, + headers, + }); + expect(beforeAddMember).toHaveBeenCalled(); + const member = await auth.api.getActiveMember({ + headers, + }); + expect(member?.role).toBe("changed-role"); + }); + + it("should apply afterAddMember hook", async () => { + const afterAddMember = vi.fn(); + const { auth, signInWithTestUser } = await getTestInstance({ + plugins: [ + organization({ + organizationHooks: { + afterAddMember: async (data) => { + afterAddMember(); + }, + }, + }), + ], + }); + const { headers } = await signInWithTestUser(); + await auth.api.createOrganization({ + body: { + name: "test", + slug: "test", + }, + headers, + }); + expect(afterAddMember).toHaveBeenCalled(); + }); + + it("should apply beforeCreateTeam hook", async () => { + const beforeCreateTeam = vi.fn(); + const { auth, signInWithTestUser } = await getTestInstance({ + plugins: [ + organization({ + teams: { + enabled: true, + }, + organizationHooks: { + beforeCreateTeam: async (data) => { + beforeCreateTeam(); + return { + data: { + name: "changed-name", + }, + }; + }, + }, + }), + ], + }); + const { headers } = await signInWithTestUser(); + const result = await auth.api.createOrganization({ + body: { + name: "test", + slug: "test", + }, + headers, + }); + expect(beforeCreateTeam).toHaveBeenCalled(); + const team = await auth.api.listOrganizationTeams({ + headers, + query: { + organizationId: result?.id, + }, + }); + expect(team[0]?.name).toBe("changed-name"); + }); + + it("should apply afterCreateTeam hook", async () => { + const afterCreateTeam = vi.fn(); + const { auth, signInWithTestUser } = await getTestInstance({ + plugins: [ + organization({ + teams: { + enabled: true, + }, + organizationHooks: { + afterCreateTeam: async (data) => { + afterCreateTeam(); + }, + }, + }), + ], + }); + const { headers } = await signInWithTestUser(); + await auth.api.createOrganization({ + body: { + name: "test", + slug: "test", + }, + headers, + }); + expect(afterCreateTeam).toHaveBeenCalled(); + }); +}); diff --git a/packages/better-auth/src/plugins/organization/routes/crud-org.ts b/packages/better-auth/src/plugins/organization/routes/crud-org.ts index e48c9a802d..033344cb0e 100644 --- a/packages/better-auth/src/plugins/organization/routes/crud-org.ts +++ b/packages/better-auth/src/plugins/organization/routes/crud-org.ts @@ -206,38 +206,79 @@ export const createOrganization = ( | (Member & InferAdditionalFieldsFromPluginOptions<"member", O, false>) | undefined; let teamMember: TeamMember | null = null; - + let data = { + userId: user.id, + organizationId: organization.id, + role: ctx.context.orgOptions.creatorRole || "owner", + }; + if (options?.organizationHooks?.beforeAddMember) { + const response = await options?.organizationHooks.beforeAddMember({ + member: { + userId: user.id, + organizationId: organization.id, + role: ctx.context.orgOptions.creatorRole || "owner", + }, + user, + organization, + }); + if (response && typeof response === "object" && "data" in response) { + data = { + ...data, + ...response.data, + }; + } + } + member = await adapter.createMember(data); + if (options?.organizationHooks?.afterAddMember) { + await options?.organizationHooks.afterAddMember({ + member, + user, + organization, + }); + } if ( options?.teams?.enabled && options.teams.defaultTeam?.enabled !== false ) { + let teamData = { + organizationId: organization.id, + name: `${organization.name}`, + createdAt: new Date(), + }; + if (options?.organizationHooks?.beforeCreateTeam) { + const response = await options?.organizationHooks.beforeCreateTeam({ + team: { + organizationId: organization.id, + name: `${organization.name}`, + }, + user, + organization, + }); + if (response && typeof response === "object" && "data" in response) { + teamData = { + ...teamData, + ...response.data, + }; + } + } const defaultTeam = (await options.teams.defaultTeam?.customCreateDefaultTeam?.( organization, ctx.request, - )) || - (await adapter.createTeam({ - organizationId: organization.id, - name: `${organization.name}`, - createdAt: new Date(), - })); - - member = await adapter.createMember({ - userId: user.id, - organizationId: organization.id, - role: ctx.context.orgOptions.creatorRole || "owner", - }); + )) || (await adapter.createTeam(teamData)); teamMember = await adapter.findOrCreateTeamMember({ teamId: defaultTeam.id, userId: user.id, }); - } else { - member = await adapter.createMember({ - userId: user.id, - organizationId: organization.id, - role: ctx.context.orgOptions.creatorRole || "owner", - }); + + if (options?.organizationHooks?.afterCreateTeam) { + await options?.organizationHooks.afterCreateTeam({ + team: defaultTeam, + user, + organization, + }); + } } if (options.organizationCreation?.afterCreate) { diff --git a/packages/better-auth/src/test-utils/test-instance.ts b/packages/better-auth/src/test-utils/test-instance.ts index 444037795b..07de0c2585 100644 --- a/packages/better-auth/src/test-utils/test-instance.ts +++ b/packages/better-auth/src/test-utils/test-instance.ts @@ -1,5 +1,3 @@ -import fs from "fs/promises"; -import { generateRandomString } from "../crypto/random"; import { afterAll } from "vitest"; import { betterAuth } from "../auth"; import { createAuthClient } from "../client/vanilla"; @@ -40,12 +38,6 @@ export async function getTestInstance< }, ) { const testWith = config?.testWith || "sqlite"; - /** - * create db folder if not exists - */ - await fs.mkdir(".db", { recursive: true }); - const randomStr = generateRandomString(4, "a-z"); - const dbName = `./.db/test-${randomStr}.db`; const postgres = new Kysely({ dialect: new PostgresDialect({ @@ -55,6 +47,8 @@ export async function getTestInstance< }), }); + const sqlite = new Database(":memory:"); + const mysql = new Kysely({ dialect: new MysqlDialect( createPool("mysql://user:password@localhost:3306/better_auth"), @@ -91,7 +85,7 @@ export async function getTestInstance< ? mongodbAdapter(await mongodbClient()) : testWith === "mysql" ? { db: mysql, type: "mysql" } - : new Database(dbName), + : sqlite, emailAndPassword: { enabled: true, }, @@ -167,8 +161,10 @@ export async function getTestInstance< await sql`SET FOREIGN_KEY_CHECKS = 1;`.execute(mysql); return; } - - await fs.unlink(dbName); + if (testWith === "sqlite") { + sqlite.close(); + return; + } }; cleanupSet.add(cleanup);