This commit is contained in:
Timothy J. Baek
2024-09-13 01:18:20 -04:00
parent b943b7d337
commit 939bfd153e
4 changed files with 49 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ import json
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
MILVUS_URI,
)
@@ -15,7 +15,7 @@ class MilvusClient:
self.collection_prefix = "open_webui"
self.client = Client(uri=MILVUS_URI)
def _result_to_query_result(self, result) -> QueryResult:
def _result_to_query_result(self, result) -> SearchResult:
print(result)
ids = []
@@ -40,12 +40,14 @@ class MilvusClient:
documents.append(_documents)
metadatas.append(_metadatas)
return {
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
}
return SearchResult(
**{
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
}
)
def _create_collection(self, collection_name: str, dimension: int):
schema = self.client.create_schema(
@@ -94,7 +96,7 @@ class MilvusClient:
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]:
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
@@ -105,10 +107,11 @@ class MilvusClient:
return self._result_to_query_result(result)
def get(self, collection_name: str) -> Optional[QueryResult]:
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
result = self.client.query(
collection_name=f"{self.collection_prefix}_{collection_name}",
filter='id != ""',
)
return self._result_to_query_result(result)