mirror of https://github.com/Shubhamsaboo/awesome-llm-apps.git, synced 2026-04-28 14:18:51 -05:00

Commit: Vision RAG with Cohere embed 4.0 and Gemini 2.5 Flash

rag_tutorials/vision_rag_agent/README.md (new file, 91 lines)
@@ -0,0 +1,91 @@
# Vision RAG with Cohere Embed-4 🖼️

A powerful visual Retrieval-Augmented Generation (RAG) system that utilizes Cohere's state-of-the-art Embed-4 model for multimodal embedding and Google's efficient Gemini 2.5 Flash model for answering questions about images.

## Features

- **Multimodal Search**: Leverages Cohere Embed-4 to find the most semantically relevant image for a given text question.
- **Visual Question Answering**: Employs Google Gemini 2.5 Flash to analyze the content of the retrieved image and generate accurate, context-aware answers.
- **Flexible Image Sources**:
  - Use pre-loaded sample financial charts and infographics.
  - Upload your own custom images (PNG, JPG, JPEG).
- **No OCR Required**: Directly processes complex images like charts, graphs, and infographics without needing separate text extraction steps.
- **Interactive UI**: Built with Streamlit for easy interaction, including image loading, question input, and result display.
- **Session Management**: Remembers loaded/uploaded images within a session.
## Requirements

- Python 3.10+ (the code uses modern union type hints such as `str | None`)
- Cohere API key
- Google Gemini API key
## How to Run

Follow these steps to set up and run the application:

1. **Clone and Navigate to Directory**:

   ```bash
   git clone https://github.com/Shubhamsaboo/awesome-llm-apps.git
   cd awesome-llm-apps/rag_tutorials/vision_rag_agent
   ```

2. **Install Dependencies**:

   ```bash
   pip install -r requirements.txt
   ```

3. **Set up your API keys**:

   - Get a Cohere API key from: [https://dashboard.cohere.com/api-keys](https://dashboard.cohere.com/api-keys)
   - Get a Google API key from: [https://aistudio.google.com/app/apikey](https://aistudio.google.com/app/apikey)

4. **Run the Streamlit app**:

   ```bash
   streamlit run vision_rag.py
   ```

5. **Access the Web Interface**:

   - Streamlit will provide a local URL (usually `http://localhost:8501`) in your terminal.
   - Open this URL in your web browser.
## How It Works

The application follows a two-stage RAG process:

1. **Retrieval**:
   - When you load sample images or upload your own, each image is converted to a base64 string.
   - Cohere's `embed-v4.0` model (with `input_type="search_document"`) is used to generate a dense vector embedding for each image.
   - When you ask a question, the text query is embedded using the same `embed-v4.0` model (with `input_type="search_query"`).
   - Cosine similarity is calculated between the question embedding and all image embeddings.
   - The image with the highest similarity score is retrieved as the most relevant context.
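The retrieval stage above boils down to a similarity argmax. The sketch below uses made-up two-dimensional vectors in place of real `embed-v4.0` outputs so the ranking logic can run standalone; the app's `search()` function treats the dot product of query and document embeddings as the cosine-similarity score in the same way.

```python
# Sketch of the retrieval stage. The toy vectors stand in for real
# embed-v4.0 responses; the dot-product argmax mirrors the app's search().
import numpy as np

def retrieve(query_emb: np.ndarray, doc_embs: np.ndarray, paths: list) -> str:
    """Return the image path whose embedding scores highest against the query."""
    scores = query_emb @ doc_embs.T  # one similarity score per image
    return paths[int(np.argmax(scores))]

# Toy example: three "image" embeddings; the query is closest to the second.
docs = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
query = np.array([0.1, 0.9])
print(retrieve(query, docs, ["tesla.png", "netflix.png", "nike.png"]))  # netflix.png
```

With real embeddings the vectors have hundreds of dimensions, but the selection step is identical.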
2. **Generation**:
   - The original text question and the retrieved image are passed as input to the Google `gemini-2.5-flash-preview-04-17` model.
   - Gemini analyzes the image content in the context of the question and generates a textual answer.
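The generation stage can be sketched as follows. `build_prompt` is a hypothetical helper that paraphrases the instruction text the app sends; `ask_gemini` shows the google-genai call shape used in `vision_rag.py` but is only defined, not invoked, so the snippet runs without an API key or network access.

```python
# Sketch of the generation stage (not the app's exact code).
def build_prompt(question: str) -> str:
    """Compose the text part of the multimodal request (paraphrased from the app)."""
    return (
        "Answer the question based on the following image. "
        "Be as elaborate as possible giving extra relevant information.\n"
        "Don't use markdown formatting in the response.\n\n"
        f"Question: {question}"
    )

def ask_gemini(question: str, image, api_key: str) -> str:
    """Send question + PIL image to Gemini; call shape follows the google-genai SDK."""
    from google import genai  # imported lazily so the sketch runs without the SDK
    client = genai.Client(api_key=api_key)
    response = client.models.generate_content(
        model="gemini-2.5-flash-preview-04-17",
        contents=[build_prompt(question), image],  # text part + image part
    )
    return response.text

print(build_prompt("What is Nike's net profit?"))
```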
## Usage

1. Enter your Cohere and Google API keys in the sidebar.
2. Load images:
   - Click **"Load Sample Images"** to download and process the built-in examples, and/or
   - Use the **"Upload Your Images"** section to upload your own image files.
3. Once images are loaded and processed (embeddings generated), the **"Ask a Question"** section will be enabled.
4. Optionally, expand **"View Loaded Images"** to see thumbnails of all images currently in the session.
5. Type your question about the loaded images into the text input field.
6. Click **"Run Vision RAG"**.
7. View the results:
   - The **Retrieved Image** deemed most relevant to your question.
   - The **Generated Answer** from Gemini based on the image and question.
## Use Cases

- Analyze financial charts and extract key figures or trends.
- Answer specific questions about diagrams, flowcharts, or infographics.
- Extract information from tables or text within screenshots without explicit OCR.
- Build and query visual knowledge bases using natural language.
- Understand the content of various complex visual documents.
## Note

- Image processing (embedding) can take time, especially for many or large images. Sample images are cached after the first load.
- Ensure your API keys have the necessary permissions and quotas for the Cohere and Gemini models used.
- The quality of the answer depends on both the relevance of the retrieved image and the capability of the Gemini model to interpret the image based on the question.
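One detail relevant to processing time: before embedding, the app downscales any image whose area exceeds a pixel budget (the `resize_image` helper in `vision_rag.py`). The scaling is a simple area-based factor, sketched here as a pure function:

```python
# Area-based downscale used before embedding (mirrors resize_image in
# vision_rag.py): if width*height exceeds the budget, scale both sides
# by sqrt(budget / area) so the result fits under the budget.
MAX_PIXELS = 1568 * 1568  # max resolution used by the app

def target_size(width: int, height: int, max_pixels: int = MAX_PIXELS) -> tuple:
    if width * height <= max_pixels:
        return (width, height)  # already small enough, keep as-is
    scale = (max_pixels / (width * height)) ** 0.5
    return (int(width * scale), int(height * scale))

print(target_size(2744, 1539))  # a sample-chart resolution gets shrunk
print(target_size(800, 600))    # small images pass through unchanged
```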
rag_tutorials/vision_rag_agent/requirements.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
streamlit>=1.32.0
cohere>=5.0.0
google-genai>=0.3.0  # provides "from google import genai" used in vision_rag.py
Pillow>=10.0.0
requests>=2.31.0
numpy>=1.24.0
tqdm>=4.66.0
rag_tutorials/vision_rag_agent/vision_rag.py (new file, 453 lines)
@@ -0,0 +1,453 @@
import requests
import os
import io
import base64
import PIL
from PIL import Image
import tqdm
import numpy as np
import streamlit as st
import cohere
from google import genai

# --- Streamlit App Configuration ---
st.set_page_config(layout="wide", page_title="Vision RAG with Cohere Embed-4")
st.title("Vision RAG with Cohere Embed-4 🖼️")

# --- API Key Input ---
with st.sidebar:
    st.header("🔑 API Keys")
    cohere_api_key = st.text_input("Cohere API Key", type="password", key="cohere_key")
    google_api_key = st.text_input("Google API Key (Gemini)", type="password", key="google_key")
    "[Get a Cohere API key](https://dashboard.cohere.com/api-keys)"
    "[Get a Google API key](https://aistudio.google.com/app/apikey)"

    st.markdown("---")
    if not cohere_api_key:
        st.warning("Please enter your Cohere API key to proceed.")
    if not google_api_key:
        st.warning("Please enter your Google API key to proceed.")
    st.markdown("---")

# --- Initialize API Clients ---
co = None
genai_client = None
# Initialize Session State for embeddings and paths
if 'image_paths' not in st.session_state:
    st.session_state.image_paths = []
if 'doc_embeddings' not in st.session_state:
    st.session_state.doc_embeddings = None

if cohere_api_key and google_api_key:
    try:
        co = cohere.ClientV2(api_key=cohere_api_key)
        st.sidebar.success("Cohere Client Initialized!")
    except Exception as e:
        st.sidebar.error(f"Cohere Initialization Failed: {e}")

    try:
        genai_client = genai.Client(api_key=google_api_key)
        st.sidebar.success("Gemini Client Initialized!")
    except Exception as e:
        st.sidebar.error(f"Gemini Initialization Failed: {e}")
else:
    st.info("Enter your API keys in the sidebar to start.")

# Information about the models
with st.expander("ℹ️ About the models used"):
    st.markdown("""
### Cohere Embed-4

Cohere's Embed-4 is a state-of-the-art multimodal embedding model designed for enterprise search and retrieval.
It enables:

- **Multimodal search**: Search text and images together seamlessly
- **High accuracy**: State-of-the-art performance for retrieval tasks
- **Efficient embedding**: Process complex images like charts, graphs, and infographics

The model processes images without requiring complex OCR pre-processing and maintains the connection between visual elements and text.

### Google Gemini 2.5 Flash

Gemini 2.5 Flash is Google's efficient multimodal model that can process text and image inputs to generate high-quality responses.
It's designed for fast inference while maintaining high accuracy, making it ideal for real-time applications like this RAG system.
""")

# --- Helper functions ---
# Some helper functions to resize images and to convert them to base64 format
max_pixels = 1568 * 1568  # Max resolution for images

# Resize too large images
def resize_image(pil_image: PIL.Image.Image) -> None:
    """Resizes the image in-place if it exceeds max_pixels."""
    org_width, org_height = pil_image.size

    # Resize image if too large
    if org_width * org_height > max_pixels:
        scale_factor = (max_pixels / (org_width * org_height)) ** 0.5
        new_width = int(org_width * scale_factor)
        new_height = int(org_height * scale_factor)
        pil_image.thumbnail((new_width, new_height))

# Convert images to a base64 string before sending it to the API
def base64_from_image(img_path: str) -> str:
    """Converts an image file to a base64 encoded string."""
    pil_image = PIL.Image.open(img_path)
    img_format = pil_image.format if pil_image.format else "PNG"

    resize_image(pil_image)

    with io.BytesIO() as img_buffer:
        pil_image.save(img_buffer, format=img_format)
        img_buffer.seek(0)
        img_data = f"data:image/{img_format.lower()};base64," + base64.b64encode(img_buffer.read()).decode("utf-8")

    return img_data

# Convert PIL image to base64 string
def pil_to_base64(pil_image: PIL.Image.Image) -> str:
    """Converts a PIL image to a base64 encoded string."""
    if pil_image.format is None:
        img_format = "PNG"
    else:
        img_format = pil_image.format

    resize_image(pil_image)

    with io.BytesIO() as img_buffer:
        pil_image.save(img_buffer, format=img_format)
        img_buffer.seek(0)
        img_data = f"data:image/{img_format.lower()};base64," + base64.b64encode(img_buffer.read()).decode("utf-8")

    return img_data

# Compute embedding for an image
@st.cache_data(ttl=3600, show_spinner=False)
def compute_image_embedding(base64_img: str, _cohere_client) -> np.ndarray:
    """Computes an embedding for an image using Cohere's Embed-4 model."""
    try:
        api_response = _cohere_client.embed(
            model="embed-v4.0",
            input_type="search_document",
            embedding_types=["float"],
            images=[base64_img],
        )

        if api_response.embeddings and api_response.embeddings.float:
            return np.asarray(api_response.embeddings.float[0])
        else:
            st.warning("Could not get embedding. API response might be empty.")
            return None
    except Exception as e:
        st.error(f"Error computing embedding: {e}")
        return None

# Download and embed sample images
@st.cache_data(ttl=3600, show_spinner=False)
def download_and_embed_sample_images(_cohere_client):
    """Downloads sample images and computes their embeddings using Cohere's Embed-4 model."""
    # Several images from https://www.appeconomyinsights.com/
    images = {
        "tesla.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fbef936e6-3efa-43b3-88d7-7ec620cdb33b_2744x1539.png",
        "netflix.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F23bd84c9-5b62-4526-b467-3088e27e4193_2744x1539.png",
        "nike.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2Fa5cd33ba-ae1a-42a8-a254-d85e690d9870_2741x1541.png",
        "google.png": "https://substackcdn.com/image/fetch/f_auto,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F395dd3b9-b38e-4d1f-91bc-d37b642ee920_2741x1541.png",
        "accenture.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F08b2227c-7dc8-49f7-b3c5-13cab5443ba6_2741x1541.png",
        "tecent.png": "https://substackcdn.com/image/fetch/w_1456,c_limit,f_webp,q_auto:good,fl_progressive:steep/https%3A%2F%2Fsubstack-post-media.s3.amazonaws.com%2Fpublic%2Fimages%2F0ec8448c-c4d1-4aab-a8e9-2ddebe0c95fd_2741x1541.png"
    }

    # Prepare folders
    img_folder = "img"
    os.makedirs(img_folder, exist_ok=True)

    img_paths = []
    doc_embeddings = []

    # Wrap tqdm with st.spinner for better UI integration
    with st.spinner("Downloading and embedding sample images..."):
        pbar = tqdm.tqdm(images.items(), desc="Processing sample images")
        for name, url in pbar:
            img_path = os.path.join(img_folder, name)
            # Don't re-append if already processed (useful if function called multiple times)
            if img_path not in img_paths:
                img_paths.append(img_path)

            # Download the image
            if not os.path.exists(img_path):
                try:
                    response = requests.get(url)
                    response.raise_for_status()
                    with open(img_path, "wb") as fOut:
                        fOut.write(response.content)
                except requests.exceptions.RequestException as e:
                    st.error(f"Failed to download {name}: {e}")
                    # Optionally remove the path if download failed
                    img_paths.pop()
                    continue  # Skip if download fails

            # Get embedding for the image if it exists and we haven't computed one yet
            # Find index corresponding to this path
            current_index = -1
            try:
                current_index = img_paths.index(img_path)
            except ValueError:
                continue  # Should not happen if append logic is correct

            # Check if embedding already exists for this index
            if current_index >= len(doc_embeddings):
                try:
                    # Ensure file exists before trying to embed
                    if os.path.exists(img_path):
                        base64_img = base64_from_image(img_path)
                        emb = compute_image_embedding(base64_img, _cohere_client=_cohere_client)
                        if emb is not None:
                            # Placeholder to ensure list length matches paths before vstack
                            while len(doc_embeddings) < current_index:
                                doc_embeddings.append(None)  # Append placeholder if needed
                            doc_embeddings.append(emb)
                    else:
                        # If file doesn't exist (maybe failed download), add placeholder
                        while len(doc_embeddings) < current_index:
                            doc_embeddings.append(None)
                        doc_embeddings.append(None)
                except Exception as e:
                    st.error(f"Failed to embed {name}: {e}")
                    # Add placeholder on error
                    while len(doc_embeddings) < current_index:
                        doc_embeddings.append(None)
                    doc_embeddings.append(None)

    # Filter out None embeddings and corresponding paths before stacking
    filtered_paths = [path for i, path in enumerate(img_paths) if i < len(doc_embeddings) and doc_embeddings[i] is not None]
    filtered_embeddings = [emb for emb in doc_embeddings if emb is not None]

    if filtered_embeddings:
        doc_embeddings_array = np.vstack(filtered_embeddings)
        return filtered_paths, doc_embeddings_array

    return [], None

# Search function
def search(question: str, co_client: cohere.ClientV2, embeddings: np.ndarray, image_paths: list[str], max_img_size: int = 800) -> str | None:
    """Finds the most relevant image path for a given question."""
    if not co_client or embeddings is None or embeddings.size == 0 or not image_paths:
        st.warning("Search prerequisites not met (client, embeddings, or paths missing/empty).")
        return None
    if embeddings.shape[0] != len(image_paths):
        st.error(f"Mismatch between embeddings count ({embeddings.shape[0]}) and image paths count ({len(image_paths)}). Cannot perform search.")
        return None

    try:
        # Compute the embedding for the query
        api_response = co_client.embed(
            model="embed-v4.0",
            input_type="search_query",
            embedding_types=["float"],
            texts=[question],
        )

        if not api_response.embeddings or not api_response.embeddings.float:
            st.error("Failed to get query embedding.")
            return None

        query_emb = np.asarray(api_response.embeddings.float[0])

        # Ensure query embedding has the correct shape for dot product
        if query_emb.shape[0] != embeddings.shape[1]:
            st.error(f"Query embedding dimension ({query_emb.shape[0]}) does not match document embedding dimension ({embeddings.shape[1]}).")
            return None

        # Compute cosine similarities
        cos_sim_scores = np.dot(query_emb, embeddings.T)

        # Get the most relevant image
        top_idx = np.argmax(cos_sim_scores)
        hit_img_path = image_paths[top_idx]
        print(f"Question: {question}")  # Keep for debugging
        print(f"Most relevant image: {hit_img_path}")  # Keep for debugging

        return hit_img_path
    except Exception as e:
        st.error(f"Error during search: {e}")
        return None

# Answer function
def answer(question: str, img_path: str, gemini_client) -> str:
    """Answers the question based on the provided image using Gemini."""
    if not gemini_client or not img_path or not os.path.exists(img_path):
        missing = []
        if not gemini_client: missing.append("Gemini client")
        if not img_path: missing.append("Image path")
        elif not os.path.exists(img_path): missing.append(f"Image file at {img_path}")
        return f"Answering prerequisites not met ({', '.join(missing)} missing or invalid)."
    try:
        img = PIL.Image.open(img_path)
        prompt = [f"""Answer the question based on the following image. Be as elaborate as possible giving extra relevant information.
Don't use markdown formatting in the response.
Please provide enough context for your answer.

Question: {question}""", img]

        response = gemini_client.models.generate_content(
            model="gemini-2.5-flash-preview-04-17",
            contents=prompt
        )

        llm_answer = response.text
        print("LLM Answer:", llm_answer)  # Keep for debugging
        return llm_answer
    except Exception as e:
        st.error(f"Error during answer generation: {e}")
        return f"Failed to generate answer: {e}"

# --- Main UI Setup ---
st.subheader("📊 Load Sample Images")
if cohere_api_key and co:
    # If button clicked, load sample images into session state
    if st.button("Load Sample Images", key="load_sample_button"):
        sample_img_paths, sample_doc_embeddings = download_and_embed_sample_images(_cohere_client=co)
        if sample_img_paths and sample_doc_embeddings is not None:
            # Append sample images to session state (avoid duplicates if clicked again)
            current_paths = set(st.session_state.image_paths)
            new_paths = [p for p in sample_img_paths if p not in current_paths]

            if new_paths:
                new_indices = [i for i, p in enumerate(sample_img_paths) if p in new_paths]
                st.session_state.image_paths.extend(new_paths)
                new_embeddings_to_add = sample_doc_embeddings[new_indices]

                if st.session_state.doc_embeddings is None or st.session_state.doc_embeddings.size == 0:
                    st.session_state.doc_embeddings = new_embeddings_to_add
                else:
                    st.session_state.doc_embeddings = np.vstack((st.session_state.doc_embeddings, new_embeddings_to_add))
                st.success(f"Loaded {len(new_paths)} sample images.")
            else:
                st.info("Sample images already loaded.")
        else:
            st.error("Failed to load sample images. Check console for errors.")
else:
    st.warning("Enter API keys to enable loading sample images.")

st.markdown("---")
# --- File Uploader (Main UI) ---
st.subheader("📤 Upload Your Images")
st.info("Or, upload your own images. The RAG process will search across all loaded sample images and uploaded images.")

# File uploader
uploaded_files = st.file_uploader("Upload images", type=["png", "jpg", "jpeg"],
                                  accept_multiple_files=True, key="image_uploader",
                                  label_visibility="collapsed")

# Process uploaded images
if uploaded_files and co:
    st.write(f"Processing {len(uploaded_files)} uploaded images...")
    progress_bar = st.progress(0)

    # Create a temporary directory for uploaded images
    upload_folder = "uploaded_img"
    os.makedirs(upload_folder, exist_ok=True)

    newly_uploaded_paths = []
    newly_uploaded_embeddings = []

    for i, uploaded_file in enumerate(uploaded_files):
        # Check if already processed this session (simple name check)
        img_path = os.path.join(upload_folder, uploaded_file.name)
        if img_path not in st.session_state.image_paths:
            try:
                # Save the uploaded file
                with open(img_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

                # Get embedding
                base64_img = base64_from_image(img_path)
                emb = compute_image_embedding(base64_img, _cohere_client=co)

                if emb is not None:
                    newly_uploaded_paths.append(img_path)
                    newly_uploaded_embeddings.append(emb)

            except Exception as e:
                st.error(f"Error processing {uploaded_file.name}: {e}")
        # Update progress regardless of processing status for user feedback
        progress_bar.progress((i + 1) / len(uploaded_files))

    # Add newly processed files to session state
    if newly_uploaded_paths:
        st.session_state.image_paths.extend(newly_uploaded_paths)
        if newly_uploaded_embeddings:
            new_embeddings_array = np.vstack(newly_uploaded_embeddings)
            if st.session_state.doc_embeddings is None or st.session_state.doc_embeddings.size == 0:
                st.session_state.doc_embeddings = new_embeddings_array
            else:
                st.session_state.doc_embeddings = np.vstack((st.session_state.doc_embeddings, new_embeddings_array))
            st.success(f"Successfully processed and added {len(newly_uploaded_paths)} new images.")
        else:
            st.warning("Failed to generate embeddings for newly uploaded images.")
    elif uploaded_files:  # If files were selected but none were new
        st.info("Selected images already seem to be processed.")

# --- Vision RAG Section (Main UI) ---
st.markdown("---")
st.subheader("❓ Ask a Question")

if not st.session_state.image_paths:
    st.warning("Please load sample images or upload your own images first.")
else:
    st.info(f"Ready to answer questions about {len(st.session_state.image_paths)} images.")

# Display thumbnails of all loaded images (optional)
with st.expander("View Loaded Images", expanded=False):
    if st.session_state.image_paths:
        num_images_to_show = len(st.session_state.image_paths)
        cols = st.columns(5)  # Show 5 thumbnails per row
        for i in range(num_images_to_show):
            with cols[i % 5]:
                # Add try-except for missing files during display
                try:
                    st.image(st.session_state.image_paths[i], width=100, caption=os.path.basename(st.session_state.image_paths[i]))
                except FileNotFoundError:
                    st.error(f"Missing: {os.path.basename(st.session_state.image_paths[i])}")
    else:
        st.write("No images loaded yet.")

question = st.text_input("Ask a question about the loaded images:",
                         key="main_question_input",
                         placeholder="E.g., What is Nike's net profit?",
                         disabled=not st.session_state.image_paths)

run_button = st.button("Run Vision RAG", key="main_run_button",
                       disabled=not (cohere_api_key and google_api_key and question and st.session_state.image_paths and st.session_state.doc_embeddings is not None and st.session_state.doc_embeddings.size > 0))

# Output Area
st.markdown("### Results")
retrieved_image_placeholder = st.empty()
answer_placeholder = st.empty()

# Run search and answer logic
if run_button:
    if co and genai_client and st.session_state.doc_embeddings is not None and len(st.session_state.doc_embeddings) > 0:
        with st.spinner("Finding relevant image..."):
            # Ensure embeddings and paths match before search
            if len(st.session_state.image_paths) != st.session_state.doc_embeddings.shape[0]:
                st.error("Error: Mismatch between number of images and embeddings. Cannot proceed.")
            else:
                top_image_path = search(question, co, st.session_state.doc_embeddings, st.session_state.image_paths)

                if top_image_path:
                    retrieved_image_placeholder.image(top_image_path, caption=f"Retrieved image for: '{question}'", use_container_width=True)

                    with st.spinner("Generating answer..."):
                        final_answer = answer(question, top_image_path, genai_client)
                    answer_placeholder.markdown(f"**Answer:**\n{final_answer}")
                else:
                    retrieved_image_placeholder.warning("Could not find a relevant image for your question.")
                    answer_placeholder.text("")  # Clear answer placeholder
    else:
        # This case should ideally be prevented by the disabled state of the button
        st.error("Cannot run RAG. Check API clients and ensure images are loaded with embeddings.")

# Footer
st.markdown("---")
st.caption("Vision RAG with Cohere Embed-4 | Built with Streamlit, Cohere Embed-4, and Google Gemini 2.5 Flash")