mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-03-09 07:25:00 -05:00
added everything - testing time
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# RAG Database Router Demo
|
||||
# RAG Agent with Database Routing
|
||||
|
||||
This demo showcases RAG (Retrieval Augmented Generation) with database routing capabilities. The application allows users to:
|
||||
This project showcases the RAG with database routing capabilities - which is a very efficient way to retrieve information from a large set of documents. The application allows users to:
|
||||
|
||||
1. Upload documents to three different databases:
|
||||
- Product Information
|
||||
@@ -9,6 +9,48 @@ This demo showcases RAG (Retrieval Augmented Generation) with database routing c
|
||||
|
||||
2. Query information using natural language, with automatic routing to the most relevant database.
|
||||
|
||||
## Setup
|
||||
## Features
|
||||
|
||||
1. Create a virtual environment:
|
||||
- **Document Upload**: Users can upload multiple PDF documents related to a particular company. These documents are processed and stored in one of the three databases: Product Information, Customer Support & FAQ, or Financial Information.
|
||||
|
||||
- **Natural Language Querying**: Users can ask questions in natural language. The system automatically routes the query to the most relevant database using a phidata agent as the router.
|
||||
|
||||
- **RAG Orchestration**: Utilizes Langchain for orchestrating the retrieval augmented generation process, ensuring that the most relevant information is retrieved and presented to the user.
|
||||
|
||||
- **Fallback Mechanism**: If no relevant documents are found in the databases, a LangGraph agent with a DuckDuckGo search tool is used to perform web research and provide an answer.
|
||||
|
||||
- **User Interface**: Built with Streamlit, providing an intuitive and interactive user experience.
|
||||
|
||||
## How to Run?
|
||||
|
||||
1. **Clone the Repository**:
|
||||
```bash
|
||||
git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git
|
||||
cd rag_tutorials/rag_database_routing
|
||||
```
|
||||
|
||||
2. **Install Dependencies**:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
3. **Run the Application**:
|
||||
```bash
|
||||
streamlit run rag_database_routing.py
|
||||
```
|
||||
|
||||
4. **Configure API Key**: Obtain an OpenAI API key and set it in the application. This is required for initializing the language models used in the application.
|
||||
|
||||
5. **Upload Documents**: Use the document upload section to add PDF documents to the desired database.
|
||||
|
||||
6. **Ask Questions**: Enter your questions in the query section. The application will route your question to the appropriate database and provide an answer.
|
||||
|
||||
## Technologies Used
|
||||
|
||||
- **Langchain**: For RAG orchestration, ensuring efficient retrieval and generation of information.
|
||||
- **Phidata Agent**: Used as the router agent to determine the most relevant database for a given query.
|
||||
- **LangGraph Agent**: Acts as a fallback mechanism, utilizing DuckDuckGo for web research when necessary.
|
||||
- **Streamlit**: Provides a user-friendly interface for document upload and querying.
|
||||
- **ChromaDB**: Used for managing the databases, storing and retrieving document embeddings efficiently.
|
||||
|
||||
This application is designed to streamline the process of retrieving information from large sets of documents, making it easier for users to find the answers they need quickly and efficiently.
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import os
|
||||
import getpass
|
||||
from typing import List, Dict, Any, Literal
|
||||
from typing import List, Dict, Literal
|
||||
from dataclasses import dataclass
|
||||
import streamlit as st
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.documents import Document
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
from langchain_community.vectorstores import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain.chains import LLMChain
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
||||
import tempfile
|
||||
from langchain_core.runnables import RunnableSequence
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_chroma import Chroma
|
||||
from phi.agent import Agent
|
||||
from phi.model.openai import OpenAIChat
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain.chains.combine_documents import create_stuff_documents_chain
|
||||
from langchain.chains import create_retrieval_chain
|
||||
from langchain import hub
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langchain_community.tools import DuckDuckGoSearchRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
def init_session_state():
|
||||
"""Initialize session state variables"""
|
||||
if 'openai_api_key' not in st.session_state:
|
||||
st.session_state.openai_api_key = ""
|
||||
if 'embeddings' not in st.session_state:
|
||||
@@ -29,30 +28,11 @@ def init_session_state():
|
||||
if 'databases' not in st.session_state:
|
||||
st.session_state.databases = {}
|
||||
|
||||
# Initialize session state at the top
|
||||
init_session_state()
|
||||
|
||||
# Constants
|
||||
DatabaseType = Literal["products", "customer_support", "financials"]
|
||||
DatabaseType = Literal["products", "support", "finance"]
|
||||
PERSIST_DIRECTORY = "db_storage"
|
||||
|
||||
ROUTER_TEMPLATE = """You are a query routing expert. Your job is to analyze user questions and determine which databases might contain relevant information.
|
||||
|
||||
Available databases:
|
||||
1. Product Information: Contains product details, specifications, and features
|
||||
2. Customer Support & FAQ: Contains customer support information, frequently asked questions, and guides
|
||||
3. Financial Information: Contains financial data, revenue, costs, and liabilities
|
||||
|
||||
User question: {question}
|
||||
|
||||
Return a comma-separated list of relevant databases (no spaces after commas). Only use these exact strings:
|
||||
- products
|
||||
- customer_support
|
||||
- financials
|
||||
|
||||
For example: "products,customer_support" if the question relates to both product info and support.
|
||||
Your response:"""
|
||||
|
||||
@dataclass
|
||||
class CollectionConfig:
|
||||
name: str
|
||||
@@ -60,7 +40,6 @@ class CollectionConfig:
|
||||
collection_name: str
|
||||
persist_directory: str
|
||||
|
||||
# Collection configurations
|
||||
COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
|
||||
"products": CollectionConfig(
|
||||
name="Product Information",
|
||||
@@ -68,13 +47,13 @@ COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
|
||||
collection_name="products_collection",
|
||||
persist_directory=f"{PERSIST_DIRECTORY}/products"
|
||||
),
|
||||
"customer_support": CollectionConfig(
|
||||
"support": CollectionConfig(
|
||||
name="Customer Support & FAQ",
|
||||
description="Customer support information, frequently asked questions, and guides",
|
||||
collection_name="support_collection",
|
||||
persist_directory=f"{PERSIST_DIRECTORY}/support"
|
||||
),
|
||||
"financials": CollectionConfig(
|
||||
"finance": CollectionConfig(
|
||||
name="Financial Information",
|
||||
description="Financial data, revenue, costs, and liabilities",
|
||||
collection_name="finance_collection",
|
||||
@@ -83,47 +62,25 @@ COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
|
||||
}
|
||||
|
||||
def initialize_models():
|
||||
"""Initialize OpenAI models with API key"""
|
||||
if st.session_state.openai_api_key:
|
||||
try:
|
||||
os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key
|
||||
# Test the API key with a small embedding request
|
||||
test_embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
|
||||
test_embeddings.embed_query("test")
|
||||
|
||||
# If successful, initialize the models
|
||||
st.session_state.embeddings = test_embeddings
|
||||
st.session_state.llm = ChatOpenAI(temperature=0)
|
||||
st.session_state.databases = {
|
||||
"products": Chroma(
|
||||
collection_name=COLLECTIONS["products"].collection_name,
|
||||
embedding_function=st.session_state.embeddings,
|
||||
persist_directory=COLLECTIONS["products"].persist_directory
|
||||
),
|
||||
"customer_support": Chroma(
|
||||
collection_name=COLLECTIONS["customer_support"].collection_name,
|
||||
embedding_function=st.session_state.embeddings,
|
||||
persist_directory=COLLECTIONS["customer_support"].persist_directory
|
||||
),
|
||||
"financials": Chroma(
|
||||
collection_name=COLLECTIONS["financials"].collection_name,
|
||||
embedding_function=st.session_state.embeddings,
|
||||
persist_directory=COLLECTIONS["financials"].persist_directory
|
||||
)
|
||||
}
|
||||
return True
|
||||
except Exception as e:
|
||||
st.error(f"Error connecting to OpenAI API: {str(e)}")
|
||||
st.error("Please check your internet connection and API key.")
|
||||
return False
|
||||
os.environ["OPENAI_API_KEY"] = st.session_state.openai_api_key
|
||||
st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
|
||||
st.session_state.llm = ChatOpenAI(temperature=0)
|
||||
|
||||
for config in COLLECTIONS.values():
|
||||
os.makedirs(config.persist_directory, exist_ok=True)
|
||||
|
||||
st.session_state.databases = {
|
||||
db_type: Chroma(
|
||||
collection_name=config.collection_name,
|
||||
embedding_function=st.session_state.embeddings,
|
||||
persist_directory=config.persist_directory
|
||||
) for db_type, config in COLLECTIONS.items()
|
||||
}
|
||||
return True
|
||||
return False
|
||||
|
||||
def process_document(file) -> List[Document]:
|
||||
"""Process uploaded PDF document"""
|
||||
if not st.session_state.embeddings:
|
||||
st.error("OpenAI API connection not initialized. Please check your API key.")
|
||||
return []
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
|
||||
tmp_file.write(file.getvalue())
|
||||
@@ -131,97 +88,99 @@ def process_document(file) -> List[Document]:
|
||||
|
||||
loader = PyPDFLoader(tmp_path)
|
||||
documents = loader.load()
|
||||
|
||||
# Clean up temporary file
|
||||
os.unlink(tmp_path)
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000,
|
||||
chunk_overlap=200
|
||||
)
|
||||
texts = text_splitter.split_documents(documents)
|
||||
|
||||
return texts
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
|
||||
return text_splitter.split_documents(documents)
|
||||
except Exception as e:
|
||||
st.error(f"Error processing document: {e}")
|
||||
return []
|
||||
|
||||
def route_query(question: str) -> List[DatabaseType]:
|
||||
"""Route the question to appropriate databases"""
|
||||
router_prompt = ChatPromptTemplate.from_template(ROUTER_TEMPLATE)
|
||||
router_chain = router_prompt | st.session_state.llm | StrOutputParser()
|
||||
response = router_chain.invoke({"question": question})
|
||||
return response.strip().lower().split(",")
|
||||
|
||||
def query_multiple_databases(question: str) -> str:
|
||||
"""Query multiple relevant databases and combine results"""
|
||||
database_types = route_query(question)
|
||||
all_docs = []
|
||||
|
||||
# Collect relevant documents from each database
|
||||
for db_type in database_types:
|
||||
db = st.session_state.databases[db_type]
|
||||
docs = db.similarity_search(question, k=2) # Reduced k since we're querying multiple DBs
|
||||
all_docs.extend(docs)
|
||||
|
||||
# Sort all documents by relevance score if available
|
||||
# Note: You might need to modify this based on your similarity search implementation
|
||||
context = "\n\n---\n\n".join([doc.page_content for doc in all_docs])
|
||||
|
||||
answer_prompt = ChatPromptTemplate.from_template(
|
||||
"""Answer the question based on the following context from multiple databases.
|
||||
If you use information from multiple sources, please indicate which type of source it came from.
|
||||
If you cannot answer the question based on the context, say "I don't have enough information to answer this question."
|
||||
|
||||
Context: {context}
|
||||
|
||||
Question: {question}
|
||||
|
||||
Answer:"""
|
||||
def create_routing_agent() -> Agent:
|
||||
return Agent(
|
||||
model=OpenAIChat(id="gpt-4o", api_key=st.session_state.openai_api_key),
|
||||
tools=[],
|
||||
description="You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.",
|
||||
instructions=[
|
||||
"1. For questions about products, return 'products'",
|
||||
"2. For questions about support, return 'support'",
|
||||
"3. For questions about finance, return 'finance'",
|
||||
"4. Return ONLY the database name"
|
||||
],
|
||||
markdown=False,
|
||||
show_tool_calls=False
|
||||
)
|
||||
|
||||
answer_chain = answer_prompt | st.session_state.llm | StrOutputParser()
|
||||
return answer_chain.invoke({"context": context, "question": question})
|
||||
|
||||
def clear_collection(collection_type: DatabaseType = None):
|
||||
"""Clear specified collection or all collections if none specified"""
|
||||
def route_query(question: str) -> DatabaseType:
|
||||
try:
|
||||
if collection_type:
|
||||
if collection_type in st.session_state.databases:
|
||||
collection_config = COLLECTIONS[collection_type]
|
||||
# Delete collection
|
||||
st.session_state.databases[collection_type]._collection.delete()
|
||||
# Remove from session state
|
||||
del st.session_state.databases[collection_type]
|
||||
# Clean up persist directory
|
||||
if os.path.exists(collection_config.persist_directory):
|
||||
import shutil
|
||||
shutil.rmtree(collection_config.persist_directory)
|
||||
st.success(f"Cleared {collection_config.name} collection")
|
||||
else:
|
||||
# Clear all collections
|
||||
for collection_type, collection_config in COLLECTIONS.items():
|
||||
if collection_type in st.session_state.databases:
|
||||
st.session_state.databases[collection_type]._collection.delete()
|
||||
if os.path.exists(collection_config.persist_directory):
|
||||
import shutil
|
||||
shutil.rmtree(collection_config.persist_directory)
|
||||
st.session_state.databases = {}
|
||||
st.success("Cleared all collections")
|
||||
routing_agent = create_routing_agent()
|
||||
response = routing_agent.run(question)
|
||||
db_type = response.content.strip().lower().translate(str.maketrans('', '', '`\'"'))
|
||||
|
||||
if db_type not in COLLECTIONS:
|
||||
st.warning(f"Invalid database type: {db_type}, defaulting to products")
|
||||
return "products"
|
||||
|
||||
st.info(f"Routing question to {db_type} database")
|
||||
return db_type
|
||||
except Exception as e:
|
||||
st.error(f"Error clearing collection(s): {str(e)}")
|
||||
st.error(f"Routing error: {str(e)}")
|
||||
return "products"
|
||||
|
||||
def create_fallback_agent(chat_model: BaseLanguageModel):
|
||||
def web_research(query: str) -> str:
|
||||
try:
|
||||
search = DuckDuckGoSearchRun(num_results=5)
|
||||
return search.run(query)
|
||||
except Exception as e:
|
||||
return f"Search failed: {str(e)}. Providing answer based on general knowledge."
|
||||
|
||||
tools = [web_research]
|
||||
return create_react_agent(model=chat_model, tools=tools, debug=False)
|
||||
|
||||
def query_database(db: Chroma, question: str) -> tuple[str, list]:
|
||||
try:
|
||||
retriever = db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"k": 4, "score_threshold": 0.4})
|
||||
relevant_docs = retriever.get_relevant_documents(question)
|
||||
|
||||
if relevant_docs:
|
||||
retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
|
||||
combine_docs_chain = create_stuff_documents_chain(st.session_state.llm, retrieval_qa_prompt)
|
||||
retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
|
||||
response = retrieval_chain.invoke({"input": question})
|
||||
return response['answer'], relevant_docs
|
||||
|
||||
return _handle_web_fallback(question)
|
||||
except Exception as e:
|
||||
st.error(f"Error: {str(e)}")
|
||||
return "I encountered an error. Please try rephrasing your question.", []
|
||||
|
||||
def _handle_web_fallback(question: str) -> tuple[str, list]:
|
||||
st.info("No relevant documents found. Searching web...")
|
||||
fallback_agent = create_fallback_agent(st.session_state.llm)
|
||||
|
||||
with st.spinner('Researching...'):
|
||||
agent_input = {
|
||||
"messages": [HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")],
|
||||
"is_last_step": False
|
||||
}
|
||||
|
||||
try:
|
||||
response = fallback_agent.invoke(agent_input, config={"recursion_limit": 100})
|
||||
if isinstance(response, dict) and "messages" in response:
|
||||
answer = response["messages"][-1].content
|
||||
return f"Web Search Result:\n{answer}", []
|
||||
except Exception:
|
||||
fallback_response = st.session_state.llm.invoke(question).content
|
||||
return f"Web search unavailable. General response: {fallback_response}", []
|
||||
|
||||
def main():
|
||||
st.title("📚 RAG with Database Routing")
|
||||
|
||||
st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚")
|
||||
st.title("📚 RAG Agent with Database Routing")
|
||||
|
||||
with st.sidebar:
|
||||
st.header("Configuration")
|
||||
api_key = st.text_input(
|
||||
"Enter OpenAI API Key:",
|
||||
type="password",
|
||||
value=st.session_state.openai_api_key,
|
||||
key="api_key_input"
|
||||
)
|
||||
api_key = st.text_input("Enter OpenAI API Key:", type="password", value=st.session_state.openai_api_key, key="api_key_input")
|
||||
|
||||
if api_key:
|
||||
st.session_state.openai_api_key = api_key
|
||||
@@ -233,53 +192,40 @@ def main():
|
||||
if not st.session_state.openai_api_key:
|
||||
st.warning("Please enter your OpenAI API key to continue")
|
||||
st.stop()
|
||||
|
||||
st.divider()
|
||||
st.header("Database Management")
|
||||
if st.button("Clear All Databases"):
|
||||
clear_collection()
|
||||
|
||||
st.divider()
|
||||
st.subheader("Clear Individual Databases")
|
||||
for collection_type, collection_config in COLLECTIONS.items():
|
||||
if st.button(f"Clear {collection_config.name}"):
|
||||
clear_collection(collection_type)
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# Document upload section
|
||||
st.header("Document Upload")
|
||||
tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()])
|
||||
st.info("Upload documents to populate the databases. Each tab corresponds to a different database.")
|
||||
tabs = st.tabs([config.name for config in COLLECTIONS.values()])
|
||||
|
||||
for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs):
|
||||
for (collection_type, config), tab in zip(COLLECTIONS.items(), tabs):
|
||||
with tab:
|
||||
st.write(collection_config.description)
|
||||
uploaded_file = st.file_uploader(
|
||||
"Upload PDF document",
|
||||
type="pdf",
|
||||
key=f"upload_{collection_type}"
|
||||
)
|
||||
st.write(config.description)
|
||||
uploaded_files = st.file_uploader(f"Upload PDF documents to {config.name}", type="pdf", key=f"upload_{collection_type}", accept_multiple_files=True)
|
||||
|
||||
if uploaded_file:
|
||||
with st.spinner('Processing document...'):
|
||||
texts = process_document(uploaded_file)
|
||||
if texts:
|
||||
if uploaded_files:
|
||||
with st.spinner('Processing documents...'):
|
||||
all_texts = []
|
||||
for uploaded_file in uploaded_files:
|
||||
texts = process_document(uploaded_file)
|
||||
all_texts.extend(texts)
|
||||
|
||||
if all_texts:
|
||||
db = st.session_state.databases[collection_type]
|
||||
db.add_documents(texts)
|
||||
st.success("Document processed and added to the database!")
|
||||
db.add_documents(all_texts)
|
||||
st.success("Documents processed and added to the database!")
|
||||
|
||||
# Query section
|
||||
st.header("Ask Questions")
|
||||
st.info("Enter your question below to find answers from the relevant database.")
|
||||
question = st.text_input("Enter your question:")
|
||||
|
||||
if question:
|
||||
with st.spinner('Finding answer...'):
|
||||
# Get relevant databases
|
||||
database_types = route_query(question)
|
||||
|
||||
# Display routing information
|
||||
st.info(f"Searching in: {', '.join([COLLECTIONS[db_type].name for db_type in database_types])}")
|
||||
|
||||
# Get and display answer
|
||||
answer = query_multiple_databases(question)
|
||||
collection_type = route_query(question)
|
||||
db = st.session_state.databases[collection_type]
|
||||
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")
|
||||
answer, relevant_docs = query_database(db, question)
|
||||
st.write("### Answer")
|
||||
st.write(answer)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
langchain>=0.1.0
|
||||
langchain-community>=0.0.10
|
||||
langchain-core>=0.1.10
|
||||
chromadb>=0.4.22
|
||||
langchain==0.3.12
|
||||
langchain-community==0.3.12
|
||||
langchain-core==0.3.28
|
||||
chromadb==0.5.20
|
||||
streamlit>=1.29.0
|
||||
python-dotenv>=1.0.0
|
||||
pypdf>=4.0.0
|
||||
sentence-transformers>=2.2.2
|
||||
openai>=1.6.1
|
||||
phidata==2.7.3
|
||||
langchain-openai==0.2.14
|
||||
langgraph==0.2.53
|
||||
duckduckgo-search==6.4.1
|
||||
Reference in New Issue
Block a user