final changes

This commit is contained in:
Madhu
2024-12-25 11:30:40 +05:30
parent 7035e9e641
commit d0c0798711

View File

@@ -1,12 +1,13 @@
import os
from typing import List, Dict, Literal
from typing import List, Dict, Any, Literal
from dataclasses import dataclass
import streamlit as st
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
import tempfile
from phi.agent import Agent
from phi.model.openai import OpenAIChat
@@ -17,8 +18,10 @@ from langchain import hub
from langgraph.prebuilt import create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.language_models import BaseLanguageModel
from langchain.prompts import ChatPromptTemplate
def init_session_state():
"""Initialize session state variables"""
if 'openai_api_key' not in st.session_state:
st.session_state.openai_api_key = ""
if 'embeddings' not in st.session_state:
@@ -40,6 +43,7 @@ class CollectionConfig:
collection_name: str
persist_directory: str
# Collection configurations
COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
"products": CollectionConfig(
name="Product Information",
@@ -62,25 +66,39 @@ COLLECTIONS: Dict[DatabaseType, CollectionConfig] = {
}
def initialize_models():
    """Initialize the OpenAI models and Chroma stores kept in session state.

    Reads ``st.session_state.openai_api_key``. When it is non-empty, the key
    is exported as ``OPENAI_API_KEY``, the embedding and chat models are
    created, every collection's persist directory is created if missing, and
    one Chroma store per entry in ``COLLECTIONS`` is stored in
    ``st.session_state.databases`` under the same key.

    Returns:
        True when the models and databases were initialized, False when no
        API key has been provided.
    """
    api_key = st.session_state.openai_api_key
    if not api_key:
        return False

    os.environ["OPENAI_API_KEY"] = api_key
    st.session_state.embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
    st.session_state.llm = ChatOpenAI(temperature=0)

    # Ensure directories exist before Chroma is pointed at them.
    for collection_config in COLLECTIONS.values():
        os.makedirs(collection_config.persist_directory, exist_ok=True)

    # Build the stores from COLLECTIONS so the two mappings cannot drift apart.
    st.session_state.databases = {
        db_type: Chroma(
            collection_name=collection_config.collection_name,
            embedding_function=st.session_state.embeddings,
            persist_directory=collection_config.persist_directory,
        )
        for db_type, collection_config in COLLECTIONS.items()
    }
    return True
def process_document(file) -> List[Document]:
"""Process uploaded PDF document"""
try:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
tmp_file.write(file.getvalue())
@@ -88,24 +106,37 @@ def process_document(file) -> List[Document]:
loader = PyPDFLoader(tmp_path)
documents = loader.load()
# Clean up temporary file
os.unlink(tmp_path)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=300)
return text_splitter.split_documents(documents)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
texts = text_splitter.split_documents(documents)
return texts
except Exception as e:
st.error(f"Error processing document: {e}")
return []
def create_routing_agent() -> Agent:
"""Creates a routing agent using phidata framework"""
return Agent(
model=OpenAIChat(id="gpt-4o", api_key=st.session_state.openai_api_key),
model=OpenAIChat(
id="gpt-4o",
api_key=st.session_state.openai_api_key
),
tools=[],
description="You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.",
description="""You are a query routing expert. Your only job is to analyze questions and determine which database they should be routed to.
You must respond with exactly one of these three options: 'products', 'support', or 'finance'. The user's question is: {question}""",
instructions=[
"1. For questions about products, return 'products'",
"2. For questions about support, return 'support'",
"3. For questions about finance, return 'finance'",
"4. Return ONLY the database name"
"Follow these rules strictly:",
"1. For questions about products, features, specifications, or item details, or product manuals → return 'products'",
"2. For questions about help, guidance, troubleshooting, or customer service, FAQ, or guides → return 'support'",
"3. For questions about costs, revenue, pricing, or financial data, or financial reports and investments → return 'finance'",
"4. Return ONLY the database name, no other text or explanation"
],
markdown=False,
show_tool_calls=False
@@ -115,42 +146,72 @@ def route_query(question: str) -> DatabaseType:
try:
routing_agent = create_routing_agent()
response = routing_agent.run(question)
db_type = response.content.strip().lower().translate(str.maketrans('', '', '`\'"'))
db_type = (response.content
.strip()
.lower()
.translate(str.maketrans('', '', '`\'"'))) # More elegant string cleaning
# Validate database type
if db_type not in COLLECTIONS:
st.warning(f"Invalid database type: {db_type}, defaulting to products")
return "products"
st.info(f"Routing question to {db_type} database")
return db_type
except Exception as e:
st.error(f"Routing error: {str(e)}")
return "products"
def create_fallback_agent(chat_model: BaseLanguageModel):
    """Create a LangGraph agent for web research.

    Args:
        chat_model: Language model passed to ``create_react_agent`` as the
            agent's model.

    Returns:
        The agent built by ``create_react_agent`` with a single DuckDuckGo
        search tool and ``debug=False``.
    """

    # NOTE(review): the docstring below is presumably used as the tool
    # description when the callable is wrapped as a tool, so its text is kept
    # unchanged.
    def web_research(query: str) -> str:
        """Web search with result formatting."""
        try:
            search = DuckDuckGoSearchRun(num_results=5)
            return search.run(query)
        except Exception as e:
            # Return the failure as text so the agent can still produce an answer.
            return f"Search failed: {str(e)}. Providing answer based on general knowledge."

    tools = [web_research]
    return create_react_agent(model=chat_model, tools=tools, debug=False)
def query_database(db: Chroma, question: str) -> tuple[str, list]:
    """Answer a question from a Chroma store, falling back to web research.

    Documents are retrieved with a similarity score threshold. When at least
    one document is found, a retrieval chain built from the session LLM and
    an inline prompt produces the answer. When none is found, the question is
    handed to ``_handle_web_fallback``.

    Args:
        db: Chroma store to search.
        question: The user's question.

    Returns:
        A ``(answer, documents)`` tuple. ``documents`` holds the retrieved
        documents on the retrieval path; on error it is an empty list and
        ``answer`` is a generic error message. On the fallback path the tuple
        is whatever ``_handle_web_fallback`` returns.
    """
    try:
        retriever = db.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 4, "score_threshold": 0.3},
        )
        relevant_docs = retriever.invoke(question)

        if not relevant_docs:
            # Still inside the try block so fallback errors are reported below.
            return _handle_web_fallback(question)

        # The prompt is defined inline, so no network call to the prompt hub
        # is needed.
        retrieval_qa_prompt = ChatPromptTemplate.from_messages([
            (
                "system",
                "You are a helpful AI assistant that answers questions based on provided context.\n"
                "Always be direct and concise in your responses.\n"
                "If the context doesn't contain enough information to fully answer the question, acknowledge this limitation.\n"
                "Base your answers strictly on the provided context and avoid making assumptions.",
            ),
            ("human", "Here is the context:\n{context}"),
            ("human", "Question: {input}"),
            ("assistant", "I'll help answer your question based on the context provided."),
            ("human", "Please provide your answer:"),
        ])
        combine_docs_chain = create_stuff_documents_chain(st.session_state.llm, retrieval_qa_prompt)
        retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
        response = retrieval_chain.invoke({"input": question})
        return response['answer'], relevant_docs

    except Exception as e:
        st.error(f"Error: {str(e)}")
        return "I encountered an error. Please try rephrasing your question.", []
@@ -161,7 +222,9 @@ def _handle_web_fallback(question: str) -> tuple[str, list]:
with st.spinner('Researching...'):
agent_input = {
"messages": [HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")],
"messages": [
HumanMessage(content=f"Research and provide a detailed answer for: '{question}'")
],
"is_last_step": False
}
@@ -170,17 +233,26 @@ def _handle_web_fallback(question: str) -> tuple[str, list]:
if isinstance(response, dict) and "messages" in response:
answer = response["messages"][-1].content
return f"Web Search Result:\n{answer}", []
except Exception:
# Fallback to general LLM response
fallback_response = st.session_state.llm.invoke(question).content
return f"Web search unavailable. General response: {fallback_response}", []
def main():
"""Main application function."""
st.set_page_config(page_title="RAG Agent with Database Routing", page_icon="📚")
st.title("📚 RAG Agent with Database Routing")
# Sidebar for API key and database management
with st.sidebar:
st.header("Configuration")
api_key = st.text_input("Enter OpenAI API Key:", type="password", value=st.session_state.openai_api_key, key="api_key_input")
api_key = st.text_input(
"Enter OpenAI API Key:",
type="password",
value=st.session_state.openai_api_key,
key="api_key_input"
)
if api_key:
st.session_state.openai_api_key = api_key
@@ -194,15 +266,20 @@ def main():
st.stop()
st.markdown("---")
st.header("Document Upload")
st.info("Upload documents to populate the databases. Each tab corresponds to a different database.")
tabs = st.tabs([config.name for config in COLLECTIONS.values()])
tabs = st.tabs([collection_config.name for collection_config in COLLECTIONS.values()])
for (collection_type, config), tab in zip(COLLECTIONS.items(), tabs):
for (collection_type, collection_config), tab in zip(COLLECTIONS.items(), tabs):
with tab:
st.write(config.description)
uploaded_files = st.file_uploader(f"Upload PDF documents to {config.name}", type="pdf", key=f"upload_{collection_type}", accept_multiple_files=True)
st.write(collection_config.description)
uploaded_files = st.file_uploader(
f"Upload PDF documents to {collection_config.name}",
type="pdf",
key=f"upload_{collection_type}",
accept_multiple_files=True
)
if uploaded_files:
with st.spinner('Processing documents...'):
@@ -216,15 +293,21 @@ def main():
db.add_documents(all_texts)
st.success("Documents processed and added to the database!")
# Query section
st.header("Ask Questions")
st.info("Enter your question below to find answers from the relevant database.")
question = st.text_input("Enter your question:")
if question:
with st.spinner('Finding answer...'):
# Route the question
collection_type = route_query(question)
db = st.session_state.databases[collection_type]
# Display routing information
st.info(f"Routing question to: {COLLECTIONS[collection_type].name}")
# Get and display answer
answer, relevant_docs = query_database(db, question)
st.write("### Answer")
st.write(answer)