fix(drizzle): drizzle with mysql update breaks on anything other than update by id(#1377)

---------

Co-authored-by: Bereket Engida <bekacru@gmail.com>
This commit is contained in:
KinfeMichael Tariku
2025-02-08 11:11:37 +03:00
committed by GitHub
parent 90a487a6b5
commit 0794de2c71
6 changed files with 99 additions and 162 deletions

View File

@@ -49,6 +49,74 @@ const createTransform = (
: model;
};
function convertWhereClause(where: Where[], model: string) {
const schemaModel = getSchema(model);
if (!where) return [];
if (where.length === 1) {
const w = where[0];
if (!w) {
return [];
}
const field = getField(model, w.field);
if (!schemaModel[field]) {
throw new BetterAuthError(
`The field "${w.field}" does not exist in the schema for the model "${model}". Please update your schema.`,
);
}
if (w.operator === "in") {
if (!Array.isArray(w.value)) {
throw new BetterAuthError(
`The value for the field "${w.field}" must be an array when using the "in" operator.`,
);
}
return [inArray(schemaModel[field], w.value)];
}
if (w.operator === "contains") {
return [like(schemaModel[field], `%${w.value}%`)];
}
if (w.operator === "starts_with") {
return [like(schemaModel[field], `${w.value}%`)];
}
if (w.operator === "ends_with") {
return [like(schemaModel[field], `%${w.value}`)];
}
return [eq(schemaModel[field], w.value)];
}
const andGroup = where.filter((w) => w.connector === "AND" || !w.connector);
const orGroup = where.filter((w) => w.connector === "OR");
const andClause = and(
...andGroup.map((w) => {
const field = getField(model, w.field);
if (w.operator === "in") {
if (!Array.isArray(w.value)) {
throw new BetterAuthError(
`The value for the field "${w.field}" must be an array when using the "in" operator.`,
);
}
return inArray(schemaModel[field], w.value);
}
return eq(schemaModel[field], w.value);
}),
);
const orClause = or(
...orGroup.map((w) => {
const field = getField(model, w.field);
return eq(schemaModel[field], w.value);
}),
);
const clause: SQL<unknown>[] = [];
if (andGroup.length) clause.push(andClause!);
if (orGroup.length) clause.push(orClause!);
return clause;
}
const useDatabaseGeneratedId = options?.advanced?.generateId === false;
return {
getSchema,
@@ -107,105 +175,35 @@ const createTransform = (
}
return transformedData as any;
},
convertWhereClause(where: Where[], model: string) {
const schemaModel = getSchema(model);
if (!where) return [];
if (where.length === 1) {
const w = where[0];
if (!w) {
return [];
}
const field = getField(model, w.field);
if (!schemaModel[field]) {
throw new BetterAuthError(
`The field "${w.field}" does not exist in the schema for the model "${model}". Please update your schema.`,
);
}
if (w.operator === "in") {
if (!Array.isArray(w.value)) {
throw new BetterAuthError(
`The value for the field "${w.field}" must be an array when using the "in" operator.`,
);
}
return [inArray(schemaModel[field], w.value)];
}
if (w.operator === "contains") {
return [like(schemaModel[field], `%${w.value}%`)];
}
if (w.operator === "starts_with") {
return [like(schemaModel[field], `${w.value}%`)];
}
if (w.operator === "ends_with") {
return [like(schemaModel[field], `%${w.value}`)];
}
return [eq(schemaModel[field], w.value)];
}
const andGroup = where.filter(
(w) => w.connector === "AND" || !w.connector,
);
const orGroup = where.filter((w) => w.connector === "OR");
const andClause = and(
...andGroup.map((w) => {
const field = getField(model, w.field);
if (w.operator === "in") {
if (!Array.isArray(w.value)) {
throw new BetterAuthError(
`The value for the field "${w.field}" must be an array when using the "in" operator.`,
);
}
return inArray(schemaModel[field], w.value);
}
return eq(schemaModel[field], w.value);
}),
);
const orClause = or(
...orGroup.map((w) => {
const field = getField(model, w.field);
return eq(schemaModel[field], w.value);
}),
);
const clause: SQL<unknown>[] = [];
if (andGroup.length) clause.push(andClause!);
if (orGroup.length) clause.push(orClause!);
return clause;
},
convertWhereClause,
withReturning: async (
model: string,
builder: any,
data: Record<string, any>,
where?: Where[],
) => {
if (config.provider !== "mysql") {
const c = await builder.returning();
return c[0];
}
const result = await builder.execute();
const updatedResult = builder.config?.where;
await builder.execute();
const schemaModel = getSchema(model);
const builderVal = builder.config?.values;
if (updatedResult) {
const upId = updatedResult?.queryChunks[3]?.value;
const schemaModel = getSchema(model);
if (where?.length) {
const clause = convertWhereClause(where, model);
const res = await db
.select()
.from(schemaModel)
.where(eq(schemaModel.id, upId));
.where(...clause);
return res[0];
} else if (builderVal) {
const tId = builderVal[0]?.id.value;
const schemaModel = getSchema(model);
const res = await db
.select()
.from(schemaModel)
.where(eq(schemaModel.id, tId));
return res[0];
} else if (data.id) {
const schemaModel = getSchema(model);
const res = await db
.select()
.from(schemaModel)
@@ -314,7 +312,12 @@ export const drizzleAdapter =
.update(schemaModel)
.set(transformed)
.where(...clause);
const returned = await withReturning(model, builder, transformed);
const returned = await withReturning(
model,
builder,
transformed,
where,
);
return transformOutput(returned, model);
},
async updateMany(data) {

View File

@@ -165,6 +165,7 @@ export async function runAdapterTest(opts: AdapterTestOptions) {
});
test("should work with reference fields", async () => {
let token = null;
const user = await adapter.create<{ id: string } & Record<string, any>>({
model: "user",
data: {
@@ -176,7 +177,7 @@ export async function runAdapterTest(opts: AdapterTestOptions) {
updatedAt: new Date(),
},
});
await adapter.create({
const session = await adapter.create({
model: "session",
data: {
id: "1",
@@ -187,6 +188,7 @@ export async function runAdapterTest(opts: AdapterTestOptions) {
expiresAt: new Date(),
},
});
token = session.token;
const res = await adapter.findOne({
model: "session",
where: [
@@ -196,9 +198,21 @@ export async function runAdapterTest(opts: AdapterTestOptions) {
},
],
});
const resToken = await adapter.findOne({
model: "session",
where: [
{
field: "token",
value: token,
},
],
});
expect(res).toMatchObject({
userId: user.id,
});
expect(resToken).toMatchObject({
userId: user.id,
});
});
test("should find many with sortBy", async () => {

View File

@@ -1,3 +1,4 @@
generator client {
provider = "prisma-client-js"
}

View File

@@ -1,3 +1,4 @@
generator client {
provider = "prisma-client-js"
}

View File

@@ -1,3 +1,4 @@
generator client {
provider = "prisma-client-js"
}

View File

@@ -1,83 +0,0 @@
generator client {
provider = "prisma-client-js"
}
datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
}
model User {
id String @id
name String
email String
emailVerified Boolean
image String?
createdAt DateTime
updatedAt DateTime
twoFactorEnabled Boolean?
username String?
sessions Session[]
accounts Account[]
twofactors TwoFactor[]
@@unique([email])
@@unique([username])
@@map("user")
}
model Session {
id String @id
expiresAt DateTime
token String
createdAt DateTime
updatedAt DateTime
ipAddress String?
userAgent String?
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([token])
@@map("session")
}
model Account {
id String @id
accountId String
providerId String
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
accessToken String?
refreshToken String?
idToken String?
accessTokenExpiresAt DateTime?
refreshTokenExpiresAt DateTime?
scope String?
password String?
createdAt DateTime
updatedAt DateTime
@@map("account")
}
model Verification {
id String @id
identifier String
value String
expiresAt DateTime
createdAt DateTime?
updatedAt DateTime?
@@map("verification")
}
model TwoFactor {
id String @id
secret String
backupCodes String
userId String
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@map("twoFactor")
}