diff --git a/packages/cli/src/generators/drizzle.ts b/packages/cli/src/generators/drizzle.ts index 6ab6907709..9b3707f12a 100644 --- a/packages/cli/src/generators/drizzle.ts +++ b/packages/cli/src/generators/drizzle.ts @@ -69,13 +69,20 @@ export const generateDrizzleSchema: SchemaGenerator = async ({ } return `text('${name}')`; } - const type = field.type as - | "string" - | "number" - | "boolean" - | "date" - | "json" - | `${"string" | "number"}[]`; + const type = field.type; + if (typeof type !== "string") { + if (Array.isArray(type) && type.every((x) => typeof x === "string")) { + return { + sqlite: `text({ enum: [${type.map((x) => `'${x}'`).join(", ")}] })`, + pg: `pgEnum('${name}', [${type.map((x) => `'${x}'`).join(", ")}])`, + mysql: `mysqlEnum([${type.map((x) => `'${x}'`).join(", ")}])`, + }[databaseType]; + } else { + throw new TypeError( + `Invalid field type for field ${name} in model ${modelName}`, + ); + } + } const typeMap: Record< typeof type, Record @@ -273,6 +280,17 @@ function generateImport({ if (needsInt) { coreImports.push("int"); } + const hasEnum = Object.values(tables).some((table) => + Object.values(table.fields).some( + (field) => + typeof field.type !== "string" && + Array.isArray(field.type) && + field.type.every((x) => typeof x === "string"), + ), + ); + if (hasEnum) { + coreImports.push("mysqlEnum"); + } } else if (databaseType === "pg") { // Only include integer for PG if actually needed const hasNonBigintNumber = Object.values(tables).some((table) => @@ -294,6 +312,18 @@ function generateImport({ if (needsInteger) { coreImports.push("integer"); } + + const hasEnum = Object.values(tables).some((table) => + Object.values(table.fields).some( + (field) => + typeof field.type !== "string" && + Array.isArray(field.type) && + field.type.every((x) => typeof x === "string"), + ), + ); + if (hasEnum) { + coreImports.push("pgEnum"); + } } else { coreImports.push("integer"); } diff --git a/packages/cli/test/__snapshots__/auth-schema-mysql-enum.txt b/packages/cli/test/__snapshots__/auth-schema-mysql-enum.txt new file mode 100644 index 0000000000..36db840415 --- /dev/null +++ b/packages/cli/test/__snapshots__/auth-schema-mysql-enum.txt @@ -0,0 +1,69 @@ +import { + mysqlTable, + varchar, + text, + timestamp, + boolean, + mysqlEnum, +} from "drizzle-orm/mysql-core"; + +export const user = mysqlTable("user", { + id: varchar("id", { length: 36 }).primaryKey(), + name: text("name").notNull(), + email: varchar("email", { length: 255 }).notNull().unique(), + emailVerified: boolean("email_verified").default(false).notNull(), + image: text("image"), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) + .defaultNow() + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + status: mysqlEnum(["active", "inactive", "pending"]), +}); + +export const session = mysqlTable("session", { + id: varchar("id", { length: 36 }).primaryKey(), + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), + token: varchar("token", { length: 255 }).notNull().unique(), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + ipAddress: text("ip_address"), + userAgent: text("user_agent"), + userId: varchar("user_id", { length: 36 }) + .notNull() + .references(() => user.id, { onDelete: "cascade" }), +}); + +export const account = mysqlTable("account", { + id: varchar("id", { length: 36 }).primaryKey(), + accountId: text("account_id").notNull(), + providerId: text("provider_id").notNull(), + userId: varchar("user_id", { length: 36 }) + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + accessToken: text("access_token"), + refreshToken: text("refresh_token"), + idToken: text("id_token"), + accessTokenExpiresAt: timestamp("access_token_expires_at", { fsp: 3 }), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at", { fsp: 3 }), + scope: text("scope"), + password: text("password"), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), +}); + +export const verification = mysqlTable("verification", { + id: varchar("id", { length: 36 }).primaryKey(), + identifier: text("identifier").notNull(), + value: text("value").notNull(), + expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(), + createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(), + updatedAt: timestamp("updated_at", { fsp: 3 }) + .defaultNow() + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), +}); diff --git a/packages/cli/test/__snapshots__/auth-schema-pg-enum.txt b/packages/cli/test/__snapshots__/auth-schema-pg-enum.txt new file mode 100644 index 0000000000..609a24b8bf --- /dev/null +++ b/packages/cli/test/__snapshots__/auth-schema-pg-enum.txt @@ -0,0 +1,62 @@ +import { pgTable, text, timestamp, boolean, pgEnum } from "drizzle-orm/pg-core"; + +export const user = pgTable("user", { + id: text("id").primaryKey(), + name: text("name").notNull(), + email: text("email").notNull().unique(), + emailVerified: boolean("email_verified").default(false).notNull(), + image: text("image"), + createdAt: timestamp("created_at").defaultNow().notNull(), + updatedAt: timestamp("updated_at") + .defaultNow() + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + role: pgEnum("role", ["admin", "user", "guest"]).notNull(), +}); + +export const session = pgTable("session", { + id: text("id").primaryKey(), + expiresAt: timestamp("expires_at").notNull(), + token: text("token").notNull().unique(), + createdAt: timestamp("created_at").defaultNow().notNull(), + updatedAt: timestamp("updated_at") + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + ipAddress: text("ip_address"), + userAgent: text("user_agent"), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), +}); + +export const account = pgTable("account", { + id: text("id").primaryKey(), + accountId: text("account_id").notNull(), + providerId: text("provider_id").notNull(), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + accessToken: text("access_token"), + refreshToken: text("refresh_token"), + idToken: text("id_token"), + accessTokenExpiresAt: timestamp("access_token_expires_at"), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), + scope: text("scope"), + password: text("password"), + createdAt: timestamp("created_at").defaultNow().notNull(), + updatedAt: timestamp("updated_at") + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), +}); + +export const verification = pgTable("verification", { + id: text("id").primaryKey(), + identifier: text("identifier").notNull(), + value: text("value").notNull(), + expiresAt: timestamp("expires_at").notNull(), + createdAt: timestamp("created_at").defaultNow().notNull(), + updatedAt: timestamp("updated_at") + .defaultNow() + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), +}); diff --git a/packages/cli/test/__snapshots__/auth-schema-sqlite-enum.txt b/packages/cli/test/__snapshots__/auth-schema-sqlite-enum.txt new file mode 100644 index 0000000000..ffed90caa5 --- /dev/null +++ b/packages/cli/test/__snapshots__/auth-schema-sqlite-enum.txt @@ -0,0 +1,77 @@ +import { sql } from "drizzle-orm"; +import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core"; + +export const user = sqliteTable("user", { + id: text("id").primaryKey(), + name: text("name").notNull(), + email: text("email").notNull().unique(), + emailVerified: integer("email_verified", { mode: "boolean" }) + .default(false) + .notNull(), + image: text("image"), + createdAt: integer("created_at", { mode: "timestamp_ms" }) + .default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`) + .notNull(), + updatedAt: integer("updated_at", { mode: "timestamp_ms" }) + .default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`) + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + priority: text({ enum: ["high", "medium", "low"] }), +}); + +export const session = sqliteTable("session", { + id: text("id").primaryKey(), + expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), + token: text("token").notNull().unique(), + createdAt: integer("created_at", { mode: "timestamp_ms" }) + .default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`) + .notNull(), + updatedAt: integer("updated_at", { mode: "timestamp_ms" }) + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + ipAddress: text("ip_address"), + userAgent: text("user_agent"), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), +}); + +export const account = sqliteTable("account", { + id: text("id").primaryKey(), + accountId: text("account_id").notNull(), + providerId: text("provider_id").notNull(), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + accessToken: text("access_token"), + refreshToken: text("refresh_token"), + idToken: text("id_token"), + accessTokenExpiresAt: integer("access_token_expires_at", { + mode: "timestamp_ms", + }), + refreshTokenExpiresAt: integer("refresh_token_expires_at", { + mode: "timestamp_ms", + }), + scope: text("scope"), + password: text("password"), + createdAt: integer("created_at", { mode: "timestamp_ms" }) + .default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`) + .notNull(), + updatedAt: integer("updated_at", { mode: "timestamp_ms" }) + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), +}); + +export const verification = sqliteTable("verification", { + id: text("id").primaryKey(), + identifier: text("identifier").notNull(), + value: text("value").notNull(), + expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), + createdAt: integer("created_at", { mode: "timestamp_ms" }) + .default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`) + .notNull(), + updatedAt: integer("updated_at", { mode: "timestamp_ms" }) + .default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`) + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), +}); diff --git a/packages/cli/test/generate.test.ts b/packages/cli/test/generate.test.ts index ac9e161b2f..78caaacc2e 100644 --- a/packages/cli/test/generate.test.ts +++ b/packages/cli/test/generate.test.ts @@ -361,3 +361,171 @@ describe("JSON field support in CLI generators", () => { expect(schema.code).toContain("preferences Json?"); }); }); + +describe("Enum field support in Drizzle schemas", () => { + it("should generate Drizzle schema with enum fields for PostgreSQL", async () => { + const schema = await generateDrizzleSchema({ + file: "test.drizzle", + adapter: { + id: "drizzle", + options: { + provider: "pg", + schema: {}, + }, + } as any, + options: { + database: {} as any, + user: { + additionalFields: { + role: { + type: ["admin", "user", "guest"], + required: true, + }, + }, + }, + } as BetterAuthOptions, + }); + expect(schema.code).toContain("pgEnum"); + expect(schema.code).toContain( + 'role: pgEnum("role", ["admin", "user", "guest"])', + ); + await expect(schema.code).toMatchFileSnapshot( + "./__snapshots__/auth-schema-pg-enum.txt", + ); + }); + + it("should generate Drizzle schema with enum fields for MySQL", async () => { + const schema = await generateDrizzleSchema({ + file: "test.drizzle", + adapter: { + id: "drizzle", + options: { + provider: "mysql", + schema: {}, + }, + } as any, + options: { + database: {} as any, + user: { + additionalFields: { + status: { + type: ["active", "inactive", "pending"], + required: false, + }, + }, + }, + } as BetterAuthOptions, + }); + expect(schema.code).toContain("mysqlEnum"); + expect(schema.code).toContain( + 'status: mysqlEnum(["active", "inactive", "pending"])', + ); + await expect(schema.code).toMatchFileSnapshot( + "./__snapshots__/auth-schema-mysql-enum.txt", + ); + }); + + it("should generate Drizzle schema with enum fields for SQLite", async () => { + const schema = await generateDrizzleSchema({ + file: "test.drizzle", + adapter: { + id: "drizzle", + options: { + provider: "sqlite", + schema: {}, + }, + } as any, + options: { + database: {} as any, + user: { + additionalFields: { + priority: { + type: ["high", "medium", "low"], + }, + }, + }, + } as BetterAuthOptions, + }); + expect(schema.code).toContain("text({ enum: ["); + expect(schema.code).toContain( + 'priority: text({ enum: ["high", "medium", "low"] })', + ); + await expect(schema.code).toMatchFileSnapshot( + "./__snapshots__/auth-schema-sqlite-enum.txt", + ); + }); + + it("should include correct imports for enum fields in PostgreSQL", async () => { + const schema = await generateDrizzleSchema({ + file: "test.drizzle", + adapter: { + id: "drizzle", + options: { + provider: "pg", + schema: {}, + }, + } as any, + options: { + database: {} as any, + user: { + additionalFields: { + role: { + type: ["admin", "user"], + }, + }, + }, + } as BetterAuthOptions, + }); + expect(schema.code).toMatch(/import.*pgEnum.*from.*drizzle-orm\/pg-core/); + }); + + it("should include correct imports for enum fields in MySQL", async () => { + const schema = await generateDrizzleSchema({ + file: "test.drizzle", + adapter: { + id: "drizzle", + options: { + provider: "mysql", + schema: {}, + }, + } as any, + options: { + database: {} as any, + user: { + additionalFields: { + status: { + type: ["active", "inactive"], + }, + }, + }, + } as BetterAuthOptions, + }); + expect(schema.code).toMatch( + /import.*mysqlEnum.*from.*drizzle-orm\/mysql-core/s, + ); + }); + + it("should not include enum imports when no enum fields are present", async () => { + const schema = await generateDrizzleSchema({ + file: "test.drizzle", + adapter: { + id: "drizzle", + options: { + provider: "pg", + schema: {}, + }, + } as any, + options: { + database: {} as any, + user: { + additionalFields: { + name: { + type: "string", + }, + }, + }, + } as BetterAuthOptions, + }); + expect(schema.code).not.toContain("pgEnum"); + }); +});