This commit is contained in:
Timothy J. Baek
2024-09-13 01:18:20 -04:00
parent b943b7d337
commit 939bfd153e
4 changed files with 49 additions and 29 deletions

View File

@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.apps.rag.vector.main import VectorItem, QueryResult
from open_webui.apps.rag.vector.main import VectorItem, SearchResult, GetResult
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
@@ -47,7 +47,7 @@ class ChromaClient:
def search(
self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[QueryResult]:
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection = self.client.get_collection(name=collection_name)
if collection:
@@ -56,19 +56,31 @@ class ChromaClient:
n_results=limit,
)
return {
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"documents": result["documents"],
"metadatas": result["metadatas"],
}
)
return None
def get(self, collection_name: str) -> Optional[QueryResult]:
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
collection = self.client.get_collection(name=collection_name)
if collection:
return collection.get()
result = collection.get()
return GetResult(
**{
"ids": [result["ids"]],
"distances": [result["distances"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
}
)
return None
def insert(self, collection_name: str, items: list[VectorItem]):