Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

@@ -55,7 +55,7 @@ VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
VECTOR_OPCLASS = 'halfvec_cosine_ops' if USE_HALFVEC else 'vector_cosine_ops'
Base = declarative_base()
log = logging.getLogger(__name__)
@@ -65,12 +65,12 @@ def pgcrypto_encrypt(val, key):
return func.pgp_sym_encrypt(val, literal(key))
def pgcrypto_decrypt(col, key, outtype="text"):
def pgcrypto_decrypt(col, key, outtype='text'):
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
__tablename__ = "document_chunk"
__tablename__ = 'document_chunk'
id = Column(Text, primary_key=True)
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
@@ -86,7 +86,6 @@ class DocumentChunk(Base):
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.internal.db import ScopedSession
@@ -105,46 +104,44 @@ class PgvectorClient(VectorDBBase):
poolclass=QueuePool,
)
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
self.session = scoped_session(SessionLocal)
try:
# Ensure the pgvector extension is available
# Use a conditional check to avoid permission issues on Azure PostgreSQL
if PGVECTOR_CREATE_EXTENSION:
self.session.execute(text("""
self.session.execute(
text("""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
CREATE EXTENSION IF NOT EXISTS vector;
END IF;
END $$;
"""))
""")
)
if PGVECTOR_PGCRYPTO:
# Ensure the pgcrypto extension is available for encryption
# Use a conditional check to avoid permission issues on Azure PostgreSQL
self.session.execute(text("""
self.session.execute(
text("""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
CREATE EXTENSION IF NOT EXISTS pgcrypto;
END IF;
END $$;
"""))
""")
)
if not PGVECTOR_PGCRYPTO_KEY:
raise ValueError(
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
)
raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.')
# Check vector length consistency
self.check_vector_length()
@@ -160,15 +157,14 @@ class PgvectorClient(VectorDBBase):
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
)
)
self.session.commit()
log.info("Initialization complete.")
log.info('Initialization complete.')
except Exception as e:
self.session.rollback()
log.exception(f"Error during initialization: {e}")
log.exception(f'Error during initialization: {e}')
raise
@staticmethod
@@ -176,7 +172,7 @@ class PgvectorClient(VectorDBBase):
if not index_def:
return None
try:
after_using = index_def.lower().split("using ", 1)[1]
after_using = index_def.lower().split('using ', 1)[1]
return after_using.split()[0]
except (IndexError, AttributeError):
return None
@@ -189,23 +185,23 @@ class PgvectorClient(VectorDBBase):
index_method,
)
elif USE_HALFVEC:
index_method = "hnsw"
index_method = 'hnsw'
log.info(
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.',
VECTOR_LENGTH,
)
else:
index_method = "ivfflat"
index_method = 'ivfflat'
if index_method == "hnsw":
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
if index_method == 'hnsw':
index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})'
else:
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})'
return index_method, index_options
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
index_name = "idx_document_chunk_vector"
index_name = 'idx_document_chunk_vector'
existing_index_def = self.session.execute(
text("""
SELECT indexdef
@@ -214,7 +210,7 @@ class PgvectorClient(VectorDBBase):
AND tablename = 'document_chunk'
AND indexname = :index_name
"""),
{"index_name": index_name},
{'index_name': index_name},
).scalar()
existing_method = self._extract_index_method(existing_index_def)
@@ -222,23 +218,23 @@ class PgvectorClient(VectorDBBase):
raise RuntimeError(
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
"and recreate it with the new method before restarting Open WebUI."
'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) '
'and recreate it with the new method before restarting Open WebUI.'
)
if not existing_index_def:
index_sql = (
f"CREATE INDEX IF NOT EXISTS {index_name} "
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
f'CREATE INDEX IF NOT EXISTS {index_name} '
f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})'
)
if index_options:
index_sql = f"{index_sql} {index_options}"
index_sql = f'{index_sql} {index_options}'
self.session.execute(text(index_sql))
log.info(
"Ensured vector index '%s' using %s%s.",
index_name,
index_method,
f" {index_options}" if index_options else "",
f' {index_options}' if index_options else '',
)
def check_vector_length(self) -> None:
@@ -249,16 +245,14 @@ class PgvectorClient(VectorDBBase):
metadata = MetaData()
try:
# Attempt to reflect the 'document_chunk' table
document_chunk_table = Table(
"document_chunk", metadata, autoload_with=self.session.bind
)
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
except NoSuchTableError:
# Table does not exist; no action needed
return
# Proceed to check the vector column
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
if 'vector' in document_chunk_table.columns:
vector_column = document_chunk_table.columns['vector']
vector_type = vector_column.type
expected_type = HALFVEC if USE_HALFVEC else Vector
@@ -268,16 +262,14 @@ class PgvectorClient(VectorDBBase):
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
)
db_vector_length = getattr(vector_type, "dim", None)
db_vector_length = getattr(vector_type, 'dim', None)
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. '
'Cannot change vector size after initialization without migrating the data.'
)
else:
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
def adjust_vector_length(self, vector: List[float]) -> List[float]:
# Adjust vector to have length VECTOR_LENGTH
@@ -294,10 +286,10 @@ class PgvectorClient(VectorDBBase):
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item['vector'])
# Use raw SQL for BYTEA/pgcrypto
# Ensure metadata is converted to its JSON text representation
json_metadata = json.dumps(item["metadata"])
json_metadata = json.dumps(item['metadata'])
self.session.execute(
text("""
INSERT INTO document_chunk
@@ -310,12 +302,12 @@ class PgvectorClient(VectorDBBase):
ON CONFLICT (id) DO NOTHING
"""),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
'id': item['id'],
'vector': vector,
'collection_name': collection_name,
'text': item['text'],
'metadata_text': json_metadata,
'key': PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
@@ -324,31 +316,29 @@ class PgvectorClient(VectorDBBase):
else:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item['vector'])
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
log.exception(f'Error during insert: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
json_metadata = json.dumps(item["metadata"])
vector = self.adjust_vector_length(item['vector'])
json_metadata = json.dumps(item['metadata'])
self.session.execute(
text("""
INSERT INTO document_chunk
@@ -365,47 +355,39 @@ class PgvectorClient(VectorDBBase):
vmetadata = EXCLUDED.vmetadata
"""),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
'id': item['id'],
'vector': vector,
'collection_name': collection_name,
'text': item['text'],
'metadata_text': json_metadata,
'key': PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
else:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
vector = self.adjust_vector_length(item['vector'])
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = process_metadata(item["metadata"])
existing.collection_name = (
collection_name # Update collection_name if necessary
)
existing.text = item['text']
existing.vmetadata = process_metadata(item['metadata'])
existing.collection_name = collection_name # Update collection_name if necessary
else:
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
log.info(f"Upserted {len(items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during upsert: {e}")
log.exception(f'Error during upsert: {e}')
raise
def search(
@@ -427,38 +409,26 @@ class PgvectorClient(VectorDBBase):
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
# Create the values for query vectors
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
qid_col = column('qid', Integer)
q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
.alias('query_vectors')
)
result_fields = [
DocumentChunk.id,
]
if PGVECTOR_PGCRYPTO:
result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'))
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text")
)
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata")
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata')
)
else:
result_fields.append(DocumentChunk.text)
result_fields.append(DocumentChunk.vmetadata)
result_fields.append(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
)
)
result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'))
# Build the lateral subquery for each query vector
where_clauses = [DocumentChunk.collection_name == collection_name]
@@ -466,9 +436,9 @@ class PgvectorClient(VectorDBBase):
# Apply metadata filter if provided
if filter:
for key, value in filter.items():
if isinstance(value, dict) and "$in" in value:
if isinstance(value, dict) and '$in' in value:
# Handle $in operator: {"field": {"$in": [values]}}
in_values = value["$in"]
in_values = value['$in']
if PGVECTOR_PGCRYPTO:
where_clauses.append(
pgcrypto_decrypt(
@@ -478,11 +448,7 @@ class PgvectorClient(VectorDBBase):
)[key].astext.in_([str(v) for v in in_values])
)
else:
where_clauses.append(
DocumentChunk.vmetadata[key].astext.in_(
[str(v) for v in in_values]
)
)
where_clauses.append(DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values]))
else:
# Handle simple equality: {"field": "value"}
if PGVECTOR_PGCRYPTO:
@@ -495,20 +461,16 @@ class PgvectorClient(VectorDBBase):
== str(value)
)
else:
where_clauses.append(
DocumentChunk.vmetadata[key].astext == str(value)
)
where_clauses.append(DocumentChunk.vmetadata[key].astext == str(value))
subq = (
select(*result_fields)
.where(*where_clauses)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
.order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)))
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
subq = subq.lateral('result')
# Build the main query by joining query_vectors and the lateral subquery
stmt = (
@@ -550,17 +512,13 @@ class PgvectorClient(VectorDBBase):
metadatas[qid].append(row.vmetadata)
self.session.rollback() # read-only transaction
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Error during search: {e}")
log.exception(f'Error during search: {e}')
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
try:
if PGVECTOR_PGCRYPTO:
# Build where clause for vmetadata filter
@@ -568,32 +526,22 @@ class PgvectorClient(VectorDBBase):
for key, value in filter.items():
# decrypt then check key: JSON filter after decryption
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
== str(value)
)
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
).where(*where_clauses)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
if limit is not None:
query = query.limit(limit)
@@ -615,22 +563,16 @@ class PgvectorClient(VectorDBBase):
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during query: {e}")
log.exception(f'Error during query: {e}')
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
try:
if PGVECTOR_PGCRYPTO:
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
).where(DocumentChunk.collection_name == collection_name)
if limit is not None:
stmt = stmt.limit(limit)
@@ -639,10 +581,7 @@ class PgvectorClient(VectorDBBase):
documents = [[row.text for row in results]]
metadatas = [[row.vmetadata for row in results]]
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if limit is not None:
query = query.limit(limit)
@@ -659,7 +598,7 @@ class PgvectorClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Error during get: {e}")
log.exception(f'Error during get: {e}')
return None
def delete(
@@ -676,43 +615,35 @@ class PgvectorClient(VectorDBBase):
if filter:
for key, value in filter.items():
wheres.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
== str(value)
)
stmt = DocumentChunk.__table__.delete().where(*wheres)
result = self.session.execute(stmt)
deleted = result.rowcount
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during delete: {e}")
log.exception(f'Error during delete: {e}')
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during reset: {e}")
log.exception(f'Error during reset: {e}')
raise
def close(self) -> None:
@@ -721,16 +652,14 @@ class PgvectorClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
is not None
)
self.session.rollback() # read-only transaction
return exists
except Exception as e:
self.session.rollback()
log.exception(f"Error checking collection existence: {e}")
log.exception(f'Error checking collection existence: {e}')
return False
def delete_collection(self, collection_name: str) -> None: