new demo - crag

Madhu
2024-12-27 02:45:11 +05:30
parent ba7478407a
commit 315092de27
3 changed files with 564 additions and 0 deletions


@@ -0,0 +1,92 @@
# Corrective RAG Demo
This project demonstrates Corrective RAG (CRAG), an approach to Retrieval-Augmented Generation that adds self-grading of retrieved documents. It checks document relevance, transforms the query when needed, and falls back to web search to improve response quality. A complete explanation of CRAG appears at the end of this README.
## Features
- **Smart Document Retrieval**: Uses Qdrant vector store for efficient document retrieval
- **Document Relevance Grading**: Employs Claude 3.5 Sonnet to assess document relevance
- **Query Transformation**: Improves search results by optimizing queries when needed
- **Web Search Fallback**: Uses Tavily API for web search when local documents aren't sufficient
- **Multi-Model Approach**: Combines OpenAI embeddings and Claude 3.5 Sonnet for different tasks
- **Interactive UI**: Built with Streamlit for easy document upload and querying
## How to Run?
1. **Clone the Repository**:
```bash
git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git
cd awesome-llm-apps/rag_tutorials/corrective_rag
```
2. **Install Dependencies**:
```bash
pip install -r requirements.txt
```
3. **Set Up API Keys**:
You'll need to obtain the following API keys:
- OpenAI API key (for embeddings)
- Anthropic API key (for Claude 3.5 Sonnet as the LLM)
- Tavily API key (for web search)
- Qdrant API key and URL
4. **Run the Application**:
```bash
streamlit run corrective_rag.py
```
5. **Use the Application**:
- Upload documents or provide URLs
- Enter your questions in the query box
- View the step-by-step Corrective RAG process
- Get comprehensive answers
## Technologies Used
- **LangChain**: For RAG orchestration and chains
- **LangGraph**: For workflow management
- **Qdrant**: Vector database for document storage
- **Claude 3.5 Sonnet**: Main language model for analysis and generation
- **OpenAI**: For document embeddings
- **Tavily**: For web search capabilities
- **Streamlit**: For the user interface
## CRAG Step by Step Explanation
1. Initial Retrieval
   - A user query is presented to the system.
   - The system uses an existing retriever model to gather relevant documents from a knowledge base. This retriever could be any existing model.
2. Evaluation of Retrieved Documents
   - A lightweight retrieval evaluator assesses the relevance of each retrieved document to the user query.
   - The evaluator assigns a confidence score to each document, indicating how confident it is that the document is relevant to the query.
3. Action Trigger
   - Based on the confidence scores, the system categorizes the retrieved documents and decides on the necessary action for each document.
     - **Correct**: If the confidence score of a retrieved document is above the upper threshold, the document is marked as "Correct".
     - **Incorrect**: If the confidence score of a retrieved document is below the lower threshold, the document is marked as "Incorrect".
     - **Ambiguous**: If the confidence score falls between the thresholds for "Correct" and "Incorrect", the document is marked as "Ambiguous".
4. Handling of Retrieved Documents
   - **Correct Documents**: These documents undergo a knowledge refinement process.
     - **Decomposition**: The document is segmented into smaller knowledge strips, typically consisting of a few sentences each.
     - **Filtering**: The retrieval evaluator is used again to assess the relevance of each knowledge strip. Irrelevant strips are discarded.
     - **Recomposition**: The remaining relevant knowledge strips are recombined to form a refined representation of the essential knowledge from the document.
   - **Incorrect Documents**: These documents are discarded, and the system resorts to web searches for additional information.
     - **Query Rewriting**: The user query is rewritten into a form suitable for web searches, typically focusing on keywords.
     - **Web Search**: The system uses a web search API to find web pages related to the rewritten query. Authoritative sources like Wikipedia are preferred.
     - **Knowledge Selection**: The content of the web pages is transcribed, and the knowledge refinement process (decomposition, filtering, and recomposition) is applied to extract the most relevant information.
   - **Ambiguous Documents**: The system combines the refined knowledge from the "Correct" documents and the external knowledge from the web searches to provide a comprehensive set of information for the generator.
5. Generation of Response
   - The refined knowledge from the retrieved documents and/or web searches is presented to a generative language model.
   - The language model generates a response to the user query based on this knowledge.
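
The action trigger and the knowledge refinement steps described above can be sketched in plain Python. This is a minimal illustration, not the code used by this demo: the threshold values, the function names (`classify`, `refine`, `keyword_overlap`), and the keyword-overlap scorer standing in for the retrieval evaluator are all assumptions made for the example.

```python
from typing import Callable, List

# Hypothetical thresholds; the explanation above does not fix concrete values.
UPPER_THRESHOLD = 0.7
LOWER_THRESHOLD = 0.3


def classify(score: float) -> str:
    """Map an evaluator confidence score to an action label."""
    if score > UPPER_THRESHOLD:
        return "correct"
    if score < LOWER_THRESHOLD:
        return "incorrect"
    return "ambiguous"


def keyword_overlap(question: str, text: str) -> float:
    """Toy stand-in for the retrieval evaluator: share of question words found in text."""
    q_words = {w.lower().strip("?.,") for w in question.split()}
    t_words = {w.lower().strip("?.,") for w in text.split()}
    q_words.discard("")
    return len(q_words & t_words) / len(q_words) if q_words else 0.0


def refine(document: str, question: str,
           scorer: Callable[[str, str], float] = keyword_overlap,
           strip_size: int = 2) -> str:
    """Knowledge refinement: decompose into strips, filter them, recompose."""
    # Decomposition: split into sentences, then group them into small strips.
    sentences: List[str] = [s.strip() for s in document.split(".") if s.strip()]
    strips = [
        ". ".join(sentences[i:i + strip_size]) + "."
        for i in range(0, len(sentences), strip_size)
    ]
    # Filtering: drop strips the scorer rates below the lower threshold.
    kept = [s for s in strips if scorer(question, s) >= LOWER_THRESHOLD]
    # Recomposition: join the surviving strips back together.
    return " ".join(kept)


if __name__ == "__main__":
    doc = ("Qdrant is a vector database. It stores embeddings. "
           "Bananas are yellow. Cats sleep a lot.")
    print(classify(0.9), classify(0.5), classify(0.1))
    print(refine(doc, "what is qdrant"))
```

In a real pipeline the scorer would be a trained evaluator or an LLM grader rather than keyword overlap; the demo in this commit, for instance, asks the LLM for a binary yes/no verdict instead of a numeric score.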


@@ -0,0 +1,453 @@
from langchain import hub
from langchain.output_parsers import PydanticOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from pydantic import BaseModel, Field
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader, WebBaseLoader
from langchain_community.tools import TavilySearchResults
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.graph import END, StateGraph
from typing import Dict, TypedDict
from langchain_core.prompts import PromptTemplate
import pprint
import yaml
import nest_asyncio
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
import tempfile
import os
from langchain_anthropic import ChatAnthropic
from tenacity import retry, stop_after_attempt, wait_exponential
nest_asyncio.apply()
retriever = None
def initialize_session_state():
    """Initialize session state variables for API keys and URLs."""
    if 'initialized' not in st.session_state:
        st.session_state.initialized = False
        # Initialize API keys and URLs
        st.session_state.anthropic_api_key = ""
        st.session_state.openai_api_key = ""
        st.session_state.tavily_api_key = ""
        st.session_state.qdrant_api_key = ""
        st.session_state.qdrant_url = "http://localhost:6333"
        st.session_state.doc_url = "https://arxiv.org/pdf/2307.09288.pdf"
def setup_sidebar():
    """Setup sidebar for API keys and configuration."""
    with st.sidebar:
        st.subheader("API Configuration")
        st.session_state.anthropic_api_key = st.text_input("Anthropic API Key", value=st.session_state.anthropic_api_key, type="password", help="Required for the Claude 3.5 Sonnet model")
        st.session_state.openai_api_key = st.text_input("OpenAI API Key", value=st.session_state.openai_api_key, type="password")
        st.session_state.tavily_api_key = st.text_input("Tavily API Key", value=st.session_state.tavily_api_key, type="password")
        st.session_state.qdrant_url = st.text_input("Qdrant URL", value=st.session_state.qdrant_url)
        st.session_state.qdrant_api_key = st.text_input("Qdrant API Key", value=st.session_state.qdrant_api_key, type="password")
        st.session_state.doc_url = st.text_input("Document URL", value=st.session_state.doc_url)
        if not all([st.session_state.openai_api_key, st.session_state.anthropic_api_key, st.session_state.qdrant_url]):
            st.warning("Please provide the required API keys and URLs")
            st.stop()
        st.session_state.initialized = True
initialize_session_state()
setup_sidebar()
# Use session state variables instead of config
openai_api_key = st.session_state.openai_api_key
tavily_api_key = st.session_state.tavily_api_key
anthropic_api_key = st.session_state.anthropic_api_key
# Update embeddings initialization
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    api_key=st.session_state.openai_api_key
)
# Update Qdrant client initialization
client = QdrantClient(
    url=st.session_state.qdrant_url,
    api_key=st.session_state.qdrant_api_key
)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def execute_tavily_search(tool, query):
    """Run the Tavily search, retrying up to three times with exponential backoff."""
    return tool.invoke({"query": query})
def web_search(state):
    """Web search based on the re-phrased question using Tavily API."""
    print("~-web search-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    # Create progress placeholder
    progress_placeholder = st.empty()
    progress_placeholder.info("Initiating web search...")
    try:
        # Validate Tavily API key
        if not st.session_state.tavily_api_key:
            progress_placeholder.warning("Tavily API key not provided - skipping web search")
            return {"keys": {"documents": documents, "question": question}}
        progress_placeholder.info("Configuring search tool...")
        # The Tavily wrapper in langchain-community reads its key from the
        # TAVILY_API_KEY environment variable, so export it before building the tool
        os.environ["TAVILY_API_KEY"] = st.session_state.tavily_api_key
        tool = TavilySearchResults(
            max_results=3,
            search_depth="advanced"
        )
        # Execute search with retry logic
        progress_placeholder.info("Executing search query...")
        try:
            search_results = execute_tavily_search(tool, question)
        except Exception as search_error:
            progress_placeholder.error(f"Search failed after retries: {str(search_error)}")
            return {"keys": {"documents": documents, "question": question}}
        if not search_results:
            progress_placeholder.warning("No search results found")
            return {"keys": {"documents": documents, "question": question}}
        # Process results
        progress_placeholder.info("Processing search results...")
        web_results = []
        for result in search_results:
            # Extract and format relevant information
            content = (
                f"Title: {result.get('title', 'No title')}\n"
                f"Content: {result.get('content', 'No content')}\n"
            )
            web_results.append(content)
        # Create document from results
        web_document = Document(
            page_content="\n\n".join(web_results),
            metadata={
                "source": "tavily_search",
                "query": question,
                "result_count": len(web_results)
            }
        )
        documents.append(web_document)
        progress_placeholder.success(f"Successfully added {len(web_results)} search results")
    except Exception as error:
        error_msg = f"Web search error: {str(error)}"
        print(error_msg)
        progress_placeholder.error(error_msg)
    finally:
        progress_placeholder.empty()
    return {"keys": {"documents": documents, "question": question}}
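def load_documents(file_or_url: str, is_url: bool = True) -> list:
    """Load a document from a URL or a local file path; return [] on failure."""
    try:
        if is_url:
            if file_or_url.lower().endswith('.pdf'):
                # PyPDFLoader downloads remote PDFs itself; WebBaseLoader
                # would try to parse the PDF bytes as HTML
                loader = PyPDFLoader(file_or_url)
            else:
                loader = WebBaseLoader(file_or_url)
                loader.requests_per_second = 1
        else:
            file_extension = os.path.splitext(file_or_url)[1].lower()
            if file_extension == '.pdf':
                loader = PyPDFLoader(file_or_url)
            elif file_extension in ['.txt', '.md']:
                loader = TextLoader(file_or_url)
            else:
                raise ValueError(f"Unsupported file type: {file_extension}")
        return loader.load()
    except Exception as e:
        st.error(f"Error loading document: {str(e)}")
        return []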
st.subheader("Document Input")
input_option = st.radio("Choose input method:", ["URL", "File Upload"])
docs = None
if input_option == "URL":
    url = st.text_input("Enter document URL:", value=st.session_state.doc_url)
    if url:
        docs = load_documents(url, is_url=True)
else:
    uploaded_file = st.file_uploader("Upload a document", type=['pdf', 'txt', 'md'])
    if uploaded_file:
        # Create a temporary file to store the upload
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
        # Load only after the file is closed so every byte is flushed to disk
        try:
            docs = load_documents(tmp_file.name, is_url=False)
        finally:
            # Clean up the temporary file
            os.unlink(tmp_file.name)
if docs:
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=500, chunk_overlap=100
    )
    all_splits = text_splitter.split_documents(docs)
    client = QdrantClient(url=st.session_state.qdrant_url, api_key=st.session_state.qdrant_api_key)
    collection_name = "rag-qdrant"
    try:
        # Try to delete the collection if it exists
        client.delete_collection(collection_name)
    except Exception:
        pass
    # 1536 is the vector size of text-embedding-3-small
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )
    # Create vectorstore
    vectorstore = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
    )
    # Add documents to the vectorstore
    vectorstore.add_documents(all_splits)
    retriever = vectorstore.as_retriever()
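from typing import Any  # the lowercase builtin `any` is a function, not a type
class GraphState(TypedDict):
    """Graph state: a single "keys" dict carrying question, documents, and results."""
    keys: Dict[str, Any]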
def retrieve(state):
    """Retrieve documents for the question from the vector store, if one was built."""
    print("~-retrieve-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    if retriever is None:
        return {"keys": {"documents": [], "question": question}}
    # invoke() replaces the deprecated get_relevant_documents()
    documents = retriever.invoke(question)
    return {"keys": {"documents": documents, "question": question}}
def generate(state):
    """Generate answer using the Claude 3.5 Sonnet model"""
    print("~-generate-~")
    state_dict = state["keys"]
    question, documents = state_dict["question"], state_dict["documents"]
    try:
        prompt = PromptTemplate(template="""Based on the following context, please answer the question.
            Context: {context}
            Question: {question}
            Answer:""", input_variables=["context", "question"])
        llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=st.session_state.anthropic_api_key,
                            temperature=0, max_tokens=1000)
        context = "\n\n".join(doc.page_content for doc in documents)
        # Create and run chain
        rag_chain = (
            {"context": lambda x: context, "question": lambda x: question}
            | prompt
            | llm
            | StrOutputParser()
        )
        generation = rag_chain.invoke({})
        return {
            "keys": {
                "documents": documents,
                "question": question,
                "generation": generation
            }
        }
    except Exception as e:
        error_msg = f"Error in generate function: {str(e)}"
        print(error_msg)
        st.error(error_msg)
        return {"keys": {"documents": documents, "question": question,
                         "generation": "Sorry, I encountered an error while generating the response."}}
def grade_documents(state):
    """Determines whether the retrieved documents are relevant."""
    import json
    import re
    print("~-check relevance-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=st.session_state.anthropic_api_key,
                        temperature=0, max_tokens=1000)
    prompt = PromptTemplate(template="""You are grading the relevance of a retrieved document to a user question.
        Return ONLY a JSON object with a "score" field that is either "yes" or "no".
        Do not include any other text or explanation.
        Document: {context}
        Question: {question}
        Rules:
        - Check for related keywords or semantic meaning
        - Use lenient grading to only filter clear mismatches
        - Return exactly like this example: {{"score": "yes"}} or {{"score": "no"}}""",
        input_variables=["context", "question"])
    chain = (
        prompt
        | llm
        | StrOutputParser()
    )
    filtered_docs = []
    search = "No"
    for d in documents:
        try:
            response = chain.invoke({"question": question, "context": d.page_content})
            # Extract the JSON object from the reply; DOTALL lets it span several lines
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                response = json_match.group()
            score = json.loads(response)
            if score.get("score") == "yes":
                print("~-grade: document relevant-~")
                filtered_docs.append(d)
            else:
                print("~-grade: document not relevant-~")
                search = "Yes"
        except Exception as e:
            print(f"Error grading document: {str(e)}")
            # On error, keep the document to be safe
            filtered_docs.append(d)
    # If nothing relevant is left (or nothing was retrieved), fall back to web search
    if not filtered_docs:
        search = "Yes"
    return {"keys": {"documents": filtered_docs, "question": question, "run_web_search": search}}
def transform_query(state):
    """Transform the query to produce a better question."""
    print("~-transform query-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    # Create a prompt template
    prompt = PromptTemplate(
        template="""Generate a search-optimized version of this question by
        analyzing its core semantic meaning and intent.
        \n ------- \n
        {question}
        \n ------- \n
        Return only the improved question with no additional text:""",
        input_variables=["question"],
    )
    # Use the same Claude 3.5 Sonnet model as the grading and generation steps
    llm = ChatAnthropic(
        model="claude-3-5-sonnet-20241022",
        api_key=st.session_state.anthropic_api_key,
        temperature=0,
        max_tokens=1000
    )
    # Build and run the rewriting chain
    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})
    return {
        "keys": {"documents": documents, "question": better_question}
    }
def decide_to_generate(state):
    """Route to query transformation plus web search, or straight to generation."""
    print("~-decide to generate-~")
    state_dict = state["keys"]
    search = state_dict["run_web_search"]
    if search == "Yes":
        print("~-decision: transform query and run web search-~")
        return "transform_query"
    else:
        print("~-decision: generate-~")
        return "generate"
def format_document(doc: Document) -> str:
    """Return a short, readable preview of a document for display."""
    return f"""
    Source: {doc.metadata.get('source', 'Unknown')}
    Title: {doc.metadata.get('title', 'No title')}
    Content: {doc.page_content[:200]}...
    """
def format_state(state: dict) -> dict:
    """Copy the state, replacing documents with short previews."""
    formatted = {}
    for key, value in state.items():
        if key == "documents":
            formatted[key] = [format_document(doc) for doc in value]
        else:
            formatted[key] = value
    return formatted
workflow = StateGraph(GraphState)
# Define the nodes by langgraph
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)
# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
st.title("Corrective RAG Demo")
st.text("A possible query: What are the experiment results and ablation studies in this research paper?")
# User input
user_question = st.text_input("Please enter your question:")
if user_question:
    inputs = {
        "keys": {
            "question": user_question,
        }
    }
    # Stream the graph and show the state after each step
    for output in app.stream(inputs):
        for key, value in output.items():
            with st.expander(f"Step '{key}':"):
                st.text(pprint.pformat(format_state(value["keys"]), indent=2, width=80))
    # The last streamed value comes from the final node ("generate")
    final_generation = value['keys'].get('generation', 'No final generation produced.')
    st.subheader("Final Generation:")
    st.write(final_generation)


@@ -0,0 +1,19 @@
# Core dependencies
langchain==0.3.12
langgraph==0.2.53
qdrant-client==1.12.1
langchain-openai==0.2.14
langchain-anthropic==0.3.0
tavily-python==0.5.0
langchain-community==0.3.12
langchain-core==0.3.28
streamlit==1.41.1
tenacity==8.5.0
anthropic>=0.7.0
openai>=1.12.0
tiktoken>=0.6.0
pydantic>=2.0.0
numpy>=1.24.0
PyYAML>=6.0.0
nest-asyncio>=1.5.0
# Needed by the document loaders (PyPDFLoader, WebBaseLoader)
pypdf>=4.0.0
beautifulsoup4>=4.12.0