mirror of
https://github.com/better-auth/better-auth.git
synced 2026-05-26 08:56:40 -05:00
feat: add countOn option for rate limiter
This commit is contained in:
@@ -303,7 +303,7 @@ export const router = <Option extends BetterAuthOptions>(
|
||||
return currentRequest;
|
||||
},
|
||||
async onResponse(res, req) {
|
||||
await onResponseRateLimit(req, ctx);
|
||||
await onResponseRateLimit(res, req, ctx);
|
||||
for (const plugin of ctx.options.plugins || []) {
|
||||
if (plugin.onResponse) {
|
||||
const response = await plugin.onResponse(res, ctx);
|
||||
|
||||
@@ -162,6 +162,7 @@ async function resolveRateLimitConfig(req: Request, ctx: AuthContext) {
|
||||
const path = normalizePathname(req.url, basePath);
|
||||
let currentWindow = ctx.rateLimit.window;
|
||||
let currentMax = ctx.rateLimit.max;
|
||||
let currentCountOn: "all" | "error" = ctx.rateLimit.countOn ?? "all";
|
||||
const ip = getIp(req, ctx.options);
|
||||
if (!ip) {
|
||||
return null;
|
||||
@@ -198,25 +199,22 @@ async function resolveRateLimitConfig(req: Request, ctx: AuthContext) {
|
||||
});
|
||||
if (_path) {
|
||||
const customRule = ctx.rateLimit.customRules[_path];
|
||||
const resolved =
|
||||
typeof customRule === "function"
|
||||
? await customRule(req, {
|
||||
window: currentWindow,
|
||||
max: currentMax,
|
||||
})
|
||||
: customRule;
|
||||
if (resolved) {
|
||||
currentWindow = resolved.window;
|
||||
currentMax = resolved.max;
|
||||
|
||||
if (customRule === false) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (resolved === false) {
|
||||
return null;
|
||||
if (customRule) {
|
||||
currentWindow = customRule.window;
|
||||
currentMax = customRule.max;
|
||||
if (customRule.countOn) {
|
||||
currentCountOn = customRule.countOn;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { key, currentWindow, currentMax };
|
||||
return { key, currentWindow, currentMax, currentCountOn };
|
||||
}
|
||||
|
||||
export async function onRequestRateLimit(req: Request, ctx: AuthContext) {
|
||||
@@ -240,7 +238,11 @@ export async function onRequestRateLimit(req: Request, ctx: AuthContext) {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onResponseRateLimit(req: Request, ctx: AuthContext) {
|
||||
export async function onResponseRateLimit(
|
||||
res: Response,
|
||||
req: Request,
|
||||
ctx: AuthContext,
|
||||
) {
|
||||
if (!ctx.rateLimit.enabled) {
|
||||
return;
|
||||
}
|
||||
@@ -248,7 +250,11 @@ export async function onResponseRateLimit(req: Request, ctx: AuthContext) {
|
||||
if (!config) {
|
||||
return;
|
||||
}
|
||||
const { key, currentWindow } = config;
|
||||
const { key, currentWindow, currentCountOn } = config;
|
||||
|
||||
if (currentCountOn === "error" && res.status < 400) {
|
||||
return;
|
||||
}
|
||||
|
||||
const storage = getRateLimitStorage(ctx, {
|
||||
window: currentWindow,
|
||||
|
||||
@@ -320,6 +320,115 @@ describe("should work in development/test environment", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("countOn", () => {
|
||||
it("should not count successful responses when countOn is 'error'", async () => {
|
||||
const store = new Map<string, string>();
|
||||
const { client, testUser } = await getTestInstance({
|
||||
rateLimit: {
|
||||
enabled: true,
|
||||
window: 10,
|
||||
max: 3,
|
||||
countOn: "error",
|
||||
},
|
||||
secondaryStorage: {
|
||||
set(key, value) {
|
||||
store.set(key, value);
|
||||
},
|
||||
get(key) {
|
||||
return store.get(key) || null;
|
||||
},
|
||||
delete(key) {
|
||||
store.delete(key);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Successful sign-in requests should not count against the rate limit
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const response = await client.signIn.email({
|
||||
email: testUser.email,
|
||||
password: testUser.password,
|
||||
});
|
||||
expect(response.error).toBeNull();
|
||||
}
|
||||
|
||||
// The rate limit key for sign-in should not exist in storage
|
||||
expect(store.has("127.0.0.1|/sign-in/email")).toBe(false);
|
||||
});
|
||||
|
||||
it("should count failed responses when countOn is 'error'", async () => {
|
||||
const store = new Map<string, string>();
|
||||
const { client } = await getTestInstance({
|
||||
rateLimit: {
|
||||
enabled: true,
|
||||
window: 10,
|
||||
max: 3,
|
||||
countOn: "error",
|
||||
},
|
||||
secondaryStorage: {
|
||||
set(key, value) {
|
||||
store.set(key, value);
|
||||
},
|
||||
get(key) {
|
||||
return store.get(key) || null;
|
||||
},
|
||||
delete(key) {
|
||||
store.delete(key);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Failed sign-in requests (wrong password) should count against the rate limit
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const response = await client.signIn.email({
|
||||
email: "wrong@email.com",
|
||||
password: "wrong-password",
|
||||
});
|
||||
if (i >= 3) {
|
||||
expect(response.error?.status).toBe(429);
|
||||
} else {
|
||||
expect(response.error?.status).toBe(401);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it("should count all responses when countOn is 'all'", async () => {
|
||||
const store = new Map<string, string>();
|
||||
const { client, testUser } = await getTestInstance({
|
||||
rateLimit: {
|
||||
enabled: true,
|
||||
window: 10,
|
||||
max: 3,
|
||||
countOn: "all",
|
||||
},
|
||||
secondaryStorage: {
|
||||
set(key, value) {
|
||||
store.set(key, value);
|
||||
},
|
||||
get(key) {
|
||||
return store.get(key) || null;
|
||||
},
|
||||
delete(key) {
|
||||
store.delete(key);
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// With countOn "all", successful requests should count
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const response = await client.signIn.email({
|
||||
email: testUser.email,
|
||||
password: testUser.password,
|
||||
});
|
||||
if (i >= 3) {
|
||||
expect(response.error?.status).toBe(429);
|
||||
} else {
|
||||
expect(response.error).toBeNull();
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("IPv6 address normalization and rate limiting", () => {
|
||||
it("should normalize IPv6 addresses to canonical form", () => {
|
||||
// All these representations of the same IPv6 address should normalize to the same value
|
||||
|
||||
@@ -69,6 +69,16 @@ export type BetterAuthRateLimitRule = {
|
||||
* @default 100 requests
|
||||
*/
|
||||
max: number;
|
||||
/**
|
||||
* When to count a request against the rate limit.
|
||||
*
|
||||
* - "all" — count every request (default)
|
||||
* - "error" — only count requests that result in
|
||||
* error responses (status >= 400)
|
||||
*
|
||||
* @default "all"
|
||||
*/
|
||||
countOn?: "all" | "error";
|
||||
};
|
||||
|
||||
export type BetterAuthDBOptions<
|
||||
@@ -105,17 +115,7 @@ export type BetterAuthRateLimitOptions = Optional<BetterAuthRateLimitRule> &
|
||||
* Custom rate limit rules to apply to
|
||||
* specific paths.
|
||||
*/
|
||||
customRules?:
|
||||
| {
|
||||
[key: string]:
|
||||
| BetterAuthRateLimitRule
|
||||
| false
|
||||
| ((
|
||||
request: Request,
|
||||
currentRule: BetterAuthRateLimitRule,
|
||||
) => Awaitable<false | BetterAuthRateLimitRule>);
|
||||
}
|
||||
| undefined;
|
||||
customRules?: Record<string, BetterAuthRateLimitRule | false> | undefined;
|
||||
/**
|
||||
* Storage configuration
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user