mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-04-30 15:20:47 -05:00
Simple implementation of chain-based DB routing
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import getpass
|
||||||
from typing import List, Dict, Any, Literal
|
from typing import List, Dict, Any, Literal
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
@@ -7,51 +8,35 @@ from langchain_core.documents import Document
|
|||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from langchain_community.document_loaders import PyPDFLoader
|
from langchain_community.document_loaders import PyPDFLoader
|
||||||
from langchain_community.vectorstores import Chroma
|
from langchain_community.vectorstores import Chroma
|
||||||
from langchain_community.embeddings import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from langchain_core.runnables import RunnableSequence
|
||||||
|
from langchain_core.output_parsers import StrOutputParser
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_chroma import Chroma
|
||||||
|
|
||||||
# Load environment variables
|
def init_session_state():
|
||||||
load_dotenv()
|
"""Initialize session state variables"""
|
||||||
|
if 'openai_api_key' not in st.session_state:
|
||||||
|
st.session_state.openai_api_key = ""
|
||||||
|
if 'embeddings' not in st.session_state:
|
||||||
|
st.session_state.embeddings = None
|
||||||
|
if 'llm' not in st.session_state:
|
||||||
|
st.session_state.llm = None
|
||||||
|
if 'databases' not in st.session_state:
|
||||||
|
st.session_state.databases = {}
|
||||||
|
|
||||||
|
# Initialize session state at the top
|
||||||
|
init_session_state()
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
DatabaseType = Literal["products", "customer_support", "financials"]
|
DatabaseType = Literal["products", "customer_support", "financials"]
|
||||||
PERSIST_DIRECTORY = "db_storage"
|
PERSIST_DIRECTORY = "db_storage"
|
||||||
|
|
||||||
@dataclass
|
ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and determine which databases might contain relevant information.
|
||||||
class Database:
|
|
||||||
"""Class to represent a database configuration"""
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
collection_name: str
|
|
||||||
persist_directory: str
|
|
||||||
|
|
||||||
# Database configurations
|
|
||||||
DATABASES: Dict[DatabaseType, Database] = {
|
|
||||||
"products": Database(
|
|
||||||
name="Product Information",
|
|
||||||
description="Product details, specifications, and features",
|
|
||||||
collection_name="products_db",
|
|
||||||
persist_directory=f"{PERSIST_DIRECTORY}/products"
|
|
||||||
),
|
|
||||||
"customer_support": Database(
|
|
||||||
name="Customer Support & FAQ",
|
|
||||||
description="Customer support information, frequently asked questions, and guides",
|
|
||||||
collection_name="support_db",
|
|
||||||
persist_directory=f"{PERSIST_DIRECTORY}/support"
|
|
||||||
),
|
|
||||||
"financials": Database(
|
|
||||||
name="Financial Information",
|
|
||||||
description="Financial data, revenue, costs, and liabilities",
|
|
||||||
collection_name="finance_db",
|
|
||||||
persist_directory=f"{PERSIST_DIRECTORY}/finance"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
# Router prompt template
|
|
||||||
ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and route them to the most appropriate database.
|
|
||||||
|
|
||||||
Available databases:
|
Available databases:
|
||||||
1. Product Information: Contains product details, specifications, and features
|
1. Product Information: Contains product details, specifications, and features
|
||||||
@@ -60,33 +45,85 @@ Available databases:
|
|||||||
|
|
||||||
User question: {question}
|
User question: {question}
|
||||||
|
|
||||||
Return only one of these exact strings:
|
Return a comma-separated list of relevant databases (no spaces after commas). Only use these exact strings:
|
||||||
- products
|
- products
|
||||||
- customer_support
|
- customer_support
|
||||||
- financials
|
- financials
|
||||||
|
|
||||||
|
For example: "products,customer_support" if the question relates to both product info and support.
|
||||||
Your response:"""
|
Your response:"""
|
||||||
|
|
||||||
def init_session_state():
|
@dataclass
|
||||||
"""Initialize session state variables"""
|
class CollectionConfig:
|
||||||
if 'databases' not in st.session_state:
|
name: str
|
||||||
st.session_state.databases = {}
|
description: str
|
||||||
if 'embeddings' not in st.session_state:
|
collection_name: str
|
||||||
st.session_state.embeddings = OpenAIEmbeddings()
|
persist_directory: str
|
||||||
if 'llm' not in st.session_state:
|
|
||||||
st.session_state.llm = ChatOpenAI(temperature=0)
|
# Collection configurations
|
||||||
if 'router_chain' not in st.session_state:
|
COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
|
||||||
router_prompt = PromptTemplate(
|
"products": CollectionConfig(
|
||||||
template=ROUTER_TEMPLATE,
|
name="Product Information",
|
||||||
input_variables=["question"]
|
description="Product details, specifications, and features",
|
||||||
)
|
collection_name="products_collection",
|
||||||
st.session_state.router_chain = LLMChain(
|
persist_directory=f"{PERSIST_DIRECTORY}/products"
|
||||||
llm=st.session_state.llm,
|
),
|
||||||
prompt=router_prompt
|
"customer_support": CollectionConfig(
|
||||||
)
|
name="Customer Support & FAQ",
|
||||||
|
description="Customer support information, frequently asked questions, and guides",
|
||||||
|
collection_name="support_collection",
|
||||||
|
persist_directory=f"{PERSIST_DIRECTORY}/support"
|
||||||
|
),
|
||||||
|
"financials": CollectionConfig(
|
||||||
|
name="Financial Information",
|
||||||
|
description="Financial data, revenue, costs, and liabilities",
|
||||||
|
collection_name="finance_collection",
|
||||||
|
persist_directory=f"{PERSIST_DIRECTORY}/finance"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
def initialize_models():
|
||||||
|
"""Initialize OpenAI models with API key"""
|
||||||
|
if st.session_state.openai_api_key:
|
||||||
|
try:
|
||||||
|
os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key
|
||||||
|
# Test the API key with a small embedding request
|
||||||
|
test_embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
|
||||||
|
test_embeddings.embed_query("test")
|
||||||
|
|
||||||
|
# If successful, initialize the models
|
||||||
|
st.session_state.embeddings = test_embeddings
|
||||||
|
st.session_state.llm = ChatOpenAI(temperature=0)
|
||||||
|
st.session_state.databases = {
|
||||||
|
"products": Chroma(
|
||||||
|
collection_name=COLLECTIONS["products"].collection_name,
|
||||||
|
embedding_function=st.session_state.embeddings,
|
||||||
|
persist_directory=COLLECTIONS["products"].persist_directory
|
||||||
|
),
|
||||||
|
"customer_support": Chroma(
|
||||||
|
collection_name=COLLECTIONS["customer_support"].collection_name,
|
||||||
|
embedding_function=st.session_state.embeddings,
|
||||||
|
persist_directory=COLLECTIONS["customer_support"].persist_directory
|
||||||
|
),
|
||||||
|
"financials": Chroma(
|
||||||
|
collection_name=COLLECTIONS["financials"].collection_name,
|
||||||
|
embedding_function=st.session_state.embeddings,
|
||||||
|
persist_directory=COLLECTIONS["financials"].persist_directory
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Error connecting to OpenAI API: {str(e)}")
|
||||||
|
st.error("Please check your internet connection and API key.")
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
def process_document(file) -> List[Document]:
|
def process_document(file) -> List[Document]:
|
||||||
"""Process uploaded PDF document"""
|
"""Process uploaded PDF document"""
|
||||||
|
if not st.session_state.embeddings:
|
||||||
|
st.error("OpenAI API connection not initialized. Please check your API key.")
|
||||||
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
|
||||||
tmp_file.write(file.getvalue())
|
tmp_file.write(file.getvalue())
|
||||||
@@ -109,124 +146,123 @@ def process_document(file) -> List[Document]:
|
|||||||
st.error(f"Error processing document: {e}")
|
st.error(f"Error processing document: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_or_create_db(db_type: DatabaseType) -> Chroma:
|
def route_query(question: str) -> List[DatabaseType]:
|
||||||
"""Get or create a database for the specified type with proper initialization and error handling"""
|
"""Route the question to appropriate databases"""
|
||||||
try:
|
router_prompt = ChatPromptTemplate.from_template(ROUTER_TEMPLATE)
|
||||||
if db_type not in st.session_state.databases:
|
router_chain = router_prompt | st.session_state.llm | StrOutputParser()
|
||||||
db_config = DATABASES[db_type]
|
response = router_chain.invoke({"question": question})
|
||||||
|
return response.strip().lower().split(",")
|
||||||
# Ensure directory exists
|
|
||||||
os.makedirs(db_config.persist_directory, exist_ok=True)
|
|
||||||
|
|
||||||
# Initialize Chroma with proper settings
|
|
||||||
st.session_state.databases[db_type] = Chroma(
|
|
||||||
persist_directory=db_config.persist_directory,
|
|
||||||
embedding_function=st.session_state.embeddings,
|
|
||||||
collection_name=db_config.collection_name,
|
|
||||||
collection_metadata={
|
|
||||||
"description": db_config.description,
|
|
||||||
"database_type": db_type
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log successful initialization
|
|
||||||
st.success(f"Initialized {db_config.name} database")
|
|
||||||
|
|
||||||
return st.session_state.databases[db_type]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Error initializing {db_type} database: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def route_query(question: str) -> DatabaseType:
|
def query_multiple_databases(question: str) -> str:
|
||||||
"""Route the question to the appropriate database"""
|
"""Query multiple relevant databases and combine results"""
|
||||||
response = st.session_state.router_chain.invoke({"question": question})
|
database_types = route_query(question)
|
||||||
return response["text"].strip().lower()
|
all_docs = []
|
||||||
|
|
||||||
def query_database(db: Chroma, question: str) -> str:
|
|
||||||
"""Query the database and return the response"""
|
|
||||||
docs = db.similarity_search(question, k=3)
|
|
||||||
|
|
||||||
context = "\n\n".join([doc.page_content for doc in docs])
|
# Collect relevant documents from each database
|
||||||
|
for db_type in database_types:
|
||||||
|
db = st.session_state.databases[db_type]
|
||||||
|
docs = db.similarity_search(question, k=2) # Reduced k since we're querying multiple DBs
|
||||||
|
all_docs.extend(docs)
|
||||||
|
|
||||||
prompt = PromptTemplate(
|
# Sort all documents by relevance score if available
|
||||||
template="""Answer the question based on the following context. If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
|
# Note: You might need to modify this based on your similarity search implementation
|
||||||
|
context = "\n\n---\n\n".join([doc.page_content for doc in all_docs])
|
||||||
|
|
||||||
|
answer_prompt = ChatPromptTemplate.from_template(
|
||||||
|
"""Answer the question based on the following context from multiple databases.
|
||||||
|
If you use information from multiple sources, please indicate which type of source it came from.
|
||||||
|
If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
|
||||||
|
|
||||||
Context: {context}
|
Context: {context}
|
||||||
|
|
||||||
Question: {question}
|
Question: {question}
|
||||||
|
|
||||||
Answer:""",
|
Answer:"""
|
||||||
input_variables=["context", "question"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chain = LLMChain(llm=st.session_state.llm, prompt=prompt)
|
answer_chain = answer_prompt | st.session_state.llm | StrOutputParser()
|
||||||
response = chain.invoke({"context": context, "question": question})
|
return answer_chain.invoke({"context": context, "question": question})
|
||||||
return response["text"]
|
|
||||||
|
|
||||||
def clear_database(db_type: DatabaseType = None):
|
def clear_collection(collection_type: DatabaseType = None):
|
||||||
"""Clear specified database or all databases if none specified"""
|
"""Clear specified collection or all collections if none specified"""
|
||||||
try:
|
try:
|
||||||
if db_type:
|
if collection_type:
|
||||||
if db_type in st.session_state.databases:
|
if collection_type in st.session_state.databases:
|
||||||
db_config = DATABASES[db_type]
|
collection_config = COLLECTIONS[collection_type]
|
||||||
# Delete collection
|
# Delete collection
|
||||||
st.session_state.databases[db_type]._collection.delete()
|
st.session_state.databases[collection_type]._collection.delete()
|
||||||
# Remove from session state
|
# Remove from session state
|
||||||
del st.session_state.databases[db_type]
|
del st.session_state.databases[collection_type]
|
||||||
# Clean up persist directory
|
# Clean up persist directory
|
||||||
if os.path.exists(db_config.persist_directory):
|
if os.path.exists(collection_config.persist_directory):
|
||||||
import shutil
|
import shutil
|
||||||
shutil.rmtree(db_config.persist_directory)
|
shutil.rmtree(collection_config.persist_directory)
|
||||||
st.success(f"Cleared {db_config.name} database")
|
st.success(f"Cleared {collection_config.name} collection")
|
||||||
else:
|
else:
|
||||||
# Clear all databases
|
# Clear all collections
|
||||||
for db_type, db_config in DATABASES.items():
|
for collection_type, collection_config in COLLECTIONS.items():
|
||||||
if db_type in st.session_state.databases:
|
if collection_type in st.session_state.databases:
|
||||||
st.session_state.databases[db_type]._collection.delete()
|
st.session_state.databases[collection_type]._collection.delete()
|
||||||
if os.path.exists(db_config.persist_directory):
|
if os.path.exists(collection_config.persist_directory):
|
||||||
import shutil
|
import shutil
|
||||||
shutil.rmtree(db_config.persist_directory)
|
shutil.rmtree(collection_config.persist_directory)
|
||||||
st.session_state.databases = {}
|
st.session_state.databases = {}
|
||||||
st.success("Cleared all databases")
|
st.success("Cleared all collections")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
st.error(f"Error clearing database(s): {str(e)}")
|
st.error(f"Error clearing collection(s): {str(e)}")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
st.title("📚 RAG Database Router ")
|
st.title("📚 RAG with Database Routing")
|
||||||
|
|
||||||
init_session_state()
|
|
||||||
|
|
||||||
# Sidebar for database management
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
|
st.header("Configuration")
|
||||||
|
api_key = st.text_input(
|
||||||
|
"Enter OpenAI API Key:",
|
||||||
|
type="password",
|
||||||
|
value=st.session_state.openai_api_key,
|
||||||
|
key="api_key_input"
|
||||||
|
)
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
st.session_state.openai_api_key = api_key
|
||||||
|
if initialize_models():
|
||||||
|
st.success("API Key set successfully!")
|
||||||
|
else:
|
||||||
|
st.error("Invalid API Key")
|
||||||
|
|
||||||
|
if not st.session_state.openai_api_key:
|
||||||
|
st.warning("Please enter your OpenAI API key to continue")
|
||||||
|
st.stop()
|
||||||
|
|
||||||
|
st.divider()
|
||||||
st.header("Database Management")
|
st.header("Database Management")
|
||||||
if st.button("Clear All Databases"):
|
if st.button("Clear All Databases"):
|
||||||
clear_database()
|
clear_collection()
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
st.subheader("Clear Individual Databases")
|
st.subheader("Clear Individual Databases")
|
||||||
for db_type, db_config in DATABASES.items():
|
for collection_type, collection_config in COLLECTIONS.items():
|
||||||
if st.button(f"Clear {db_config.name}"):
|
if st.button(f"Clear {collection_config.name}"):
|
||||||
clear_database(db_type)
|
clear_collection(collection_type)
|
||||||
|
|
||||||
# Document upload section
|
# Document upload section
|
||||||
st.header("Document Upload")
|
st.header("Document Upload")
|
||||||
tabs = st.tabs([db.name for db in DATABASES.values()])
|
tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()])
|
||||||
|
|
||||||
for (db_type, db_config), tab in zip(DATABASES.items(), tabs):
|
for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs):
|
||||||
with tab:
|
with tab:
|
||||||
st.write(db_config.description)
|
st.write(collection_config.description)
|
||||||
uploaded_file = st.file_uploader(
|
uploaded_file = st.file_uploader(
|
||||||
"Upload PDF document",
|
"Upload PDF document",
|
||||||
type="pdf",
|
type="pdf",
|
||||||
key=f"upload_{db_type}"
|
key=f"upload_{collection_type}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if uploaded_file:
|
if uploaded_file:
|
||||||
with st.spinner('Processing document...'):
|
with st.spinner('Processing document...'):
|
||||||
texts = process_document(uploaded_file)
|
texts = process_document(uploaded_file)
|
||||||
if texts:
|
if texts:
|
||||||
db = get_or_create_db(db_type)
|
db = st.session_state.databases[collection_type]
|
||||||
db.add_documents(texts)
|
db.add_documents(texts)
|
||||||
st.success("Document processed and added to the database!")
|
st.success("Document processed and added to the database!")
|
||||||
|
|
||||||
@@ -236,15 +272,14 @@ def main():
|
|||||||
|
|
||||||
if question:
|
if question:
|
||||||
with st.spinner('Finding answer...'):
|
with st.spinner('Finding answer...'):
|
||||||
# Route the question
|
# Get relevant databases
|
||||||
db_type = route_query(question)
|
database_types = route_query(question)
|
||||||
db = get_or_create_db(db_type)
|
|
||||||
|
|
||||||
# Display routing information
|
# Display routing information
|
||||||
st.info(f"Routing question to: {DATABASES[db_type].name}")
|
st.info(f"Searching in: {', '.join([COLLECTIONS[db_type].name for db_type in database_types])}")
|
||||||
|
|
||||||
# Get and display answer
|
# Get and display answer
|
||||||
answer = query_database(db, question)
|
answer = query_multiple_databases(question)
|
||||||
st.write("### Answer")
|
st.write("### Answer")
|
||||||
st.write(answer)
|
st.write(answer)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user