This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -60,47 +60,43 @@ class WeaviateClient(VectorDBBase):
try:
# Build connection parameters
connection_params = {
"http_host": WEAVIATE_HTTP_HOST,
"http_port": WEAVIATE_HTTP_PORT,
"http_secure": WEAVIATE_HTTP_SECURE,
"grpc_host": WEAVIATE_GRPC_HOST,
"grpc_port": WEAVIATE_GRPC_PORT,
"grpc_secure": WEAVIATE_GRPC_SECURE,
"skip_init_checks": WEAVIATE_SKIP_INIT_CHECKS,
'http_host': WEAVIATE_HTTP_HOST,
'http_port': WEAVIATE_HTTP_PORT,
'http_secure': WEAVIATE_HTTP_SECURE,
'grpc_host': WEAVIATE_GRPC_HOST,
'grpc_port': WEAVIATE_GRPC_PORT,
'grpc_secure': WEAVIATE_GRPC_SECURE,
'skip_init_checks': WEAVIATE_SKIP_INIT_CHECKS,
}
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
if WEAVIATE_API_KEY:
connection_params["auth_credentials"] = (
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
)
connection_params['auth_credentials'] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
self.client = weaviate.connect_to_custom(**connection_params)
self.client.connect()
except Exception as e:
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
raise ConnectionError(f'Failed to connect to Weaviate: {e}') from e
def _sanitize_collection_name(self, collection_name: str) -> str:
"""Sanitize collection name to be a valid Weaviate class name."""
if not isinstance(collection_name, str) or not collection_name.strip():
raise ValueError("Collection name must be a non-empty string")
raise ValueError('Collection name must be a non-empty string')
# Requirements for a valid Weaviate class name:
# The collection name must begin with a capital letter.
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
# Replace hyphens with underscores and keep only alphanumeric characters
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
name = name.strip("_")
name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace('-', '_'))
name = name.strip('_')
if not name:
raise ValueError(
"Could not sanitize collection name to be a valid Weaviate class name"
)
raise ValueError('Could not sanitize collection name to be a valid Weaviate class name')
# Ensure it starts with a letter and is capitalized
if not name[0].isalpha():
name = "C" + name
name = 'C' + name
return name[0].upper() + name[1:]
@@ -118,9 +114,7 @@ class WeaviateClient(VectorDBBase):
name=collection_name,
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
properties=[
weaviate.classes.config.Property(
name="text", data_type=weaviate.classes.config.DataType.TEXT
),
weaviate.classes.config.Property(name='text', data_type=weaviate.classes.config.DataType.TEXT),
],
)
@@ -133,19 +127,15 @@ class WeaviateClient(VectorDBBase):
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
item_uuid = str(uuid.uuid4()) if not item['id'] else str(item['id'])
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties = {'text': item['text']}
if item['metadata']:
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
clean_metadata.pop('text', None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
@@ -156,19 +146,15 @@ class WeaviateClient(VectorDBBase):
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(item["id"]) if item["id"] else None
item_uuid = str(item['id']) if item['id'] else None
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties = {'text': item['text']}
if item['metadata']:
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
clean_metadata.pop('text', None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
def search(
self,
@@ -205,16 +191,12 @@ class WeaviateClient(VectorDBBase):
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
documents.append(properties.pop('text', ''))
metadatas.append(_convert_uuids_to_strings(properties))
# Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
raw_distances = [
(
obj.metadata.distance
if obj.metadata and obj.metadata.distance
else 2.0
)
(obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0)
for obj in response.objects
]
distances = [(2 - dist) / 2 for dist in raw_distances]
@@ -231,16 +213,14 @@ class WeaviateClient(VectorDBBase):
return SearchResult(
**{
"ids": result_ids,
"documents": result_documents,
"metadatas": result_metadatas,
"distances": result_distances,
'ids': result_ids,
'documents': result_documents,
'metadatas': result_metadatas,
'distances': result_distances,
}
)
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
@@ -250,21 +230,15 @@ class WeaviateClient(VectorDBBase):
weaviate_filter = None
if filter:
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
value
)
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
)
try:
response = collection.query.fetch_objects(
filters=weaviate_filter, limit=limit
)
response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
@@ -272,14 +246,14 @@ class WeaviateClient(VectorDBBase):
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
documents.append(properties.pop('text', ''))
metadatas.append(_convert_uuids_to_strings(properties))
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
'ids': [ids],
'documents': [documents],
'metadatas': [metadatas],
}
)
except Exception:
@@ -297,7 +271,7 @@ class WeaviateClient(VectorDBBase):
for item in collection.iterator():
ids.append(str(item.uuid))
properties = dict(item.properties) if item.properties else {}
documents.append(properties.pop("text", ""))
documents.append(properties.pop('text', ''))
metadatas.append(_convert_uuids_to_strings(properties))
if not ids:
@@ -305,9 +279,9 @@ class WeaviateClient(VectorDBBase):
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
'ids': [ids],
'documents': [documents],
'metadatas': [metadatas],
}
)
except Exception:
@@ -332,15 +306,11 @@ class WeaviateClient(VectorDBBase):
elif filter:
weaviate_filter = None
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(
name=key
).equal(value)
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
)
if weaviate_filter: