feat: enum support for drizzle schema (#5287)

This commit is contained in:
Alex Yang
2025-10-13 16:56:39 -07:00
committed by GitHub
parent 2e62296432
commit fd780aca6b
5 changed files with 413 additions and 7 deletions

View File

@@ -69,13 +69,20 @@ export const generateDrizzleSchema: SchemaGenerator = async ({
}
return `text('${name}')`;
}
const type = field.type as
| "string"
| "number"
| "boolean"
| "date"
| "json"
| `${"string" | "number"}[]`;
const type = field.type;
if (typeof type !== "string") {
if (Array.isArray(type) && type.every((x) => typeof x === "string")) {
return {
sqlite: `text({ enum: [${type.map((x) => `'${x}'`).join(", ")}] })`,
pg: `pgEnum('${name}', [${type.map((x) => `'${x}'`).join(", ")}])`,
mysql: `mysqlEnum([${type.map((x) => `'${x}'`).join(", ")}])`,
}[databaseType];
} else {
throw new TypeError(
`Invalid field type for field ${name} in model ${modelName}`,
);
}
}
const typeMap: Record<
typeof type,
Record<typeof databaseType, string>
@@ -273,6 +280,17 @@ function generateImport({
if (needsInt) {
coreImports.push("int");
}
const hasEnum = Object.values(tables).some((table) =>
Object.values(table.fields).some(
(field) =>
typeof field.type !== "string" &&
Array.isArray(field.type) &&
field.type.every((x) => typeof x === "string"),
),
);
if (hasEnum) {
coreImports.push("mysqlEnum");
}
} else if (databaseType === "pg") {
// Only include integer for PG if actually needed
const hasNonBigintNumber = Object.values(tables).some((table) =>
@@ -294,6 +312,18 @@ function generateImport({
if (needsInteger) {
coreImports.push("integer");
}
const hasEnum = Object.values(tables).some((table) =>
Object.values(table.fields).some(
(field) =>
typeof field.type !== "string" &&
Array.isArray(field.type) &&
field.type.every((x) => typeof x === "string"),
),
);
if (hasEnum) {
coreImports.push("pgEnum");
}
} else {
coreImports.push("integer");
}

View File

@@ -0,0 +1,69 @@
import {
mysqlTable,
varchar,
text,
timestamp,
boolean,
mysqlEnum,
} from "drizzle-orm/mysql-core";
export const user = mysqlTable("user", {
id: varchar("id", { length: 36 }).primaryKey(),
name: text("name").notNull(),
email: varchar("email", { length: 255 }).notNull().unique(),
emailVerified: boolean("email_verified").default(false).notNull(),
image: text("image"),
createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(),
updatedAt: timestamp("updated_at", { fsp: 3 })
.defaultNow()
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
status: mysqlEnum(["active", "inactive", "pending"]),
});
export const session = mysqlTable("session", {
id: varchar("id", { length: 36 }).primaryKey(),
expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(),
token: varchar("token", { length: 255 }).notNull().unique(),
createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(),
updatedAt: timestamp("updated_at", { fsp: 3 })
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
ipAddress: text("ip_address"),
userAgent: text("user_agent"),
userId: varchar("user_id", { length: 36 })
.notNull()
.references(() => user.id, { onDelete: "cascade" }),
});
export const account = mysqlTable("account", {
id: varchar("id", { length: 36 }).primaryKey(),
accountId: text("account_id").notNull(),
providerId: text("provider_id").notNull(),
userId: varchar("user_id", { length: 36 })
.notNull()
.references(() => user.id, { onDelete: "cascade" }),
accessToken: text("access_token"),
refreshToken: text("refresh_token"),
idToken: text("id_token"),
accessTokenExpiresAt: timestamp("access_token_expires_at", { fsp: 3 }),
refreshTokenExpiresAt: timestamp("refresh_token_expires_at", { fsp: 3 }),
scope: text("scope"),
password: text("password"),
createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(),
updatedAt: timestamp("updated_at", { fsp: 3 })
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
});
export const verification = mysqlTable("verification", {
id: varchar("id", { length: 36 }).primaryKey(),
identifier: text("identifier").notNull(),
value: text("value").notNull(),
expiresAt: timestamp("expires_at", { fsp: 3 }).notNull(),
createdAt: timestamp("created_at", { fsp: 3 }).defaultNow().notNull(),
updatedAt: timestamp("updated_at", { fsp: 3 })
.defaultNow()
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
});

View File

@@ -0,0 +1,62 @@
import { pgTable, text, timestamp, boolean, pgEnum } from "drizzle-orm/pg-core";
export const user = pgTable("user", {
id: text("id").primaryKey(),
name: text("name").notNull(),
email: text("email").notNull().unique(),
emailVerified: boolean("email_verified").default(false).notNull(),
image: text("image"),
createdAt: timestamp("created_at").defaultNow().notNull(),
updatedAt: timestamp("updated_at")
.defaultNow()
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
role: pgEnum("role", ["admin", "user", "guest"]).notNull(),
});
export const session = pgTable("session", {
id: text("id").primaryKey(),
expiresAt: timestamp("expires_at").notNull(),
token: text("token").notNull().unique(),
createdAt: timestamp("created_at").defaultNow().notNull(),
updatedAt: timestamp("updated_at")
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
ipAddress: text("ip_address"),
userAgent: text("user_agent"),
userId: text("user_id")
.notNull()
.references(() => user.id, { onDelete: "cascade" }),
});
export const account = pgTable("account", {
id: text("id").primaryKey(),
accountId: text("account_id").notNull(),
providerId: text("provider_id").notNull(),
userId: text("user_id")
.notNull()
.references(() => user.id, { onDelete: "cascade" }),
accessToken: text("access_token"),
refreshToken: text("refresh_token"),
idToken: text("id_token"),
accessTokenExpiresAt: timestamp("access_token_expires_at"),
refreshTokenExpiresAt: timestamp("refresh_token_expires_at"),
scope: text("scope"),
password: text("password"),
createdAt: timestamp("created_at").defaultNow().notNull(),
updatedAt: timestamp("updated_at")
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
});
export const verification = pgTable("verification", {
id: text("id").primaryKey(),
identifier: text("identifier").notNull(),
value: text("value").notNull(),
expiresAt: timestamp("expires_at").notNull(),
createdAt: timestamp("created_at").defaultNow().notNull(),
updatedAt: timestamp("updated_at")
.defaultNow()
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
});

View File

@@ -0,0 +1,77 @@
import { sql } from "drizzle-orm";
import { sqliteTable, text, integer } from "drizzle-orm/sqlite-core";
export const user = sqliteTable("user", {
id: text("id").primaryKey(),
name: text("name").notNull(),
email: text("email").notNull().unique(),
emailVerified: integer("email_verified", { mode: "boolean" })
.default(false)
.notNull(),
image: text("image"),
createdAt: integer("created_at", { mode: "timestamp_ms" })
.default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`)
.notNull(),
updatedAt: integer("updated_at", { mode: "timestamp_ms" })
.default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`)
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
priority: text({ enum: ["high", "medium", "low"] }),
});
export const session = sqliteTable("session", {
id: text("id").primaryKey(),
expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(),
token: text("token").notNull().unique(),
createdAt: integer("created_at", { mode: "timestamp_ms" })
.default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`)
.notNull(),
updatedAt: integer("updated_at", { mode: "timestamp_ms" })
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
ipAddress: text("ip_address"),
userAgent: text("user_agent"),
userId: text("user_id")
.notNull()
.references(() => user.id, { onDelete: "cascade" }),
});
export const account = sqliteTable("account", {
id: text("id").primaryKey(),
accountId: text("account_id").notNull(),
providerId: text("provider_id").notNull(),
userId: text("user_id")
.notNull()
.references(() => user.id, { onDelete: "cascade" }),
accessToken: text("access_token"),
refreshToken: text("refresh_token"),
idToken: text("id_token"),
accessTokenExpiresAt: integer("access_token_expires_at", {
mode: "timestamp_ms",
}),
refreshTokenExpiresAt: integer("refresh_token_expires_at", {
mode: "timestamp_ms",
}),
scope: text("scope"),
password: text("password"),
createdAt: integer("created_at", { mode: "timestamp_ms" })
.default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`)
.notNull(),
updatedAt: integer("updated_at", { mode: "timestamp_ms" })
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
});
export const verification = sqliteTable("verification", {
id: text("id").primaryKey(),
identifier: text("identifier").notNull(),
value: text("value").notNull(),
expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(),
createdAt: integer("created_at", { mode: "timestamp_ms" })
.default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`)
.notNull(),
updatedAt: integer("updated_at", { mode: "timestamp_ms" })
.default(sql`(cast(unixepoch('subsecond') * 1000 as integer))`)
.$onUpdate(() => /* @__PURE__ */ new Date())
.notNull(),
});

View File

@@ -361,3 +361,171 @@ describe("JSON field support in CLI generators", () => {
expect(schema.code).toContain("preferences Json?");
});
});
describe("Enum field support in Drizzle schemas", () => {
it("should generate Drizzle schema with enum fields for PostgreSQL", async () => {
const schema = await generateDrizzleSchema({
file: "test.drizzle",
adapter: {
id: "drizzle",
options: {
provider: "pg",
schema: {},
},
} as any,
options: {
database: {} as any,
user: {
additionalFields: {
role: {
type: ["admin", "user", "guest"],
required: true,
},
},
},
} as BetterAuthOptions,
});
expect(schema.code).toContain("pgEnum");
expect(schema.code).toContain(
'role: pgEnum("role", ["admin", "user", "guest"])',
);
await expect(schema.code).toMatchFileSnapshot(
"./__snapshots__/auth-schema-pg-enum.txt",
);
});
it("should generate Drizzle schema with enum fields for MySQL", async () => {
const schema = await generateDrizzleSchema({
file: "test.drizzle",
adapter: {
id: "drizzle",
options: {
provider: "mysql",
schema: {},
},
} as any,
options: {
database: {} as any,
user: {
additionalFields: {
status: {
type: ["active", "inactive", "pending"],
required: false,
},
},
},
} as BetterAuthOptions,
});
expect(schema.code).toContain("mysqlEnum");
expect(schema.code).toContain(
'status: mysqlEnum(["active", "inactive", "pending"])',
);
await expect(schema.code).toMatchFileSnapshot(
"./__snapshots__/auth-schema-mysql-enum.txt",
);
});
it("should generate Drizzle schema with enum fields for SQLite", async () => {
const schema = await generateDrizzleSchema({
file: "test.drizzle",
adapter: {
id: "drizzle",
options: {
provider: "sqlite",
schema: {},
},
} as any,
options: {
database: {} as any,
user: {
additionalFields: {
priority: {
type: ["high", "medium", "low"],
},
},
},
} as BetterAuthOptions,
});
expect(schema.code).toContain("text({ enum: [");
expect(schema.code).toContain(
'priority: text({ enum: ["high", "medium", "low"] })',
);
await expect(schema.code).toMatchFileSnapshot(
"./__snapshots__/auth-schema-sqlite-enum.txt",
);
});
it("should include correct imports for enum fields in PostgreSQL", async () => {
const schema = await generateDrizzleSchema({
file: "test.drizzle",
adapter: {
id: "drizzle",
options: {
provider: "pg",
schema: {},
},
} as any,
options: {
database: {} as any,
user: {
additionalFields: {
role: {
type: ["admin", "user"],
},
},
},
} as BetterAuthOptions,
});
expect(schema.code).toMatch(/import.*pgEnum.*from.*drizzle-orm\/pg-core/);
});
it("should include correct imports for enum fields in MySQL", async () => {
const schema = await generateDrizzleSchema({
file: "test.drizzle",
adapter: {
id: "drizzle",
options: {
provider: "mysql",
schema: {},
},
} as any,
options: {
database: {} as any,
user: {
additionalFields: {
status: {
type: ["active", "inactive"],
},
},
},
} as BetterAuthOptions,
});
expect(schema.code).toMatch(
/import.*mysqlEnum.*from.*drizzle-orm\/mysql-core/s,
);
});
it("should not include enum imports when no enum fields are present", async () => {
const schema = await generateDrizzleSchema({
file: "test.drizzle",
adapter: {
id: "drizzle",
options: {
provider: "pg",
schema: {},
},
} as any,
options: {
database: {} as any,
user: {
additionalFields: {
name: {
type: "string",
},
},
},
} as BetterAuthOptions,
});
expect(schema.code).not.toContain("pgEnum");
});
});