mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-03-08 23:13:56 -05:00
329 lines
13 KiB
Python
329 lines
13 KiB
Python
import os
|
|
import tempfile
|
|
import time
|
|
from typing import List, Optional, Tuple, Any
|
|
|
|
import streamlit as st
|
|
import requests
|
|
import json
|
|
import re
|
|
from contextual import ContextualAI
|
|
|
|
|
|
def init_session_state() -> None:
    """Seed ``st.session_state`` with a default for every key the app uses.

    Keys that already exist are left untouched, so values survive Streamlit
    reruns; only missing keys receive their default.
    """
    defaults = {
        "api_key_submitted": False,
        "contextual_api_key": "",
        "base_url": "https://api.contextual.ai/v1",
        "agent_id": "",
        "datastore_id": "",
        "chat_history": [],
        "processed_file": False,
        "last_raw_response": None,
        "last_user_query": "",
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default
|
|
|
|
|
|
def sidebar_api_form() -> bool:
    """Render the sidebar credential form and report whether setup is complete.

    When session state already marks the API key as verified, a "Reset Setup"
    button is shown and ``True`` is returned. Otherwise the credential form is
    rendered; on submit the credentials are checked by listing agents, stored
    in session state, and the script is rerun.

    Returns:
        True if credentials were verified on an earlier run, False otherwise.
    """
    with st.sidebar:
        st.header("API & Resource Setup")

        if st.session_state.api_key_submitted:
            st.success("API verified")
            if st.button("Reset Setup"):
                st.session_state.clear()
                st.rerun()
            return True

        with st.form("contextual_api_form"):
            api_key = st.text_input("Contextual AI API Key", type="password")
            base_url = st.text_input(
                "Base URL",
                value=st.session_state.base_url,
                help="Include /v1 (e.g., https://api.contextual.ai/v1)",
            )
            existing_agent_id = st.text_input("Existing Agent ID (optional)")
            existing_datastore_id = st.text_input("Existing Datastore ID (optional)")

            if st.form_submit_button("Save & Verify"):
                # Pasted values often carry stray whitespace; storing it would
                # break every later request that reuses these values.
                api_key = api_key.strip()
                base_url = base_url.strip()
                existing_agent_id = existing_agent_id.strip()
                existing_datastore_id = existing_datastore_id.strip()

                if not api_key:
                    # Fail fast instead of making a network call that cannot succeed.
                    st.error("Please enter your Contextual AI API key.")
                else:
                    try:
                        client = ContextualAI(api_key=api_key, base_url=base_url)
                        # Listing agents is used purely as a credential check.
                        _ = client.agents.list()

                        st.session_state.contextual_api_key = api_key
                        st.session_state.base_url = base_url
                        st.session_state.agent_id = existing_agent_id
                        st.session_state.datastore_id = existing_datastore_id
                        st.session_state.api_key_submitted = True

                        st.success("Credentials verified!")
                        st.rerun()
                    except Exception as e:
                        st.error(f"Credential verification failed: {str(e)}")
    return False
|
|
|
|
|
|
def ensure_client():
    """Build a ``ContextualAI`` client from the credentials in session state.

    Raises:
        ValueError: if no API key is stored in session state.
    """
    stored_key = st.session_state.get("contextual_api_key")
    if stored_key:
        return ContextualAI(api_key=stored_key, base_url=st.session_state.base_url)
    raise ValueError("Contextual AI API key not provided")
|
|
|
|
|
|
def create_datastore(client, name: str) -> Optional[str]:
    """Create a datastore called *name* and return its id.

    Any failure is reported through ``st.error`` and ``None`` is returned;
    ``None`` is also returned when the created object exposes no ``id``.
    """
    try:
        created = client.datastores.create(name=name)
        datastore_id = getattr(created, "id", None)
    except Exception as err:
        st.error(f"Failed to create datastore: {err}")
        return None
    return datastore_id
|
|
|
|
|
|
# File extensions accepted for ingestion; anything else is rejected before upload.
ALLOWED_EXTS = {".pdf", ".html", ".htm", ".mhtml", ".doc", ".docx", ".ppt", ".pptx"}


def upload_documents(client, datastore_id: str, files: List[bytes], filenames: List[str], metadata: Optional[dict]) -> List[str]:
    """Ingest in-memory files into a datastore and return the resulting ids.

    Each payload is written to a temporary file (keeping the original
    extension), handed to ``client.datastores.documents.ingest`` and the
    temporary file is removed afterwards. Files with an extension outside
    ``ALLOWED_EXTS`` are skipped with an error message; a failure on one file
    is reported and does not stop the remaining uploads.

    Args:
        client: object exposing ``datastores.documents.ingest``.
        datastore_id: target datastore.
        files: raw file contents, paired positionally with ``filenames``.
        filenames: original names; only the extension is used.
        metadata: optional metadata forwarded to ``ingest`` when truthy.
            NOTE(review): passed through as given -- confirm whether the SDK
            expects a dict or a JSON string here.

    Returns:
        One entry per successfully ingested file: the result's ``id``, or
        ``""`` when the result has no ``id`` attribute.
    """
    doc_ids: List[str] = []
    for content, fname in zip(files, filenames):
        # Reset per file so cleanup never sees a path left over from a previous
        # iteration (or an unbound name when no temp file was created).
        tmp_path = None
        try:
            ext = os.path.splitext(fname)[1].lower()
            if ext not in ALLOWED_EXTS:
                st.error(f"Unsupported file extension for {fname}. Allowed: {sorted(ALLOWED_EXTS)}")
                continue

            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp_path = tmp.name
                tmp.write(content)

            # Only send the metadata argument when there is something to send.
            ingest_kwargs = {"metadata": metadata} if metadata else {}
            with open(tmp_path, "rb") as f:
                result = client.datastores.documents.ingest(datastore_id, file=f, **ingest_kwargs)
            doc_ids.append(getattr(result, "id", ""))
        except Exception as e:
            st.error(f"Failed to upload {fname}: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    # Best-effort cleanup: a leftover temp file must not fail the upload.
                    pass
    return doc_ids
|
|
|
|
|
|
def wait_until_documents_ready(api_key: str, datastore_id: str, base_url: str, max_checks: int = 30, interval_sec: float = 5.0) -> bool:
    """Poll the datastore's document list until nothing is still being processed.

    Issues ``GET {base_url}/datastores/{datastore_id}/documents`` up to
    ``max_checks`` times, sleeping ``interval_sec`` seconds between attempts.
    A check succeeds when the response is HTTP 200 and no listed document has
    status ``"processing"`` or ``"pending"``.

    Args:
        api_key: bearer token sent in the ``Authorization`` header.
        datastore_id: datastore whose documents are inspected.
        base_url: API root; a trailing slash is tolerated.
        max_checks: maximum number of polling attempts.
        interval_sec: pause between consecutive attempts.

    Returns:
        True once a check succeeds, False if every attempt was used up
        (callers that ignore the result behave as before).
    """
    url = f"{base_url.rstrip('/')}/datastores/{datastore_id}/documents"
    headers = {"Authorization": f"Bearer {api_key}"}

    for attempt in range(max_checks):
        # Sleep between attempts only -- not before the first or after the last.
        if attempt:
            time.sleep(interval_sec)
        try:
            resp = requests.get(url, headers=headers, timeout=30)
            if resp.status_code == 200:
                docs = resp.json().get("documents", [])
                if not any(d.get("status") in ("processing", "pending") for d in docs):
                    return True
        except Exception:
            # Polling is best-effort: a transient network or parsing failure
            # simply counts as a "not ready yet" attempt.
            continue
    return False
|
|
|
|
|
|
def create_agent(client, name: str, description: str, datastore_id: str) -> Optional[str]:
    """Create an agent bound to a single datastore and return its id.

    Any failure is reported through ``st.error`` and ``None`` is returned;
    ``None`` is also returned when the created object exposes no ``id``.
    """
    try:
        created = client.agents.create(
            name=name,
            description=description,
            datastore_ids=[datastore_id],
        )
        new_agent_id = getattr(created, "id", None)
    except Exception as err:
        st.error(f"Failed to create agent: {err}")
        return None
    return new_agent_id
|
|
|
|
|
|
def query_agent(client, agent_id: str, query: str) -> Tuple[str, Any]:
    """Send a single user message to an agent and extract the answer text.

    The answer is looked up in several possible response shapes, in order:
    ``resp.content``, ``resp.message.content``, the ``content`` of the last
    item in ``resp.messages`` (or that item itself), and finally the response
    object as a whole. A shape whose value is ``None`` is skipped, so the
    returned text is always a string.

    Args:
        client: object exposing ``agents.query.create``.
        agent_id: agent to query.
        query: the user's question.

    Returns:
        ``(answer_text, raw_response)``. On any exception the text is an
        error message and the raw response is ``None``.
    """
    try:
        resp = client.agents.query.create(
            agent_id=agent_id,
            messages=[{"role": "user", "content": query}],
        )

        # An attribute that exists but is None must not win over later shapes.
        text = getattr(resp, "content", None)
        if text is None:
            text = getattr(getattr(resp, "message", None), "content", None)
        if text is None:
            messages = getattr(resp, "messages", None)
            if messages:
                last_msg = messages[-1]
                text = getattr(last_msg, "content", None)
                if text is None:
                    text = last_msg
        if text is None:
            text = resp
        return str(text), resp
    except Exception as e:
        return f"Error querying agent: {e}", None
|
|
|
|
|
|
def show_retrieval_info(client, raw_response, agent_id: str) -> None:
    """Display the page image behind the first retrieved content of an answer.

    Reads ``message_id`` and ``retrieval_contents`` from *raw_response*,
    requests retrieval info for the first content id, and renders the
    base64-encoded ``page_img`` of the first returned metadata entry. Each
    missing piece produces an informational message instead; any exception is
    reported through ``st.error``.
    """
    try:
        if not raw_response:
            st.info("No retrieval info available.")
            return

        msg_id = getattr(raw_response, "message_id", None)
        retrieved = getattr(raw_response, "retrieval_contents", [])
        if not (msg_id and retrieved):
            st.info("No retrieval metadata returned.")
            return

        # Only the top-ranked retrieval is inspected.
        top_content_id = getattr(retrieved[0], "content_id", None)
        if not top_content_id:
            st.info("Missing content_id in retrieval metadata.")
            return

        info = client.agents.query.retrieval_info(
            message_id=msg_id,
            agent_id=agent_id,
            content_ids=[top_content_id],
        )
        content_metadatas = getattr(info, "content_metadatas", [])
        if not content_metadatas:
            st.info("No content metadatas found.")
            return

        encoded_page = getattr(content_metadatas[0], "page_img", None)
        if not encoded_page:
            st.info("No page image provided in metadata.")
            return

        import base64

        st.image(
            base64.b64decode(encoded_page),
            caption="Top Attribution Page",
            use_container_width=True,
        )
    except Exception as err:
        st.error(f"Failed to load retrieval info: {err}")
|
|
|
|
|
|
def update_agent_prompt(client, agent_id: str, system_prompt: str) -> bool:
    """Replace an agent's system prompt.

    Returns:
        True when the update call completes; False when it raises, in which
        case the error is shown through ``st.error``.
    """
    try:
        client.agents.update(agent_id=agent_id, system_prompt=system_prompt)
    except Exception as err:
        st.error(f"Failed to update system prompt: {err}")
        return False
    return True
|
|
|
|
|
|
def evaluate_with_lmunit(client, query: str, response_text: str, unit_test: str):
    """Score a response against a unit test with LMUnit and display the result.

    The value returned by ``client.lmunit.create`` is rendered as its string
    form in a code block; any exception is reported through ``st.error``.
    """
    try:
        evaluation = client.lmunit.create(
            query=query,
            response=response_text,
            unit_test=unit_test,
        )
        st.subheader("Evaluation Result")
        st.code(str(evaluation), language="json")
    except Exception as err:
        st.error(f"LMUnit evaluation failed: {err}")
|
|
|
|
|
|
def post_process_answer(text: str) -> str:
    """Tidy an agent answer for Markdown rendering.

    Removes empty parentheses (optionally containing whitespace) and turns
    ``"• "`` bullets into Markdown list items on their own line.

    ``None`` yields an empty string and any other non-string value is
    converted with ``str`` first, so an unexpected response shape cannot
    raise ``TypeError`` here.
    """
    if text is None:
        return ""
    if not isinstance(text, str):
        text = str(text)
    text = re.sub(r"\(\s*\)", "", text)
    return text.replace("• ", "\n- ")
|
|
|
|
|
|
# --- App bootstrap -----------------------------------------------------------
init_session_state()

st.title("Contextual AI RAG Agent")

# Nothing below the credential gate runs until the sidebar form reports success.
credentials_ready = sidebar_api_form()
if not credentials_ready:
    st.info("Please enter your Contextual AI API key in the sidebar to continue.")
    st.stop()

client = ensure_client()
|
|
|
|
# Step 1: reuse the datastore recorded in session state, or offer to create one.
with st.expander("1) Create or Select Datastore", expanded=True):
    if st.session_state.datastore_id:
        st.success(f"Using Datastore: {st.session_state.datastore_id}")
    else:
        default_name = "contextualai_rag_datastore"
        ds_name = st.text_input("Datastore Name", value=default_name)
        if st.button("Create Datastore"):
            ds_id = create_datastore(client, ds_name)
            # create_datastore already reported the failure when it returns None.
            if ds_id:
                st.session_state.datastore_id = ds_id
                st.success(f"Created datastore: {ds_id}")
|
|
|
|
# Step 2: upload documents into the selected datastore.
with st.expander("2) Upload Documents", expanded=True):
    # Offer exactly the extensions that upload_documents accepts, so the widget
    # never lets through a file that is then rejected as unsupported.
    uploaded_files = st.file_uploader(
        "Upload documents (PDF, HTML, Word, or PowerPoint)",
        type=sorted(ext.lstrip(".") for ext in ALLOWED_EXTS),
        accept_multiple_files=True,
    )
    metadata_json = st.text_area("Custom Metadata (JSON)", value="", placeholder='{"custom_metadata": {"field1": "value1"}}')
    if uploaded_files and st.session_state.datastore_id:
        contents = [f.getvalue() for f in uploaded_files]
        names = [f.name for f in uploaded_files]
        if st.button("Ingest Documents"):
            parsed_metadata = None
            metadata_ok = True
            if metadata_json.strip():
                try:
                    parsed_metadata = json.loads(metadata_json)
                except ValueError as e:
                    st.error(f"Invalid metadata JSON: {e}")
                    # Do not ingest without the metadata the user asked for.
                    metadata_ok = False
            if metadata_ok:
                ids = upload_documents(client, st.session_state.datastore_id, contents, names, parsed_metadata)
                if ids:
                    st.success(f"Uploaded {len(ids)} document(s)")
                    ready = wait_until_documents_ready(st.session_state.contextual_api_key, st.session_state.datastore_id, st.session_state.base_url)
                    # Only an explicit False means the polling budget ran out;
                    # a None result (no status reported) keeps the original message.
                    if ready is False:
                        st.warning("Documents are still processing. Please wait a moment before querying.")
                    else:
                        st.info("Documents are ready.")
|
|
|
|
# Step 3: reuse the agent recorded in session state, or create one once a
# datastore is available.
with st.expander("3) Create or Select Agent", expanded=True):
    if st.session_state.agent_id:
        st.success(f"Using Agent: {st.session_state.agent_id}")
    elif st.session_state.datastore_id:
        agent_name = st.text_input("Agent Name", value="ContextualAI RAG Agent")
        agent_desc = st.text_area("Agent Description", value="RAG agent over uploaded documents")
        if st.button("Create Agent"):
            a_id = create_agent(client, agent_name, agent_desc, st.session_state.datastore_id)
            # create_agent already reported the failure when it returns None.
            if a_id:
                st.session_state.agent_id = a_id
                st.success(f"Created agent: {a_id}")
|
|
|
|
# Step 4 (optional): overwrite the agent's system prompt.
with st.expander("4) Agent Settings (Optional)"):
    if st.session_state.agent_id:
        system_prompt_val = st.text_area("System Prompt", value="", placeholder="Paste a new system prompt to update your agent")
        update_clicked = st.button("Update System Prompt")
        new_prompt = system_prompt_val.strip()
        # A click with an empty or whitespace-only prompt is ignored.
        if update_clicked and new_prompt:
            if update_agent_prompt(client, st.session_state.agent_id, new_prompt):
                st.success("System prompt updated.")
|
|
|
|
st.divider()

# Replay the stored conversation so it persists across reruns.
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

query = st.chat_input("Ask a question about your documents")
if query:
    # Record and echo the user's turn before contacting the agent.
    st.session_state.last_user_query = query
    st.session_state.chat_history.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    if not st.session_state.agent_id:
        st.error("Please create or select an agent first.")
    else:
        with st.chat_message("assistant"):
            answer, raw = query_agent(client, st.session_state.agent_id, query)
            # Kept for the retrieval-info view in the debug section.
            st.session_state.last_raw_response = raw
            processed = post_process_answer(answer)
            st.markdown(processed)
            st.session_state.chat_history.append({"role": "assistant", "content": processed})
|
|
|
|
# Retrieval inspection and LMUnit evaluation; only offered once an agent exists.
with st.expander("Debug & Evaluation", expanded=False):
    st.caption("Tools to inspect retrievals and evaluate answers")
    if st.session_state.agent_id:
        if st.checkbox("Show Retrieval Info", value=False):
            show_retrieval_info(client, st.session_state.last_raw_response, st.session_state.agent_id)
        st.markdown("")
        unit_test = st.text_area("LMUnit rubric / unit test", value="Does the response avoid unnecessary information?", height=80)
        if st.button("Evaluate Last Answer with LMUnit"):
            has_conversation = bool(st.session_state.last_user_query and st.session_state.chat_history)
            if not has_conversation:
                st.info("Ask a question first to run an evaluation.")
            else:
                last_assistant_msgs = [
                    m for m in st.session_state.chat_history if m["role"] == "assistant"
                ]
                if not last_assistant_msgs:
                    st.info("No assistant response to evaluate yet.")
                else:
                    # Evaluate the most recent assistant turn against the last query.
                    evaluate_with_lmunit(client, st.session_state.last_user_query, last_assistant_msgs[-1]["content"], unit_test)
|
|
|
|
# Sidebar controls: clear only the conversation, or wipe all session state.
with st.sidebar:
    st.divider()
    clear_col, reset_col = st.columns(2)
    with clear_col:
        if st.button("Clear Chat"):
            # Drop the conversation but keep credentials, datastore and agent.
            st.session_state.chat_history = []
            st.session_state.last_raw_response = None
            st.session_state.last_user_query = ""
            st.rerun()
    with reset_col:
        if st.button("Reset App"):
            st.session_state.clear()
            st.rerun()
|
|
|
|
|