import asyncio
import inspect
import logging
import time
from urllib.parse import urlparse

import redis

from open_webui.env import (
    REDIS_CLUSTER,
    REDIS_HEALTH_CHECK_INTERVAL,
    REDIS_RECONNECT_DELAY,
    REDIS_SENTINEL_HOSTS,
    REDIS_SENTINEL_MAX_RETRY_COUNT,
    REDIS_SENTINEL_PORT,
    REDIS_SOCKET_CONNECT_TIMEOUT,
    REDIS_SOCKET_KEEPALIVE,
    REDIS_URL,
)

log = logging.getLogger(__name__)

# Process-wide cache of Redis connections, keyed by
# (url, sentinels, cluster, async_mode, decode_responses), so repeated
# calls with the same configuration share a single client instead of
# opening a new connection each time.
_CONNECTION_CACHE = {}


class SentinelRedisProxy:
    """Proxy that forwards Redis commands to the current sentinel master.

    Every attribute access re-resolves the master from sentinel, and
    callable commands are wrapped with retry logic so that a fail-over
    (``ConnectionError``, or a write hitting a demoted, now read-only
    replica raising ``ReadOnlyError``) is transparently retried against
    the newly elected master, up to ``REDIS_SENTINEL_MAX_RETRY_COUNT``
    attempts with an optional ``REDIS_RECONNECT_DELAY`` (milliseconds)
    between attempts.
    """

    # Methods that hand back stateful helper objects (pipelines, pubsub
    # handles, ...).  These are returned unwrapped: the caller keeps the
    # object, so per-call retry against a fresh master makes no sense.
    _FACTORY_METHODS = {"pipeline", "pubsub", "monitor", "client", "transaction"}

    def __init__(self, sentinel, service, *, async_mode: bool = True, **kw):
        self._sentinel = sentinel
        self._service = service
        self._kw = kw
        self._async_mode = async_mode

    def _master(self):
        # Ask sentinel for the current master on every call so commands
        # always reach the active node after a fail-over.
        return self._sentinel.master_for(self._service, **self._kw)

    # NOTE: __getattr__ must be synchronous.  An ``async def`` here would
    # make *every* attribute access return a coroutine, so even plain
    # calls like ``client.get(...)`` would break.
    def __getattr__(self, item):
        master = self._master()
        orig_attr = getattr(master, item)

        # Plain data attributes pass straight through.
        if not callable(orig_attr):
            return orig_attr

        if item in self._FACTORY_METHODS:
            return orig_attr

        if self._async_mode:
            if inspect.isasyncgenfunction(orig_attr):
                # Async generator commands (e.g. scan_iter) need a wrapper
                # that retries the whole iteration on fail-over.
                def _wrapped_iter(*args, **kwargs):
                    async def _iter():
                        for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                            try:
                                method = getattr(self._master(), item)
                                async for value in method(*args, **kwargs):
                                    yield value
                                return
                            except (
                                redis.exceptions.ConnectionError,
                                redis.exceptions.ReadOnlyError,
                            ) as e:
                                if attempt < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
                                    log.debug(
                                        "Redis sentinel fail-over (%s). Retry %s/%s",
                                        type(e).__name__,
                                        attempt + 1,
                                        REDIS_SENTINEL_MAX_RETRY_COUNT,
                                    )
                                    if REDIS_RECONNECT_DELAY:
                                        # Must not block the event loop, so
                                        # asyncio.sleep, never time.sleep.
                                        await asyncio.sleep(
                                            REDIS_RECONNECT_DELAY / 1000
                                        )
                                    continue
                                log.error(
                                    "Redis operation failed after %s retries: %s",
                                    REDIS_SENTINEL_MAX_RETRY_COUNT,
                                    e,
                                )
                                raise

                    return _iter()

                return _wrapped_iter

            async def _wrapped(*args, **kwargs):
                for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                    try:
                        method = getattr(self._master(), item)
                        result = method(*args, **kwargs)
                        # Some redis.asyncio methods are sync-returning;
                        # only await genuine coroutines.
                        if inspect.iscoroutine(result):
                            return await result
                        return result
                    except (
                        redis.exceptions.ConnectionError,
                        redis.exceptions.ReadOnlyError,
                    ) as e:
                        if attempt < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
                            log.debug(
                                "Redis sentinel fail-over (%s). Retry %s/%s",
                                type(e).__name__,
                                attempt + 1,
                                REDIS_SENTINEL_MAX_RETRY_COUNT,
                            )
                            if REDIS_RECONNECT_DELAY:
                                await asyncio.sleep(REDIS_RECONNECT_DELAY / 1000)
                            continue
                        log.error(
                            "Redis operation failed after %s retries: %s",
                            REDIS_SENTINEL_MAX_RETRY_COUNT,
                            e,
                        )
                        raise

            return _wrapped

        def _wrapped(*args, **kwargs):
            for attempt in range(REDIS_SENTINEL_MAX_RETRY_COUNT):
                try:
                    method = getattr(self._master(), item)
                    return method(*args, **kwargs)
                except (
                    redis.exceptions.ConnectionError,
                    redis.exceptions.ReadOnlyError,
                ) as e:
                    if attempt < REDIS_SENTINEL_MAX_RETRY_COUNT - 1:
                        log.debug(
                            "Redis sentinel fail-over (%s). Retry %s/%s",
                            type(e).__name__,
                            attempt + 1,
                            REDIS_SENTINEL_MAX_RETRY_COUNT,
                        )
                        if REDIS_RECONNECT_DELAY:
                            time.sleep(REDIS_RECONNECT_DELAY / 1000)
                        continue
                    log.error(
                        "Redis operation failed after %s retries: %s",
                        REDIS_SENTINEL_MAX_RETRY_COUNT,
                        e,
                    )
                    raise

        return _wrapped


def parse_redis_service_url(redis_url):
    """Split a ``redis://``/``rediss://`` URL into sentinel service parts.

    Returns a dict with ``username``/``password`` (or ``None``),
    ``service`` (hostname, defaulting to ``'mymaster'``), ``port``
    (default 6379) and ``db`` (default 0).

    Raises:
        ValueError: if the URL uses any other scheme.
    """
    parsed_url = urlparse(redis_url)
    if parsed_url.scheme not in ("redis", "rediss"):
        raise ValueError("Invalid Redis URL scheme. Must be 'redis' or 'rediss'.")

    return {
        "username": parsed_url.username or None,
        "password": parsed_url.password or None,
        "service": parsed_url.hostname or "mymaster",
        "port": parsed_url.port or 6379,
        "db": int(parsed_url.path.lstrip("/") or 0),
    }


def get_redis_client(async_mode=False):
    """Return a Redis client built from the environment, or ``None``.

    Any failure is logged at debug level and swallowed, so callers can
    treat an unavailable Redis as a soft dependency.
    """
    try:
        return get_redis_connection(
            redis_url=REDIS_URL,
            redis_sentinels=get_sentinels_from_env(
                REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT
            ),
            redis_cluster=REDIS_CLUSTER,
            async_mode=async_mode,
        )
    except Exception as e:
        log.debug("Failed to get Redis client: %s", e)
        return None


def get_redis_connection(
    redis_url,
    redis_sentinels,
    redis_cluster=False,
    async_mode=False,
    decode_responses=True,
):
    """Create (or fetch from cache) a Redis connection.

    Depending on the arguments this yields a ``SentinelRedisProxy``
    (when sentinels are configured), a cluster client, or a plain
    client — in either sync or asyncio flavors.  Connections are cached
    per configuration in ``_CONNECTION_CACHE``.
    """
    cache_key = (
        redis_url,
        tuple(redis_sentinels) if redis_sentinels else (),
        redis_cluster,
        async_mode,
        decode_responses,
    )
    if cache_key in _CONNECTION_CACHE:
        return _CONNECTION_CACHE[cache_key]

    connection = None

    # Optional keyword blocks: only pass tunables that are configured so
    # the redis-py defaults apply otherwise.
    connect_timeout_kwargs = (
        {"socket_connect_timeout": REDIS_SOCKET_CONNECT_TIMEOUT}
        if REDIS_SOCKET_CONNECT_TIMEOUT is not None
        else {}
    )
    keepalive_kwargs = {"socket_keepalive": True} if REDIS_SOCKET_KEEPALIVE else {}
    health_check_kwargs = (
        {"health_check_interval": REDIS_HEALTH_CHECK_INTERVAL}
        if REDIS_HEALTH_CHECK_INTERVAL
        else {}
    )

    if async_mode:
        import redis.asyncio as redis

        if redis_sentinels:
            redis_config = parse_redis_service_url(redis_url)
            sentinel = redis.sentinel.Sentinel(
                redis_sentinels,
                port=redis_config["port"],
                db=redis_config["db"],
                username=redis_config["username"],
                password=redis_config["password"],
                decode_responses=decode_responses,
                **connect_timeout_kwargs,
                **keepalive_kwargs,
                **health_check_kwargs,
            )
            connection = SentinelRedisProxy(
                sentinel,
                redis_config["service"],
                async_mode=async_mode,
            )
        elif redis_cluster:
            if not redis_url:
                raise ValueError("Redis URL must be provided for cluster mode.")
            # Assign (not return) so cluster clients are cached like the
            # other connection kinds.
            connection = redis.cluster.RedisCluster.from_url(
                redis_url,
                decode_responses=decode_responses,
                **connect_timeout_kwargs,
                **keepalive_kwargs,
                **health_check_kwargs,
            )
        elif redis_url:
            connection = redis.from_url(
                redis_url,
                decode_responses=decode_responses,
                **connect_timeout_kwargs,
                **keepalive_kwargs,
                **health_check_kwargs,
            )
    else:
        import redis

        if redis_sentinels:
            redis_config = parse_redis_service_url(redis_url)
            sentinel = redis.sentinel.Sentinel(
                redis_sentinels,
                port=redis_config["port"],
                db=redis_config["db"],
                username=redis_config["username"],
                password=redis_config["password"],
                decode_responses=decode_responses,
                **connect_timeout_kwargs,
                **keepalive_kwargs,
                **health_check_kwargs,
            )
            connection = SentinelRedisProxy(
                sentinel,
                redis_config["service"],
                async_mode=async_mode,
            )
        elif redis_cluster:
            if not redis_url:
                raise ValueError("Redis URL must be provided for cluster mode.")
            connection = redis.cluster.RedisCluster.from_url(
                redis_url,
                decode_responses=decode_responses,
                **connect_timeout_kwargs,
                **keepalive_kwargs,
                **health_check_kwargs,
            )
        elif redis_url:
            connection = redis.Redis.from_url(
                redis_url,
                decode_responses=decode_responses,
                **connect_timeout_kwargs,
                **keepalive_kwargs,
                **health_check_kwargs,
            )

    _CONNECTION_CACHE[cache_key] = connection
    return connection


def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
    """Turn ``"h1,h2"`` + port env values into ``[(host, port), ...]``.

    Returns an empty list when no sentinel hosts are configured.
    """
    if sentinel_hosts_env:
        sentinel_hosts = sentinel_hosts_env.split(",")
        sentinel_port = int(sentinel_port_env)
        return [(host, sentinel_port) for host in sentinel_hosts]
    return []


def get_sentinel_url_from_env(redis_url, sentinel_hosts_env, sentinel_port_env):
    """Build a ``redis+sentinel://`` URL from a redis URL and sentinel env.

    Credentials and db come from *redis_url* (via
    :func:`parse_redis_service_url`); the host list is every entry of the
    comma-separated *sentinel_hosts_env* paired with *sentinel_port_env*.
    """
    redis_config = parse_redis_service_url(redis_url)
    username = redis_config["username"] or ""
    password = redis_config["password"] or ""
    auth_part = ""
    if username or password:
        auth_part = f"{username}:{password}@"
    hosts_part = ",".join(
        f"{host}:{sentinel_port_env}" for host in sentinel_hosts_env.split(",")
    )
    return (
        f"redis+sentinel://{auth_part}{hosts_part}"
        f"/{redis_config['db']}/{redis_config['service']}"
    )