from typing import Optional, List, Dict, Any, Tuple import logging import json from sqlalchemy import ( func, literal, cast, column, create_engine, Column, Integer, MetaData, LargeBinary, select, text, Text, Table, values, ) from sqlalchemy.sql import true from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.dialects.postgresql import JSONB, array from pgvector.sqlalchemy import Vector, HALFVEC from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.exc import NoSuchTableError from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, SearchResult, GetResult, ) from open_webui.utils.misc import sanitize_text_for_db from open_webui.config import ( PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH, PGVECTOR_CREATE_EXTENSION, PGVECTOR_PGCRYPTO, PGVECTOR_PGCRYPTO_KEY, PGVECTOR_POOL_SIZE, PGVECTOR_POOL_MAX_OVERFLOW, PGVECTOR_POOL_TIMEOUT, PGVECTOR_POOL_RECYCLE, PGVECTOR_INDEX_METHOD, PGVECTOR_HNSW_M, PGVECTOR_HNSW_EF_CONSTRUCTION, PGVECTOR_IVFFLAT_LISTS, PGVECTOR_USE_HALFVEC, ) VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH USE_HALFVEC = PGVECTOR_USE_HALFVEC VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector VECTOR_OPCLASS = 'halfvec_cosine_ops' if USE_HALFVEC else 'vector_cosine_ops' Base = declarative_base() log = logging.getLogger(__name__) def pgcrypto_encrypt(val, key): return func.pgp_sym_encrypt(val, literal(key)) def pgcrypto_decrypt(col, key, outtype='text'): return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype) class DocumentChunk(Base): __tablename__ = 'document_chunk' id = Column(Text, primary_key=True) vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True) collection_name = Column(Text, nullable=False) if PGVECTOR_PGCRYPTO: text = Column(LargeBinary, nullable=True) vmetadata = Column(LargeBinary, nullable=True) else: text = Column(Text, nullable=True) vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) class PgvectorClient(VectorDBBase): def __init__(self) -> None: # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: from open_webui.internal.db import ScopedSession self.session = ScopedSession else: if isinstance(PGVECTOR_POOL_SIZE, int): if PGVECTOR_POOL_SIZE > 0: engine = create_engine( PGVECTOR_DB_URL, pool_size=PGVECTOR_POOL_SIZE, max_overflow=PGVECTOR_POOL_MAX_OVERFLOW, pool_timeout=PGVECTOR_POOL_TIMEOUT, pool_recycle=PGVECTOR_POOL_RECYCLE, pool_pre_ping=True, poolclass=QueuePool, ) else: engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool) else: engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False) self.session = scoped_session(SessionLocal) try: # Ensure the pgvector extension is available # Use a conditional check to avoid permission issues on Azure PostgreSQL if PGVECTOR_CREATE_EXTENSION: self.session.execute( text(""" DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN CREATE EXTENSION IF NOT EXISTS vector; END IF; END $$; """) ) if PGVECTOR_PGCRYPTO: # Ensure the pgcrypto extension is available for encryption # Use a conditional check to avoid permission issues on Azure PostgreSQL self.session.execute( text(""" DO $$ BEGIN IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN CREATE EXTENSION IF NOT EXISTS pgcrypto; END IF; END $$; """) ) if not PGVECTOR_PGCRYPTO_KEY: raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.') # Check vector length consistency self.check_vector_length() # Create the tables if they do not exist # Base.metadata.create_all requires a bind (engine or connection) # Get the connection from the session connection = self.session.connection() Base.metadata.create_all(bind=connection) index_method, index_options = self._vector_index_configuration() self._ensure_vector_index(index_method, index_options) self.session.execute( text( 'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);' ) ) self.session.commit() log.info('Initialization complete.') except Exception as e: self.session.rollback() log.exception(f'Error during initialization: {e}') raise @staticmethod def _extract_index_method(index_def: Optional[str]) -> Optional[str]: if not index_def: return None try: after_using = index_def.lower().split('using ', 1)[1] return after_using.split()[0] except (IndexError, AttributeError): return None def _vector_index_configuration(self) -> Tuple[str, str]: if PGVECTOR_INDEX_METHOD: index_method = PGVECTOR_INDEX_METHOD log.info( "Using vector index method '%s' from PGVECTOR_INDEX_METHOD.", index_method, ) elif USE_HALFVEC: index_method = 'hnsw' log.info( 'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.', VECTOR_LENGTH, ) else: index_method = 'ivfflat' if index_method == 'hnsw': index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})' else: index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})' return index_method, index_options def _ensure_vector_index(self, index_method: str, index_options: str) -> None: index_name = 'idx_document_chunk_vector' existing_index_def = self.session.execute( text(""" SELECT indexdef FROM pg_indexes WHERE schemaname = current_schema() AND tablename = 'document_chunk' AND indexname = :index_name """), {'index_name': index_name}, ).scalar() existing_method = self._extract_index_method(existing_index_def) if existing_method and existing_method != index_method: raise RuntimeError( f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now " f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. " 'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) ' 'and recreate it with the new method before restarting Open WebUI.' ) if not existing_index_def: index_sql = ( f'CREATE INDEX IF NOT EXISTS {index_name} ' f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})' ) if index_options: index_sql = f'{index_sql} {index_options}' self.session.execute(text(index_sql)) log.info( "Ensured vector index '%s' using %s%s.", index_name, index_method, f' {index_options}' if index_options else '', ) def check_vector_length(self) -> None: """ Check if the VECTOR_LENGTH matches the existing vector column dimension in the database. Raises an exception if there is a mismatch. """ metadata = MetaData() try: # Attempt to reflect the 'document_chunk' table document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind) except NoSuchTableError: # Table does not exist; no action needed return # Proceed to check the vector column if 'vector' in document_chunk_table.columns: vector_column = document_chunk_table.columns['vector'] vector_type = vector_column.type expected_type = HALFVEC if USE_HALFVEC else Vector if not isinstance(vector_type, expected_type): raise Exception( "The 'vector' column type does not match the expected type " f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}." ) db_vector_length = getattr(vector_type, 'dim', None) if db_vector_length is not None and db_vector_length != VECTOR_LENGTH: raise Exception( f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. ' 'Cannot change vector size after initialization without migrating the data.' ) else: raise Exception("The 'vector' column does not exist in the 'document_chunk' table.") def adjust_vector_length(self, vector: List[float]) -> List[float]: # Adjust vector to have length VECTOR_LENGTH current_length = len(vector) if current_length < VECTOR_LENGTH: # Pad the vector with zeros vector += [0.0] * (VECTOR_LENGTH - current_length) elif current_length > VECTOR_LENGTH: # Truncate the vector to VECTOR_LENGTH vector = vector[:VECTOR_LENGTH] return vector def insert(self, collection_name: str, items: List[VectorItem]) -> None: try: if PGVECTOR_PGCRYPTO: for item in items: vector = self.adjust_vector_length(item['vector']) # Use raw SQL for BYTEA/pgcrypto # Ensure metadata is converted to its JSON text representation # Sanitize to strip null bytes / surrogates that PostgreSQL cannot store json_metadata = sanitize_text_for_db(json.dumps(item['metadata'])) item_text = sanitize_text_for_db(item['text']) self.session.execute( text(""" INSERT INTO document_chunk (id, vector, collection_name, text, vmetadata) VALUES ( :id, :vector, :collection_name, pgp_sym_encrypt(:text, :key), pgp_sym_encrypt(:metadata_text, :key) ) ON CONFLICT (id) DO NOTHING """), { 'id': item['id'], 'vector': vector, 'collection_name': collection_name, 'text': item_text, 'metadata_text': json_metadata, 'key': PGVECTOR_PGCRYPTO_KEY, }, ) self.session.commit() log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'") else: new_items = [] for item in items: vector = self.adjust_vector_length(item['vector']) new_chunk = DocumentChunk( id=item['id'], vector=vector, collection_name=collection_name, text=item['text'], vmetadata=process_metadata(item['metadata']), ) new_items.append(new_chunk) self.session.bulk_save_objects(new_items) self.session.commit() log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.") except Exception as e: self.session.rollback() log.exception(f'Error during insert: {e}') raise def upsert(self, collection_name: str, items: List[VectorItem]) -> None: try: if PGVECTOR_PGCRYPTO: for item in items: vector = self.adjust_vector_length(item['vector']) # Sanitize to strip null bytes / surrogates that PostgreSQL cannot store json_metadata = sanitize_text_for_db(json.dumps(item['metadata'])) item_text = sanitize_text_for_db(item['text']) self.session.execute( text(""" INSERT INTO document_chunk (id, vector, collection_name, text, vmetadata) VALUES ( :id, :vector, :collection_name, pgp_sym_encrypt(:text, :key), pgp_sym_encrypt(:metadata_text, :key) ) ON CONFLICT (id) DO UPDATE SET vector = EXCLUDED.vector, collection_name = EXCLUDED.collection_name, text = EXCLUDED.text, vmetadata = EXCLUDED.vmetadata """), { 'id': item['id'], 'vector': vector, 'collection_name': collection_name, 'text': item_text, 'metadata_text': json_metadata, 'key': PGVECTOR_PGCRYPTO_KEY, }, ) self.session.commit() log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'") else: for item in items: vector = self.adjust_vector_length(item['vector']) existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first() if existing: existing.vector = vector existing.text = item['text'] existing.vmetadata = process_metadata(item['metadata']) existing.collection_name = collection_name # Update collection_name if necessary else: new_chunk = DocumentChunk( id=item['id'], vector=vector, collection_name=collection_name, text=item['text'], vmetadata=process_metadata(item['metadata']), ) self.session.add(new_chunk) self.session.commit() log.info(f"Upserted {len(items)} items into collection '{collection_name}'.") except Exception as e: self.session.rollback() log.exception(f'Error during upsert: {e}') raise def search( self, collection_name: str, vectors: List[List[float]], filter: Optional[Dict[str, Any]] = None, limit: int = 10, ) -> Optional[SearchResult]: try: if not vectors: return None # Adjust query vectors to VECTOR_LENGTH vectors = [self.adjust_vector_length(vector) for vector in vectors] num_queries = len(vectors) def vector_expr(vector): return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) # Create the values for query vectors qid_col = column('qid', Integer) q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH)) query_vectors = ( values(qid_col, q_vector_col) .data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]) .alias('query_vectors') ) result_fields = [ DocumentChunk.id, ] if PGVECTOR_PGCRYPTO: result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text')) result_fields.append( pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata') ) else: result_fields.append(DocumentChunk.text) result_fields.append(DocumentChunk.vmetadata) result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance')) # Build the lateral subquery for each query vector where_clauses = [DocumentChunk.collection_name == collection_name] # Apply metadata filter if provided if filter: for key, value in filter.items(): if isinstance(value, dict) and '$in' in value: # Handle $in operator: {"field": {"$in": [values]}} in_values = value['$in'] if PGVECTOR_PGCRYPTO: where_clauses.append( pgcrypto_decrypt( DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB, )[key].astext.in_([str(v) for v in in_values]) ) else: where_clauses.append(DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values])) else: # Handle simple equality: {"field": "value"} if PGVECTOR_PGCRYPTO: where_clauses.append( pgcrypto_decrypt( DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB, )[key].astext == str(value) ) else: where_clauses.append(DocumentChunk.vmetadata[key].astext == str(value)) subq = ( select(*result_fields) .where(*where_clauses) .order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))) ) if limit is not None: subq = subq.limit(limit) subq = subq.lateral('result') # Build the main query by joining query_vectors and the lateral subquery stmt = ( select( query_vectors.c.qid, subq.c.id, subq.c.text, subq.c.vmetadata, subq.c.distance, ) .select_from(query_vectors) .join(subq, true()) .order_by(query_vectors.c.qid, subq.c.distance) ) result_proxy = self.session.execute(stmt) results = result_proxy.all() ids = [[] for _ in range(num_queries)] distances = [[] for _ in range(num_queries)] documents = [[] for _ in range(num_queries)] metadatas = [[] for _ in range(num_queries)] if not results: return SearchResult( ids=ids, distances=distances, documents=documents, metadatas=metadatas, ) for row in results: qid = int(row.qid) ids[qid].append(row.id) # normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying distances[qid].append((2.0 - row.distance) / 2.0) documents[qid].append(row.text) metadatas[qid].append(row.vmetadata) self.session.rollback() # read-only transaction return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() log.exception(f'Error during search: {e}') return None def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]: try: if PGVECTOR_PGCRYPTO: # Build where clause for vmetadata filter where_clauses = [DocumentChunk.collection_name == collection_name] for key, value in filter.items(): # decrypt then check key: JSON filter after decryption where_clauses.append( pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext == str(value) ) stmt = select( DocumentChunk.id, pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'), pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'), ).where(*where_clauses) if limit is not None: stmt = stmt.limit(limit) results = self.session.execute(stmt).all() else: query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) for key, value in filter.items(): query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) if limit is not None: query = query.limit(limit) results = query.all() if not results: return None ids = [[result.id for result in results]] documents = [[result.text for result in results]] metadatas = [[result.vmetadata for result in results]] self.session.rollback() # read-only transaction return GetResult( ids=ids, documents=documents, metadatas=metadatas, ) except Exception as e: self.session.rollback() log.exception(f'Error during query: {e}') return None def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]: try: if PGVECTOR_PGCRYPTO: stmt = select( DocumentChunk.id, pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'), pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'), ).where(DocumentChunk.collection_name == collection_name) if limit is not None: stmt = stmt.limit(limit) results = self.session.execute(stmt).all() ids = [[row.id for row in results]] documents = [[row.text for row in results]] metadatas = [[row.vmetadata for row in results]] else: query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) if limit is not None: query = query.limit(limit) results = query.all() if not results: return None ids = [[result.id for result in results]] documents = [[result.text for result in results]] metadatas = [[result.vmetadata for result in results]] self.session.rollback() # read-only transaction return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: self.session.rollback() log.exception(f'Error during get: {e}') return None def delete( self, collection_name: str, ids: Optional[List[str]] = None, filter: Optional[Dict[str, Any]] = None, ) -> None: try: if PGVECTOR_PGCRYPTO: wheres = [DocumentChunk.collection_name == collection_name] if ids: wheres.append(DocumentChunk.id.in_(ids)) if filter: for key, value in filter.items(): wheres.append( pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext == str(value) ) stmt = DocumentChunk.__table__.delete().where(*wheres) result = self.session.execute(stmt) deleted = result.rowcount else: query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name) if ids: query = query.filter(DocumentChunk.id.in_(ids)) if filter: for key, value in filter.items(): query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) deleted = query.delete(synchronize_session=False) self.session.commit() log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: self.session.rollback() log.exception(f'Error during delete: {e}') raise def reset(self) -> None: try: deleted = self.session.query(DocumentChunk).delete() self.session.commit() log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.") except Exception as e: self.session.rollback() log.exception(f'Error during reset: {e}') raise def close(self) -> None: pass def has_collection(self, collection_name: str) -> bool: try: exists = ( self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first() is not None ) self.session.rollback() # read-only transaction return exists except Exception as e: self.session.rollback() log.exception(f'Error checking collection existence: {e}') return False def delete_collection(self, collection_name: str) -> None: self.delete(collection_name) log.info(f"Collection '{collection_name}' deleted.")