Files
awesome-llm-apps/rag_tutorials/rag_agent_cohere/rag_agent_cohere.py
2024-12-15 10:16:56 +05:30

320 lines
12 KiB
Python

import os
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_cohere import CohereEmbeddings, ChatCohere
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain import hub
import tempfile
from langgraph.prebuilt import create_react_agent
from langchain_community.tools import DuckDuckGoSearchRun
from typing import TypedDict, List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from time import sleep
from tenacity import retry, wait_exponential, stop_after_attempt
def init_session_state():
    """Seed ``st.session_state`` with defaults for any key that is not set yet.

    Keys that already hold a value are left untouched, so calling this
    repeatedly never overwrites existing session data.
    """
    # Built on every call so the mutable default (the chat history list) is
    # always a fresh object.
    defaults = {
        'api_keys_submitted': False,
        'chat_history': [],
        'vectorstore': None,
        'qdrant_api_key': "",
        'qdrant_url': "",
    }
    for key, initial_value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = initial_value
def sidebar_api_form():
    """Render the sidebar credential form and report whether credentials are set.

    When credentials were already submitted in this session, a confirmation
    and a "Reset Credentials" button are shown instead of the form.

    On submission the three fields must be non-empty; the Qdrant URL/key pair
    is then checked by listing collections. Only after that succeeds are the
    values stored in ``st.session_state`` and the script rerun.

    Returns:
        True if credentials have already been submitted in this session,
        otherwise False.
    """
    with st.sidebar:
        st.header("API Credentials")

        if st.session_state.api_keys_submitted:
            st.success("API credentials verified")
            if st.button("Reset Credentials"):
                st.session_state.clear()
                st.rerun()
            return True

        with st.form("api_credentials"):
            cohere_key = st.text_input("Cohere API Key", type="password")
            qdrant_key = st.text_input("Qdrant API Key", type="password", help="Enter your Qdrant API key")
            qdrant_url = st.text_input("Qdrant URL",
                                       placeholder="https://xyz-example.eu-central.aws.cloud.qdrant.io:6333",
                                       help="Enter your Qdrant instance URL")

            if st.form_submit_button("Submit Credentials"):
                # Pasted secrets often carry stray whitespace; drop it up front.
                cohere_key = cohere_key.strip()
                qdrant_key = qdrant_key.strip()
                qdrant_url = qdrant_url.strip()

                if not (cohere_key and qdrant_key and qdrant_url):
                    # init_qdrant() raises on an empty key/URL and the Cohere
                    # key is never probed here, so reject blanks instead of
                    # marking them as verified.
                    st.error("Please provide the Cohere API key, Qdrant API key and Qdrant URL.")
                else:
                    try:
                        # Connectivity probe only; named so it does not shadow
                        # the module-level ``client``.
                        probe_client = QdrantClient(url=qdrant_url, api_key=qdrant_key, timeout=60)
                        probe_client.get_collections()
                    except Exception as e:
                        st.error(f"Qdrant connection failed: {str(e)}")
                    else:
                        st.session_state.cohere_api_key = cohere_key
                        st.session_state.qdrant_api_key = qdrant_key
                        st.session_state.qdrant_url = qdrant_url
                        st.session_state.api_keys_submitted = True

                        st.success("Credentials verified!")
                        st.rerun()
    return False
def init_qdrant() -> QdrantClient:
    """Build a Qdrant client from the credentials held in session state.

    Returns:
        A ``QdrantClient`` configured with the stored URL and API key and a
        60 second timeout.

    Raises:
        ValueError: If the Qdrant API key or the Qdrant URL is missing or
            empty in ``st.session_state``.
    """
    # Checked in this order so the API-key message wins when both are missing.
    required_settings = (
        ("qdrant_api_key", "Qdrant API key not provided"),
        ("qdrant_url", "Qdrant URL not provided"),
    )
    for state_key, problem in required_settings:
        if not st.session_state.get(state_key):
            raise ValueError(problem)

    return QdrantClient(
        url=st.session_state.qdrant_url,
        api_key=st.session_state.qdrant_api_key,
        timeout=60,
    )
# Make sure the session-state defaults exist before the sidebar form reads them.
init_session_state()

# Gate the rest of the script: until sidebar_api_form() reports that
# credentials were submitted, show a hint and stop here.
if not sidebar_api_form():
    st.info("Please enter your API credentials in the sidebar to continue.")
    st.stop()

# Cohere embedding model handed to the Qdrant vector store in
# create_vector_stores().
# NOTE(review): the collection there is created with size=1024 vectors,
# presumably matching this model's output dimension — confirm.
embedding = CohereEmbeddings(model="embed-english-v3.0",
                             cohere_api_key=st.session_state.cohere_api_key)

# Chat model shared by the retrieval chain, the web-research agent and
# post_process().
chat_model = ChatCohere(model="command-r7b-12-2024",
                        temperature=0.1,
                        max_tokens=512,
                        verbose=True,
                        cohere_api_key=st.session_state.cohere_api_key)

# Module-level Qdrant client built from the credentials saved in session state.
client = init_qdrant()
def process_document(file):
    """Load an uploaded PDF and split it into overlapping text chunks.

    The upload is written to a temporary ``.pdf`` file because the loader is
    given a filesystem path. The temporary file is removed on every exit
    path, including when loading or splitting fails.

    Args:
        file: Uploaded file object exposing ``getvalue()`` that returns the
            raw bytes of the document.

    Returns:
        The list of split documents (chunk size 1000, overlap 200), or an
        empty list if anything goes wrong; the error is reported via
        ``st.error``.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            # Record the path first so cleanup still happens if the write fails.
            tmp_path = tmp_file.name
            tmp_file.write(file.getvalue())

        # The file is closed (and flushed) before the loader opens it by path.
        loader = PyPDFLoader(tmp_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        return text_splitter.split_documents(documents)
    except Exception as e:
        st.error(f"Error processing document: {e}")
        return []
    finally:
        # delete=False means nobody else removes the file for us.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Name of the Qdrant collection that create_vector_stores() creates and fills,
# and that the "Clear All Data" button deletes.
COLLECTION_NAME = "cohere_rag"
def create_vector_stores(texts):
    """Create the Qdrant collection if needed and store *texts* in it.

    Args:
        texts: Document chunks to embed and add to the collection.

    Returns:
        The ``QdrantVectorStore`` bound to ``COLLECTION_NAME``, or ``None``
        when there is nothing to store or any step fails (the problem is
        reported via ``st.error``).
    """
    if not texts:
        # Without this guard an empty upload would be reported as stored and
        # a store over whatever the collection already holds would be returned.
        st.error("No document content to store.")
        return None

    try:
        try:
            client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
            )
            st.success(f"Created new collection: {COLLECTION_NAME}")
        except Exception as e:
            # An existing collection is fine and is reused; anything else is
            # a real failure.
            if "already exists" not in str(e).lower():
                raise

        vector_store = QdrantVectorStore(client=client,
                                         collection_name=COLLECTION_NAME,
                                         embedding=embedding)

        with st.spinner('Storing documents in Qdrant...'):
            vector_store.add_documents(texts)
            st.success("Documents successfully stored in Qdrant!")

        return vector_store
    except Exception as e:
        st.error(f"Error in vector store creation: {str(e)}")
        return None
# Define the state schema using TypedDict
class AgentState(TypedDict):
    """State schema for the agent.

    NOTE(review): nothing else in the visible part of this file references
    this class (create_fallback_agent does not pass it to the agent) —
    confirm whether it is still needed.
    """
    # Conversation messages: any mix of human, AI and system messages.
    messages: List[HumanMessage | AIMessage | SystemMessage]
    # Boolean step flag; process_query sends a key of the same name (False)
    # in the agent input. Its consumer is not visible in this file.
    is_last_step: bool
class RateLimitedDuckDuckGo(DuckDuckGoSearchRun):
    """DuckDuckGo search tool that paces its requests and retries on failure.

    Every call sleeps 2 seconds before searching. If the search raises an
    error whose text contains "Ratelimit", it waits 5 more seconds and tries
    once more. The whole call is additionally retried by ``tenacity`` for up
    to 3 attempts with an exponential wait bounded to 4-10 seconds.
    """

    @retry(wait=wait_exponential(multiplier=1, min=4, max=10),
           stop=stop_after_attempt(3))
    def run(self, query: str, *args, **kwargs) -> str:
        """Run search with rate limiting.

        Args:
            query: The search query.
            *args: Extra positional arguments forwarded unchanged to the
                parent ``run``.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                parent ``run``. Accepting them keeps this override
                call-compatible with the base class, so callers that pass
                them no longer fail with a TypeError.

        Returns:
            The search result text produced by the parent tool.
        """
        try:
            sleep(2)  # Add delay between requests
            return super().run(query, *args, **kwargs)
        except Exception as e:
            if "Ratelimit" in str(e):
                sleep(5)  # Longer delay on rate limit
                return super().run(query, *args, **kwargs)
            # Re-raise with the original traceback so tenacity can retry.
            raise
def create_fallback_agent(chat_model: BaseLanguageModel):
    """Build the ReAct agent that answers questions through web research.

    Args:
        chat_model: Language model that drives the agent.

    Returns:
        The agent produced by ``create_react_agent`` with a single tool,
        ``web_research``, and debugging disabled.
    """

    # The tool's name, signature and docstring are kept verbatim: the function
    # is handed to the agent as a tool, and these are presumably what the
    # model sees as the tool's schema and description.
    def web_research(query: str) -> str:
        """Web search with result formatting."""
        try:
            return DuckDuckGoSearchRun(num_results=5).run(query)
        except Exception as search_error:
            # Return the failure as text so the agent can still produce an answer.
            return f"Search failed: {str(search_error)}. Providing answer based on general knowledge."

    return create_react_agent(model=chat_model, tools=[web_research], debug=False)
def process_query(vectorstore, query) -> tuple[str, list]:
    """Answer *query* from the vector store, falling back to web research.

    Documents are retrieved with a similarity-score threshold of 0.7 (at
    most 10 hits). If any are found, the answer comes from a retrieval chain
    over those documents. Otherwise a web-research agent is asked; if the
    agent fails or returns no messages, the chat model gives a general answer.

    Args:
        vectorstore: Vector store exposing ``as_retriever``.
        query: The user's question.

    Returns:
        A ``(answer, sources)`` tuple. ``sources`` holds the retrieved
        documents on the RAG path and is empty on every fallback path. A
        tuple is returned on all paths, so callers can always unpack it.
    """
    try:
        retriever = vectorstore.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 10,
                "score_threshold": 0.7,
            },
        )
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        relevant_docs = retriever.invoke(query)

        if relevant_docs:
            retrieval_qa_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
            combine_docs_chain = create_stuff_documents_chain(chat_model, retrieval_qa_prompt)
            retrieval_chain = create_retrieval_chain(retriever, combine_docs_chain)
            response = retrieval_chain.invoke({"input": query})
            return response['answer'], relevant_docs

        st.info("No relevant documents found. Searching web...")
        fallback_agent = create_fallback_agent(chat_model)

        with st.spinner('Researching...'):
            research_prompt = (
                f"Please thoroughly research the question: '{query}' and provide a "
                "detailed and comprehensive response. Make sure to gather the latest "
                "information from credible sources. Minimum 400 words."
            )
            agent_input = {
                "messages": [HumanMessage(content=research_prompt)],
                "is_last_step": False,
            }
            config = {"recursion_limit": 100}

            try:
                response = fallback_agent.invoke(agent_input, config=config)
                messages = response["messages"] if isinstance(response, dict) else None
                if not messages:
                    # Previously this case fell through and returned None,
                    # which broke the caller's tuple unpacking. Route it to
                    # the general-answer fallback instead.
                    raise ValueError("Agent returned no messages")

                last_message = messages[-1]
                answer = last_message.content if hasattr(last_message, 'content') else str(last_message)
                return f"Web Search Result:\n{answer}\n", []
            except Exception:
                fallback_response = chat_model.invoke(f"Please provide a general answer to: {query}").content
                return f"Web search unavailable. General response: {fallback_response}", []
    except Exception as e:
        st.error(f"Error: {str(e)}")
        return "I encountered an error. Please try rephrasing your question.", []
def post_process(answer, sources):
    """Tidy an answer and turn its sources into numbered excerpts.

    Args:
        answer: Raw answer text; surrounding whitespace is removed.
        sources: Iterable of objects with a ``page_content`` string.

    Returns:
        A ``(answer, formatted_sources)`` tuple. Answers longer than 500
        characters are prefixed with a 2-3 sentence summary produced by the
        chat model, followed by the full text. Each source becomes
        ``"<n>. <first 200 characters>..."``, numbered from 1.
    """
    cleaned = answer.strip()

    # Long answers get a model-written summary placed ahead of the full text.
    if len(cleaned) > 500:
        condensed = chat_model.invoke(
            f"Summarize the following answer in 2-3 sentences: {cleaned}"
        ).content
        cleaned = f"{condensed}\n\nFull Answer: {cleaned}"

    excerpts = [
        f"{position}. {doc.page_content[:200]}..."
        for position, doc in enumerate(sources, start=1)
    ]
    return cleaned, excerpts
st.title("RAG Agent with Cohere ⌘R")

# Only PDFs are offered: process_document() always loads the upload with a PDF
# loader, so the image types that used to be listed here could never work.
uploaded_file = st.file_uploader("Choose a PDF File", type=["pdf"])

if uploaded_file is None:
    # The file was removed from the uploader; forget it so the same file can
    # be uploaded and ingested again later.
    st.session_state.pop('processed_file', None)
else:
    # Remember WHICH file was ingested instead of a one-shot True flag.
    # Otherwise a second upload, or any upload after "Clear All Data", would
    # be silently ignored.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get('processed_file') != file_key:
        with st.spinner('Processing file...'):
            texts = process_document(uploaded_file)
            # Nothing extracted means nothing to store; treat it as a failure.
            vectorstore = create_vector_stores(texts) if texts else None
            if vectorstore:
                st.session_state.vectorstore = vectorstore
                st.session_state.processed_file = file_key
                st.success('File uploaded and processed successfully!')
            else:
                st.error('Failed to process file. Please try again.')

# Replay the conversation stored in session state.
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if query := st.chat_input("Ask a question about the document:"):
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    if st.session_state.vectorstore:
        with st.chat_message("assistant"):
            try:
                answer, sources = process_query(st.session_state.vectorstore, query)
                st.markdown(answer)

                if sources:
                    with st.expander("Sources"):
                        for source in sources:
                            st.markdown(f"- {source.page_content[:200]}...")

                st.session_state.chat_history.append({
                    "role": "assistant",
                    "content": answer
                })
            except Exception as e:
                st.error(f"Error: {str(e)}")
                st.info("Please try asking your question again.")
    else:
        st.error("Please upload a document first.")

with st.sidebar:
    st.divider()

    col1, col2 = st.columns(2)
    with col1:
        if st.button('Clear Chat History'):
            st.session_state.chat_history = []
            st.rerun()

    with col2:
        if st.button('Clear All Data'):
            try:
                existing_names = {col.name for col in client.get_collections().collections}
                for name in (COLLECTION_NAME, f"{COLLECTION_NAME}_compressed"):
                    if name in existing_names:
                        client.delete_collection(name)
            except Exception as e:
                st.error(f"Error clearing data: {str(e)}")
            else:
                # 'processed_file' is deliberately kept: the uploader may still
                # hold the file, and dropping the marker would re-ingest it on
                # the rerun. Removing the file from the uploader resets it.
                st.session_state.vectorstore = None
                st.session_state.chat_history = []
                st.success("All data cleared successfully!")
                st.rerun()