fix: chunk account data cookie when exceeding limit (#6393)

This commit is contained in:
Joél Solano
2025-11-30 06:45:17 +01:00
committed by GitHub
parent 6d5e8a9338
commit 57d36f295d
4 changed files with 383 additions and 129 deletions

View File

@@ -13,7 +13,7 @@ import {
vi,
} from "vitest";
import { parseSetCookieHeader } from "../../cookies";
import { signJWT } from "../../crypto";
import { signJWT, symmetricDecodeJWT } from "../../crypto";
import { getTestInstance } from "../../test-utils/test-instance";
import type { Account } from "../../types";
import { DEFAULT_SECRET } from "../../utils/constants";
@@ -492,4 +492,240 @@ describe("account", async () => {
expect(accessTokenRes.data).toBeDefined();
expect(accessTokenRes.data?.accessToken).toBe("test");
});
it("should NOT chunk account data cookies when exceeding 4KB", async () => {
const { client, cookieSetter } = await getTestInstance({
secret: "better-auth.secret",
account: {
storeAccountCookie: true,
},
socialProviders: {
google: {
clientId: "test",
clientSecret: "test",
enabled: true,
},
},
});
const headers = new Headers();
email = "oauth-test@test.com";
const signInRes = await client.signIn.social({
provider: "google",
callbackURL: "/callback",
fetchOptions: {
onSuccess: cookieSetter(headers),
},
});
expect(signInRes.data).toMatchObject({
url: expect.stringContaining("google.com"),
redirect: true,
});
const state =
signInRes.data && "url" in signInRes.data && signInRes.data.url
? new URL(signInRes.data.url).searchParams.get("state") || ""
: "";
// Complete OAuth callback
await client.$fetch("/callback/google", {
query: {
state,
code: "test",
},
headers,
method: "GET",
async onError(context) {
const setCookie = context.response.headers.get("set-cookie");
expect(setCookie).toBeDefined();
const parsed = parseSetCookieHeader(setCookie!);
let hasChunks = false;
let hasSingleAccountData = false;
parsed.forEach((_value, name) => {
if (
name.includes("account_data.0") ||
name.includes("account_data.1")
) {
hasChunks = true;
}
if (name.endsWith("account_data")) {
hasSingleAccountData = true;
}
});
expect(hasChunks).toBe(false);
expect(hasSingleAccountData).toBe(true);
parsed.forEach((value, name) => {
headers.append("cookie", `${name}=${value.value}`);
});
},
});
const accessTokenRes = await client.getAccessToken(
{
providerId: "google",
},
{
headers,
},
);
expect(accessTokenRes.data).toBeDefined();
expect(accessTokenRes.data?.accessToken).toBe("test");
});
it("should chunk account data cookies when exceeding 4KB", async () => {
const { client, cookieSetter } = await getTestInstance({
secret: "better-auth.secret",
account: {
storeAccountCookie: true,
additionalFields: {
largeField: {
type: "string",
defaultValue: "x".repeat(5000), // 5KB field to exceed cookie size
},
},
},
socialProviders: {
google: {
clientId: "test",
clientSecret: "test",
enabled: true,
},
},
});
const headers = new Headers();
email = "oauth-test@test.com";
const signInRes = await client.signIn.social({
provider: "google",
callbackURL: "/callback",
fetchOptions: {
onSuccess: cookieSetter(headers),
},
});
expect(signInRes.data).toMatchObject({
url: expect.stringContaining("google.com"),
redirect: true,
});
const state =
signInRes.data && "url" in signInRes.data && signInRes.data.url
? new URL(signInRes.data.url).searchParams.get("state") || ""
: "";
// Complete OAuth callback
await client.$fetch("/callback/google", {
query: {
state,
code: "test",
},
headers,
method: "GET",
async onError(context) {
const setCookie = context.response.headers.get("set-cookie");
expect(setCookie).toBeDefined();
const parsed = parseSetCookieHeader(setCookie!);
let hasChunks = false;
parsed.forEach((_value, name) => {
if (
name.includes("account_data.0") ||
name.includes("account_data.1")
) {
hasChunks = true;
}
});
expect(hasChunks).toBe(true);
parsed.forEach((value, name) => {
headers.append("cookie", `${name}=${value.value}`);
});
},
});
const accessTokenRes = await client.getAccessToken(
{
providerId: "google",
},
{
headers,
},
);
expect(accessTokenRes.data).toBeDefined();
expect(accessTokenRes.data?.accessToken).toBe("test");
});
it("should encrypt account cookie payload", async () => {
const { auth, client, cookieSetter } = await getTestInstance({
secret: "better-auth.secret",
account: {
storeAccountCookie: true,
},
socialProviders: {
google: {
clientId: "test",
clientSecret: "test",
enabled: true,
},
},
});
const ctx = await auth.$context;
const headers = new Headers();
email = "oauth-test@test.com";
const signInRes = await client.signIn.social({
provider: "google",
callbackURL: "/callback",
fetchOptions: {
onSuccess: cookieSetter(headers),
},
});
expect(signInRes.data).toMatchObject({
url: expect.stringContaining("google.com"),
redirect: true,
});
const state =
signInRes.data && "url" in signInRes.data && signInRes.data.url
? new URL(signInRes.data.url).searchParams.get("state") || ""
: "";
// Complete OAuth callback
await client.$fetch("/callback/google", {
query: {
state,
code: "test",
},
headers,
method: "GET",
async onError(context) {
const setCookie = context.response.headers.get("set-cookie");
expect(setCookie).toBeDefined();
const parsed = parseSetCookieHeader(setCookie!);
const accountData = parsed.get("better-auth.account_data")?.value;
expect(accountData).toBeDefined();
expect(accountData!.startsWith("ey")).toBe(true);
await expect(
symmetricDecodeJWT(accountData!, ctx.secret, "better-auth-account"),
).resolves.toMatchObject({
accessToken: "test",
refreshToken: "test",
providerId: "google",
});
},
});
});
});

View File

@@ -5,9 +5,12 @@ import type { OAuth2Tokens } from "@better-auth/core/oauth2";
import { SocialProviderListEnum } from "@better-auth/core/social-providers";
import { APIError } from "better-call";
import * as z from "zod";
import {
getAccountCookie,
setAccountCookie,
} from "../../cookies/session-store";
import { generateState } from "../../oauth2/state";
import { decryptOAuthToken, setTokenUtil } from "../../oauth2/utils";
import { safeJSONParse } from "../../utils/json";
import {
freshSessionMiddleware,
getSessionFromCtx,
@@ -516,16 +519,7 @@ export const getAccessToken = createAuthEndpoint(
message: `Provider ${providerId} is not supported.`,
});
}
const accountDataCookieName = ctx.context.authCookies.accountData.name;
const accountDataCookie = await ctx.getSignedCookie(
accountDataCookieName,
ctx.context.secret,
);
const accountData = accountDataCookie
? safeJSONParse<Account>(accountDataCookie)
: null;
const accountData = await getAccountCookie(ctx);
let account: Account | undefined = undefined;
if (
accountData &&
@@ -587,12 +581,7 @@ export const getAccessToken = createAuthEndpoint(
const storeAccountCookie =
ctx.context.options.account?.storeAccountCookie;
if (storeAccountCookie && updatedAccount) {
await ctx.setSignedCookie(
accountDataCookieName,
JSON.stringify(updatedAccount),
ctx.context.secret,
ctx.context.authCookies.accountData.options,
);
await setAccountCookie(ctx, updatedAccount);
}
}
const tokens = {
@@ -708,16 +697,8 @@ export const refreshToken = createAuthEndpoint(
}
// Try to read refresh token from cookie first
const accountDataCookieName = ctx.context.authCookies.accountData.name;
const accountDataCookie = await ctx.getSignedCookie(
accountDataCookieName,
ctx.context.secret,
);
let account: Account | undefined = undefined;
const accountData = accountDataCookie
? safeJSONParse<Account>(accountDataCookie)
: null;
const accountData = await getAccountCookie(ctx);
if (
accountData &&
(!providerId || providerId === accountData?.providerId)
@@ -788,12 +769,7 @@ export const refreshToken = createAuthEndpoint(
scope: tokens.scopes?.join(",") || accountData.scope,
idToken: tokens.idToken || accountData.idToken,
};
await ctx.setSignedCookie(
accountDataCookieName,
JSON.stringify(updateData),
ctx.context.secret,
ctx.context.authCookies.accountData.options,
);
await setAccountCookie(ctx, updateData);
}
return ctx.json({
accessToken: tokens.accessToken,
@@ -884,19 +860,10 @@ export const accountInfo = createAuthEndpoint(
const providedAccountId = ctx.query?.accountId;
let account: Account | undefined = undefined;
if (!providedAccountId) {
const storeAccountCookie =
ctx.context.options.account?.storeAccountCookie;
if (storeAccountCookie) {
const accountCookieName = ctx.context.authCookies.accountData.name;
const accountCookie = await ctx.getSignedCookie(
accountCookieName,
ctx.context.secret,
);
if (accountCookie) {
const accountData = safeJSONParse<Account>(accountCookie);
if (accountData) {
account = accountData;
}
if (ctx.context.options.account?.storeAccountCookie) {
const accountData = await getAccountCookie(ctx);
if (accountData) {
account = accountData;
}
}
} else {

View File

@@ -1,7 +1,10 @@
import type { GenericEndpointContext } from "@better-auth/core";
import type { Account } from "@better-auth/core/db";
import type { InternalLogger } from "@better-auth/core/env";
import type { CookieOptions } from "better-call";
import * as z from "zod";
import { symmetricDecodeJWT, symmetricEncodeJWT } from "../crypto";
import { safeJSONParse } from "../utils/json";
// Cookie size constants based on browser limits
const ALLOWED_COOKIE_SIZE = 4096;
@@ -88,6 +91,7 @@ function joinChunks(chunks: Chunks): string {
* Split a cookie value into chunks if needed
*/
function chunkCookie(
storeName: string,
cookie: Cookie,
chunks: Chunks,
logger: InternalLogger,
@@ -108,8 +112,8 @@ function chunkCookie(
chunks[name] = value;
}
logger.debug("CHUNKING_SESSION_COOKIE", {
message: `Session cookie exceeds allowed ${ALLOWED_COOKIE_SIZE} bytes.`,
logger.debug(`CHUNKING_${storeName.toUpperCase()}_COOKIE`, {
message: `${storeName} cookie exceeds allowed ${ALLOWED_COOKIE_SIZE} bytes.`,
emptyCookieSize: ESTIMATED_EMPTY_COOKIE_SIZE,
valueSize: cookie.value.length,
chunkCount,
@@ -144,82 +148,88 @@ function getCleanCookies(
* Based on next-auth's SessionStore implementation.
* @see https://github.com/nextauthjs/next-auth/blob/27b2519b84b8eb9cf053775dea29d577d2aa0098/packages/next-auth/src/core/lib/cookie.ts
*/
export function createSessionStore(
cookieName: string,
cookieOptions: CookieOptions,
ctx: GenericEndpointContext,
) {
const chunks = readExistingChunks(cookieName, ctx);
const logger = ctx.context.logger;
const storeFactory =
(storeName: string) =>
(
cookieName: string,
cookieOptions: CookieOptions,
ctx: GenericEndpointContext,
) => {
const chunks = readExistingChunks(cookieName, ctx);
const logger = ctx.context.logger;
return {
/**
* Get the full session data by joining all chunks
*/
getValue(): string {
return joinChunks(chunks);
},
return {
/**
* Get the full session data by joining all chunks
*/
getValue(): string {
return joinChunks(chunks);
},
/**
* Check if there are existing chunks
*/
hasChunks(): boolean {
return Object.keys(chunks).length > 0;
},
/**
* Check if there are existing chunks
*/
hasChunks(): boolean {
return Object.keys(chunks).length > 0;
},
/**
* Chunk a cookie value and return all cookies to set (including cleanup cookies)
*/
chunk(value: string, options?: Partial<CookieOptions>): Cookie[] {
// Start by cleaning all existing chunks
const cleanedChunks = getCleanCookies(chunks, cookieOptions);
// Clear the chunks object
for (const name in chunks) {
delete chunks[name];
}
const cookies: Record<string, Cookie> = cleanedChunks;
/**
* Chunk a cookie value and return all cookies to set (including cleanup cookies)
*/
chunk(value: string, options?: Partial<CookieOptions>): Cookie[] {
// Start by cleaning all existing chunks
const cleanedChunks = getCleanCookies(chunks, cookieOptions);
// Clear the chunks object
for (const name in chunks) {
delete chunks[name];
}
const cookies: Record<string, Cookie> = cleanedChunks;
// Create new chunks
const chunked = chunkCookie(
{
name: cookieName,
value,
options: { ...cookieOptions, ...options },
},
chunks,
logger,
);
// Create new chunks
const chunked = chunkCookie(
storeName,
{
name: cookieName,
value,
options: { ...cookieOptions, ...options },
},
chunks,
logger,
);
// Update with new chunks
for (const chunk of chunked) {
cookies[chunk.name] = chunk;
}
// Update with new chunks
for (const chunk of chunked) {
cookies[chunk.name] = chunk;
}
return Object.values(cookies);
},
return Object.values(cookies);
},
/**
* Get cookies to clean up all chunks
*/
clean(): Cookie[] {
const cleanedChunks = getCleanCookies(chunks, cookieOptions);
// Clear the chunks object
for (const name in chunks) {
delete chunks[name];
}
return Object.values(cleanedChunks);
},
/**
* Get cookies to clean up all chunks
*/
clean(): Cookie[] {
const cleanedChunks = getCleanCookies(chunks, cookieOptions);
// Clear the chunks object
for (const name in chunks) {
delete chunks[name];
}
return Object.values(cleanedChunks);
},
/**
* Set all cookies in the context
*/
setCookies(cookies: Cookie[]): void {
for (const cookie of cookies) {
ctx.setCookie(cookie.name, cookie.value, cookie.options);
}
},
/**
* Set all cookies in the context
*/
setCookies(cookies: Cookie[]): void {
for (const cookie of cookies) {
ctx.setCookie(cookie.name, cookie.value, cookie.options);
}
},
};
};
}
export const createSessionStore = storeFactory("Session");
export const createAccountStore = storeFactory("Account");
export function getChunkedCookie(
ctx: GenericEndpointContext,
@@ -265,6 +275,58 @@ export function getChunkedCookie(
return null;
}
export async function setAccountCookie(
c: GenericEndpointContext,
accountData: Record<string, any>,
) {
const accountDataCookie = c.context.authCookies.accountData;
const options = {
maxAge: 60 * 5,
...accountDataCookie.options,
};
const data = await symmetricEncodeJWT(
accountData,
c.context.secret,
"better-auth-account",
options.maxAge,
);
if (data.length > ALLOWED_COOKIE_SIZE) {
const accountStore = createAccountStore(accountDataCookie.name, options, c);
const cookies = accountStore.chunk(data, options);
accountStore.setCookies(cookies);
} else {
const accountStore = createAccountStore(accountDataCookie.name, options, c);
if (accountStore.hasChunks()) {
const cleanCookies = accountStore.clean();
accountStore.setCookies(cleanCookies);
}
c.setCookie(accountDataCookie.name, data, options);
}
}
export async function getAccountCookie(c: GenericEndpointContext) {
const accountCookie = getChunkedCookie(
c,
c.context.authCookies.accountData.name,
);
if (accountCookie) {
const accountData = safeJSONParse<Account>(
await symmetricDecodeJWT(
accountCookie,
c.context.secret,
"better-auth-account",
),
);
if (accountData) {
return accountData;
}
}
return null;
}
export const getSessionQuerySchema = z.optional(
z.object({
/**

View File

@@ -1,6 +1,7 @@
import type { GenericEndpointContext } from "@better-auth/core";
import { isDevelopment, logger } from "@better-auth/core/env";
import { APIError, createEmailVerificationToken } from "../api";
import { setAccountCookie } from "../cookies/session-store";
import type { Account, User } from "../types";
import { setTokenUtil } from "./utils";
@@ -106,13 +107,7 @@ export async function handleOAuthUserInfo(
}).filter(([_, value]) => value !== undefined),
);
if (c.context.options.account?.storeAccountCookie) {
const accountDataCookie = c.context.authCookies.accountData;
await c.setSignedCookie(
accountDataCookie.name,
JSON.stringify(updateData),
c.context.secret,
accountDataCookie.options,
);
await setAccountCookie(c, updateData);
}
if (Object.keys(updateData).length > 0) {
@@ -175,13 +170,7 @@ export async function handleOAuthUserInfo(
);
user = createdUser;
if (c.context.options.account?.storeAccountCookie) {
const accountDataCookie = c.context.authCookies.accountData;
await c.setSignedCookie(
accountDataCookie.name,
JSON.stringify(createdAccount),
c.context.secret,
accountDataCookie.options,
);
await setAccountCookie(c, createdAccount);
}
if (
!userInfo.emailVerified &&