mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-04 19:29:27 -05:00
refac
This commit is contained in:
@@ -55,7 +55,7 @@ VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
USE_HALFVEC = PGVECTOR_USE_HALFVEC
|
||||
|
||||
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
|
||||
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
|
||||
VECTOR_OPCLASS = 'halfvec_cosine_ops' if USE_HALFVEC else 'vector_cosine_ops'
|
||||
Base = declarative_base()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -65,12 +65,12 @@ def pgcrypto_encrypt(val, key):
|
||||
return func.pgp_sym_encrypt(val, literal(key))
|
||||
|
||||
|
||||
def pgcrypto_decrypt(col, key, outtype="text"):
|
||||
def pgcrypto_decrypt(col, key, outtype='text'):
|
||||
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
__tablename__ = 'document_chunk'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
|
||||
@@ -86,7 +86,6 @@ class DocumentChunk(Base):
|
||||
|
||||
class PgvectorClient(VectorDBBase):
|
||||
def __init__(self) -> None:
|
||||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
if not PGVECTOR_DB_URL:
|
||||
from open_webui.internal.db import ScopedSession
|
||||
@@ -105,46 +104,44 @@ class PgvectorClient(VectorDBBase):
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
|
||||
else:
|
||||
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
self.session = scoped_session(SessionLocal)
|
||||
|
||||
try:
|
||||
# Ensure the pgvector extension is available
|
||||
# Use a conditional check to avoid permission issues on Azure PostgreSQL
|
||||
if PGVECTOR_CREATE_EXTENSION:
|
||||
self.session.execute(text("""
|
||||
self.session.execute(
|
||||
text("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
END IF;
|
||||
END $$;
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Ensure the pgcrypto extension is available for encryption
|
||||
# Use a conditional check to avoid permission issues on Azure PostgreSQL
|
||||
self.session.execute(text("""
|
||||
self.session.execute(
|
||||
text("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
|
||||
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||
END IF;
|
||||
END $$;
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
if not PGVECTOR_PGCRYPTO_KEY:
|
||||
raise ValueError(
|
||||
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
|
||||
)
|
||||
raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.')
|
||||
|
||||
# Check vector length consistency
|
||||
self.check_vector_length()
|
||||
@@ -160,15 +157,14 @@ class PgvectorClient(VectorDBBase):
|
||||
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||
"ON document_chunk (collection_name);"
|
||||
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
log.info("Initialization complete.")
|
||||
log.info('Initialization complete.')
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during initialization: {e}")
|
||||
log.exception(f'Error during initialization: {e}')
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
@@ -176,7 +172,7 @@ class PgvectorClient(VectorDBBase):
|
||||
if not index_def:
|
||||
return None
|
||||
try:
|
||||
after_using = index_def.lower().split("using ", 1)[1]
|
||||
after_using = index_def.lower().split('using ', 1)[1]
|
||||
return after_using.split()[0]
|
||||
except (IndexError, AttributeError):
|
||||
return None
|
||||
@@ -189,23 +185,23 @@ class PgvectorClient(VectorDBBase):
|
||||
index_method,
|
||||
)
|
||||
elif USE_HALFVEC:
|
||||
index_method = "hnsw"
|
||||
index_method = 'hnsw'
|
||||
log.info(
|
||||
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
|
||||
'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.',
|
||||
VECTOR_LENGTH,
|
||||
)
|
||||
else:
|
||||
index_method = "ivfflat"
|
||||
index_method = 'ivfflat'
|
||||
|
||||
if index_method == "hnsw":
|
||||
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
|
||||
if index_method == 'hnsw':
|
||||
index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})'
|
||||
else:
|
||||
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
|
||||
index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})'
|
||||
|
||||
return index_method, index_options
|
||||
|
||||
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
|
||||
index_name = "idx_document_chunk_vector"
|
||||
index_name = 'idx_document_chunk_vector'
|
||||
existing_index_def = self.session.execute(
|
||||
text("""
|
||||
SELECT indexdef
|
||||
@@ -214,7 +210,7 @@ class PgvectorClient(VectorDBBase):
|
||||
AND tablename = 'document_chunk'
|
||||
AND indexname = :index_name
|
||||
"""),
|
||||
{"index_name": index_name},
|
||||
{'index_name': index_name},
|
||||
).scalar()
|
||||
|
||||
existing_method = self._extract_index_method(existing_index_def)
|
||||
@@ -222,23 +218,23 @@ class PgvectorClient(VectorDBBase):
|
||||
raise RuntimeError(
|
||||
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
|
||||
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
|
||||
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
|
||||
"and recreate it with the new method before restarting Open WebUI."
|
||||
'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) '
|
||||
'and recreate it with the new method before restarting Open WebUI.'
|
||||
)
|
||||
|
||||
if not existing_index_def:
|
||||
index_sql = (
|
||||
f"CREATE INDEX IF NOT EXISTS {index_name} "
|
||||
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
|
||||
f'CREATE INDEX IF NOT EXISTS {index_name} '
|
||||
f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})'
|
||||
)
|
||||
if index_options:
|
||||
index_sql = f"{index_sql} {index_options}"
|
||||
index_sql = f'{index_sql} {index_options}'
|
||||
self.session.execute(text(index_sql))
|
||||
log.info(
|
||||
"Ensured vector index '%s' using %s%s.",
|
||||
index_name,
|
||||
index_method,
|
||||
f" {index_options}" if index_options else "",
|
||||
f' {index_options}' if index_options else '',
|
||||
)
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
@@ -249,16 +245,14 @@ class PgvectorClient(VectorDBBase):
|
||||
metadata = MetaData()
|
||||
try:
|
||||
# Attempt to reflect the 'document_chunk' table
|
||||
document_chunk_table = Table(
|
||||
"document_chunk", metadata, autoload_with=self.session.bind
|
||||
)
|
||||
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
|
||||
except NoSuchTableError:
|
||||
# Table does not exist; no action needed
|
||||
return
|
||||
|
||||
# Proceed to check the vector column
|
||||
if "vector" in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns["vector"]
|
||||
if 'vector' in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns['vector']
|
||||
vector_type = vector_column.type
|
||||
expected_type = HALFVEC if USE_HALFVEC else Vector
|
||||
|
||||
@@ -268,16 +262,14 @@ class PgvectorClient(VectorDBBase):
|
||||
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
|
||||
)
|
||||
|
||||
db_vector_length = getattr(vector_type, "dim", None)
|
||||
db_vector_length = getattr(vector_type, 'dim', None)
|
||||
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
||||
"Cannot change vector size after initialization without migrating the data."
|
||||
f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. '
|
||||
'Cannot change vector size after initialization without migrating the data.'
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||
)
|
||||
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
# Adjust vector to have length VECTOR_LENGTH
|
||||
@@ -294,10 +286,10 @@ class PgvectorClient(VectorDBBase):
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
# Use raw SQL for BYTEA/pgcrypto
|
||||
# Ensure metadata is converted to its JSON text representation
|
||||
json_metadata = json.dumps(item["metadata"])
|
||||
json_metadata = json.dumps(item['metadata'])
|
||||
self.session.execute(
|
||||
text("""
|
||||
INSERT INTO document_chunk
|
||||
@@ -310,12 +302,12 @@ class PgvectorClient(VectorDBBase):
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata_text": json_metadata,
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
'id': item['id'],
|
||||
'vector': vector,
|
||||
'collection_name': collection_name,
|
||||
'text': item['text'],
|
||||
'metadata_text': json_metadata,
|
||||
'key': PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
@@ -324,31 +316,29 @@ class PgvectorClient(VectorDBBase):
|
||||
else:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
id=item['id'],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=process_metadata(item["metadata"]),
|
||||
text=item['text'],
|
||||
vmetadata=process_metadata(item['metadata']),
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
log.exception(f'Error during insert: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
json_metadata = json.dumps(item["metadata"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
json_metadata = json.dumps(item['metadata'])
|
||||
self.session.execute(
|
||||
text("""
|
||||
INSERT INTO document_chunk
|
||||
@@ -365,47 +355,39 @@ class PgvectorClient(VectorDBBase):
|
||||
vmetadata = EXCLUDED.vmetadata
|
||||
"""),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata_text": json_metadata,
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
'id': item['id'],
|
||||
'vector': vector,
|
||||
'collection_name': collection_name,
|
||||
'text': item['text'],
|
||||
'metadata_text': json_metadata,
|
||||
'key': PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
|
||||
else:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = process_metadata(item["metadata"])
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
)
|
||||
existing.text = item['text']
|
||||
existing.vmetadata = process_metadata(item['metadata'])
|
||||
existing.collection_name = collection_name # Update collection_name if necessary
|
||||
else:
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
id=item['id'],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=process_metadata(item["metadata"]),
|
||||
text=item['text'],
|
||||
vmetadata=process_metadata(item['metadata']),
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
log.exception(f'Error during upsert: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -427,38 +409,26 @@ class PgvectorClient(VectorDBBase):
|
||||
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
|
||||
# Create the values for query vectors
|
||||
qid_col = column("qid", Integer)
|
||||
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
qid_col = column('qid', Integer)
|
||||
q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data(
|
||||
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||
)
|
||||
.alias("query_vectors")
|
||||
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
|
||||
.alias('query_vectors')
|
||||
)
|
||||
|
||||
result_fields = [
|
||||
DocumentChunk.id,
|
||||
]
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'))
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text")
|
||||
)
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata")
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata')
|
||||
)
|
||||
else:
|
||||
result_fields.append(DocumentChunk.text)
|
||||
result_fields.append(DocumentChunk.vmetadata)
|
||||
result_fields.append(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||
"distance"
|
||||
)
|
||||
)
|
||||
result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'))
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
where_clauses = [DocumentChunk.collection_name == collection_name]
|
||||
@@ -466,9 +436,9 @@ class PgvectorClient(VectorDBBase):
|
||||
# Apply metadata filter if provided
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
if isinstance(value, dict) and "$in" in value:
|
||||
if isinstance(value, dict) and '$in' in value:
|
||||
# Handle $in operator: {"field": {"$in": [values]}}
|
||||
in_values = value["$in"]
|
||||
in_values = value['$in']
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
@@ -478,11 +448,7 @@ class PgvectorClient(VectorDBBase):
|
||||
)[key].astext.in_([str(v) for v in in_values])
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
DocumentChunk.vmetadata[key].astext.in_(
|
||||
[str(v) for v in in_values]
|
||||
)
|
||||
)
|
||||
where_clauses.append(DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values]))
|
||||
else:
|
||||
# Handle simple equality: {"field": "value"}
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
@@ -495,20 +461,16 @@ class PgvectorClient(VectorDBBase):
|
||||
== str(value)
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
where_clauses.append(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
|
||||
subq = (
|
||||
select(*result_fields)
|
||||
.where(*where_clauses)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
)
|
||||
.order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)))
|
||||
)
|
||||
if limit is not None:
|
||||
subq = subq.limit(limit)
|
||||
subq = subq.lateral("result")
|
||||
subq = subq.lateral('result')
|
||||
|
||||
# Build the main query by joining query_vectors and the lateral subquery
|
||||
stmt = (
|
||||
@@ -550,17 +512,13 @@ class PgvectorClient(VectorDBBase):
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
self.session.rollback() # read-only transaction
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during search: {e}")
|
||||
log.exception(f'Error during search: {e}')
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Build where clause for vmetadata filter
|
||||
@@ -568,32 +526,22 @@ class PgvectorClient(VectorDBBase):
|
||||
for key, value in filter.items():
|
||||
# decrypt then check key: JSON filter after decryption
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
|
||||
).where(*where_clauses)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
results = self.session.execute(stmt).all()
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
@@ -615,22 +563,16 @@ class PgvectorClient(VectorDBBase):
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during query: {e}")
|
||||
log.exception(f'Error during query: {e}')
|
||||
return None
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
|
||||
).where(DocumentChunk.collection_name == collection_name)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
@@ -639,10 +581,7 @@ class PgvectorClient(VectorDBBase):
|
||||
documents = [[row.text for row in results]]
|
||||
metadatas = [[row.vmetadata for row in results]]
|
||||
else:
|
||||
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
@@ -659,7 +598,7 @@ class PgvectorClient(VectorDBBase):
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during get: {e}")
|
||||
log.exception(f'Error during get: {e}')
|
||||
return None
|
||||
|
||||
def delete(
|
||||
@@ -676,43 +615,35 @@ class PgvectorClient(VectorDBBase):
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
wheres.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = DocumentChunk.__table__.delete().where(*wheres)
|
||||
result = self.session.execute(stmt)
|
||||
deleted = result.rowcount
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
if ids:
|
||||
query = query.filter(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during delete: {e}")
|
||||
log.exception(f'Error during delete: {e}')
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during reset: {e}")
|
||||
log.exception(f'Error during reset: {e}')
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -721,16 +652,14 @@ class PgvectorClient(VectorDBBase):
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
try:
|
||||
exists = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.collection_name == collection_name)
|
||||
.first()
|
||||
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
|
||||
is not None
|
||||
)
|
||||
self.session.rollback() # read-only transaction
|
||||
return exists
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
log.exception(f'Error checking collection existence: {e}')
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user