mirror of
https://github.com/better-auth/better-auth.git
synced 2026-06-05 05:46:50 -05:00
feat: allow provider config to add extra fields to linked accounts
This allows users to do things like add an additional column to the accounts table with the email address of the linked account (see #2272).
This commit is contained in:
@@ -1402,4 +1402,92 @@ describe("account", async () => {
|
||||
expect(refreshedSessionCookie).toBe(true);
|
||||
expect(refreshedAccountCookie).toBe(true);
|
||||
});
|
||||
|
||||
it("should allow additional account fields to be set based on provider configuration", async () => {
|
||||
const { signInWithTestUser, client } = await getTestInstance({
|
||||
disableTestUser: true,
|
||||
socialProviders: {
|
||||
google: {
|
||||
clientId: "test",
|
||||
clientSecret: "test",
|
||||
enabled: true,
|
||||
getAccountFields: async (_token, userInfo) => {
|
||||
return {
|
||||
foo: "bar",
|
||||
providerEmail: userInfo.email,
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
account: {
|
||||
accountLinking: {
|
||||
allowDifferentEmails: true,
|
||||
},
|
||||
additionalFields: {
|
||||
foo: {
|
||||
type: "string",
|
||||
required: false,
|
||||
},
|
||||
providerEmail: {
|
||||
type: "string",
|
||||
required: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { runWithUser: runWithClient2 } = await signInWithTestUser();
|
||||
|
||||
await runWithClient2(async (headers) => {
|
||||
const linkAccountRes = await client.linkSocial(
|
||||
{
|
||||
provider: "google",
|
||||
callbackURL: "/callback",
|
||||
},
|
||||
{
|
||||
onSuccess(context) {
|
||||
const cookies = parseSetCookieHeader(
|
||||
context.response.headers.get("set-cookie") || "",
|
||||
);
|
||||
headers.set(
|
||||
"cookie",
|
||||
`better-auth.state=${cookies.get("better-auth.state")?.value}`,
|
||||
);
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(linkAccountRes.data).toMatchObject({
|
||||
url: expect.stringContaining("google.com"),
|
||||
redirect: true,
|
||||
});
|
||||
const state =
|
||||
linkAccountRes.data && "url" in linkAccountRes.data
|
||||
? new URL(linkAccountRes.data.url).searchParams.get("state") || ""
|
||||
: "";
|
||||
email = "test2@test.com";
|
||||
await client.$fetch("/callback/google", {
|
||||
query: {
|
||||
state,
|
||||
code: "test",
|
||||
},
|
||||
method: "GET",
|
||||
onError(context) {
|
||||
expect(context.response.status).toBe(302);
|
||||
const location = context.response.headers.get("location");
|
||||
expect(location).toBeDefined();
|
||||
expect(location).toContain("/callback");
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const { runWithUser: runWithClient3 } = await signInWithTestUser();
|
||||
|
||||
await runWithClient3(async () => {
|
||||
const accounts = await client.listAccounts();
|
||||
expect(accounts.data?.length).toBe(2);
|
||||
const newAccount = accounts.data?.[1] as Record<string,any>
|
||||
expect(newAccount.foo).toEqual("bar");
|
||||
expect(newAccount.providerEmail).toEqual("test2@test.com");
|
||||
});
|
||||
})
|
||||
});
|
||||
|
||||
@@ -174,6 +174,8 @@ export const callbackOAuth = createAuthEndpoint(
|
||||
throw redirectOnError("no_callback_url");
|
||||
}
|
||||
|
||||
const additionalAccountFields = await provider.options?.getAccountFields?.(tokens, userInfo)
|
||||
|
||||
if (link) {
|
||||
const isTrustedProvider = c.context.trustedProviders.includes(
|
||||
provider.id,
|
||||
@@ -209,6 +211,7 @@ export const callbackOAuth = createAuthEndpoint(
|
||||
accessTokenExpiresAt: tokens.accessTokenExpiresAt,
|
||||
refreshTokenExpiresAt: tokens.refreshTokenExpiresAt,
|
||||
scope: tokens.scopes?.join(","),
|
||||
...additionalAccountFields,
|
||||
}).filter(([_, value]) => value !== undefined),
|
||||
);
|
||||
await c.context.internalAdapter.updateAccount(
|
||||
@@ -224,6 +227,7 @@ export const callbackOAuth = createAuthEndpoint(
|
||||
accessToken: await setTokenUtil(tokens.accessToken, c.context),
|
||||
refreshToken: await setTokenUtil(tokens.refreshToken, c.context),
|
||||
scope: tokens.scopes?.join(","),
|
||||
...additionalAccountFields,
|
||||
});
|
||||
if (!newAccount) {
|
||||
return redirectOnError("unable_to_link_account");
|
||||
@@ -250,6 +254,7 @@ export const callbackOAuth = createAuthEndpoint(
|
||||
accountId: String(userInfo.id),
|
||||
...tokens,
|
||||
scope: tokens.scopes?.join(","),
|
||||
...additionalAccountFields,
|
||||
};
|
||||
const result = await handleOAuthUserInfo(c, {
|
||||
userInfo: {
|
||||
|
||||
@@ -272,10 +272,13 @@ export const signInSocial = <O extends BetterAuthOptions>() =>
|
||||
});
|
||||
throw APIError.from("UNAUTHORIZED", BASE_ERROR_CODES.INVALID_TOKEN);
|
||||
}
|
||||
const userInfo = await provider.getUserInfo({
|
||||
const tokens = {
|
||||
idToken: token,
|
||||
accessToken: c.body.idToken.accessToken,
|
||||
refreshToken: c.body.idToken.refreshToken,
|
||||
};
|
||||
const userInfo = await provider.getUserInfo({
|
||||
...tokens,
|
||||
user: c.body.idToken.user,
|
||||
});
|
||||
if (!userInfo || !userInfo?.user) {
|
||||
@@ -296,6 +299,7 @@ export const signInSocial = <O extends BetterAuthOptions>() =>
|
||||
BASE_ERROR_CODES.USER_EMAIL_NOT_FOUND,
|
||||
);
|
||||
}
|
||||
const additionalAccountFields = await provider.options?.getAccountFields?.(tokens, userInfo.user)
|
||||
const data = await handleOAuthUserInfo(c, {
|
||||
userInfo: {
|
||||
...userInfo.user,
|
||||
@@ -309,6 +313,7 @@ export const signInSocial = <O extends BetterAuthOptions>() =>
|
||||
providerId: provider.id,
|
||||
accountId: String(userInfo.user.id),
|
||||
accessToken: c.body.idToken.accessToken,
|
||||
...additionalAccountFields,
|
||||
},
|
||||
callbackURL: c.body.callbackURL,
|
||||
disableSignUp:
|
||||
|
||||
@@ -158,14 +158,12 @@ export async function handleOAuthUserInfo(
|
||||
}
|
||||
try {
|
||||
const { id: _, ...restUserInfo } = userInfo;
|
||||
const { accessToken, refreshToken, ...restAccount } = account
|
||||
const accountData = {
|
||||
accessToken: await setTokenUtil(account.accessToken, c.context),
|
||||
refreshToken: await setTokenUtil(account.refreshToken, c.context),
|
||||
idToken: account.idToken,
|
||||
accessTokenExpiresAt: account.accessTokenExpiresAt,
|
||||
refreshTokenExpiresAt: account.refreshTokenExpiresAt,
|
||||
scope: account.scope,
|
||||
providerId: account.providerId,
|
||||
...restAccount,
|
||||
accountId: userInfo.id.toString(),
|
||||
};
|
||||
const { user: createdUser, account: createdAccount } =
|
||||
|
||||
@@ -187,6 +187,14 @@ export type ProviderOptions<Profile extends Record<string, any> = any> = {
|
||||
[key: string]: any;
|
||||
}>)
|
||||
| undefined;
|
||||
/**
|
||||
* Custom function to get account fields from userInfo
|
||||
*/
|
||||
getAccountFields?:
|
||||
| ((token: OAuth2Tokens, userInfo: OAuth2UserInfo) => Promise<{
|
||||
[key: string]: any;
|
||||
} | null>)
|
||||
| undefined;
|
||||
/**
|
||||
* Disable implicit sign up for new users. When set to true for the provider,
|
||||
* sign-in need to be called with with requestSignUp as true to create new users.
|
||||
|
||||
@@ -43,6 +43,7 @@ onlyBuiltDependencies:
|
||||
- '@swc/core'
|
||||
- '@tsparticles/engine'
|
||||
- better-sqlite3
|
||||
- electron
|
||||
- esbuild
|
||||
- less
|
||||
- msw
|
||||
|
||||
Reference in New Issue
Block a user