mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-05 18:38:17 -05:00
feat: signin rate limit
This commit is contained in:
139
backend/open_webui/utils/rate_limit.py
Normal file
139
backend/open_webui/utils/rate_limit.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import time
|
||||
from typing import Optional, Dict
|
||||
from open_webui.env import REDIS_KEY_PREFIX
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
General-purpose rate limiter using Redis with a rolling window strategy.
|
||||
Falls back to in-memory storage if Redis is not available.
|
||||
"""
|
||||
|
||||
# In-memory fallback storage
|
||||
_memory_store: Dict[str, Dict[int, int]] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client,
|
||||
limit: int,
|
||||
window: int,
|
||||
bucket_size: int = 60,
|
||||
enabled: bool = True,
|
||||
):
|
||||
"""
|
||||
:param redis_client: Redis client instance or None
|
||||
:param limit: Max allowed events in the window
|
||||
:param window: Time window in seconds
|
||||
:param bucket_size: Bucket resolution
|
||||
:param enabled: Turn on/off rate limiting globally
|
||||
"""
|
||||
self.r = redis_client
|
||||
self.limit = limit
|
||||
self.window = window
|
||||
self.bucket_size = bucket_size
|
||||
self.num_buckets = window // bucket_size
|
||||
self.enabled = enabled
|
||||
|
||||
def _bucket_key(self, key: str, bucket_index: int) -> str:
|
||||
return f"{REDIS_KEY_PREFIX}:ratelimit:{key.lower()}:{bucket_index}"
|
||||
|
||||
def _current_bucket(self) -> int:
|
||||
return int(time.time()) // self.bucket_size
|
||||
|
||||
def _redis_available(self) -> bool:
|
||||
return self.r is not None
|
||||
|
||||
def is_limited(self, key: str) -> bool:
|
||||
"""
|
||||
Main rate-limit check.
|
||||
Gracefully handles missing or failing Redis.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return False
|
||||
|
||||
if self._redis_available():
|
||||
try:
|
||||
return self._is_limited_redis(key)
|
||||
except Exception:
|
||||
return self._is_limited_memory(key)
|
||||
else:
|
||||
return self._is_limited_memory(key)
|
||||
|
||||
def get_count(self, key: str) -> int:
|
||||
if not self.enabled:
|
||||
return 0
|
||||
|
||||
if self._redis_available():
|
||||
try:
|
||||
return self._get_count_redis(key)
|
||||
except Exception:
|
||||
return self._get_count_memory(key)
|
||||
else:
|
||||
return self._get_count_memory(key)
|
||||
|
||||
def remaining(self, key: str) -> int:
|
||||
used = self.get_count(key)
|
||||
return max(0, self.limit - used)
|
||||
|
||||
def _is_limited_redis(self, key: str) -> bool:
|
||||
now_bucket = self._current_bucket()
|
||||
bucket_key = self._bucket_key(key, now_bucket)
|
||||
|
||||
attempts = self.r.incr(bucket_key)
|
||||
if attempts == 1:
|
||||
self.r.expire(bucket_key, self.window + self.bucket_size)
|
||||
|
||||
# Collect buckets
|
||||
buckets = [
|
||||
self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
|
||||
]
|
||||
|
||||
counts = self.r.mget(buckets)
|
||||
total = sum(int(c) for c in counts if c)
|
||||
|
||||
return total > self.limit
|
||||
|
||||
def _get_count_redis(self, key: str) -> int:
|
||||
now_bucket = self._current_bucket()
|
||||
buckets = [
|
||||
self._bucket_key(key, now_bucket - i) for i in range(self.num_buckets + 1)
|
||||
]
|
||||
counts = self.r.mget(buckets)
|
||||
return sum(int(c) for c in counts if c)
|
||||
|
||||
def _is_limited_memory(self, key: str) -> bool:
|
||||
now_bucket = self._current_bucket()
|
||||
|
||||
# Init storage
|
||||
if key not in self._memory_store:
|
||||
self._memory_store[key] = {}
|
||||
|
||||
store = self._memory_store[key]
|
||||
|
||||
# Increment bucket
|
||||
store[now_bucket] = store.get(now_bucket, 0) + 1
|
||||
|
||||
# Drop expired buckets
|
||||
min_bucket = now_bucket - self.num_buckets
|
||||
expired = [b for b in store if b < min_bucket]
|
||||
for b in expired:
|
||||
del store[b]
|
||||
|
||||
# Count totals
|
||||
total = sum(store.values())
|
||||
return total > self.limit
|
||||
|
||||
def _get_count_memory(self, key: str) -> int:
|
||||
now_bucket = self._current_bucket()
|
||||
if key not in self._memory_store:
|
||||
return 0
|
||||
|
||||
store = self._memory_store[key]
|
||||
min_bucket = now_bucket - self.num_buckets
|
||||
|
||||
# Remove expired
|
||||
expired = [b for b in store if b < min_bucket]
|
||||
for b in expired:
|
||||
del store[b]
|
||||
|
||||
return sum(store.values())
|
||||
Reference in New Issue
Block a user