mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-03-11 17:48:31 -05:00
new demo - crag
92
rag_tutorials/corrective_rag/README.md
Normal file
@@ -0,0 +1,92 @@
# Corrective RAG Demo

This project demonstrates Corrective RAG (CRAG), an extension of Retrieval-Augmented Generation that grades the documents it retrieves before answering. It combines document relevance checking, query transformation, and a web search fallback to improve response quality. A step-by-step explanation of CRAG appears at the end of this README.

## Features

- **Smart Document Retrieval**: Uses a Qdrant vector store for efficient document retrieval
- **Document Relevance Grading**: Uses Claude 3.5 Sonnet to assess the relevance of each retrieved document
- **Query Transformation**: Rewrites the query into a search-optimized form when the retrieved documents fall short
- **Web Search Fallback**: Uses the Tavily API for web search when local documents aren't sufficient
- **Multi-Model Approach**: Combines OpenAI embeddings with Claude 3.5 Sonnet for grading and generation
- **Interactive UI**: Built with Streamlit for easy document upload and querying
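
The grading and fallback features can be sketched without any API calls. The snippet below is a minimal illustration, assuming the grader replies with a small JSON object such as `{"score": "yes"}` (the format the prompt in `corrective_rag.py` asks for). The helper names `parse_grade` and `route` are hypothetical and are not part of the app.

```python
import json
import re


def parse_grade(reply: str) -> bool:
    """Return True if the grader's reply marks the document as relevant."""
    # Pull the {...} span out of the reply in case the model added extra text.
    match = re.search(r"\{.*\}", reply, re.DOTALL)
    try:
        payload = json.loads(match.group() if match else reply)
    except json.JSONDecodeError:
        # Lenient policy (an assumption mirroring the demo): keep the
        # document when the grade cannot be read.
        return True
    if not isinstance(payload, dict):
        return True
    return payload.get("score") == "yes"


def route(replies: list[str]) -> str:
    """Pick the next step: fall back to web search if any document was rejected."""
    if all(parse_grade(reply) for reply in replies):
        return "generate"
    return "transform_query"
```

With this routing, a single rejected document is enough to trigger query transformation and web search; an unreadable grade does not trigger it.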

## How to Run?

1. **Clone the Repository**:
   ```bash
   git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git
   cd awesome-llm-apps/rag_tutorials/corrective_rag
   ```

2. **Install Dependencies**:
   ```bash
   pip install -r requirements.txt
   ```

3. **Set Up API Keys**:
   You'll need the following credentials, which you enter in the app's sidebar:
   - OpenAI API key (for embeddings)
   - Anthropic API key (for Claude 3.5 Sonnet as the LLM)
   - Tavily API key (for web search; the app skips web search if it is missing)
   - Qdrant API key and URL

4. **Run the Application**:
   ```bash
   streamlit run corrective_rag.py
   ```

5. **Use the Application**:
   - Upload a document or provide a URL
   - Enter your question in the query box
   - View the step-by-step Corrective RAG process
   - Read the final generated answer

## Technologies Used

- **LangChain**: For RAG orchestration and chains
- **LangGraph**: For workflow management
- **Qdrant**: Vector database for document storage
- **Claude 3.5 Sonnet**: Main language model for analysis and generation
- **OpenAI**: For document embeddings
- **Tavily**: For web search capabilities
- **Streamlit**: For the user interface

## CRAG Step by Step Explanation

1. Initial Retrieval

   - A user query is presented to the system.
   - The system uses an existing retriever to gather candidate documents from a knowledge base. Any retriever can fill this role.

2. Evaluation of Retrieved Documents

   - A lightweight retrieval evaluator assesses the relevance of each retrieved document to the user query.
   - The evaluator assigns each document a confidence score that reflects how relevant it judges the document to be.

3. Action Trigger

   Based on the confidence scores, the system categorizes each retrieved document and decides what to do with it:

   - Correct: The confidence score is above the upper threshold, so the document is marked as "Correct".
   - Incorrect: The confidence score is below the lower threshold, so the document is marked as "Incorrect".
   - Ambiguous: The confidence score falls between the two thresholds, so the document is marked as "Ambiguous".
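
The action trigger can be sketched as a simple two-threshold rule. This is only an illustration: the threshold values `0.7` and `0.3` and the function name `classify` are assumptions chosen for the example, and the demo in `corrective_rag.py` uses a binary yes/no grade rather than numeric scores.

```python
def classify(score: float, upper: float = 0.7, lower: float = 0.3) -> str:
    """Map an evaluator confidence score to a CRAG action label.

    The thresholds are hypothetical defaults, not values from the README.
    """
    if not lower <= upper:
        raise ValueError("lower threshold must not exceed upper threshold")
    if score > upper:
        return "Correct"
    if score < lower:
        return "Incorrect"
    # Scores between the two thresholds (inclusive) are ambiguous.
    return "Ambiguous"


# Example: label a small batch of hypothetical scores.
labels = [classify(s) for s in (0.9, 0.5, 0.1)]
```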

4. Handling of Retrieved Documents

   - Correct Documents: These documents undergo a knowledge refinement process.
     - Decomposition: The document is segmented into smaller knowledge strips, typically a few sentences each.
     - Filtering: The retrieval evaluator is applied again to score each knowledge strip. Irrelevant strips are discarded.
     - Recomposition: The remaining strips are recombined into a refined representation of the document's essential knowledge.
   - Incorrect Documents: These documents are discarded, and the system falls back to web search for additional information.
     - Query Rewriting: The user query is rewritten into a form suitable for web search, typically focused on keywords.
     - Web Search: The system uses a web search API to find pages related to the rewritten query. Authoritative sources such as Wikipedia are preferred.
     - Knowledge Selection: The content of the web pages is extracted, and the same knowledge refinement process (decomposition, filtering, and recomposition) is applied to keep the most relevant information.
   - Ambiguous Documents: The system combines the refined knowledge from the retrieved documents with the external knowledge from web search, giving the generator a more comprehensive set of information.
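
The decomposition, filtering, and recomposition steps can be sketched in a few lines. This is a toy illustration, not the evaluator the README describes: `overlap_score` is a hypothetical stand-in that scores a strip by keyword overlap with the query, and the `0.5` threshold and two-sentence strip size are assumptions.

```python
import re


def decompose(text: str, sentences_per_strip: int = 2) -> list[str]:
    """Split a document into knowledge strips of a few sentences each."""
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    return [
        " ".join(sentences[i:i + sentences_per_strip])
        for i in range(0, len(sentences), sentences_per_strip)
    ]


def overlap_score(query: str, strip: str) -> float:
    """Toy relevance score: fraction of query words that appear in the strip."""
    query_words = set(re.findall(r"\w+", query.lower()))
    strip_words = set(re.findall(r"\w+", strip.lower()))
    return len(query_words & strip_words) / len(query_words) if query_words else 0.0


def refine(query: str, text: str, threshold: float = 0.5) -> str:
    """Decompose, filter out low-scoring strips, and recompose the rest."""
    strips = decompose(text)
    kept = [s for s in strips if overlap_score(query, s) >= threshold]
    return " ".join(kept)
```

In a real pipeline the toy scorer would be replaced by the retrieval evaluator, and the same `refine` step would be applied to web search results.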

5. Generation of Response

   - The refined knowledge from the retrieved documents and/or web searches is passed to a generative language model.
   - The language model generates a response to the user query based on this knowledge.
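
The hand-off to the generator amounts to joining the refined knowledge into one context string and filling a prompt. The sketch below reuses the prompt wording from `generate` in `corrective_rag.py`; the helper name `build_prompt` is hypothetical, and the model call itself is left out.

```python
PROMPT_TEMPLATE = (
    "Based on the following context, please answer the question.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)


def build_prompt(question: str, knowledge: list[str]) -> str:
    """Join non-empty knowledge pieces and fill the generation prompt."""
    context = "\n\n".join(piece for piece in knowledge if piece.strip())
    return PROMPT_TEMPLATE.format(context=context, question=question)
```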
453
rag_tutorials/corrective_rag/corrective_rag.py
Normal file
@@ -0,0 +1,453 @@
from langchain import hub
from langchain.output_parsers import PydanticOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document
from pydantic import BaseModel, Field
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, TextLoader, WebBaseLoader
from langchain_community.tools import TavilySearchResults
from langchain_community.vectorstores import Qdrant
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.graph import END, StateGraph
from typing import Dict, TypedDict
from langchain_core.prompts import PromptTemplate
import pprint
import yaml
import nest_asyncio
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
import tempfile
import os
from langchain_anthropic import ChatAnthropic
from tenacity import retry, stop_after_attempt, wait_exponential


nest_asyncio.apply()

retriever = None

def initialize_session_state():
    """Initialize session state variables for API keys and URLs."""
    if 'initialized' not in st.session_state:
        st.session_state.initialized = False
        # Initialize API keys and URLs
        st.session_state.anthropic_api_key = ""
        st.session_state.openai_api_key = ""
        st.session_state.tavily_api_key = ""
        st.session_state.qdrant_api_key = ""
        st.session_state.qdrant_url = "http://localhost:6333"
        st.session_state.doc_url = "https://arxiv.org/pdf/2307.09288.pdf"


def setup_sidebar():
    """Setup sidebar for API keys and configuration."""
    with st.sidebar:
        st.subheader("API Configuration")
        st.session_state.anthropic_api_key = st.text_input("Anthropic API Key", value=st.session_state.anthropic_api_key, type="password", help="Required for Claude 3 model")
        st.session_state.openai_api_key = st.text_input("OpenAI API Key", value=st.session_state.openai_api_key, type="password")
        st.session_state.tavily_api_key = st.text_input("Tavily API Key", value=st.session_state.tavily_api_key, type="password")
        st.session_state.qdrant_url = st.text_input("Qdrant URL", value=st.session_state.qdrant_url)
        st.session_state.qdrant_api_key = st.text_input("Qdrant API Key", value=st.session_state.qdrant_api_key, type="password")
        st.session_state.doc_url = st.text_input("Document URL", value=st.session_state.doc_url)

        if not all([st.session_state.openai_api_key, st.session_state.anthropic_api_key, st.session_state.qdrant_url]):
            st.warning("Please provide the required API keys and URLs")
            st.stop()

        st.session_state.initialized = True


initialize_session_state()
setup_sidebar()

# Use session state variables instead of config
openai_api_key = st.session_state.openai_api_key
tavily_api_key = st.session_state.tavily_api_key
anthropic_api_key = st.session_state.anthropic_api_key

# Update embeddings initialization
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-small",
    api_key=st.session_state.openai_api_key
)

# Update Qdrant client initialization
client = QdrantClient(
    url=st.session_state.qdrant_url,
    api_key=st.session_state.qdrant_api_key
)


# Retry the search up to 3 times with exponential backoff (4-10 seconds)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def execute_tavily_search(tool, query):
    return tool.invoke({"query": query})


def web_search(state):
    """Web search based on the re-phrased question using Tavily API."""
    print("~-web search-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create progress placeholder
    progress_placeholder = st.empty()
    progress_placeholder.info("Initiating web search...")

    try:
        # Validate Tavily API key
        if not st.session_state.tavily_api_key:
            progress_placeholder.warning("Tavily API key not provided - skipping web search")
            return {"keys": {"documents": documents, "question": question}}

        progress_placeholder.info("Configuring search tool...")

        # Initialize Tavily search tool
        tool = TavilySearchResults(
            api_key=st.session_state.tavily_api_key,
            max_results=3,
            search_depth="advanced"
        )

        # Execute search with retry logic
        progress_placeholder.info("Executing search query...")
        try:
            search_results = execute_tavily_search(tool, question)
        except Exception as search_error:
            progress_placeholder.error(f"Search failed after retries: {str(search_error)}")
            return {"keys": {"documents": documents, "question": question}}

        if not search_results:
            progress_placeholder.warning("No search results found")
            return {"keys": {"documents": documents, "question": question}}

        # Process results
        progress_placeholder.info("Processing search results...")
        web_results = []
        for result in search_results:
            # Extract and format relevant information
            content = (
                f"Title: {result.get('title', 'No title')}\n"
                f"Content: {result.get('content', 'No content')}\n"
            )
            web_results.append(content)

        # Create document from results
        web_document = Document(
            page_content="\n\n".join(web_results),
            metadata={
                "source": "tavily_search",
                "query": question,
                "result_count": len(web_results)
            }
        )
        documents.append(web_document)

        progress_placeholder.success(f"Successfully added {len(web_results)} search results")

    except Exception as error:
        error_msg = f"Web search error: {str(error)}"
        print(error_msg)
        progress_placeholder.error(error_msg)
    finally:
        progress_placeholder.empty()

    return {"keys": {"documents": documents, "question": question}}


def load_documents(file_or_url: str, is_url: bool = True) -> list:
    """Load documents from a URL or from a local PDF, text, or Markdown file."""
    try:
        if is_url:
            loader = WebBaseLoader(file_or_url)
            loader.requests_per_second = 1
        else:
            file_extension = os.path.splitext(file_or_url)[1].lower()
            if file_extension == '.pdf':
                loader = PyPDFLoader(file_or_url)
            elif file_extension in ['.txt', '.md']:
                loader = TextLoader(file_or_url)
            else:
                raise ValueError(f"Unsupported file type: {file_extension}")

        return loader.load()
    except Exception as e:
        st.error(f"Error loading document: {str(e)}")
        return []


st.subheader("Document Input")
input_option = st.radio("Choose input method:", ["URL", "File Upload"])

docs = None

if input_option == "URL":
    url = st.text_input("Enter document URL:", value=st.session_state.doc_url)
    if url:
        docs = load_documents(url, is_url=True)
else:
    uploaded_file = st.file_uploader("Upload a document", type=['pdf', 'txt', 'md'])
    if uploaded_file:
        # Create a temporary file to store the upload
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
        # Load only after the file is closed, so the written bytes are flushed to disk
        docs = load_documents(tmp_file.name, is_url=False)
        # Clean up the temporary file
        os.unlink(tmp_file.name)

if docs:
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=500, chunk_overlap=100
    )
    all_splits = text_splitter.split_documents(docs)

    client = QdrantClient(url=st.session_state.qdrant_url, api_key=st.session_state.qdrant_api_key)
    collection_name = "rag-qdrant"

    try:
        # Try to delete the collection if it exists
        client.delete_collection(collection_name)
    except Exception:
        pass

    # 1536 is the vector size of text-embedding-3-small
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )

    # Create vectorstore
    vectorstore = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=embeddings,
    )

    # Add documents to the vectorstore
    vectorstore.add_documents(all_splits)
    retriever = vectorstore.as_retriever()


from typing import Any  # needed for the state annotation below


class GraphState(TypedDict):
    """Graph state: a single "keys" dict holding question, documents, and so on."""
    keys: Dict[str, Any]


def retrieve(state):
    """Retrieve documents for the question from the vector store."""
    print("~-retrieve-~")
    state_dict = state["keys"]
    question = state_dict["question"]

    # No document has been indexed yet, so there is nothing to retrieve
    if retriever is None:
        return {"keys": {"documents": [], "question": question}}

    documents = retriever.invoke(question)
    return {"keys": {"documents": documents, "question": question}}


def generate(state):
    """Generate answer using Claude 3 model"""
    print("~-generate-~")
    state_dict = state["keys"]
    question, documents = state_dict["question"], state_dict["documents"]
    try:
        prompt = PromptTemplate(template="""Based on the following context, please answer the question.
        Context: {context}
        Question: {question}
        Answer:""", input_variables=["context", "question"])
        llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=st.session_state.anthropic_api_key,
                            temperature=0, max_tokens=1000)
        context = "\n\n".join(doc.page_content for doc in documents)

        # Create and run chain
        rag_chain = (
            {"context": lambda x: context, "question": lambda x: question}
            | prompt
            | llm
            | StrOutputParser()
        )

        generation = rag_chain.invoke({})

        return {
            "keys": {
                "documents": documents,
                "question": question,
                "generation": generation
            }
        }

    except Exception as e:
        error_msg = f"Error in generate function: {str(e)}"
        print(error_msg)
        st.error(error_msg)
        return {"keys": {"documents": documents, "question": question,
                         "generation": "Sorry, I encountered an error while generating the response."}}


def grade_documents(state):
    """Determines whether the retrieved documents are relevant."""
    import json
    import re

    print("~-check relevance-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", api_key=st.session_state.anthropic_api_key,
                        temperature=0, max_tokens=1000)

    prompt = PromptTemplate(template="""You are grading the relevance of a retrieved document to a user question.
        Return ONLY a JSON object with a "score" field that is either "yes" or "no".
        Do not include any other text or explanation.

        Document: {context}
        Question: {question}

        Rules:
        - Check for related keywords or semantic meaning
        - Use lenient grading to only filter clear mismatches
        - Return exactly like this example: {{"score": "yes"}} or {{"score": "no"}}""",
        input_variables=["context", "question"])

    chain = (
        prompt
        | llm
        | StrOutputParser()
    )

    filtered_docs = []
    search = "No"

    for d in documents:
        try:
            response = chain.invoke({"question": question, "context": d.page_content})
            # Extract the JSON object in case the model added surrounding text
            json_match = re.search(r'\{.*\}', response)
            if json_match:
                response = json_match.group()

            score = json.loads(response)

            if score.get("score") == "yes":
                print("~-grade: document relevant-~")
                filtered_docs.append(d)
            else:
                print("~-grade: document not relevant-~")
                search = "Yes"

        except Exception as e:
            print(f"Error grading document: {str(e)}")
            # On error, keep the document to be safe
            filtered_docs.append(d)
            continue

    return {"keys": {"documents": filtered_docs, "question": question, "run_web_search": search}}


def transform_query(state):
    """Transform the query to produce a better question."""
    print("~-transform query-~")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]

    # Create a prompt template
    prompt = PromptTemplate(
        template="""Generate a search-optimized version of this question by
        analyzing its core semantic meaning and intent.
        \n ------- \n
        {question}
        \n ------- \n
        Return only the improved question with no additional text:""",
        input_variables=["question"],
    )

    # Claude 3.5 Sonnet rewrites the question
    llm = ChatAnthropic(
        model="claude-3-5-sonnet-20240620",
        anthropic_api_key=st.session_state.anthropic_api_key,
        temperature=0,
        max_tokens=1000
    )

    # Build and run the rewrite chain
    chain = prompt | llm | StrOutputParser()
    better_question = chain.invoke({"question": question})

    return {
        "keys": {"documents": documents, "question": better_question}
    }


def decide_to_generate(state):
    """Route to web search if any document was graded irrelevant, else generate."""
    print("~-decide to generate-~")
    state_dict = state["keys"]
    search = state_dict["run_web_search"]

    if search == "Yes":
        print("~-decision: transform query and run web search-~")
        return "transform_query"
    else:
        print("~-decision: generate-~")
        return "generate"


def format_document(doc: Document) -> str:
    """Summarize a document as its source, title, and first 200 characters."""
    return f"""
    Source: {doc.metadata.get('source', 'Unknown')}
    Title: {doc.metadata.get('title', 'No title')}
    Content: {doc.page_content[:200]}...
    """


def format_state(state: dict) -> dict:
    """Return a copy of the state with documents replaced by short summaries."""
    formatted = {}

    for key, value in state.items():
        if key == "documents":
            formatted[key] = [format_document(doc) for doc in value]
        else:
            formatted[key] = value

    return formatted


workflow = StateGraph(GraphState)

# Define the nodes by langgraph
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("transform_query", transform_query)
workflow.add_node("web_search", web_search)

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

app = workflow.compile()

st.title("Corrective RAG Demo")

st.text("A possible query: What are the experiment results and ablation studies in this research paper?")

# User input
user_question = st.text_input("Please enter your question:")

if user_question:
    inputs = {
        "keys": {
            "question": user_question,
        }
    }

    # Stream each graph step into its own expander
    for output in app.stream(inputs):
        for key, value in output.items():
            with st.expander(f"Step '{key}':"):
                st.text(pprint.pformat(format_state(value["keys"]), indent=2, width=80))

    # The last streamed step is "generate", which carries the final answer
    final_generation = value['keys'].get('generation', 'No final generation produced.')
    st.subheader("Final Generation:")
    st.write(final_generation)
19
rag_tutorials/corrective_rag/requirements.txt
Normal file
@@ -0,0 +1,19 @@
# Core dependencies
langchain==0.3.12
langgraph==0.2.53
qdrant-client==1.12.1
langchain-openai==0.2.14
langchain-anthropic==0.3.0
tavily-python==0.5.0
langchain-community==0.3.12
langchain-core==0.3.28
streamlit==1.41.1
tenacity==8.5.0

anthropic>=0.7.0
openai>=1.12.0
tiktoken>=0.6.0
pydantic>=2.0.0
numpy>=1.24.0
PyYAML>=6.0.0
nest-asyncio>=1.5.0