from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from uuid import uuid4
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.tools.retriever import create_retriever_tool
from typing import Annotated, Literal, Sequence
from typing_extensions import TypedDict
from functools import partial
from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage
from langgraph.graph.message import add_messages
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel, Field
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode, tools_condition
import streamlit as st
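
# Streamlit reruns this script top-to-bottom on every interaction, so API keys are
# kept in st.session_state. Launch with `streamlit run <this_file>.py` (filename
# assumed, e.g. app.py).
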
st.set_page_config(page_title="AI Blog Search", page_icon=":mag_right:")
st.header(":blue[Agentic RAG with LangGraph:] :green[AI Blog Search]")
# Initialize session state variables if they don't exist
if 'qdrant_host' not in st.session_state:
    st.session_state.qdrant_host = ""
if 'qdrant_api_key' not in st.session_state:
    st.session_state.qdrant_api_key = ""
if 'gemini_api_key' not in st.session_state:
    st.session_state.gemini_api_key = ""

def set_sidebar():
    """Set up the sidebar for API keys and configuration."""
    with st.sidebar:
        st.subheader("API Configuration")
        qdrant_host = st.text_input("Enter your Qdrant Host URL:", type="password")
        qdrant_api_key = st.text_input("Enter your Qdrant API key:", type="password")
        gemini_api_key = st.text_input("Enter your Gemini API key:", type="password")

        if st.button("Done"):
            if qdrant_host and qdrant_api_key and gemini_api_key:
                st.session_state.qdrant_host = qdrant_host
                st.session_state.qdrant_api_key = qdrant_api_key
                st.session_state.gemini_api_key = gemini_api_key
                st.success("API keys saved!")
            else:
                st.warning("Please fill all API fields")

def initialize_components():
    """Initialize components that require API keys."""
    if not all([st.session_state.qdrant_host,
                st.session_state.qdrant_api_key,
                st.session_state.gemini_api_key]):
        return None, None, None

    try:
        # Initialize embedding model with API key
        embedding_model = GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            google_api_key=st.session_state.gemini_api_key
        )
        # Initialize Qdrant client
        client = QdrantClient(
            st.session_state.qdrant_host,
            api_key=st.session_state.qdrant_api_key
        )
        # Initialize vector store
        db = QdrantVectorStore(
            client=client,
            collection_name="qdrant_db",
            embedding=embedding_model
        )
        return embedding_model, client, db
    except Exception as e:
        st.error(f"Initialization error: {str(e)}")
        return None, None, None
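
# Graph state: `add_messages` is a reducer, so the messages each node returns are
# appended to the running conversation instead of replacing it.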
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]

# Edges
## Check Relevance
def grade_documents(state) -> Literal["generate", "rewrite"]:
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (messages): The current state

    Returns:
        str: A decision for whether the documents are relevant or not
    """
    print("---CHECK RELEVANCE---")

    # Data model
    class grade(BaseModel):
        """Binary score for relevance check."""
        binary_score: str = Field(description="Relevance score 'yes' or 'no'")

    # LLM
    model = ChatGoogleGenerativeAI(api_key=st.session_state.gemini_api_key, temperature=0, model="gemini-2.0-flash", streaming=True)

    # LLM with structured output for validation
    llm_with_tool = model.with_structured_output(grade)

    # Prompt
    prompt = PromptTemplate(
        template="""You are a grader assessing relevance of a retrieved document to a user question. \n
        Here is the retrieved document: \n\n {context} \n\n
        Here is the user question: {question} \n
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
        Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question.""",
        input_variables=["context", "question"],
    )

    # Chain
    chain = prompt | llm_with_tool

    messages = state["messages"]
    last_message = messages[-1]
    question = messages[0].content
    docs = last_message.content

    scored_result = chain.invoke({"question": question, "context": docs})
    score = scored_result.binary_score

    if score == "yes":
        print("---DECISION: DOCS RELEVANT---")
        return "generate"
    else:
        print("---DECISION: DOCS NOT RELEVANT---")
        print(score)
        return "rewrite"
# Nodes
## agent node
def agent(state, tools):
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.

    Args:
        state (messages): The current state
        tools (list): Tools the agent is allowed to call

    Returns:
        dict: The updated state with the agent response appended to messages
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatGoogleGenerativeAI(api_key=st.session_state.gemini_api_key, temperature=0, streaming=True, model="gemini-2.0-flash")
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

## rewrite node
def rewrite(state):
    """
    Transform the query to produce a better question.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the re-phrased question
    """
    print("---TRANSFORM QUERY---")
    messages = state["messages"]
    question = messages[0].content

    msg = [
        HumanMessage(
            content=f""" \n
    Look at the input and try to reason about the underlying semantic intent / meaning. \n
    Here is the initial question:
    \n ------- \n
    {question}
    \n ------- \n
    Formulate an improved question: """,
        )
    ]

    # Rewriter model
    model = ChatGoogleGenerativeAI(api_key=st.session_state.gemini_api_key, temperature=0, model="gemini-2.0-flash", streaming=True)
    response = model.invoke(msg)
    return {"messages": [response]}

## generate node
def generate(state):
    """
    Generate an answer from the retrieved documents.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated answer appended to messages
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]
    docs = last_message.content

    # Initialize a Chat Prompt Template
    prompt_template = hub.pull("rlm/rag-prompt")

    # Initialize a Generator (i.e. Chat Model)
    chat_model = ChatGoogleGenerativeAI(api_key=st.session_state.gemini_api_key, model="gemini-2.0-flash", temperature=0, streaming=True)

    # Initialize an Output Parser
    output_parser = StrOutputParser()

    # RAG Chain
    rag_chain = prompt_template | chat_model | output_parser
    response = rag_chain.invoke({"context": docs, "question": question})
    return {"messages": [response]}
# graph function
def get_graph(retriever_tool):
    tools = [retriever_tool]  # Create tools list here

    # Define a new graph
    workflow = StateGraph(AgentState)

    # Use partial to pass tools to the agent function
    workflow.add_node("agent", partial(agent, tools=tools))
    retrieve = ToolNode(tools)
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("rewrite", rewrite)  # Re-writing the question
    workflow.add_node(
        "generate", generate
    )  # Generating a response after we know the documents are relevant

    # Call agent node to decide to retrieve or not
    workflow.add_edge(START, "agent")

    # Decide whether to retrieve
    workflow.add_conditional_edges(
        "agent",
        # Assess agent decision
        tools_condition,
        {
            # Translate the condition outputs to nodes in our graph
            "tools": "retrieve",
            END: END,
        },
    )

    # Edges taken after the `retrieve` node is called
    workflow.add_conditional_edges(
        "retrieve",
        # Assess whether the retrieved documents are relevant
        grade_documents,
    )
    workflow.add_edge("generate", END)
    workflow.add_edge("rewrite", "agent")

    # Compile
    graph = workflow.compile()
    return graph
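
# Resulting topology:
#   START -> agent -> (tool call ? retrieve : END)
#   retrieve -> grade_documents -> "generate" (relevant) | "rewrite" (not relevant)
#   rewrite -> agent (retry with an improved question); generate -> END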
def generate_message(graph, inputs):
    """Stream the graph and return the output of the final `generate` node."""
    generated_message = ""
    for output in graph.stream(inputs):
        for key, value in output.items():
            if key == "generate" and isinstance(value, dict):
                generated_message = value.get("messages", [""])[0]
    return generated_message

def add_documents_to_qdrant(url, db):
    """Load a blog post from `url`, split it into chunks, and index it in Qdrant."""
    try:
        docs = WebBaseLoader(url).load()
        # chunk_size / chunk_overlap are measured in tiktoken tokens, not characters
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=100, chunk_overlap=50
        )
        doc_chunks = text_splitter.split_documents(docs)
        uuids = [str(uuid4()) for _ in range(len(doc_chunks))]
        db.add_documents(documents=doc_chunks, ids=uuids)
        return True
    except Exception as e:
        st.error(f"Error adding documents: {str(e)}")
        return False

def main():
    set_sidebar()

    # Check if API keys are set
    if not all([st.session_state.qdrant_host,
                st.session_state.qdrant_api_key,
                st.session_state.gemini_api_key]):
        st.warning("Please configure your API keys in the sidebar first")
        return

    # Initialize components
    embedding_model, client, db = initialize_components()
    if not all([embedding_model, client, db]):
        return

    # Initialize retriever and tools
    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    retriever_tool = create_retriever_tool(
        retriever,
        "retrieve_blog_posts",
        "Search and return information about blog posts on LLMs, LLM agents, prompt engineering, and adversarial attacks on LLMs.",
    )
    tools = [retriever_tool]

    # URL input section
    url = st.text_input(
        ":link: Paste the blog link:",
        placeholder="e.g., https://lilianweng.github.io/posts/2023-06-23-agent/"
    )
    if st.button("Enter URL"):
        if url:
            with st.spinner("Processing documents..."):
                if add_documents_to_qdrant(url, db):
                    st.success("Documents added successfully!")
                else:
                    st.error("Failed to add documents")
        else:
            st.warning("Please enter a URL")

    # Query section
    graph = get_graph(retriever_tool)
    query = st.text_area(
        ":bulb: Enter your query about the blog post:",
        placeholder="e.g., What does Lilian Weng say about the types of agent memory?"
    )
    if st.button("Submit Query"):
        if not query:
            st.warning("Please enter a query")
            return
        inputs = {"messages": [HumanMessage(content=query)]}
        with st.spinner("Generating response..."):
            try:
                response = generate_message(graph, inputs)
                st.write(response)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")

    st.markdown("---")
    st.write("Built with :blue-background[LangChain] | :blue-background[LangGraph] by [Charan](https://www.linkedin.com/in/codewithcharan/)")

if __name__ == "__main__":
main()