diff --git a/packages/better-auth/src/api/routes/account.test.ts b/packages/better-auth/src/api/routes/account.test.ts index 00e79cc8a7..50ee578e8a 100644 --- a/packages/better-auth/src/api/routes/account.test.ts +++ b/packages/better-auth/src/api/routes/account.test.ts @@ -1402,4 +1402,92 @@ describe("account", async () => { expect(refreshedSessionCookie).toBe(true); expect(refreshedAccountCookie).toBe(true); }); + + it("should allow additional account fields to be set based on provider configuration", async () => { + const { signInWithTestUser, client } = await getTestInstance({ + disableTestUser: true, + socialProviders: { + google: { + clientId: "test", + clientSecret: "test", + enabled: true, + getAccountFields: async (_token, userInfo) => { + return { + foo: "bar", + providerEmail: userInfo.email, + } + }, + }, + }, + account: { + accountLinking: { + allowDifferentEmails: true, + }, + additionalFields: { + foo: { + type: "string", + required: false, + }, + providerEmail: { + type: "string", + required: false, + }, + }, + }, + }); + + const { runWithUser: runWithClient2 } = await signInWithTestUser(); + + await runWithClient2(async (headers) => { + const linkAccountRes = await client.linkSocial( + { + provider: "google", + callbackURL: "/callback", + }, + { + onSuccess(context) { + const cookies = parseSetCookieHeader( + context.response.headers.get("set-cookie") || "", + ); + headers.set( + "cookie", + `better-auth.state=${cookies.get("better-auth.state")?.value}`, + ); + }, + }, + ); + expect(linkAccountRes.data).toMatchObject({ + url: expect.stringContaining("google.com"), + redirect: true, + }); + const state = + linkAccountRes.data && "url" in linkAccountRes.data + ? new URL(linkAccountRes.data.url).searchParams.get("state") || "" + : ""; + email = "test2@test.com"; + await client.$fetch("/callback/google", { + query: { + state, + code: "test", + }, + method: "GET", + onError(context) { + expect(context.response.status).toBe(302); + const location = context.response.headers.get("location"); + expect(location).toBeDefined(); + expect(location).toContain("/callback"); + }, + }); + }); + + const { runWithUser: runWithClient3 } = await signInWithTestUser(); + + await runWithClient3(async () => { + const accounts = await client.listAccounts(); + expect(accounts.data?.length).toBe(2); + const newAccount = accounts.data?.[1] as Record + expect(newAccount.foo).toEqual("bar"); + expect(newAccount.providerEmail).toEqual("test2@test.com"); + }); + }) }); diff --git a/packages/better-auth/src/api/routes/callback.ts b/packages/better-auth/src/api/routes/callback.ts index 6e98fbd2e5..116dde2f47 100644 --- a/packages/better-auth/src/api/routes/callback.ts +++ b/packages/better-auth/src/api/routes/callback.ts @@ -174,6 +174,8 @@ export const callbackOAuth = createAuthEndpoint( throw redirectOnError("no_callback_url"); } + const additionalAccountFields = await provider.options?.getAccountFields?.(tokens, userInfo) + if (link) { const isTrustedProvider = c.context.trustedProviders.includes( provider.id, @@ -209,6 +211,7 @@ export const callbackOAuth = createAuthEndpoint( accessTokenExpiresAt: tokens.accessTokenExpiresAt, refreshTokenExpiresAt: tokens.refreshTokenExpiresAt, scope: tokens.scopes?.join(","), + ...additionalAccountFields, }).filter(([_, value]) => value !== undefined), ); await c.context.internalAdapter.updateAccount( @@ -224,6 +227,7 @@ export const callbackOAuth = createAuthEndpoint( accessToken: await setTokenUtil(tokens.accessToken, c.context), refreshToken: await setTokenUtil(tokens.refreshToken, c.context), scope: tokens.scopes?.join(","), + ...additionalAccountFields, }); if (!newAccount) { return redirectOnError("unable_to_link_account"); @@ -250,6 +254,7 @@ export const callbackOAuth = createAuthEndpoint( accountId: String(userInfo.id), ...tokens, scope: tokens.scopes?.join(","), + ...additionalAccountFields, }; const result = await handleOAuthUserInfo(c, { userInfo: { diff --git a/packages/better-auth/src/api/routes/sign-in.ts b/packages/better-auth/src/api/routes/sign-in.ts index 7a2f71b090..e68c843208 100644 --- a/packages/better-auth/src/api/routes/sign-in.ts +++ b/packages/better-auth/src/api/routes/sign-in.ts @@ -272,10 +272,13 @@ export const signInSocial = () => }); throw APIError.from("UNAUTHORIZED", BASE_ERROR_CODES.INVALID_TOKEN); } - const userInfo = await provider.getUserInfo({ + const tokens = { idToken: token, accessToken: c.body.idToken.accessToken, refreshToken: c.body.idToken.refreshToken, + }; + const userInfo = await provider.getUserInfo({ + ...tokens, user: c.body.idToken.user, }); if (!userInfo || !userInfo?.user) { @@ -296,6 +299,7 @@ export const signInSocial = () => BASE_ERROR_CODES.USER_EMAIL_NOT_FOUND, ); } + const additionalAccountFields = await provider.options?.getAccountFields?.(tokens, userInfo.user) const data = await handleOAuthUserInfo(c, { userInfo: { ...userInfo.user, @@ -309,6 +313,7 @@ export const signInSocial = () => providerId: provider.id, accountId: String(userInfo.user.id), accessToken: c.body.idToken.accessToken, + ...additionalAccountFields, }, callbackURL: c.body.callbackURL, disableSignUp: diff --git a/packages/better-auth/src/oauth2/link-account.ts b/packages/better-auth/src/oauth2/link-account.ts index ba27b8c90d..b21ba7ecaf 100644 --- a/packages/better-auth/src/oauth2/link-account.ts +++ b/packages/better-auth/src/oauth2/link-account.ts @@ -158,14 +158,12 @@ export async function handleOAuthUserInfo( } try { const { id: _, ...restUserInfo } = userInfo; + const { accessToken, refreshToken, ...restAccount } = account const accountData = { accessToken: await setTokenUtil(account.accessToken, c.context), refreshToken: await setTokenUtil(account.refreshToken, c.context), idToken: account.idToken, - accessTokenExpiresAt: account.accessTokenExpiresAt, - refreshTokenExpiresAt: account.refreshTokenExpiresAt, - scope: account.scope, - providerId: account.providerId, + ...restAccount, accountId: userInfo.id.toString(), }; const { user: createdUser, account: createdAccount } = diff --git a/packages/core/src/oauth2/oauth-provider.ts b/packages/core/src/oauth2/oauth-provider.ts index b8bc33a514..871b55bb34 100644 --- a/packages/core/src/oauth2/oauth-provider.ts +++ b/packages/core/src/oauth2/oauth-provider.ts @@ -187,6 +187,14 @@ export type ProviderOptions = any> = { [key: string]: any; }>) | undefined; + /** + * Custom function to get account fields from userInfo + */ + getAccountFields?: + | ((token: OAuth2Tokens, userInfo: OAuth2UserInfo) => Promise<{ + [key: string]: any; + } | null>) + | undefined; /** * Disable implicit sign up for new users. When set to true for the provider, * sign-in need to be called with with requestSignUp as true to create new users. diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index 1cdeafadff..2000f824bc 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -43,6 +43,7 @@ onlyBuiltDependencies: - '@swc/core' - '@tsparticles/engine' - better-sqlite3 + - electron - esbuild - less - msw