mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-03-11 17:48:31 -05:00
RAG with database routing - first initialization
This commit is contained in:
14
rag_tutorials/rag_database_routing/README.md
Normal file
14
rag_tutorials/rag_database_routing/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# RAG Database Router Demo
|
||||
|
||||
This demo showcases RAG (Retrieval-Augmented Generation) with database routing capabilities. The application allows users to:
|
||||
|
||||
1. Upload documents to three different databases:
|
||||
- Product Information
|
||||
- Customer Support & FAQ
|
||||
- Financial Information
|
||||
|
||||
2. Query information using natural language, with automatic routing to the most relevant database.
|
||||
|
||||
## Setup
|
||||
|
||||
1. Create a virtual environment:
|
||||
252
rag_tutorials/rag_database_routing/rag_database_routing.py
Normal file
252
rag_tutorials/rag_database_routing/rag_database_routing.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import os
|
||||
from typing import List, Dict, Any, Literal
|
||||
from dataclasses import dataclass
|
||||
import streamlit as st
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.documents import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_community.embeddings import OpenAIEmbeddings
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
import tempfile
|
||||
|
||||
# Read variables from a .env file (if one exists) into the process environment
load_dotenv()

# Routing keys: the identifiers used to select one of the databases
DatabaseType = Literal[
    "products",
    "customer_support",
    "financials",
]

# Root folder for on-disk vector store data
PERSIST_DIRECTORY = "db_storage"
|
||||
|
||||
@dataclass
class Database:
    """Configuration for one routable document database.

    Attributes:
        name: Human-readable label for the database.
        description: Short summary of what the database contains.
        collection_name: Name of the vector store collection backing it.
        persist_directory: Directory in which the collection is persisted.
    """

    name: str
    description: str
    collection_name: str
    persist_directory: str
|
||||
|
||||
# One entry per routing key; every collection persists in its own subfolder
# of PERSIST_DIRECTORY.
DATABASES: Dict[DatabaseType, Database] = {
    "products": Database(
        name="Product Information",
        description="Product details, specifications, and features",
        collection_name="products_db",
        persist_directory=PERSIST_DIRECTORY + "/products",
    ),
    "customer_support": Database(
        name="Customer Support & FAQ",
        description="Customer support information, frequently asked questions, and guides",
        collection_name="support_db",
        persist_directory=PERSIST_DIRECTORY + "/support",
    ),
    "financials": Database(
        name="Financial Information",
        description="Financial data, revenue, costs, and liabilities",
        collection_name="finance_db",
        persist_directory=PERSIST_DIRECTORY + "/finance",
    ),
}
|
||||
|
||||
# Router prompt template.
# The single placeholder {question} receives the user's query. The model is
# instructed to answer with exactly one routing key: "products",
# "customer_support" or "financials".
ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and route them to the most appropriate database.

Available databases:
1. Product Information: Contains product details, specifications, and features
2. Customer Support & FAQ: Contains customer support information, frequently asked questions, and guides
3. Financial Information: Contains financial data, revenue, costs, and liabilities

User question: {question}

Return only one of these exact strings:
- products
- customer_support
- financials

Your response:"""
|
||||
|
||||
def init_session_state():
    """Create the shared session objects the first time the app runs.

    Each of the following is created only when it is not already present in
    Streamlit's session state: the ``databases`` cache (an empty dict), the
    embedding model, the chat LLM (temperature 0) and the router chain built
    from ROUTER_TEMPLATE.
    """
    state = st.session_state

    if 'databases' not in state:
        state.databases = {}

    if 'embeddings' not in state:
        state.embeddings = OpenAIEmbeddings()

    if 'llm' not in state:
        state.llm = ChatOpenAI(temperature=0)

    # The router chain depends on the LLM, so it is set up last.
    if 'router_chain' in state:
        return

    state.router_chain = LLMChain(
        llm=state.llm,
        prompt=PromptTemplate(
            template=ROUTER_TEMPLATE,
            input_variables=["question"],
        ),
    )
|
||||
|
||||
def process_document(file) -> List[Document]:
    """Load an uploaded PDF and split it into overlapping text chunks.

    The upload is written to a temporary ``.pdf`` file because the loader is
    given a filesystem path. The temporary file is removed whether or not
    loading succeeds.

    Args:
        file: Uploaded file object exposing ``getvalue()`` that returns the
            PDF bytes.

    Returns:
        The chunks produced by the splitter (chunk_size=1000,
        chunk_overlap=200), or an empty list if anything fails. Failures are
        reported in the UI via ``st.error`` rather than raised.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp_file.name
            tmp_file.write(file.getvalue())

        # The file is closed at this point, so the loader can open it by path.
        documents = PyPDFLoader(tmp_path).load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        return text_splitter.split_documents(documents)
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return []
    finally:
        # Always remove the temporary file, including on the error path.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not hide the
                # result (or the error already reported above).
                pass
|
||||
|
||||
def get_or_create_db(db_type: DatabaseType) -> Chroma:
    """Return the vector store for ``db_type``, creating it on first use.

    Stores are cached in ``st.session_state.databases``. On a cache miss the
    persist directory is created if needed, a Chroma store is built from the
    matching DATABASES entry, and a success message is shown.

    Args:
        db_type: Routing key of the database to fetch.

    Returns:
        The cached or newly created Chroma store.

    Raises:
        Exception: Any error raised during lookup or initialization is
            reported with ``st.error`` and then re-raised.
    """
    try:
        cache = st.session_state.databases
        if db_type in cache:
            return cache[db_type]

        config = DATABASES[db_type]

        # The persist directory must exist before the store is opened on it.
        os.makedirs(config.persist_directory, exist_ok=True)

        store_metadata = {
            "description": config.description,
            "database_type": db_type,
        }
        cache[db_type] = Chroma(
            persist_directory=config.persist_directory,
            embedding_function=st.session_state.embeddings,
            collection_name=config.collection_name,
            collection_metadata=store_metadata,
        )

        st.success(f"Initialized {config.name} database")
        return cache[db_type]
    except Exception as e:
        st.error(f"Error initializing {db_type} database: {str(e)}")
        raise
|
||||
|
||||
def route_query(question: str) -> DatabaseType:
    """Ask the router chain which database should answer ``question``.

    The model's reply is free text, so it is normalized and validated against
    the keys of DATABASES before being returned. Minor decoration such as a
    leading "- ", surrounding quotes, a trailing period, or a space instead of
    an underscore is tolerated.

    Args:
        question: The user's natural-language question.

    Returns:
        One of the keys of DATABASES.

    Raises:
        ValueError: If the reply cannot be mapped to exactly one known
            database.
    """
    response = st.session_state.router_chain.invoke({"question": question})
    raw = response["text"].strip().lower()

    # Fast path: the model followed the instructions exactly.
    if raw in DATABASES:
        return raw

    # Strip list markers, quotes and punctuation the model may have added.
    cleaned = raw.strip(" \t\r\n-*.:`'\"").replace(" ", "_")
    if cleaned in DATABASES:
        return cleaned

    # Last resort: accept the reply only if it mentions exactly one database,
    # either by routing key or by display name.
    matches = [
        key
        for key, config in DATABASES.items()
        if key in cleaned or config.name.lower() in raw
    ]
    if len(matches) == 1:
        return matches[0]

    raise ValueError(
        f"Router returned an unrecognized database: {response['text']!r}"
    )
|
||||
|
||||
def query_database(db: Chroma, question: str) -> str:
    """Answer ``question`` from the three most similar chunks stored in ``db``.

    Args:
        db: Vector store to search.
        question: The user's question.

    Returns:
        The text produced by the LLM for the context-grounded prompt.
    """
    top_docs = db.similarity_search(question, k=3)
    context = "\n\n".join(doc.page_content for doc in top_docs)

    answer_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""Answer the question based on the following context. If you cannot answer the question based on the context, say "I don't have enough information to answer this question."

Context: {context}

Question: {question}

Answer:""",
    )

    answer_chain = LLMChain(llm=st.session_state.llm, prompt=answer_prompt)
    result = answer_chain.invoke({"context": context, "question": question})
    return result["text"]
|
||||
|
||||
def _clear_single_database(key, config):
    """Delete one database's collection (if loaded) and its on-disk data."""
    import shutil  # local import: only needed for directory cleanup

    databases = st.session_state.databases
    store = databases.get(key)
    if store is not None:
        # Drop the whole collection through the vector store's public API
        # instead of calling delete() on the private _collection with no filter.
        store.delete_collection()
        del databases[key]

    # Remove persisted data even when the store was never loaded in this
    # session, so data left over from an earlier run is cleared as well.
    # NOTE(review): the underlying client may still hold this directory open;
    # confirm removal behaves on every target platform.
    if os.path.exists(config.persist_directory):
        shutil.rmtree(config.persist_directory)


def clear_database(db_type: DatabaseType = None):
    """Clear the specified database, or all databases if none is specified.

    Clearing removes the collection when the store is loaded in this session,
    drops it from ``st.session_state.databases`` and deletes its persist
    directory.

    Args:
        db_type: Routing key of the database to clear. When omitted (or
            falsy), every database in DATABASES is cleared.

    Errors are not raised; they are reported in the UI via ``st.error``.
    """
    try:
        if db_type:
            db_config = DATABASES[db_type]
            _clear_single_database(db_type, db_config)
            st.success(f"Cleared {db_config.name} database")
        else:
            # Distinct loop names so the db_type parameter is not shadowed.
            for key, config in DATABASES.items():
                _clear_single_database(key, config)
            st.session_state.databases = {}
            st.success("Cleared all databases")
    except Exception as e:
        st.error(f"Error clearing database(s): {str(e)}")
|
||||
|
||||
def main():
    """Render the Streamlit app: database management, uploads and Q&A.

    Streamlit re-executes this function on every interaction, so the upload
    section remembers which file each tab has already ingested and skips it
    on later reruns.
    """
    st.title("📚 RAG Database Router ")

    init_session_state()

    # Maps a database key to the (name, size) of the upload already added to
    # it. Without this, a file left in the uploader would be re-processed and
    # re-added on every rerun, duplicating its chunks and immediately
    # repopulating a database that was just cleared.
    if 'processed_uploads' not in st.session_state:
        st.session_state.processed_uploads = {}
    processed_uploads = st.session_state.processed_uploads

    # Sidebar for database management
    with st.sidebar:
        st.header("Database Management")
        if st.button("Clear All Databases"):
            clear_database()

        st.divider()
        st.subheader("Clear Individual Databases")
        for db_type, db_config in DATABASES.items():
            if st.button(f"Clear {db_config.name}"):
                clear_database(db_type)

    # Document upload section
    st.header("Document Upload")
    tabs = st.tabs([db.name for db in DATABASES.values()])

    for (db_type, db_config), tab in zip(DATABASES.items(), tabs):
        with tab:
            st.write(db_config.description)
            uploaded_file = st.file_uploader(
                "Upload PDF document",
                type="pdf",
                key=f"upload_{db_type}"
            )

            if not uploaded_file:
                # Uploader emptied: forget the previous file so the same
                # document can be uploaded again later.
                processed_uploads.pop(db_type, None)
                continue

            signature = (uploaded_file.name, uploaded_file.size)
            if processed_uploads.get(db_type) == signature:
                # Already ingested during an earlier run of this script.
                continue

            with st.spinner('Processing document...'):
                texts = process_document(uploaded_file)
                if texts:
                    db = get_or_create_db(db_type)
                    db.add_documents(texts)
                    processed_uploads[db_type] = signature
                    st.success("Document processed and added to the database!")

    # Query section
    st.header("Ask Questions")
    question = st.text_input("Enter your question:")

    if question:
        with st.spinner('Finding answer...'):
            # Route the question; the routing key originates from an LLM
            # reply, so confirm it names a known database before using it.
            try:
                db_type = route_query(question)
            except ValueError as e:
                st.error(f"Could not route question: {e}")
                return
            if db_type not in DATABASES:
                st.error(f"Router returned an unknown database: {db_type!r}")
                return

            db = get_or_create_db(db_type)

            # Display routing information
            st.info(f"Routing question to: {DATABASES[db_type].name}")

            # Get and display answer
            answer = query_database(db, question)
            st.write("### Answer")
            st.write(answer)
|
||||
|
||||
# Call main() only when this file is executed as the main module, not when it
# is imported.
if __name__ == "__main__":
    main()
|
||||
9
rag_tutorials/rag_database_routing/requirements.txt
Normal file
9
rag_tutorials/rag_database_routing/requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
langchain>=0.1.0
|
||||
langchain-community>=0.0.10
|
||||
langchain-core>=0.1.10
|
||||
chromadb>=0.4.22
|
||||
streamlit>=1.29.0
|
||||
python-dotenv>=1.0.0
|
||||
pypdf>=4.0.0
|
||||
sentence-transformers>=2.2.2
|
||||
openai>=1.6.1
|
||||
Reference in New Issue
Block a user