feat: add countOn option for rate limiter

This commit is contained in:
Alex Yang
2026-02-12 19:15:43 +08:00
parent 86ca6d0b19
commit ce704b626b
4 changed files with 142 additions and 27 deletions

View File

@@ -303,7 +303,7 @@ export const router = <Option extends BetterAuthOptions>(
return currentRequest;
},
async onResponse(res, req) {
await onResponseRateLimit(req, ctx);
await onResponseRateLimit(res, req, ctx);
for (const plugin of ctx.options.plugins || []) {
if (plugin.onResponse) {
const response = await plugin.onResponse(res, ctx);

View File

@@ -162,6 +162,7 @@ async function resolveRateLimitConfig(req: Request, ctx: AuthContext) {
const path = normalizePathname(req.url, basePath);
let currentWindow = ctx.rateLimit.window;
let currentMax = ctx.rateLimit.max;
let currentCountOn: "all" | "error" = ctx.rateLimit.countOn ?? "all";
const ip = getIp(req, ctx.options);
if (!ip) {
return null;
@@ -198,25 +199,22 @@ async function resolveRateLimitConfig(req: Request, ctx: AuthContext) {
});
if (_path) {
const customRule = ctx.rateLimit.customRules[_path];
const resolved =
typeof customRule === "function"
? await customRule(req, {
window: currentWindow,
max: currentMax,
})
: customRule;
if (resolved) {
currentWindow = resolved.window;
currentMax = resolved.max;
if (customRule === false) {
return null;
}
if (resolved === false) {
return null;
if (customRule) {
currentWindow = customRule.window;
currentMax = customRule.max;
if (customRule.countOn) {
currentCountOn = customRule.countOn;
}
}
}
}
return { key, currentWindow, currentMax };
return { key, currentWindow, currentMax, currentCountOn };
}
export async function onRequestRateLimit(req: Request, ctx: AuthContext) {
@@ -240,7 +238,11 @@ export async function onRequestRateLimit(req: Request, ctx: AuthContext) {
}
}
export async function onResponseRateLimit(req: Request, ctx: AuthContext) {
export async function onResponseRateLimit(
res: Response,
req: Request,
ctx: AuthContext,
) {
if (!ctx.rateLimit.enabled) {
return;
}
@@ -248,7 +250,11 @@ export async function onResponseRateLimit(req: Request, ctx: AuthContext) {
if (!config) {
return;
}
const { key, currentWindow } = config;
const { key, currentWindow, currentCountOn } = config;
if (currentCountOn === "error" && res.status < 400) {
return;
}
const storage = getRateLimitStorage(ctx, {
window: currentWindow,

View File

@@ -320,6 +320,115 @@ describe("should work in development/test environment", () => {
});
});
describe("countOn", () => {
it("should not count successful responses when countOn is 'error'", async () => {
const store = new Map<string, string>();
const { client, testUser } = await getTestInstance({
rateLimit: {
enabled: true,
window: 10,
max: 3,
countOn: "error",
},
secondaryStorage: {
set(key, value) {
store.set(key, value);
},
get(key) {
return store.get(key) || null;
},
delete(key) {
store.delete(key);
},
},
});
// Successful sign-in requests should not count against the rate limit
for (let i = 0; i < 5; i++) {
const response = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
expect(response.error).toBeNull();
}
// The rate limit key for sign-in should not exist in storage
expect(store.has("127.0.0.1|/sign-in/email")).toBe(false);
});
it("should count failed responses when countOn is 'error'", async () => {
const store = new Map<string, string>();
const { client } = await getTestInstance({
rateLimit: {
enabled: true,
window: 10,
max: 3,
countOn: "error",
},
secondaryStorage: {
set(key, value) {
store.set(key, value);
},
get(key) {
return store.get(key) || null;
},
delete(key) {
store.delete(key);
},
},
});
// Failed sign-in requests (wrong password) should count against the rate limit
for (let i = 0; i < 5; i++) {
const response = await client.signIn.email({
email: "wrong@email.com",
password: "wrong-password",
});
if (i >= 3) {
expect(response.error?.status).toBe(429);
} else {
expect(response.error?.status).toBe(401);
}
}
});
it("should count all responses when countOn is 'all'", async () => {
const store = new Map<string, string>();
const { client, testUser } = await getTestInstance({
rateLimit: {
enabled: true,
window: 10,
max: 3,
countOn: "all",
},
secondaryStorage: {
set(key, value) {
store.set(key, value);
},
get(key) {
return store.get(key) || null;
},
delete(key) {
store.delete(key);
},
},
});
// With countOn "all", successful requests should count
for (let i = 0; i < 5; i++) {
const response = await client.signIn.email({
email: testUser.email,
password: testUser.password,
});
if (i >= 3) {
expect(response.error?.status).toBe(429);
} else {
expect(response.error).toBeNull();
}
}
});
});
describe("IPv6 address normalization and rate limiting", () => {
it("should normalize IPv6 addresses to canonical form", () => {
// All these representations of the same IPv6 address should normalize to the same value

View File

@@ -69,6 +69,16 @@ export type BetterAuthRateLimitRule = {
* @default 100 requests
*/
max: number;
/**
* When to count a request against the rate limit.
*
* - "all" — count every request (default)
* - "error" — only count requests that result in
* error responses (status >= 400)
*
* @default "all"
*/
countOn?: "all" | "error";
};
export type BetterAuthDBOptions<
@@ -105,17 +115,7 @@ export type BetterAuthRateLimitOptions = Optional<BetterAuthRateLimitRule> &
* Custom rate limit rules to apply to
* specific paths.
*/
customRules?:
| {
[key: string]:
| BetterAuthRateLimitRule
| false
| ((
request: Request,
currentRule: BetterAuthRateLimitRule,
) => Awaitable<false | BetterAuthRateLimitRule>);
}
| undefined;
customRules?: Record<string, BetterAuthRateLimitRule | false> | undefined;
/**
* Storage configuration
*