mirror of
https://github.com/Shubhamsaboo/awesome-llm-apps.git
synced 2026-04-28 22:28:59 -05:00
554 lines
26 KiB
Python
554 lines
26 KiB
Python
import requests
|
||
import os
|
||
import io
|
||
import base64
|
||
import PIL
|
||
from PIL import Image
|
||
import tqdm
|
||
import numpy as np
|
||
import streamlit as st
|
||
import cohere
|
||
from google import genai
|
||
import fitz # PyMuPDF
|
||
|
||
# --- Streamlit App Configuration ---
# Wide layout so retrieved chart images have room to render.
st.set_page_config(layout="wide", page_title="Vision RAG with Cohere Embed-4")
st.title("Vision RAG with Cohere Embed-4 🖼️")

# --- API Key Input ---
# Keys are collected per-session in the sidebar; they are never persisted.
with st.sidebar:
    st.header("🔑 API Keys")
    cohere_api_key = st.text_input("Cohere API Key", type="password", key="cohere_key")
    google_api_key = st.text_input("Google API Key (Gemini)", type="password", key="google_key")
    # Bare string literals are Streamlit "magic": they render as markdown links.
    "[Get a Cohere API key](https://dashboard.cohere.com/api-keys)"
    "[Get a Google API key](https://aistudio.google.com/app/apikey)"

    st.markdown("---")
    # Warn (don't stop) so the rest of the page can still explain itself.
    if not cohere_api_key:
        st.warning("Please enter your Cohere API key to proceed.")
    if not google_api_key:
        st.warning("Please enter your Google API key to proceed.")
    st.markdown("---")
||
# --- Initialize API Clients ---
# Both clients stay None until valid keys are supplied; downstream code
# checks for None before using them.
co = None
genai_client = None

# Initialize Session State for embeddings and paths.
# image_paths (list[str]) and doc_embeddings (np.ndarray | None) are kept
# in parallel: row i of doc_embeddings embeds the image at image_paths[i].
if 'image_paths' not in st.session_state:
    st.session_state.image_paths = []
if 'doc_embeddings' not in st.session_state:
    st.session_state.doc_embeddings = None

# Only attempt client construction once both keys are present.
if cohere_api_key and google_api_key:
    try:
        co = cohere.ClientV2(api_key=cohere_api_key)
        st.sidebar.success("Cohere Client Initialized!")
    except Exception as e:
        st.sidebar.error(f"Cohere Initialization Failed: {e}")

    try:
        genai_client = genai.Client(api_key=google_api_key)
        st.sidebar.success("Gemini Client Initialized!")
    except Exception as e:
        st.sidebar.error(f"Gemini Initialization Failed: {e}")
else:
    st.info("Enter your API keys in the sidebar to start.")
|
||
|
||
# Information about the models (static help text, collapsed by default).
with st.expander("ℹ️ About the models used"):
    st.markdown("""
    ### Cohere Embed-4

    Cohere's Embed-4 is a state-of-the-art multimodal embedding model designed for enterprise search and retrieval.
    It enables:

    - **Multimodal search**: Search text and images together seamlessly
    - **High accuracy**: State-of-the-art performance for retrieval tasks
    - **Efficient embedding**: Process complex images like charts, graphs, and infographics

    The model processes images without requiring complex OCR pre-processing and maintains the connection between visual elements and text.

    ### Google Gemini 2.5 Flash

    Gemini 2.5 Flash is Google's efficient multimodal model that can process text and image inputs to generate high-quality responses.
    It's designed for fast inference while maintaining high accuracy, making it ideal for real-time applications like this RAG system.
    """)

# --- Helper functions ---
# Some helper functions to resize images and to convert them to base64 format
# Pixel budget used by resize_image before images are sent to the embed API.
max_pixels = 1568*1568 #Max resolution for images
|
||
|
||
# Resize too large images
|
||
def resize_image(pil_image: PIL.Image.Image, max_px: int = max_pixels) -> None:
    """Shrinks the image in place if its pixel count exceeds ``max_px``.

    Generalized so the pixel budget can be overridden per call; the default
    binds the module-level ``max_pixels`` constant, so existing callers see
    identical behavior.

    Args:
        pil_image: Image to (possibly) resize; mutated in place.
        max_px: Maximum allowed width * height in pixels.

    Returns:
        None. Images already within the budget are left untouched.
    """
    org_width, org_height = pil_image.size

    # Only downscale; never enlarge small images.
    if org_width * org_height > max_px:
        # Uniform scale factor chosen so new_w * new_h ≈ max_px.
        scale_factor = (max_px / (org_width * org_height)) ** 0.5
        new_width = int(org_width * scale_factor)
        new_height = int(org_height * scale_factor)
        # thumbnail() resizes in place and preserves aspect ratio.
        pil_image.thumbnail((new_width, new_height))
|
||
|
||
# Convert images to a base64 string before sending it to the API
|
||
def base64_from_image(img_path: str) -> str:
    """Converts an image file to a base64 encoded data-URI string.

    Previously this duplicated the resize/encode steps of pil_to_base64
    line-for-line; it now delegates so the logic lives in one place.
    PIL.Image.open records the source format on the image object, so
    pil_to_base64 chooses the same format (or the PNG fallback) as before.

    Args:
        img_path: Path to an image file readable by Pillow.

    Returns:
        A "data:image/<fmt>;base64,..." string suitable for the Cohere API.
    """
    return pil_to_base64(PIL.Image.open(img_path))
|
||
|
||
# Convert PIL image to base64 string
|
||
def pil_to_base64(pil_image: PIL.Image.Image) -> str:
    """Encodes a PIL image as a base64 data-URI string.

    The image is shrunk in place (via resize_image) before encoding, and
    its original format is reused so the data-URI MIME type matches the
    source file.
    """
    # Images created in memory carry no format; fall back to PNG for them.
    img_format = "PNG" if pil_image.format is None else pil_image.format

    # Enforce the module-wide pixel budget before encoding.
    resize_image(pil_image)

    # Serialize to an in-memory buffer and base64-encode the raw bytes.
    with io.BytesIO() as buffer:
        pil_image.save(buffer, format=img_format)
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return f"data:image/{img_format.lower()};base64,{encoded}"
|
||
|
||
# Compute embedding for an image
|
||
@st.cache_data(ttl=3600, show_spinner=False)
def compute_image_embedding(base64_img: str, _cohere_client) -> np.ndarray | None:
    """Computes an embedding for an image using Cohere's Embed-4 model.

    The underscore prefix on _cohere_client tells st.cache_data not to
    hash the client object; results are cached by base64_img for an hour.

    Returns the embedding as a 1-D numpy array, or None on any failure.
    """
    try:
        response = _cohere_client.embed(
            model="embed-v4.0",
            input_type="search_document",
            embedding_types=["float"],
            images=[base64_img],
        )
        embeddings = response.embeddings
        if not embeddings or not embeddings.float:
            # API succeeded but returned nothing usable.
            st.warning("Could not get embedding. API response might be empty.")
            return None
        return np.asarray(embeddings.float[0])
    except Exception as exc:
        st.error(f"Error computing embedding: {exc}")
        return None
||
|
||
# Process a PDF file: extract pages as images and embed them
|
||
# Note: Caching PDF processing might be complex due to potential large file sizes and streams
|
||
# We will process it directly for now, but show progress.
|
||
def process_pdf_file(pdf_file, cohere_client, base_output_folder="pdf_pages") -> tuple[list[str], list[np.ndarray] | None]:
    """Extracts pages from a PDF as images, embeds them, and saves them.

    Each page is rendered to a PNG under
    ``base_output_folder/<pdf name without extension>/page_<n>.png`` and
    embedded with Cohere Embed-4 via compute_image_embedding. Pages whose
    embedding fails are dropped from the returned lists (their PNGs remain
    on disk — NOTE(review): consider cleaning those up).

    Args:
        pdf_file: UploadedFile object from Streamlit.
        cohere_client: Initialized Cohere client.
        base_output_folder: Directory to save page images.

    Returns:
        A tuple containing:
        - list of paths to the saved page images.
        - list of numpy array embeddings for each page, or None if embedding fails.
    """
    page_image_paths = []
    page_embeddings = []
    pdf_filename = pdf_file.name
    # One sub-folder per PDF so page files from different PDFs never collide.
    output_folder = os.path.join(base_output_folder, os.path.splitext(pdf_filename)[0])
    os.makedirs(output_folder, exist_ok=True)

    try:
        # Open PDF from stream (the UploadedFile's raw bytes).
        doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
        st.write(f"Processing PDF: {pdf_filename} ({len(doc)} pages)")
        pdf_progress = st.progress(0.0)

        for i, page in enumerate(doc.pages()):
            page_num = i + 1
            page_img_path = os.path.join(output_folder, f"page_{page_num}.png")
            page_image_paths.append(page_img_path)

            # Render page to pixmap (image)
            pix = page.get_pixmap(dpi=150) # Adjust DPI as needed for quality/performance
            pil_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

            # Save the page image temporarily
            pil_image.save(page_img_path, "PNG")

            # Convert PIL image to base64
            base64_img = pil_to_base64(pil_image)

            # Compute embedding for the page image
            emb = compute_image_embedding(base64_img, _cohere_client=cohere_client)
            if emb is not None:
                page_embeddings.append(emb)
            else:
                st.warning(f"Could not embed page {page_num} from {pdf_filename}. Skipping.")
                # Add a placeholder to keep lists aligned, will be filtered later
                page_embeddings.append(None)

            # Update progress
            pdf_progress.progress((i + 1) / len(doc))

        doc.close()
        pdf_progress.empty() # Remove progress bar after completion

        # Filter out pages where embedding failed; paths and embeddings are
        # filtered with the same index so they stay aligned.
        valid_paths = [path for i, path in enumerate(page_image_paths) if page_embeddings[i] is not None]
        valid_embeddings = [emb for emb in page_embeddings if emb is not None]

        if not valid_embeddings:
            st.error(f"Failed to generate any embeddings for {pdf_filename}.")
            return [], None

        return valid_paths, valid_embeddings

    except Exception as e:
        # Any fitz/IO/API failure is reported but not propagated; callers
        # treat ([], None) as "this PDF contributed nothing".
        st.error(f"Error processing PDF {pdf_filename}: {e}")
        return [], None
|
||
|
||
# Download and embed sample images
|
||
@st.cache_data(ttl=3600, show_spinner=False)
def download_and_embed_sample_images(_cohere_client) -> tuple[list[str], np.ndarray | None]:
    """Downloads sample images and computes their embeddings using Cohere's Embed-4 model.

    Images are cached on disk under ``img/`` and re-used on later calls;
    st.cache_data additionally memoizes the whole result for an hour.
    Returns (paths, stacked embeddings) with failed downloads/embeddings
    filtered out, or ([], None) if nothing could be embedded.
    """
    # Several images from https://www.appeconomyinsights.com/
    images = {
        "tesla.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbef936e6-3efa-43b3-88d7-7ec620cdb33b_2744x1539.png",
        "netflix.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F23bd84c9-5b62-4526-b467-3088e27e4193_2744x1539.png",
        "nike.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fa5cd33ba-ae1a-42a8-a254-d85e690d9870_2741x1541.png",
        "google.png": "https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F395dd3b9-b38e-4d1f-91bc-d37b642ee920_2741x1541.png",
        "accenture.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F08b2227c-7dc8-49f7-b3c5-13cab5443ba6_2741x1541.png",
        "tecent.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0ec8448c-c4d1-4aab-a8e9-2ddebe0c95fd_2741x1541.png"
    }

    # Prepare folders
    img_folder = "img"
    os.makedirs(img_folder, exist_ok=True)

    # img_paths and doc_embeddings are kept index-aligned; None placeholders
    # mark entries whose embedding failed and are filtered out at the end.
    img_paths = []
    doc_embeddings = []

    # Wrap TQDM with st.spinner for better UI integration
    with st.spinner("Downloading and embedding sample images..."):
        pbar = tqdm.tqdm(images.items(), desc="Processing sample images")
        for name, url in pbar:
            img_path = os.path.join(img_folder, name)
            # Don't re-append if already processed (useful if function called multiple times)
            if img_path not in img_paths:
                img_paths.append(img_path)

            # Download the image only if not already cached on disk.
            if not os.path.exists(img_path):
                try:
                    response = requests.get(url)
                    response.raise_for_status()
                    with open(img_path, "wb") as fOut:
                        fOut.write(response.content)
                except requests.exceptions.RequestException as e:
                    st.error(f"Failed to download {name}: {e}")
                    # Optionally remove the path if download failed
                    # (sample names are unique, so the popped entry is the
                    # one appended just above).
                    img_paths.pop()
                    continue # Skip if download fails

            # Get embedding for the image if it exists and we haven't computed one yet
            # Find index corresponding to this path
            current_index = -1
            try:
                current_index = img_paths.index(img_path)
            except ValueError:
                continue # Should not happen if append logic is correct

            # Check if embedding already exists for this index
            if current_index >= len(doc_embeddings):
                try:
                    # Ensure file exists before trying to embed
                    if os.path.exists(img_path):
                        base64_img = base64_from_image(img_path)
                        emb = compute_image_embedding(base64_img, _cohere_client=_cohere_client)
                        if emb is not None:
                            # Placeholder to ensure list length matches paths before vstack
                            while len(doc_embeddings) < current_index:
                                doc_embeddings.append(None) # Append placeholder if needed
                            doc_embeddings.append(emb)
                    else:
                        # If file doesn't exist (maybe failed download), add placeholder
                        while len(doc_embeddings) < current_index:
                            doc_embeddings.append(None)
                        doc_embeddings.append(None)
                except Exception as e:
                    st.error(f"Failed to embed {name}: {e}")
                    # Add placeholder on error
                    while len(doc_embeddings) < current_index:
                        doc_embeddings.append(None)
                    doc_embeddings.append(None)

    # Filter out None embeddings and corresponding paths before stacking.
    # The i < len(...) guard also tolerates trailing paths that never
    # received an embedding or placeholder.
    filtered_paths = [path for i, path in enumerate(img_paths) if i < len(doc_embeddings) and doc_embeddings[i] is not None]
    filtered_embeddings = [emb for emb in doc_embeddings if emb is not None]

    if filtered_embeddings:
        doc_embeddings_array = np.vstack(filtered_embeddings)
        return filtered_paths, doc_embeddings_array

    return [], None
|
||
|
||
# Search function
|
||
def search(question: str, co_client: cohere.Client, embeddings: np.ndarray, image_paths: list[str], max_img_size: int = 800) -> str | None:
    """Finds the most relevant image path for a given question.

    Embeds the question with Embed-4 (input_type="search_query") and
    returns the path whose document embedding scores highest.

    Args:
        question: Free-text query.
        co_client: Initialized Cohere client (a ClientV2 is passed in
            practice despite the cohere.Client annotation).
        embeddings: (n_images, dim) array, row-aligned with image_paths.
        image_paths: Paths of the embedded images.
        max_img_size: Unused here; kept for interface compatibility.

    Returns:
        Path of the best-matching image, or None on any failure.
    """
    # Guard: all search inputs must be present and non-empty.
    if not co_client or embeddings is None or embeddings.size == 0 or not image_paths:
        st.warning("Search prerequisites not met (client, embeddings, or paths missing/empty).")
        return None
    if embeddings.shape[0] != len(image_paths):
        st.error(f"Mismatch between embeddings count ({embeddings.shape[0]}) and image paths count ({len(image_paths)}). Cannot perform search.")
        return None

    try:
        # Compute the embedding for the query
        api_response = co_client.embed(
            model="embed-v4.0",
            input_type="search_query",
            embedding_types=["float"],
            texts=[question],
        )

        if not api_response.embeddings or not api_response.embeddings.float:
            st.error("Failed to get query embedding.")
            return None

        query_emb = np.asarray(api_response.embeddings.float[0])

        # Ensure query embedding has the correct shape for dot product
        if query_emb.shape[0] != embeddings.shape[1]:
            st.error(f"Query embedding dimension ({query_emb.shape[0]}) does not match document embedding dimension ({embeddings.shape[1]}).")
            return None

        # Dot-product scores; this equals cosine similarity only if the
        # embeddings are unit-norm — NOTE(review): confirm Embed-4 output
        # normalization, argmax may differ otherwise.
        cos_sim_scores = np.dot(query_emb, embeddings.T)

        # Get the most relevant image (highest score wins).
        top_idx = np.argmax(cos_sim_scores)
        hit_img_path = image_paths[top_idx]
        print(f"Question: {question}") # Keep for debugging
        print(f"Most relevant image: {hit_img_path}") # Keep for debugging

        return hit_img_path
    except Exception as e:
        st.error(f"Error during search: {e}")
        return None
|
||
|
||
# Answer function
|
||
def answer(question: str, img_path: str, gemini_client) -> str:
    """Answers the question based on the provided image using Gemini.

    Returns the model's text answer, or a human-readable error string when
    prerequisites are missing or generation fails.
    """
    # Collect human-readable names of any missing prerequisites; an empty
    # list means we are clear to call the model.
    missing = []
    if not gemini_client:
        missing.append("Gemini client")
    if not img_path:
        missing.append("Image path")
    elif not os.path.exists(img_path):
        missing.append(f"Image file at {img_path}")
    if missing:
        return f"Answering prerequisites not met ({', '.join(missing)} missing or invalid)."

    try:
        img = PIL.Image.open(img_path)
        # Multimodal prompt: instruction text followed by the image itself.
        prompt = [
            f"""Answer the question based on the following image. Be as elaborate as possible giving extra relevant information.
Don't use markdown formatting in the response.
Please provide enough context for your answer.

Question: {question}""",
            img,
        ]

        response = gemini_client.models.generate_content(
            model="gemini-2.5-flash-preview-04-17",
            contents=prompt,
        )

        llm_answer = response.text
        print("LLM Answer:", llm_answer)  # Keep for debugging
        return llm_answer
    except Exception as exc:
        st.error(f"Error during answer generation: {exc}")
        return f"Failed to generate answer: {exc}"
|
||
|
||
# --- Main UI Setup ---
|
||
st.subheader("📊 Load Sample Images")
# Sample loading requires a working Cohere client.
if cohere_api_key and co:
    # If button clicked, load sample images into session state
    if st.button("Load Sample Images", key="load_sample_button"):
        sample_img_paths, sample_doc_embeddings = download_and_embed_sample_images(_cohere_client=co)
        if sample_img_paths and sample_doc_embeddings is not None:
            # Append sample images to session state (avoid duplicates if clicked again)
            current_paths = set(st.session_state.image_paths)
            new_paths = [p for p in sample_img_paths if p not in current_paths]

            if new_paths:
                # NOTE(review): new_indices is computed but unused; the
                # fancy-index below recomputes the same index list inline.
                new_indices = [i for i, p in enumerate(sample_img_paths) if p in new_paths]
                st.session_state.image_paths.extend(new_paths)
                # Select only the embedding rows for the genuinely new paths.
                new_embeddings_to_add = sample_doc_embeddings[[idx for idx, p in enumerate(sample_img_paths) if p in new_paths]]

                # Initialize or grow the session-wide embedding matrix.
                if st.session_state.doc_embeddings is None or st.session_state.doc_embeddings.size == 0:
                    st.session_state.doc_embeddings = new_embeddings_to_add
                else:
                    st.session_state.doc_embeddings = np.vstack((st.session_state.doc_embeddings, new_embeddings_to_add))
                st.success(f"Loaded {len(new_paths)} sample images.")
            else:
                st.info("Sample images already loaded.")
        else:
            st.error("Failed to load sample images. Check console for errors.")
else:
    st.warning("Enter API keys to enable loading sample images.")

st.markdown("--- ")
# --- File Uploader (Main UI) ---
st.subheader("📤 Upload Your Images")
st.info("Or, upload your own images or PDFs. The RAG process will search across all loaded content.")

# File uploader (multi-select; PDFs are split into per-page images later).
uploaded_files = st.file_uploader("Upload images (PNG, JPG, JPEG) or PDFs",
                                  type=["png", "jpg", "jpeg", "pdf"],
                                  accept_multiple_files=True, key="image_uploader",
                                  label_visibility="collapsed")
|
||
|
||
# Process uploaded images
|
||
# Process uploaded files (runs on every rerun while files are selected;
# per-file dedup against session state prevents re-embedding).
if uploaded_files and co:
    st.write(f"Processing {len(uploaded_files)} uploaded images...")
    progress_bar = st.progress(0)

    # Create a temporary directory for uploaded images
    upload_folder = "uploaded_img"
    os.makedirs(upload_folder, exist_ok=True)

    # Accumulated here first, then merged into session state in one batch.
    newly_uploaded_paths = []
    newly_uploaded_embeddings = []

    for i, uploaded_file in enumerate(uploaded_files):
        # Check if already processed this session (simple name check)
        img_path = os.path.join(upload_folder, uploaded_file.name)
        if img_path not in st.session_state.image_paths:
            try:
                # Check file type (MIME type reported by the browser).
                file_type = uploaded_file.type
                if file_type == "application/pdf":
                    # Process PDF - returns list of paths and list of embeddings
                    pdf_page_paths, pdf_page_embeddings = process_pdf_file(uploaded_file, cohere_client=co)
                    if pdf_page_paths and pdf_page_embeddings:
                        # Add only paths/embeddings not already in session state
                        current_paths_set = set(st.session_state.image_paths)
                        unique_new_paths = [p for p in pdf_page_paths if p not in current_paths_set]
                        if unique_new_paths:
                            # Comprehension `i` is scoped to the comprehension
                            # in Python 3, so the outer loop index is safe.
                            indices_to_add = [i for i, p in enumerate(pdf_page_paths) if p in unique_new_paths]
                            newly_uploaded_paths.extend(unique_new_paths)
                            newly_uploaded_embeddings.extend([pdf_page_embeddings[idx] for idx in indices_to_add])
                elif file_type in ["image/png", "image/jpeg"]:
                    # Process regular image
                    # Save the uploaded file
                    with open(img_path, "wb") as f:
                        f.write(uploaded_file.getbuffer())

                    # Get embedding
                    base64_img = base64_from_image(img_path)
                    emb = compute_image_embedding(base64_img, _cohere_client=co)

                    # Only record the pair when embedding succeeded, keeping
                    # paths and embeddings aligned.
                    if emb is not None:
                        newly_uploaded_paths.append(img_path)
                        newly_uploaded_embeddings.append(emb)
                else:
                    st.warning(f"Unsupported file type skipped: {uploaded_file.name} ({file_type})")

            except Exception as e:
                st.error(f"Error processing {uploaded_file.name}: {e}")
        # Update progress regardless of processing status for user feedback
        progress_bar.progress((i + 1) / len(uploaded_files))

    # Add newly processed files to session state
    if newly_uploaded_paths:
        st.session_state.image_paths.extend(newly_uploaded_paths)
        if newly_uploaded_embeddings:
            new_embeddings_array = np.vstack(newly_uploaded_embeddings)
            if st.session_state.doc_embeddings is None or st.session_state.doc_embeddings.size == 0:
                st.session_state.doc_embeddings = new_embeddings_array
            else:
                st.session_state.doc_embeddings = np.vstack((st.session_state.doc_embeddings, new_embeddings_array))
            st.success(f"Successfully processed and added {len(newly_uploaded_paths)} new images.")
        else:
            st.warning("Failed to generate embeddings for newly uploaded images.")
elif uploaded_files: # If files were selected but none were new
    st.info("Selected images already seem to be processed.")
|
||
|
||
# --- Vision RAG Section (Main UI) ---
st.markdown("---")
st.subheader("❓ Ask a Question")

if not st.session_state.image_paths:
    st.warning("Please load sample images or upload your own images first.")
else:
    st.info(f"Ready to answer questions about {len(st.session_state.image_paths)} images.")

# Display thumbnails of all loaded images (optional)
with st.expander("View Loaded Images", expanded=False):
    if st.session_state.image_paths:
        num_images_to_show = len(st.session_state.image_paths)
        cols = st.columns(5) # Show 5 thumbnails per row
        for i in range(num_images_to_show):
            with cols[i % 5]:
                # Add try-except for missing files during display
                try:
                    # Display PDF pages differently? For now, just show the image
                    st.image(st.session_state.image_paths[i], width=100, caption=os.path.basename(st.session_state.image_paths[i]))
                except FileNotFoundError:
                    st.error(f"Missing: {os.path.basename(st.session_state.image_paths[i])}")
    else:
        st.write("No images loaded yet.")

question = st.text_input("Ask a question about the loaded images:",
                         key="main_question_input",
                         placeholder="E.g., What is Nike's net profit?",
                         disabled=not st.session_state.image_paths)

# Button is disabled until keys, a question, and embedded images all exist.
run_button = st.button("Run Vision RAG", key="main_run_button",
                       disabled=not (cohere_api_key and google_api_key and question and st.session_state.image_paths and st.session_state.doc_embeddings is not None and st.session_state.doc_embeddings.size > 0))

# Output Area (placeholders filled after retrieval/answering).
st.markdown("### Results")
retrieved_image_placeholder = st.empty()
answer_placeholder = st.empty()

# Run search and answer logic
if run_button:
    if co and genai_client and st.session_state.doc_embeddings is not None and len(st.session_state.doc_embeddings) > 0:
        with st.spinner("Finding relevant image..."):
            # Ensure embeddings and paths match before search
            if len(st.session_state.image_paths) != st.session_state.doc_embeddings.shape[0]:
                st.error("Error: Mismatch between number of images and embeddings. Cannot proceed.")
            else:
                top_image_path = search(question, co, st.session_state.doc_embeddings, st.session_state.image_paths)

                if top_image_path:
                    caption = f"Retrieved content for: '{question}' (Source: {os.path.basename(top_image_path)})"
                    # Add source PDF name if it's a page image
                    # NOTE(review): startswith uses a forward slash while the
                    # path was built with os.path.join — this branch likely
                    # won't trigger on Windows; confirm target platforms.
                    if top_image_path.startswith("pdf_pages/"):
                        parts = top_image_path.split(os.sep)
                        if len(parts) >= 3:
                            pdf_name = parts[1]
                            page_name = parts[-1]
                            caption = f"Retrieved content for: '{question}' (Source: {pdf_name}.pdf, {page_name.replace('.png','')})"

                    retrieved_image_placeholder.image(top_image_path, caption=caption, use_container_width=True)

                    with st.spinner("Generating answer..."):
                        final_answer = answer(question, top_image_path, genai_client)
                    answer_placeholder.markdown(f"**Answer:**\n{final_answer}")
                else:
                    retrieved_image_placeholder.warning("Could not find a relevant image for your question.")
                    answer_placeholder.text("") # Clear answer placeholder
    else:
        # This case should ideally be prevented by the disabled state of the button
        st.error("Cannot run RAG. Check API clients and ensure images are loaded with embeddings.")

# Footer
st.markdown("---")
st.caption("Vision RAG with Cohere Embed-4 | Built with Streamlit, Cohere Embed-4, and Google Gemini 2.5 Flash")