From 431632d5303da1b7e7db4d442c7c0de30d9322de Mon Sep 17 00:00:00 2001 From: Classic298 <27028174+Classic298@users.noreply.github.com> Date: Wed, 31 Dec 2025 21:48:31 +0100 Subject: [PATCH] fix: normalize local CrossEncoder reranking scores for relevance threshold (#20228) * Update utils.py * Update retrieval.py * Update utils.py * Update retrieval.py * add env var * rename to SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION --- backend/open_webui/env.py | 7 +++++++ backend/open_webui/routers/retrieval.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index a5184aa4d5..95349d6e62 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -762,6 +762,13 @@ else: except Exception: SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None +# Whether to apply sigmoid normalization to CrossEncoder reranking scores. +# When enabled (default), scores are normalized to 0-1 range for proper +# relevance threshold behavior with MS MARCO models. +SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION = ( + os.environ.get("SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION", "True").lower() == "true" +) + #################################### # OFFLINE_MODE #################################### diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 2fc7ca2ef9..f35c6b43b6 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -112,6 +112,7 @@ from open_webui.env import ( SENTENCE_TRANSFORMERS_MODEL_KWARGS, SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION, ) from open_webui.constants import ERROR_MESSAGES @@ -190,6 +191,7 @@ def get_rf( raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers + import torch try: rf = sentence_transformers.CrossEncoder( @@ -198,6 +200,11 @@ def get_rf( trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, + activation_fn=( + torch.nn.Sigmoid() + if SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION + else None + ), ) except Exception as e: log.error(f"CrossEncoder: {e}")