diff --git a/packages/better-auth/src/plugins/generic-oauth/index.ts b/packages/better-auth/src/plugins/generic-oauth/index.ts index 0323fceeab..f9440e46f5 100644 --- a/packages/better-auth/src/plugins/generic-oauth/index.ts +++ b/packages/better-auth/src/plugins/generic-oauth/index.ts @@ -274,6 +274,12 @@ export const genericOAuth = (options: GenericOAuthOptions) => { description: "The URL to redirect to if an error occurs", }) .optional(), + newUserCallbackURL: z + .string({ + description: + "The URL to redirect to after login if the user is new", + }) + .optional(), disableRedirect: z .boolean({ description: "Disable redirect", @@ -460,7 +466,8 @@ export const genericOAuth = (options: GenericOAuthOptions) => { let tokens: OAuth2Tokens | undefined = undefined; const parsedState = await parseState(ctx); - const { callbackURL, codeVerifier, errorURL } = parsedState; + const { callbackURL, codeVerifier, errorURL, newUserURL } = + parsedState; const code = ctx.query.code; let finalTokenUrl = provider.tokenUrl; @@ -557,10 +564,14 @@ export const genericOAuth = (options: GenericOAuthOptions) => { }); let toRedirectTo: string; try { - const url = new URL(callbackURL); + const url = result.isRegister + ? newUserURL || callbackURL + : callbackURL; toRedirectTo = url.toString(); } catch { - toRedirectTo = callbackURL; + toRedirectTo = result.isRegister + ? newUserURL || callbackURL + : callbackURL; } throw ctx.redirect(toRedirectTo); }, diff --git a/packages/better-auth/src/plugins/generic-oauth/oauth2.test.ts b/packages/better-auth/src/plugins/generic-oauth/oauth2.test.ts index 811352e24a..6d18535d78 100644 --- a/packages/better-auth/src/plugins/generic-oauth/oauth2.test.ts +++ b/packages/better-auth/src/plugins/generic-oauth/oauth2.test.ts @@ -15,20 +15,7 @@ describe("oauth2", async () => { const clientId = "test-client-id"; const clientSecret = "test-client-secret"; - beforeAll(async () => { - await server.issuer.keys.generate("RS256"); - - server.issuer.on; - // Start the server - await server.start(8080, "localhost"); - console.log("Issuer URL:", server.issuer.url); // -> http://localhost:8080 - }); - - afterAll(async () => { - await server.stop(); - }); - - const { customFetchImpl } = await getTestInstance({ + const { customFetchImpl, auth } = await getTestInstance({ plugins: [ genericOAuth({ config: [ @@ -54,6 +41,24 @@ describe("oauth2", async () => { }, }); + beforeAll(async () => { + const context = await auth.$context; + await context.internalAdapter.createUser({ + email: "oauth2@test.com", + name: "OAuth2 Test", + }); + await server.issuer.keys.generate("RS256"); + + server.issuer.on; + // Start the server + await server.start(8080, "localhost"); + console.log("Issuer URL:", server.issuer.url); // -> http://localhost:8080 + }); + + afterAll(async () => { + await server.stop(); + }); + server.service.on("beforeUserinfo", (userInfoResponse, req) => { userInfoResponse.body = { email: "oauth2@test.com", @@ -99,6 +104,7 @@ describe("oauth2", async () => { const signInRes = await authClient.signIn.oauth2({ providerId: "test", callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", }); expect(signInRes.data).toMatchObject({ url: expect.stringContaining("http://localhost:8080/authorize"), @@ -111,11 +117,41 @@ describe("oauth2", async () => { expect(callbackURL).toBe("http://localhost:3000/dashboard"); }); + it("should redirect to the provider and handle the response for a new user", async () => { + server.service.once("beforeUserinfo", (userInfoResponse) => { + userInfoResponse.body = { + email: "oauth2-2@test.com", + name: "OAuth2 Test 2", + sub: "oauth2-2", + picture: "https://test.com/picture.png", + email_verified: true, + }; + userInfoResponse.statusCode = 200; + }); + + let headers = new Headers(); + const signInRes = await authClient.signIn.oauth2({ + providerId: "test", + callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", + }); + expect(signInRes.data).toMatchObject({ + url: expect.stringContaining("http://localhost:8080/authorize"), + redirect: true, + }); + const callbackURL = await simulateOAuthFlow( + signInRes.data?.url || "", + headers, + ); + expect(callbackURL).toBe("http://localhost:3000/new_user"); + }); + it("should redirect to the provider and handle the response after linked", async () => { let headers = new Headers(); const res = await authClient.signIn.oauth2({ providerId: "test", callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", }); const callbackURL = await simulateOAuthFlow(res.data?.url || "", headers); expect(callbackURL).toBe("http://localhost:3000/dashboard"); @@ -125,6 +161,7 @@ describe("oauth2", async () => { const res = await authClient.signIn.oauth2({ providerId: "invalid-provider", callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", }); expect(res.error?.status).toBe(400); }); @@ -146,6 +183,7 @@ describe("oauth2", async () => { { providerId: "test", callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", }, { onSuccess(context) { @@ -169,7 +207,7 @@ describe("oauth2", async () => { }); it("should work with custom redirect uri", async () => { - const { customFetchImpl } = await getTestInstance({ + const { customFetchImpl, auth } = await getTestInstance({ plugins: [ genericOAuth({ config: [ @@ -198,6 +236,7 @@ describe("oauth2", async () => { const res = await authClient.signIn.oauth2({ providerId: "test2", callbackURL: "http://localhost:3000/dashboard", + newUserCallbackURL: "http://localhost:3000/new_user", }); expect(res.data?.url).toContain("http://localhost:8080/authorize"); const headers = new Headers(); @@ -206,6 +245,6 @@ describe("oauth2", async () => { headers, customFetchImpl, ); - expect(callbackURL).toBe("http://localhost:3000/dashboard"); + expect(callbackURL).toBe("http://localhost:3000/new_user"); }); }); diff --git a/packages/better-auth/src/plugins/sso/index.ts b/packages/better-auth/src/plugins/sso/index.ts index 13d137c076..923ee024b8 100644 --- a/packages/better-auth/src/plugins/sso/index.ts +++ b/packages/better-auth/src/plugins/sso/index.ts @@ -679,10 +679,14 @@ export const sso = (options?: SSOOptions) => { }); let toRedirectTo: string; try { - const url = new URL(callbackURL); + const url = linked.isRegister + ? newUserURL || callbackURL + : callbackURL; toRedirectTo = url.toString(); } catch { - toRedirectTo = callbackURL; + toRedirectTo = linked.isRegister + ? newUserURL || callbackURL + : callbackURL; } throw ctx.redirect(toRedirectTo); },