mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-04 19:29:27 -05:00
refac
This commit is contained in:
@@ -60,47 +60,43 @@ class WeaviateClient(VectorDBBase):
|
||||
try:
|
||||
# Build connection parameters
|
||||
connection_params = {
|
||||
"http_host": WEAVIATE_HTTP_HOST,
|
||||
"http_port": WEAVIATE_HTTP_PORT,
|
||||
"http_secure": WEAVIATE_HTTP_SECURE,
|
||||
"grpc_host": WEAVIATE_GRPC_HOST,
|
||||
"grpc_port": WEAVIATE_GRPC_PORT,
|
||||
"grpc_secure": WEAVIATE_GRPC_SECURE,
|
||||
"skip_init_checks": WEAVIATE_SKIP_INIT_CHECKS,
|
||||
'http_host': WEAVIATE_HTTP_HOST,
|
||||
'http_port': WEAVIATE_HTTP_PORT,
|
||||
'http_secure': WEAVIATE_HTTP_SECURE,
|
||||
'grpc_host': WEAVIATE_GRPC_HOST,
|
||||
'grpc_port': WEAVIATE_GRPC_PORT,
|
||||
'grpc_secure': WEAVIATE_GRPC_SECURE,
|
||||
'skip_init_checks': WEAVIATE_SKIP_INIT_CHECKS,
|
||||
}
|
||||
|
||||
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
|
||||
if WEAVIATE_API_KEY:
|
||||
connection_params["auth_credentials"] = (
|
||||
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
||||
)
|
||||
connection_params['auth_credentials'] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
||||
|
||||
self.client = weaviate.connect_to_custom(**connection_params)
|
||||
self.client.connect()
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
|
||||
raise ConnectionError(f'Failed to connect to Weaviate: {e}') from e
|
||||
|
||||
def _sanitize_collection_name(self, collection_name: str) -> str:
    """Sanitize collection name to be a valid Weaviate class name.

    The result contains only ASCII letters, digits and underscores, has no
    leading or trailing underscores, and starts with an uppercase letter.

    Args:
        collection_name: Raw collection name supplied by the caller.

    Returns:
        The sanitized name.

    Raises:
        ValueError: If ``collection_name`` is not a string, is empty or
            whitespace-only, or has no usable characters left after
            sanitizing.
    """
    if not isinstance(collection_name, str) or not collection_name.strip():
        raise ValueError('Collection name must be a non-empty string')

    # Requirements for a valid Weaviate class name:
    # The collection name must begin with a capital letter.
    # The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.

    # Replace hyphens with underscores first so they survive the filter below,
    # then drop every character that is not an ASCII letter, digit or underscore.
    name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace('-', '_'))
    # Trim underscores at both ends so the first character is a letter or digit.
    name = name.strip('_')

    if not name:
        raise ValueError('Could not sanitize collection name to be a valid Weaviate class name')

    # Ensure it starts with a letter and is capitalized: a leading digit gets
    # a single 'C' prefix (applied exactly once).
    if not name[0].isalpha():
        name = 'C' + name

    return name[0].upper() + name[1:]
|
||||
|
||||
@@ -118,9 +114,7 @@ class WeaviateClient(VectorDBBase):
|
||||
name=collection_name,
|
||||
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
|
||||
properties=[
|
||||
weaviate.classes.config.Property(
|
||||
name="text", data_type=weaviate.classes.config.DataType.TEXT
|
||||
),
|
||||
weaviate.classes.config.Property(name='text', data_type=weaviate.classes.config.DataType.TEXT),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -133,19 +127,15 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
||||
for item in items:
|
||||
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
|
||||
item_uuid = str(uuid.uuid4()) if not item['id'] else str(item['id'])
|
||||
|
||||
properties = {"text": item["text"]}
|
||||
if item["metadata"]:
|
||||
clean_metadata = _convert_uuids_to_strings(
|
||||
process_metadata(item["metadata"])
|
||||
)
|
||||
clean_metadata.pop("text", None)
|
||||
properties = {'text': item['text']}
|
||||
if item['metadata']:
|
||||
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
|
||||
clean_metadata.pop('text', None)
|
||||
properties.update(clean_metadata)
|
||||
|
||||
batch.add_object(
|
||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
||||
)
|
||||
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
||||
@@ -156,19 +146,15 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
||||
for item in items:
|
||||
item_uuid = str(item["id"]) if item["id"] else None
|
||||
item_uuid = str(item['id']) if item['id'] else None
|
||||
|
||||
properties = {"text": item["text"]}
|
||||
if item["metadata"]:
|
||||
clean_metadata = _convert_uuids_to_strings(
|
||||
process_metadata(item["metadata"])
|
||||
)
|
||||
clean_metadata.pop("text", None)
|
||||
properties = {'text': item['text']}
|
||||
if item['metadata']:
|
||||
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
|
||||
clean_metadata.pop('text', None)
|
||||
properties.update(clean_metadata)
|
||||
|
||||
batch.add_object(
|
||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
||||
)
|
||||
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -205,16 +191,12 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
for obj in response.objects:
|
||||
properties = dict(obj.properties) if obj.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
documents.append(properties.pop('text', ''))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
# Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
||||
raw_distances = [
|
||||
(
|
||||
obj.metadata.distance
|
||||
if obj.metadata and obj.metadata.distance
|
||||
else 2.0
|
||||
)
|
||||
(obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0)
|
||||
for obj in response.objects
|
||||
]
|
||||
distances = [(2 - dist) / 2 for dist in raw_distances]
|
||||
@@ -231,16 +213,14 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result_ids,
|
||||
"documents": result_documents,
|
||||
"metadatas": result_metadatas,
|
||||
"distances": result_distances,
|
||||
'ids': result_ids,
|
||||
'documents': result_documents,
|
||||
'metadatas': result_metadatas,
|
||||
'distances': result_distances,
|
||||
}
|
||||
)
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
||||
if not self.client.collections.exists(sane_collection_name):
|
||||
return None
|
||||
@@ -250,21 +230,15 @@ class WeaviateClient(VectorDBBase):
|
||||
weaviate_filter = None
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
|
||||
value
|
||||
)
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
|
||||
weaviate_filter = (
|
||||
prop_filter
|
||||
if weaviate_filter is None
|
||||
else weaviate.classes.query.Filter.all_of(
|
||||
[weaviate_filter, prop_filter]
|
||||
)
|
||||
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
|
||||
)
|
||||
|
||||
try:
|
||||
response = collection.query.fetch_objects(
|
||||
filters=weaviate_filter, limit=limit
|
||||
)
|
||||
response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit)
|
||||
|
||||
ids = [str(obj.uuid) for obj in response.objects]
|
||||
documents = []
|
||||
@@ -272,14 +246,14 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
for obj in response.objects:
|
||||
properties = dict(obj.properties) if obj.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
documents.append(properties.pop('text', ''))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
'ids': [ids],
|
||||
'documents': [documents],
|
||||
'metadatas': [metadatas],
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
@@ -297,7 +271,7 @@ class WeaviateClient(VectorDBBase):
|
||||
for item in collection.iterator():
|
||||
ids.append(str(item.uuid))
|
||||
properties = dict(item.properties) if item.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
documents.append(properties.pop('text', ''))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
if not ids:
|
||||
@@ -305,9 +279,9 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
'ids': [ids],
|
||||
'documents': [documents],
|
||||
'metadatas': [metadatas],
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
@@ -332,15 +306,11 @@ class WeaviateClient(VectorDBBase):
|
||||
elif filter:
|
||||
weaviate_filter = None
|
||||
for key, value in filter.items():
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(
|
||||
name=key
|
||||
).equal(value)
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
|
||||
weaviate_filter = (
|
||||
prop_filter
|
||||
if weaviate_filter is None
|
||||
else weaviate.classes.query.Filter.all_of(
|
||||
[weaviate_filter, prop_filter]
|
||||
)
|
||||
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
|
||||
)
|
||||
|
||||
if weaviate_filter:
|
||||
|
||||
Reference in New Issue
Block a user