# Patch summary (review notes; leading text before the first "diff --git" header).
#
# Touches two files:
#   * backend/open_webui/env.py
#       Adds SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION, a boolean
#       read from the environment variable of the same name. The default string is
#       "True"; the value is lower-cased and compared to "true", so any other value
#       (including "1" or "yes") yields False.
#   * backend/open_webui/routers/retrieval.py
#       Imports the new flag, imports torch next to sentence_transformers inside
#       get_rf, and passes activation_fn to sentence_transformers.CrossEncoder:
#       torch.nn.Sigmoid() when the flag is True, None when it is False.
#
# NOTE(review): with the flag off the patch passes activation_fn=None rather than an
# explicit identity function, so the activation actually applied is presumably whatever
# CrossEncoder falls back to on its own -- confirm this is the intended "disabled"
# behavior.
# NOTE(review): in this copy of the patch all line breaks and indentation are collapsed
# into single spaces; restore the original whitespace before attempting to apply it.
diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index a5184aa4d5..95349d6e62 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -762,6 +762,13 @@ else: except Exception: SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None +# Whether to force a sigmoid activation on CrossEncoder reranking scores. +# When enabled (default), scores are normalized to the 0-1 range for proper +# relevance threshold behavior with MS MARCO models; otherwise activation_fn is None. +SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION = ( + os.environ.get("SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION", "True").lower() == "true" +) + #################################### # OFFLINE_MODE #################################### diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 2fc7ca2ef9..f35c6b43b6 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -112,6 +112,7 @@ from open_webui.env import ( SENTENCE_TRANSFORMERS_MODEL_KWARGS, SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, + SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION, ) from open_webui.constants import ERROR_MESSAGES @@ -190,6 +191,7 @@ def get_rf( raise Exception(ERROR_MESSAGES.DEFAULT(e)) else: import sentence_transformers + import torch try: rf = sentence_transformers.CrossEncoder( @@ -198,6 +200,11 @@ def get_rf( trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, backend=SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, model_kwargs=SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, + activation_fn=( + torch.nn.Sigmoid() + if SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION + else None + ), ) except Exception as e: log.error(f"CrossEncoder: {e}")