From 3cd4dcc317adeafbc30e4ff2d33605c799fc6326 Mon Sep 17 00:00:00 2001 From: Madhu Date: Tue, 29 Apr 2025 01:09:02 +0530 Subject: [PATCH] Vision PDF Support --- rag_tutorials/vision_rag_agent/README.md | 48 ++++--- .../vision_rag_agent/requirements.txt | 1 + rag_tutorials/vision_rag_agent/vision_rag.py | 135 +++++++++++++++--- 3 files changed, 145 insertions(+), 39 deletions(-) diff --git a/rag_tutorials/vision_rag_agent/README.md b/rag_tutorials/vision_rag_agent/README.md index 2b2f015..8ad8e58 100644 --- a/rag_tutorials/vision_rag_agent/README.md +++ b/rag_tutorials/vision_rag_agent/README.md @@ -1,17 +1,18 @@ # Vision RAG with Cohere Embed-4 🖼️ -A powerful visual Retrieval-Augmented Generation (RAG) system that utilizes Cohere's state-of-the-art Embed-4 model for multimodal embedding and Google's efficient Gemini 2.5 Flash model for answering questions about images. +A powerful visual Retrieval-Augmented Generation (RAG) system that utilizes Cohere's state-of-the-art Embed-4 model for multimodal embedding and Google's efficient Gemini 2.5 Flash model for answering questions about images and PDF pages. ## Features -- **Multimodal Search**: Leverages Cohere Embed-4 to find the most semantically relevant image for a given text question. -- **Visual Question Answering**: Employs Google Gemini 2.5 Flash to analyze the content of the retrieved image and generate accurate, context-aware answers. -- **Flexible Image Sources**: +- **Multimodal Search**: Leverages Cohere Embed-4 to find the most semantically relevant image (or PDF page image) for a given text question. +- **Visual Question Answering**: Employs Google Gemini 2.5 Flash to analyze the content of the retrieved image/page and generate accurate, context-aware answers. +- **Flexible Content Sources**: - Use pre-loaded sample financial charts and infographics. - Upload your own custom images (PNG, JPG, JPEG). -- **No OCR Required**: Directly processes complex images like charts, graphs, and infographics without needing separate text extraction steps. -- **Interactive UI**: Built with Streamlit for easy interaction, including image loading, question input, and result display. -- **Session Management**: Remembers loaded/uploaded images within a session. + - **Upload PDF documents**: Automatically extracts pages as images for analysis. +- **No OCR Required**: Directly processes complex images and visual elements within PDF pages without needing separate text extraction steps. +- **Interactive UI**: Built with Streamlit for easy interaction, including content loading, question input, and result display. +- **Session Management**: Remembers loaded/uploaded content (images and processed PDF pages) within a session. ## Requirements @@ -33,6 +34,7 @@ Follow these steps to set up and run the application: ```bash pip install -r requirements.txt ``` + *(Ensure you have the latest `PyMuPDF` installed along with other requirements)* 3. **Set up your API keys**: - Get a Cohere API key from: [https://dashboard.cohere.com/api-keys](https://dashboard.cohere.com/api-keys) @@ -52,40 +54,42 @@ Follow these steps to set up and run the application: The application follows a two-stage RAG process: 1. **Retrieval**: - - When you load sample images or upload your own, each image is converted to a base64 string. - - Cohere's `embed-v4.0` model (with `input_type="search_document"`) is used to generate a dense vector embedding for each image. + - When you load sample images or upload your own images/PDFs: + - Regular images are converted to base64 strings. + - **PDFs are processed page by page**: Each page is rendered as an image, saved temporarily, and converted to a base64 string. + - Cohere's `embed-v4.0` model (with `input_type="search_document"`) is used to generate a dense vector embedding for each image or PDF page image. - When you ask a question, the text query is embedded using the same `embed-v4.0` model (with `input_type="search_query"`). - Cosine similarity is calculated between the question embedding and all image embeddings. - - The image with the highest similarity score is retrieved as the most relevant context. + - The image with the highest similarity score (which could be a regular image or a specific PDF page image) is retrieved as the most relevant context. 2. **Generation**: - - The original text question and the retrieved image are passed as input to the Google `gemini-2.5-flash-preview-04-17` model. + - The original text question and the retrieved image/page image are passed as input to the Google `gemini-2.5-flash-preview-04-17` model. - Gemini analyzes the image content in the context of the question and generates a textual answer. ## Usage 1. Enter your Cohere and Google API keys in the sidebar. -2. Load images: +2. Load content: - Click **"Load Sample Images"** to download and process the built-in examples. - - *OR/AND* Use the **"Upload Your Images"** section to upload your own image files. -3. Once images are loaded and processed (embeddings generated), the **"Ask a Question"** section will be enabled. -4. Optionally, expand **"View Loaded Images"** to see thumbnails of all images currently in the session. -5. Type your question about the loaded images into the text input field. + - *OR/AND* Use the **"Upload Your Images or PDFs"** section to upload your own image or PDF files. +3. Once content is loaded and processed (embeddings generated), the **"Ask a Question"** section will be enabled. +4. Optionally, expand **"View Loaded Images"** to see thumbnails of all images and processed PDF pages currently in the session. +5. Type your question about the loaded content into the text input field. 6. Click **"Run Vision RAG"**. 7. View the results: - - The **Retrieved Image** deemed most relevant to your question. + - The **Retrieved Image/Page** deemed most relevant to your question (caption indicates source PDF and page number if applicable). - The **Generated Answer** from Gemini based on the image and question. ## Use Cases - Analyze financial charts and extract key figures or trends. -- Answer specific questions about diagrams, flowcharts, or infographics. -- Extract information from tables or text within screenshots without explicit OCR. -- Build and query visual knowledge bases using natural language. -- Understand the content of various complex visual documents. +- Answer specific questions about diagrams, flowcharts, or infographics within images or PDFs. +- Extract information from tables or text within screenshots or PDF pages without explicit OCR. +- Build and query visual knowledge bases (from images and PDFs) using natural language. +- Understand the content of various complex visual documents, including multi-page reports. ## Note -- Image processing (embedding) can take time, especially for many or large images. Sample images are cached after the first load. +- Image and PDF processing (page rendering + embedding) can take time, especially for many items or large files. Sample images are cached after the first load; PDF processing currently happens on each upload within a session. - Ensure your API keys have the necessary permissions and quotas for the Cohere and Gemini models used. - The quality of the answer depends on both the relevance of the retrieved image and the capability of the Gemini model to interpret the image based on the question. diff --git a/rag_tutorials/vision_rag_agent/requirements.txt b/rag_tutorials/vision_rag_agent/requirements.txt index d22b9ee..f2483bd 100644 --- a/rag_tutorials/vision_rag_agent/requirements.txt +++ b/rag_tutorials/vision_rag_agent/requirements.txt @@ -5,3 +5,4 @@ Pillow>=10.0.0 requests>=2.31.0 numpy>=1.24.0 tqdm>=4.66.0 +PyMuPDF>=1.23.0 \ No newline at end of file diff --git a/rag_tutorials/vision_rag_agent/vision_rag.py b/rag_tutorials/vision_rag_agent/vision_rag.py index c5ca4b3..1bf3042 100644 --- a/rag_tutorials/vision_rag_agent/vision_rag.py +++ b/rag_tutorials/vision_rag_agent/vision_rag.py @@ -9,6 +9,7 @@ import numpy as np import streamlit as st import cohere from google import genai +import fitz # PyMuPDF # --- Streamlit App Configuration --- st.set_page_config(layout="wide", page_title="Vision RAG with Cohere Embed-4") @@ -124,7 +125,7 @@ def pil_to_base64(pil_image: PIL.Image.Image) -> str: # Compute embedding for an image @st.cache_data(ttl=3600, show_spinner=False) -def compute_image_embedding(base64_img: str, _cohere_client) -> np.ndarray: +def compute_image_embedding(base64_img: str, _cohere_client) -> np.ndarray | None: """Computes an embedding for an image using Cohere's Embed-4 model.""" try: api_response = _cohere_client.embed( @@ -143,9 +144,81 @@ def compute_image_embedding(base64_img: str, _cohere_client) -> np.ndarray: st.error(f"Error computing embedding: {e}") return None +# Process a PDF file: extract pages as images and embed them +# Note: Caching PDF processing might be complex due to potential large file sizes and streams +# We will process it directly for now, but show progress. +def process_pdf_file(pdf_file, cohere_client, base_output_folder="pdf_pages") -> tuple[list[str], list[np.ndarray] | None]: + """Extracts pages from a PDF as images, embeds them, and saves them. + + Args: + pdf_file: UploadedFile object from Streamlit. + cohere_client: Initialized Cohere client. + base_output_folder: Directory to save page images. + + Returns: + A tuple containing: + - list of paths to the saved page images. + - list of numpy array embeddings for each page, or None if embedding fails. + """ + page_image_paths = [] + page_embeddings = [] + pdf_filename = pdf_file.name + output_folder = os.path.join(base_output_folder, os.path.splitext(pdf_filename)[0]) + os.makedirs(output_folder, exist_ok=True) + + try: + # Open PDF from stream + doc = fitz.open(stream=pdf_file.read(), filetype="pdf") + st.write(f"Processing PDF: {pdf_filename} ({len(doc)} pages)") + pdf_progress = st.progress(0.0) + + for i, page in enumerate(doc.pages()): + page_num = i + 1 + page_img_path = os.path.join(output_folder, f"page_{page_num}.png") + page_image_paths.append(page_img_path) + + # Render page to pixmap (image) + pix = page.get_pixmap(dpi=150) # Adjust DPI as needed for quality/performance + pil_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + + # Save the page image temporarily + pil_image.save(page_img_path, "PNG") + + # Convert PIL image to base64 + base64_img = pil_to_base64(pil_image) + + # Compute embedding for the page image + emb = compute_image_embedding(base64_img, _cohere_client=cohere_client) + if emb is not None: + page_embeddings.append(emb) + else: + st.warning(f"Could not embed page {page_num} from {pdf_filename}. Skipping.") + # Add a placeholder to keep lists aligned, will be filtered later + page_embeddings.append(None) + + # Update progress + pdf_progress.progress((i + 1) / len(doc)) + + doc.close() + pdf_progress.empty() # Remove progress bar after completion + + # Filter out pages where embedding failed + valid_paths = [path for i, path in enumerate(page_image_paths) if page_embeddings[i] is not None] + valid_embeddings = [emb for emb in page_embeddings if emb is not None] + + if not valid_embeddings: + st.error(f"Failed to generate any embeddings for {pdf_filename}.") + return [], None + + return valid_paths, valid_embeddings + + except Exception as e: + st.error(f"Error processing PDF {pdf_filename}: {e}") + return [], None + # Download and embed sample images @st.cache_data(ttl=3600, show_spinner=False) -def download_and_embed_sample_images(_cohere_client): +def download_and_embed_sample_images(_cohere_client) -> tuple[list[str], np.ndarray | None]: """Downloads sample images and computes their embeddings using Cohere's Embed-4 model.""" # Several images from https://www.appeconomyinsights.com/ images = { @@ -332,10 +405,11 @@ else: st.markdown("--- ") # --- File Uploader (Main UI) --- st.subheader("📤 Upload Your Images") -st.info("Or, upload your own images. The RAG process will search across all loaded sample images and uploaded images.") +st.info("Or, upload your own images or PDFs. The RAG process will search across all loaded content.") # File uploader -uploaded_files = st.file_uploader("Upload images", type=["png", "jpg", "jpeg"], +uploaded_files = st.file_uploader("Upload images (PNG, JPG, JPEG) or PDFs", + type=["png", "jpg", "jpeg", "pdf"], accept_multiple_files=True, key="image_uploader", label_visibility="collapsed") @@ -356,18 +430,35 @@ if uploaded_files and co: img_path = os.path.join(upload_folder, uploaded_file.name) if img_path not in st.session_state.image_paths: try: - # Save the uploaded file - with open(img_path, "wb") as f: - f.write(uploaded_file.getbuffer()) - - # Get embedding - base64_img = base64_from_image(img_path) - emb = compute_image_embedding(base64_img, _cohere_client=co) - - if emb is not None: - newly_uploaded_paths.append(img_path) - newly_uploaded_embeddings.append(emb) - + # Check file type + file_type = uploaded_file.type + if file_type == "application/pdf": + # Process PDF - returns list of paths and list of embeddings + pdf_page_paths, pdf_page_embeddings = process_pdf_file(uploaded_file, cohere_client=co) + if pdf_page_paths and pdf_page_embeddings: + # Add only paths/embeddings not already in session state + current_paths_set = set(st.session_state.image_paths) + unique_new_paths = [p for p in pdf_page_paths if p not in current_paths_set] + if unique_new_paths: + indices_to_add = [i for i, p in enumerate(pdf_page_paths) if p in unique_new_paths] + newly_uploaded_paths.extend(unique_new_paths) + newly_uploaded_embeddings.extend([pdf_page_embeddings[idx] for idx in indices_to_add]) + elif file_type in ["image/png", "image/jpeg"]: + # Process regular image + # Save the uploaded file + with open(img_path, "wb") as f: + f.write(uploaded_file.getbuffer()) + + # Get embedding + base64_img = base64_from_image(img_path) + emb = compute_image_embedding(base64_img, _cohere_client=co) + + if emb is not None: + newly_uploaded_paths.append(img_path) + newly_uploaded_embeddings.append(emb) + else: + st.warning(f"Unsupported file type skipped: {uploaded_file.name} ({file_type})") + except Exception as e: st.error(f"Error processing {uploaded_file.name}: {e}") # Update progress regardless of processing status for user feedback @@ -406,6 +497,7 @@ else: with cols[i % 5]: # Add try-except for missing files during display try: + # Display PDF pages differently? For now, just show the image st.image(st.session_state.image_paths[i], width=100, caption=os.path.basename(st.session_state.image_paths[i])) except FileNotFoundError: st.error(f"Missing: {os.path.basename(st.session_state.image_paths[i])}") @@ -436,7 +528,16 @@ if run_button: top_image_path = search(question, co, st.session_state.doc_embeddings, st.session_state.image_paths) if top_image_path: - retrieved_image_placeholder.image(top_image_path, caption=f"Retrieved image for: '{question}'", use_container_width=True) + caption = f"Retrieved content for: '{question}' (Source: {os.path.basename(top_image_path)})" + # Add source PDF name if it's a page image + if top_image_path.startswith("pdf_pages/"): + parts = top_image_path.split(os.sep) + if len(parts) >= 3: + pdf_name = parts[1] + page_name = parts[-1] + caption = f"Retrieved content for: '{question}' (Source: {pdf_name}.pdf, {page_name.replace('.png','')})" + + retrieved_image_placeholder.image(top_image_path, caption=caption, use_container_width=True) with st.spinner("Generating answer..."): final_answer = answer(question, top_image_path, genai_client)