diff --git a/packages/better-auth/src/api/middlewares/origin-check.test.ts b/packages/better-auth/src/api/middlewares/origin-check.test.ts index ebcae1ae9e..55a0c641d4 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.test.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.test.ts @@ -206,6 +206,36 @@ describe("Origin Check", async (it) => { expect(res.error?.status).toBe(403); }); + it("shouldn't work with callback url with malicious", async (ctx) => { + const client = createAuthClient({ + baseURL: "http://localhost:3000", + fetchOptions: { + customFetchImpl, + headers: { + origin: "https://localhost:3000", + }, + }, + }); + const res = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + callbackURL: "/%5C/evil.com", + }); + expect(res.error?.status).toBe(403); + const res2 = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + callbackURL: `/\/\/evil.com`, + }); + expect(res2.error?.status).toBe(403); + const res3 = await client.signIn.email({ + email: testUser.email, + password: testUser.password, + callbackURL: "/%5C/evil.com", + }); + expect(res3.error?.status).toBe(403); + }); + it("should work with GET requests", async (ctx) => { const client = createAuthClient({ baseURL: "https://sub-domain.my-site.com", diff --git a/packages/better-auth/src/api/middlewares/origin-check.ts b/packages/better-auth/src/api/middlewares/origin-check.ts index 6d0a85370a..b3144bd16f 100644 --- a/packages/better-auth/src/api/middlewares/origin-check.ts +++ b/packages/better-auth/src/api/middlewares/origin-check.ts @@ -49,8 +49,7 @@ export const originCheckMiddleware = createAuthMiddleware(async (ctx) => { matchesPattern(url, origin) || (url?.startsWith("/") && label !== "origin" && - !url.includes(":") && - !url.includes("//")), + /^\/(?![\\/%])[\w\-./]*$/.test(url)), ); if (!isTrustedOrigin) { ctx.context.logger.error(`Invalid ${label}: ${url}`); @@ -107,8 +106,7 @@ export const originCheck = ( matchesPattern(url, origin) || (url?.startsWith("/") && label !== "origin" && - !url.includes(":") && - !url.includes("//")), + /^\/(?![\\/%])[\w\-./]*$/.test(url)), ); if (!isTrustedOrigin) { ctx.context.logger.error(`Invalid ${label}: ${url}`);