mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-02 10:19:44 -05:00
refac
This commit is contained in:
@@ -29,22 +29,18 @@ from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
TENANT_ID_FIELD = "tenant_id"
|
||||
TENANT_ID_FIELD = 'tenant_id'
|
||||
DEFAULT_DIMENSION = 384
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
|
||||
return models.FieldCondition(
|
||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
return models.FieldCondition(key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id))
|
||||
|
||||
|
||||
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
|
||||
return models.FieldCondition(
|
||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||
)
|
||||
return models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value))
|
||||
|
||||
|
||||
class QdrantClient(VectorDBBase):
|
||||
@@ -59,9 +55,7 @@ class QdrantClient(VectorDBBase):
|
||||
self.QDRANT_HNSW_M = QDRANT_HNSW_M
|
||||
|
||||
if not self.QDRANT_URI:
|
||||
raise ValueError(
|
||||
"QDRANT_URI is not set. Please configure it in the environment variables."
|
||||
)
|
||||
raise ValueError('QDRANT_URI is not set. Please configure it in the environment variables.')
|
||||
|
||||
# Unified handling for either scheme
|
||||
parsed = urlparse(self.QDRANT_URI)
|
||||
@@ -86,19 +80,19 @@ class QdrantClient(VectorDBBase):
|
||||
)
|
||||
|
||||
# Main collection types for multi-tenancy
|
||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
||||
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
|
||||
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
|
||||
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
|
||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
|
||||
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
|
||||
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
|
||||
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
|
||||
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web-search'
|
||||
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash-based'
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids, documents, metadatas = [], [], []
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
documents.append(payload['text'])
|
||||
metadatas.append(payload['metadata'])
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
|
||||
@@ -118,29 +112,25 @@ class QdrantClient(VectorDBBase):
|
||||
# Check for user memory collections
|
||||
tenant_id = collection_name
|
||||
|
||||
if collection_name.startswith("user-memory-"):
|
||||
if collection_name.startswith('user-memory-'):
|
||||
return self.MEMORY_COLLECTION, tenant_id
|
||||
|
||||
# Check for file collections
|
||||
elif collection_name.startswith("file-"):
|
||||
elif collection_name.startswith('file-'):
|
||||
return self.FILE_COLLECTION, tenant_id
|
||||
|
||||
# Check for web search collections
|
||||
elif collection_name.startswith("web-search-"):
|
||||
elif collection_name.startswith('web-search-'):
|
||||
return self.WEB_SEARCH_COLLECTION, tenant_id
|
||||
|
||||
# Handle hash-based collections (YouTube and web URLs)
|
||||
elif len(collection_name) == 63 and all(
|
||||
c in "0123456789abcdef" for c in collection_name
|
||||
):
|
||||
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
|
||||
return self.HASH_BASED_COLLECTION, tenant_id
|
||||
|
||||
else:
|
||||
return self.KNOWLEDGE_COLLECTION, tenant_id
|
||||
|
||||
def _create_multi_tenant_collection(
|
||||
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
|
||||
):
|
||||
def _create_multi_tenant_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION):
|
||||
"""
|
||||
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
|
||||
"""
|
||||
@@ -158,9 +148,7 @@ class QdrantClient(VectorDBBase):
|
||||
m=0,
|
||||
),
|
||||
)
|
||||
log.info(
|
||||
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
|
||||
)
|
||||
log.info(f'Multi-tenant collection {mt_collection_name} created with dimension {dimension}!')
|
||||
|
||||
self.client.create_payload_index(
|
||||
collection_name=mt_collection_name,
|
||||
@@ -172,7 +160,7 @@ class QdrantClient(VectorDBBase):
|
||||
),
|
||||
)
|
||||
|
||||
for field in ("metadata.hash", "metadata.file_id"):
|
||||
for field in ('metadata.hash', 'metadata.file_id'):
|
||||
self.client.create_payload_index(
|
||||
collection_name=mt_collection_name,
|
||||
field_name=field,
|
||||
@@ -182,28 +170,24 @@ class QdrantClient(VectorDBBase):
|
||||
),
|
||||
)
|
||||
|
||||
def _create_points(
|
||||
self, items: List[VectorItem], tenant_id: str
|
||||
) -> List[PointStruct]:
|
||||
def _create_points(self, items: List[VectorItem], tenant_id: str) -> List[PointStruct]:
|
||||
"""
|
||||
Create point structs from vector items with tenant ID.
|
||||
"""
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
id=item['id'],
|
||||
vector=item['vector'],
|
||||
payload={
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
'text': item['text'],
|
||||
'metadata': item['metadata'],
|
||||
TENANT_ID_FIELD: tenant_id,
|
||||
},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def _ensure_collection(
|
||||
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
|
||||
):
|
||||
def _ensure_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION):
|
||||
"""
|
||||
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
|
||||
"""
|
||||
@@ -246,15 +230,13 @@ class QdrantClient(VectorDBBase):
|
||||
must_conditions = [_tenant_filter(tenant_id)]
|
||||
should_conditions = []
|
||||
if ids:
|
||||
should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
|
||||
should_conditions = [_metadata_filter('id', id_value) for id_value in ids]
|
||||
elif filter:
|
||||
must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=mt_collection,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=must_conditions, should=should_conditions)
|
||||
),
|
||||
points_selector=models.FilterSelector(filter=models.Filter(must=must_conditions, should=should_conditions)),
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -289,9 +271,7 @@ class QdrantClient(VectorDBBase):
|
||||
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
|
||||
)
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
):
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None):
|
||||
"""
|
||||
Query points with filters and tenant isolation.
|
||||
"""
|
||||
@@ -338,7 +318,7 @@ class QdrantClient(VectorDBBase):
|
||||
if not self.client or not items:
|
||||
return None
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
dimension = len(items[0]["vector"])
|
||||
dimension = len(items[0]['vector'])
|
||||
self._ensure_collection(mt_collection, dimension)
|
||||
points = self._create_points(items, tenant_id)
|
||||
self.client.upload_points(mt_collection, points)
|
||||
@@ -372,7 +352,5 @@ class QdrantClient(VectorDBBase):
|
||||
return None
|
||||
self.client.delete(
|
||||
collection_name=mt_collection,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=[_tenant_filter(tenant_id)])
|
||||
),
|
||||
points_selector=models.FilterSelector(filter=models.Filter(must=[_tenant_filter(tenant_id)])),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user