This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -36,17 +36,15 @@ from sqlalchemy.dialects import registry
class OpenGaussDialect(PGDialect_psycopg2):
name = "opengauss"
name = 'opengauss'
def _get_server_version_info(self, connection):
try:
version = connection.exec_driver_sql("SELECT version()").scalar()
version = connection.exec_driver_sql('SELECT version()').scalar()
if not version:
return (9, 0, 0)
match = re.search(
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
)
match = re.search(r'openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?', version, re.IGNORECASE)
if match:
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
@@ -56,7 +54,7 @@ class OpenGaussDialect(PGDialect_psycopg2):
# Register dialect
registry.register("opengauss", __name__, "OpenGaussDialect")
registry.register('opengauss', __name__, 'OpenGaussDialect')
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
@@ -80,11 +78,11 @@ VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
log.setLevel(SRC_LOG_LEVELS['RAG'])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
__tablename__ = 'document_chunk'
id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
@@ -100,26 +98,24 @@ class OpenGaussClient(VectorDBBase):
self.session = ScopedSession
else:
engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
engine_kwargs = {'pool_pre_ping': True, 'dialect': OpenGaussDialect()}
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
engine_kwargs.update(
{
"pool_size": OPENGAUSS_POOL_SIZE,
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
"poolclass": QueuePool,
'pool_size': OPENGAUSS_POOL_SIZE,
'max_overflow': OPENGAUSS_POOL_MAX_OVERFLOW,
'pool_timeout': OPENGAUSS_POOL_TIMEOUT,
'pool_recycle': OPENGAUSS_POOL_RECYCLE,
'poolclass': QueuePool,
}
)
else:
engine_kwargs["poolclass"] = NullPool
engine_kwargs['poolclass'] = NullPool
engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
self.session = scoped_session(SessionLocal)
try:
@@ -128,47 +124,42 @@ class OpenGaussClient(VectorDBBase):
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
'CREATE INDEX IF NOT EXISTS idx_document_chunk_vector '
'ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);'
)
)
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
)
)
self.session.commit()
log.info("OpenGauss vector database initialization completed.")
log.info('OpenGauss vector database initialization completed.')
except Exception as e:
self.session.rollback()
log.exception(f"OpenGauss Initialization failed.: {e}")
log.exception(f'OpenGauss Initialization failed.: {e}')
raise
def check_vector_length(self) -> None:
metadata = MetaData()
try:
document_chunk_table = Table(
"document_chunk", metadata, autoload_with=self.session.bind
)
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
except NoSuchTableError:
return
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
if 'vector' in document_chunk_table.columns:
vector_column = document_chunk_table.columns['vector']
vector_type = vector_column.type
if isinstance(vector_type, Vector):
db_vector_length = vector_type.dim
if db_vector_length != VECTOR_LENGTH:
raise Exception(
f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database."
f'Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database.'
)
else:
raise Exception("The 'vector' column type is not Vector.")
else:
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
def adjust_vector_length(self, vector: List[float]) -> List[float]:
current_length = len(vector)
@@ -182,55 +173,47 @@ class OpenGaussClient(VectorDBBase):
try:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item['vector'])
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserting {len(new_items)} items into collection '{collection_name}'."
)
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Failed to insert data: {e}")
log.exception(f'Failed to insert data: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
vector = self.adjust_vector_length(item['vector'])
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = process_metadata(item["metadata"])
existing.text = item['text']
existing.vmetadata = process_metadata(item['metadata'])
existing.collection_name = collection_name
else:
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
)
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Failed to insert or update data.: {e}")
log.exception(f'Failed to insert or update data.: {e}')
raise
def search(
@@ -250,35 +233,29 @@ class OpenGaussClient(VectorDBBase):
def vector_expr(vector):
return cast(array(vector), Vector(VECTOR_LENGTH))
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
qid_col = column('qid', Integer)
q_vector_col = column('q_vector', Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
.alias('query_vectors')
)
result_fields = [
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
),
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'),
]
subq = (
select(*result_fields)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
)
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
subq = subq.lateral('result')
stmt = (
select(
@@ -309,21 +286,15 @@ class OpenGaussClient(VectorDBBase):
metadatas[qid].append(row.vmetadata)
self.session.rollback()
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Vector search failed: {e}")
log.exception(f'Vector search failed: {e}')
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
@@ -344,16 +315,12 @@ class OpenGaussClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Conditional query failed: {e}")
log.exception(f'Conditional query failed: {e}')
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if limit is not None:
query = query.limit(limit)
@@ -370,7 +337,7 @@ class OpenGaussClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Failed to retrieve data: {e}")
log.exception(f'Failed to retrieve data: {e}')
return None
def delete(
@@ -380,32 +347,28 @@ class OpenGaussClient(VectorDBBase):
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
except Exception as e:
self.session.rollback()
log.exception(f"Failed to delete data: {e}")
log.exception(f'Failed to delete data: {e}')
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(f"Reset completed. Deleted {deleted} items")
log.info(f'Reset completed. Deleted {deleted} items')
except Exception as e:
self.session.rollback()
log.exception(f"Reset failed: {e}")
log.exception(f'Reset failed: {e}')
raise
def close(self) -> None:
@@ -414,16 +377,14 @@ class OpenGaussClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
is not None
)
self.session.rollback()
return exists
except Exception as e:
self.session.rollback()
log.exception(f"Failed to check collection existence: {e}")
log.exception(f'Failed to check collection existence: {e}')
return False
def delete_collection(self, collection_name: str) -> None: