mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-04-30 23:31:31 -05:00
added qdrant as db
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
from typing import List, Dict, Any, Literal
|
||||
from typing import List, Dict, Any, Literal, Optional
|
||||
from dataclasses import dataclass
|
||||
import streamlit as st
|
||||
from langchain_core.documents import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_community.vectorstores import Qdrant
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain_openai import ChatOpenAI
|
||||
import tempfile
|
||||
@@ -19,11 +19,17 @@ from langgraph.prebuilt import create_react_agent
|
||||
from langchain_community.tools import DuckDuckGoSearchRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
def init_session_state():
|
||||
"""Initialize session state variables"""
|
||||
if 'openai_api_key' not in st.session_state:
|
||||
st.session_state.openai_api_key = ""
|
||||
if 'qdrant_url' not in st.session_state:
|
||||
st.session_state.qdrant_url = ""
|
||||
if 'qdrant_api_key' not in st.session_state:
|
||||
st.session_state.qdrant_api_key = ""
|
||||
if 'embeddings' not in st.session_state:
|
||||
st.session_state.embeddings = None
|
||||
if 'llm' not in st.session_state:
|
||||
@@ -40,61 +46,68 @@ PERSIST_DIRECTORY = "db_storage"
|
||||
class CollectionConfig:
|
||||
name: str
|
||||
description: str
|
||||
collection_name: str
|
||||
persist_directory: str
|
||||
collection_name: str # This will be used as Qdrant collection name
|
||||
|
||||
# Registry of the vector collections the app can route a question to.
# Each row is (routing key, display name, description, Qdrant collection name).
_COLLECTION_ROWS = (
    (
        "products",
        "Product Information",
        "Product details, specifications, and features",
        "products_collection",
    ),
    (
        "support",
        "Customer Support & FAQ",
        "Customer support information, frequently asked questions, and guides",
        "support_collection",
    ),
    (
        "finance",
        "Financial Information",
        "Financial data, revenue, costs, and liabilities",
        "finance_collection",
    ),
)

COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
    key: CollectionConfig(
        name=title,
        description=summary,
        collection_name=qdrant_name,
    )
    for key, title, summary, qdrant_name in _COLLECTION_ROWS
}
|
||||
|
||||
def initialize_models():
    """Set up the OpenAI models and one Qdrant-backed store per collection.

    Reads ``openai_api_key``, ``qdrant_url`` and ``qdrant_api_key`` from
    ``st.session_state``. When all three are set it exports the OpenAI key to
    the environment, stores the embedding model and chat model in session
    state, connects to Qdrant, creates any collection from ``COLLECTIONS``
    that cannot be fetched, and fills ``st.session_state.databases`` with one
    ``Qdrant`` vector store per routing key.

    Returns:
        True when every store was created; False when a credential is missing
        or any Qdrant call raised (the error is shown with ``st.error``).
    """
    state = st.session_state
    if not all((state.openai_api_key, state.qdrant_url, state.qdrant_api_key)):
        return False

    # The OpenAI clients read the key from the environment, so export it
    # before constructing them.
    os.environ["OPENAI_API_KEY"] = state.openai_api_key
    state.embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    state.llm = ChatOpenAI(temperature=0)

    try:
        # Initialize Qdrant client with session state credentials
        qdrant = QdrantClient(
            url=state.qdrant_url,
            api_key=state.qdrant_api_key,
        )

        # Listing collections doubles as a connectivity/credentials check.
        qdrant.get_collections()

        # Size used for newly created collections.
        # NOTE(review): presumably the output dimension of the embedding
        # model configured above — confirm if the model is changed.
        embedding_dim = 1536
        state.databases = {}

        for routing_key, cfg in COLLECTIONS.items():
            try:
                qdrant.get_collection(cfg.collection_name)
            except Exception:
                # Any failure to fetch is treated as "collection missing".
                qdrant.create_collection(
                    collection_name=cfg.collection_name,
                    vectors_config=VectorParams(
                        size=embedding_dim,
                        distance=Distance.COSINE,
                    ),
                )

            state.databases[routing_key] = Qdrant(
                client=qdrant,
                collection_name=cfg.collection_name,
                embeddings=state.embeddings,
            )
    except Exception as e:
        st.error(f"Failed to connect to Qdrant: {str(e)}")
        return False

    return True
|
||||
|
||||
def process_document(file) -> List[Document]:
|
||||
@@ -136,33 +149,62 @@ def create_routing_agent() -> Agent:
|
||||
"1. For questions about products, features, specifications, or item details, or product manuals → return 'products'",
|
||||
"2. For questions about help, guidance, troubleshooting, or customer service, FAQ, or guides → return 'support'",
|
||||
"3. For questions about costs, revenue, pricing, or financial data, or financial reports and investments → return 'finance'",
|
||||
"4. Return ONLY the database name, no other text or explanation"
|
||||
"4. Return ONLY the database name, no other text or explanation",
|
||||
"5. If you're not confident about the routing, return an empty response"
|
||||
],
|
||||
markdown=False,
|
||||
show_tool_calls=False
|
||||
)
|
||||
|
||||
def route_query(question: str) -> Optional[DatabaseType]:
    """Choose the collection that should answer ``question``.

    Routing runs in two stages:

    1. Vector routing: run a top-3 ``similarity_search_with_score`` against
       every store in ``st.session_state.databases`` and average the scores.
       The store with the highest average is chosen if that average reaches
       the confidence threshold (0.5).
    2. LLM routing: otherwise ask the routing agent for a collection name and
       accept it only when it is a key of ``COLLECTIONS``.

    Args:
        question: The user's question.

    Returns:
        The routing key of the selected collection, or ``None`` when neither
        stage yields one or when any step raises (the error is reported with
        ``st.error``). Callers treat ``None`` as "use the web-search
        fallback".
    """
    confidence_threshold = 0.5

    try:
        # Stage 1: vector-similarity routing.
        # NOTE(review): the comparison assumes higher scores mean "more
        # similar" — confirm for the distance metric the stores use.
        best_db_type = None
        best_score = float("-inf")  # no magic sentinel inside the score range

        for db_type, db in st.session_state.databases.items():
            results = db.similarity_search_with_score(question, k=3)
            if not results:
                # Nothing came back from this store, so there is no score.
                continue

            avg_score = sum(score for _, score in results) / len(results)
            if avg_score > best_score:
                best_score, best_db_type = avg_score, db_type

        if best_db_type and best_score >= confidence_threshold:
            st.success(f"Using vector similarity routing: {best_db_type} (confidence: {best_score:.3f})")
            return best_db_type

        st.warning(f"Low confidence scores (below {confidence_threshold}), falling back to LLM routing")

        # Stage 2: LLM routing.
        response = create_routing_agent().run(question)

        # The agent may answer with no content when it is not confident.
        # Coerce that to "" so it flows to the "no suitable database" path
        # below instead of raising and being reported as a routing error.
        raw_choice = response.content or ""
        # Strip whitespace, normalise case, and drop quotes/backticks the
        # model may wrap around the name.
        db_type = raw_choice.strip().lower().translate(str.maketrans('', '', '`\'"'))

        if db_type in COLLECTIONS:
            st.success(f"Using LLM routing decision: {db_type}")
            return db_type

        st.warning("No suitable database found, will use web search fallback")
        return None

    except Exception as e:
        # Top-level UI boundary: surface the problem and let the caller fall
        # back to web search rather than crash the page.
        st.error(f"Routing error: {str(e)}")
        return None
|
||||
|
||||
def create_fallback_agent(chat_model: BaseLanguageModel):
|
||||
"""Create a LangGraph agent for web research."""
|
||||
@@ -184,11 +226,12 @@ def create_fallback_agent(chat_model: BaseLanguageModel):
|
||||
|
||||
return agent
|
||||
|
||||
def query_database(db: Chroma, question: str) -> tuple[str, list]:
|
||||
def query_database(db: Qdrant, question: str) -> tuple[str, list]:
|
||||
"""Query the database and return answer and relevant documents"""
|
||||
try:
|
||||
retriever = db.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"k": 4, "score_threshold": 0.3}
|
||||
search_type="similarity",
|
||||
search_kwargs={"k": 4}
|
||||
)
|
||||
|
||||
relevant_docs = retriever.get_relevant_documents(question)
|
||||
@@ -210,7 +253,8 @@ def query_database(db: Chroma, question: str) -> tuple[str, list]:
|
||||
|
||||
response = retrieval_chain.invoke({"input": question})
|
||||
return response['answer'], relevant_docs
|
||||
return _handle_web_fallback(question)
|
||||
|
||||
raise ValueError("No relevant documents found in database")
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error: {str(e)}")
|
||||
@@ -244,9 +288,11 @@ def main():
|
||||
st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚")
|
||||
st.title("📚 RAG Agent with Database Routing")
|
||||
|
||||
# Sidebar for API key and database management
|
||||
# Sidebar for API keys and configuration
|
||||
with st.sidebar:
|
||||
st.header("Configuration")
|
||||
|
||||
# OpenAI API Key
|
||||
api_key = st.text_input(
|
||||
"Enter OpenAI API Key:",
|
||||
type="password",
|
||||
@@ -254,15 +300,37 @@ def main():
|
||||
key="api_key_input"
|
||||
)
|
||||
|
||||
# Qdrant Configuration
|
||||
qdrant_url = st.text_input(
|
||||
"Enter Qdrant URL:",
|
||||
value=st.session_state.qdrant_url,
|
||||
help="Example: https://your-cluster.qdrant.tech"
|
||||
)
|
||||
|
||||
qdrant_api_key = st.text_input(
|
||||
"Enter Qdrant API Key:",
|
||||
type="password",
|
||||
value=st.session_state.qdrant_api_key
|
||||
)
|
||||
|
||||
# Update session state
|
||||
if api_key:
|
||||
st.session_state.openai_api_key = api_key
|
||||
if qdrant_url:
|
||||
st.session_state.qdrant_url = qdrant_url
|
||||
if qdrant_api_key:
|
||||
st.session_state.qdrant_api_key = qdrant_api_key
|
||||
|
||||
# Initialize models if all credentials are provided
|
||||
if (st.session_state.openai_api_key and
|
||||
st.session_state.qdrant_url and
|
||||
st.session_state.qdrant_api_key):
|
||||
if initialize_models():
|
||||
st.success("API Key set successfully!")
|
||||
st.success("Connected to OpenAI and Qdrant successfully!")
|
||||
else:
|
||||
st.error("Invalid API Key")
|
||||
|
||||
if not st.session_state.openai_api_key:
|
||||
st.warning("Please enter your OpenAI API key to continue")
|
||||
st.error("Failed to initialize. Please check your credentials.")
|
||||
else:
|
||||
st.warning("Please enter all required credentials to continue")
|
||||
st.stop()
|
||||
|
||||
st.markdown("---")
|
||||
@@ -302,15 +370,19 @@ def main():
|
||||
with st.spinner('Finding answer...'):
|
||||
# Route the question
|
||||
collection_type = route_query(question)
|
||||
db = st.session_state.databases[collection_type]
|
||||
|
||||
# Display routing information
|
||||
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")
|
||||
|
||||
# Get and display answer
|
||||
answer, relevant_docs = query_database(db, question)
|
||||
st.write("### Answer")
|
||||
st.write(answer)
|
||||
if collection_type is None:
|
||||
# Use web search fallback directly
|
||||
answer, relevant_docs = _handle_web_fallback(question)
|
||||
st.write("### Answer (from web search)")
|
||||
st.write(answer)
|
||||
else:
|
||||
# Display routing information and query the database
|
||||
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")
|
||||
db = st.session_state.databases[collection_type]
|
||||
answer, relevant_docs = query_database(db, question)
|
||||
st.write("### Answer")
|
||||
st.write(answer)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user