code is perfect

This commit is contained in:
Madhu
2025-01-05 17:22:09 +05:30
parent 8cdb9bd663
commit 9fd044b81e
2 changed files with 195 additions and 362 deletions

View File

@@ -1,202 +1,214 @@
import os
import json
import re
import sys
import io
import contextlib
import warnings
from typing import Optional, List, Any, Tuple
from dotenv import load_dotenv
from PIL import Image
import streamlit as st
import pandas as pd
import tempfile
import re
from together import Together
import csv
from dotenv import load_dotenv
import base64
import matplotlib.pyplot as plt
import io
import seaborn as sns
from io import BytesIO
from together import Together
from e2b_code_interpreter import Sandbox
# Load environment variables (python-dotenv reads them from a .env file when
# one is present) before any later code looks them up
load_dotenv()
# Suppress Pydantic warnings globally: every UserWarning attributed to a
# module whose name matches "pydantic" is ignored for the whole process
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
# Function to preprocess and save the uploaded file to a temporary file
# Cell values that are treated as missing when the upload is parsed.
_NA_MARKERS = ['NA', 'N/A', 'missing']
# Matches a "value/scale" rating such as "4.1/5"; group 1 is the value part.
_RATIO_PATTERN = r'^\s*([-+]?\d+(?:\.\d+)?)\s*/\s*\d+(?:\.\d+)?\s*$'
# Minimum share of non-null values that must parse as numbers before a text
# column is converted; below this the column is considered genuine text.
_NUMERIC_SHARE = 0.5


def _to_numeric_if_mostly_numeric(series):
    """Return *series* converted to numbers when most of its values are numeric.

    Values written as "value/scale" (e.g. "4.1/5") contribute their value
    part. When fewer than ``_NUMERIC_SHARE`` of the non-null values parse,
    the column is returned untouched so real text is never wiped out.
    """
    present = series[series.notna()]
    if present.empty:
        return series
    text = present.astype(str).str.strip()
    text = text.str.replace(_RATIO_PATTERN, r'\1', regex=True)
    numeric = pd.to_numeric(text, errors='coerce')
    if numeric.notna().mean() < _NUMERIC_SHARE:
        return series
    # Put the missing rows back (as NaN) so the result aligns with the frame.
    return numeric.reindex(series.index)


def preprocess_and_save(file):
    """Clean an uploaded CSV/XLSX file and save it as a temporary CSV.

    Args:
        file: Uploaded file object exposing ``.name`` and readable by pandas.

    Returns:
        A ``(temp_path, columns, df)`` tuple: the path of the temporary CSV,
        the list of column labels and the cleaned DataFrame. On an
        unsupported extension or any processing error the function reports
        the problem through ``st.error`` and returns ``(None, None, None)``.
    """
    try:
        # Compare the extension case-insensitively so "DATA.CSV" is accepted.
        file_name = file.name.lower()
        if file_name.endswith('.csv'):
            df = pd.read_csv(file, encoding='utf-8', na_values=_NA_MARKERS)
        elif file_name.endswith('.xlsx'):
            df = pd.read_excel(file, na_values=_NA_MARKERS)
        else:
            st.error("Unsupported file format. Please upload a CSV or Excel file.")
            return None, None, None

        # NOTE: embedded double quotes are deliberately NOT doubled by hand;
        # to_csv with csv.QUOTE_ALL already escapes them, and doing it twice
        # corrupts the text when the file is read back.
        for col in df.columns:
            # Column labels are not guaranteed to be strings.
            if 'date' in str(col).lower():
                df[col] = pd.to_datetime(df[col], errors='coerce')
            elif pd.api.types.is_object_dtype(df[col]) or pd.api.types.is_string_dtype(df[col]):
                # An unconditional to_numeric(errors='coerce') would turn
                # every text column into NaN, so convert only when the
                # column really holds numbers.
                df[col] = _to_numeric_if_mostly_numeric(df[col])

        # Drop rows with all NaN values
        df = df.dropna(how='all')

        # Reserve a path, then write after the handle is closed so the file
        # is not opened twice at once (which fails on some platforms).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_file:
            temp_path = temp_file.name
        df.to_csv(temp_path, index=False, quoting=csv.QUOTE_ALL)
        return temp_path, df.columns.tolist(), df
    except Exception as e:
        # UI boundary: surface the problem to the user instead of crashing.
        st.error(f"Error processing file: {e}")
        return None, None, None
# Regex pattern to extract code from LLM response: captures the text between
# a "```python" fence line and the next closing "```" fence. The match is
# non-greedy, and re.DOTALL lets "." span newlines so multi-line code is captured.
pattern = re.compile(r"```python\n(.*?)\n```", re.DOTALL)
# Function to execute Python code and generate plots
def execute_code(code: str, df):
try:
# Define locals with necessary imports and the DataFrame
local_env = {
'pd': pd,
'df': df,
'plt': plt,
'sns': sns # if seaborn is needed
}
# Execute the code in the local environment
exec(code, globals(), local_env)
def code_interpret(e2b_code_interpreter: Sandbox, code: str) -> Optional[List[Any]]:
"""
Runs the given Python code in the E2B sandbox.
Args:
e2b_code_interpreter: The E2B sandbox instance
code: Python code to execute
# Check if a plot was generated
if 'plt' in local_env:
# Save the plot to a BytesIO object
buf = io.BytesIO()
plt.savefig(buf, format='png')
plt.close()
buf.seek(0)
# Encode the plot as base64
base64_image = base64.b64encode(buf.read()).decode('utf-8')
return base64_image
else:
st.warning("No plot generated. Ensure the data being plotted is numeric.")
Returns:
Optional[List[Any]]: Results from code execution
"""
with st.spinner('Executing code in E2B sandbox...'):
# Capture stdout and stderr
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
# Suppress warnings during code execution
with warnings.catch_warnings():
warnings.simplefilter("ignore")
exec = e2b_code_interpreter.run_code(code)
# Log stderr (warnings and errors) to the terminal
if stderr_capture.getvalue():
print("[Code Interpreter Warnings/Errors]", file=sys.stderr)
print(stderr_capture.getvalue(), file=sys.stderr)
# Log stdout (normal output) to the terminal
if stdout_capture.getvalue():
print("[Code Interpreter Output]", file=sys.stdout)
print(stdout_capture.getvalue(), file=sys.stdout)
if exec.error:
print(f"[Code Interpreter ERROR] {exec.error}", file=sys.stderr)
return None
except Exception as e:
st.error(f"Error executing code: {e}")
return None
return exec.results
# Function to communicate with Together AI
def chat_with_llm(user_message, file_path, columns, df):
print(f"\n{'='*50}\nUser message: {user_message}\n{'='*50}")
def match_code_blocks(llm_response: str) -> str:
    """
    Extracts the first Python code block from the LLM response.

    Args:
        llm_response: The response from the LLM; ``None`` or an empty
            string is tolerated.

    Returns:
        str: Extracted Python code or empty string
    """
    # A message without text content would make pattern.search raise
    # TypeError, so treat "nothing to search" as "no code found".
    if not llm_response:
        return ""
    match = pattern.search(llm_response)
    return match.group(1) if match else ""
# Update the system prompt with the file path, columns, and plot path
system_prompt = SYSTEM_PROMPT.format(
file_path=file_path,
columns=columns,
)
# Add a hint to include a plot if the user asks for visualization
if "plot" in user_message.lower():
system_prompt += " Include a plot in your response and output the base64 string of the plot image."
def chat_with_llm(e2b_code_interpreter: Sandbox, user_message: str, dataset_path: str) -> Tuple[Optional[List[Any]], str]:
"""
Sends the user message to the LLM and executes the generated code.
Args:
e2b_code_interpreter: The E2B sandbox instance
user_message: User's query message
dataset_path: Path to the uploaded dataset
Returns:
Tuple[Optional[List[Any]], str]: Code execution results and LLM response
"""
# Update system prompt to include dataset path information
system_prompt = f"""You're a Python data scientist and data visualization expert. You are given a dataset at path '{dataset_path}' and also the user's query.
You need to analyze the dataset and answer the user's query with a response and you run Python code to solve them.
IMPORTANT: Always use the dataset path variable '{dataset_path}' in your code when reading the CSV file."""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message},
]
# Use the Together API key from session state
response = client.chat.completions.create(
model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
messages=messages,
)
with st.spinner('Getting response from Together AI LLM model...'):
client = Together(api_key=st.session_state.together_api_key)
response = client.chat.completions.create(
model=st.session_state.model_name,
messages=messages,
)
response_message = response.choices[0].message.content
print("LLM Response:", response_message) # Debug: Print the LLM's response
python_code = match_code_blocks(response_message)
print("Extracted Python Code:", python_code) # Debug: Print the extracted code
if python_code:
# Execute the code and generate the plot
base64_image = execute_code(python_code, df)
return response_message, base64_image
else:
print(f"Failed to match any Python code in model's response {response_message}")
return response_message, None
# Set up Streamlit app
st.title("AI Data Visualisation Agent")
# Sidebar for API keys and file upload
st.sidebar.header("API Keys")
together_api_key = st.sidebar.text_input("Together AI API Key", type="password")
# Store API key in session state
if 'together_api_key' not in st.session_state:
st.session_state.together_api_key = None
uploaded_file = st.sidebar.file_uploader("Upload CSV or Excel File", type=['csv', 'xlsx'])
# System prompt (dynamic based on the uploaded file).
# This is a str.format template: '{file_path}' and '{columns}' are filled in
# by chat_with_llm via SYSTEM_PROMPT.format(file_path=..., columns=...), so
# any literal brace added to the text later must be doubled.
SYSTEM_PROMPT = """
You are a Python data scientist and Visualisation expert. You have access to a CSV file located at '{file_path}'.
The dataset has the following columns: {columns}.
You can read this file into a DataFrame using `df = pd.read_csv('{file_path}')` and perform data analysis tasks based on user queries.
Make sure to handle missing values and data type inconsistencies. When generating plots,
use matplotlib or seaborn and output the plot as a base64 string.
Always check if the data being plotted is numeric. If the data is not numeric, preprocess it to convert it to numeric values.
Always respond with the Python code to answer the user's query, and include visualizations only if explicitly requested.
"""
# Function to match Python code blocks: the pattern captures the text between
# a "```python" fence line and the next closing "```" fence.
pattern = re.compile(r"```python\n(.*?)\n```", re.DOTALL)


def match_code_blocks(llm_response):
    """Return the first fenced Python block of *llm_response* without comment lines.

    Args:
        llm_response: Text returned by the LLM; ``None`` or an empty string
            is tolerated.

    Returns:
        The captured code with every line whose first non-blank character
        is ``#`` removed, or an empty string when no block is found.
    """
    # A message without text content would make pattern.search raise
    # TypeError, so treat "nothing to search" as "no code found".
    if not llm_response:
        return ""
    match = pattern.search(llm_response)
    if match is None:
        return ""
    # Remove comments and extra text
    return "\n".join(
        line
        for line in match.group(1).split("\n")
        if not line.strip().startswith("#")
    )
# Main app logic
if uploaded_file:
if not together_api_key:
st.warning("Please provide the Together AI API key.")
else:
# Update session state with API key
st.session_state.together_api_key = together_api_key
response_message = response.choices[0].message
python_code = match_code_blocks(response_message.content)
# Initialize Together AI client only after confirming API key exists
try:
client = Together(api_key=together_api_key)
# Preprocess and save the uploaded file
temp_path, columns, df = preprocess_and_save(uploaded_file)
if temp_path:
# Rest of your code for user query handling
user_query = st.text_input("Ask a query about the data:")
if st.button("Submit Query"):
response_message, base64_image = chat_with_llm(user_query, temp_path, columns, df)
# Display AI's response
st.write("AI's Response:")
st.write(response_message)
# Display the plot if generated
if base64_image:
st.image(base64.b64decode(base64_image), use_container_width=True)
else:
st.write("No plot generated.")
if python_code:
code_interpreter_results = code_interpret(e2b_code_interpreter, python_code)
return code_interpreter_results, response_message.content
else:
st.warning(f"Failed to match any Python code in model's response")
return None, response_message.content
def upload_dataset(code_interpreter: Sandbox, uploaded_file) -> str:
    """
    Uploads the dataset to the E2B sandbox.

    Args:
        code_interpreter: The E2B sandbox instance
        uploaded_file: Streamlit uploaded file

    Returns:
        str: Path where file was uploaded

    Raises:
        Exception: Whatever the sandbox write raised, re-raised unchanged
            after the error has been shown in the UI.
    """
    dataset_path = f"./{uploaded_file.name}"
    try:
        # main() has already read this stream with pd.read_csv for the
        # preview, which leaves it positioned at EOF. Rewind so the whole
        # file is uploaded even if files.write reads from the current
        # position. NOTE(review): confirm whether the SDK rewinds by itself.
        rewind = getattr(uploaded_file, "seek", None)
        if callable(rewind):
            rewind(0)
        code_interpreter.files.write(dataset_path, uploaded_file)
    except Exception as error:
        st.error(f"Error during file upload: {error}")
        # Bare raise keeps the original traceback intact.
        raise
    return dataset_path
def main():
"""Main Streamlit application."""
st.title("AI Data Visualization Agent")
st.write("Upload your dataset and ask questions about it!")
# Sidebar for API keys and model selection
with st.sidebar:
st.header("API Keys and Model Configuration")
st.session_state.together_api_key = st.text_input("Enter Together API Key", type="password")
st.session_state.e2b_api_key = st.text_input("Enter E2B API Key", type="password")
# Add model selection dropdown
model_options = {
"Meta-Llama 3.1 405B": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
"DeepSeek V3": "deepseek-ai/DeepSeek-V3",
"Qwen 2.5 7B": "Qwen/Qwen2.5-7B-Instruct-Turbo",
"Meta-Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct-Turbo"
}
selected_model = st.selectbox(
"Select Model",
options=list(model_options.keys()),
index=0 # Default to first option
)
st.session_state.model_name = model_options[selected_model]
uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
if uploaded_file is not None:
# Display dataset with toggle
df = pd.read_csv(uploaded_file)
st.write("Dataset:")
show_full = st.checkbox("Show full dataset")
if show_full:
st.dataframe(df)
else:
st.write("Preview (first 5 rows):")
st.dataframe(df.head())
# Query input
query = st.text_area("What would you like to know about your data?",
"Can you compare the average cost for two people between different categories?")
if st.button("Analyze"):
if not st.session_state.together_api_key or not st.session_state.e2b_api_key:
st.error("Please enter both API keys in the sidebar.")
else:
st.error("Failed to preprocess and save the data.")
except Exception as e:
st.error(f"Error initializing Together AI client: {str(e)}")
else:
st.warning("Please upload a file.")
with Sandbox(api_key=st.session_state.e2b_api_key) as code_interpreter:
# Upload the dataset
dataset_path = upload_dataset(code_interpreter, uploaded_file)
# Pass dataset_path to chat_with_llm
code_results, llm_response = chat_with_llm(code_interpreter, query, dataset_path)
# Display LLM's text response
st.write("AI Response:")
st.write(llm_response)
# Display results/visualizations
if code_results:
for result in code_results:
if hasattr(result, 'png') and result.png: # Check if PNG data is available
# Decode the base64-encoded PNG data
png_data = base64.b64decode(result.png)
# Convert PNG data to an image and display it
image = Image.open(BytesIO(png_data))
st.image(image, caption="Generated Visualization", use_container_width=False)
elif hasattr(result, 'figure'): # For matplotlib figures
fig = result.figure # Extract the matplotlib figure
st.pyplot(fig) # Display using st.pyplot
elif hasattr(result, 'show'): # For plotly figures
st.plotly_chart(result)
elif isinstance(result, (pd.DataFrame, pd.Series)):
st.dataframe(result)
else:
st.write(result)
if __name__ == "__main__":
main()

View File

@@ -1,179 +0,0 @@
import os
import json
import re
from typing import Optional, List, Any, Tuple
from dotenv import load_dotenv
from PIL import Image
import io
import streamlit as st
import pandas as pd
import base64
from io import BytesIO
from PIL import Image
from together import Together
from e2b_code_interpreter import Sandbox
# Load environment variables (python-dotenv reads them from a .env file when
# one is present) before any later code looks them up
load_dotenv()
# Regex pattern to extract code from LLM response: captures the text between
# a "```python" fence line and the next closing "```" fence. The match is
# non-greedy, and re.DOTALL lets "." span newlines so multi-line code is captured.
pattern = re.compile(r"```python\n(.*?)\n```", re.DOTALL)
def code_interpret(e2b_code_interpreter: Sandbox, code: str) -> Optional[List[Any]]:
    """
    Runs the given Python code in the E2B sandbox.

    Output the sandbox writes to stdout/stderr is forwarded to the page via
    st.info / st.error through the run_code callbacks.

    Args:
        e2b_code_interpreter: The E2B sandbox instance
        code: Python code to execute

    Returns:
        Optional[List[Any]]: Results from code execution, or None when the
        execution reported an error.
    """
    with st.spinner('Executing code in E2B sandbox...'):
        # Named `execution` rather than `exec` so the builtin is not shadowed.
        execution = e2b_code_interpreter.run_code(
            code,
            on_stderr=lambda stderr: st.error(f"[Code Interpreter] {stderr}"),
            on_stdout=lambda stdout: st.info(f"[Code Interpreter] {stdout}"),
        )
    if execution.error:
        st.error(f"[Code Interpreter ERROR] {execution.error}")
        return None
    return execution.results
def match_code_blocks(llm_response: str) -> str:
    """
    Extracts the first Python code block from the LLM response.

    Args:
        llm_response: The response from the LLM; ``None`` or an empty
            string is tolerated.

    Returns:
        str: Extracted Python code or empty string
    """
    # A message without text content would make pattern.search raise
    # TypeError, so treat "nothing to search" as "no code found".
    if not llm_response:
        return ""
    match = pattern.search(llm_response)
    return match.group(1) if match else ""
def chat_with_llm(e2b_code_interpreter: Sandbox, user_message: str, dataset_path: str) -> Tuple[Optional[List[Any]], str]:
    """
    Sends the user message to the LLM and executes the generated code.

    The Together API key and the model name are read from st.session_state.

    Args:
        e2b_code_interpreter: The E2B sandbox instance
        user_message: User's query message
        dataset_path: Path to the uploaded dataset

    Returns:
        Tuple[Optional[List[Any]], str]: Code execution results and LLM
        response. The results are None when the response contains no Python
        code block or when the sandbox execution failed.
    """
    # Update system prompt to include dataset path information
    system_prompt = f"""You're a Python data scientist and data visualization expert. You are given a dataset at path '{dataset_path}' and also the user's query.
You need to analyze the dataset and answer the user's query with a response and you run Python code to solve them.
IMPORTANT: Always use the dataset path variable '{dataset_path}' in your code when reading the CSV file."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    with st.spinner('Getting response from together AI...'):
        client = Together(api_key=st.session_state.together_api_key)
        response = client.chat.completions.create(
            model=st.session_state.model_name,
            messages=messages,
        )
    # NOTE(review): a chat message may come back without text content;
    # normalise to "" so the regex search and the declared `str` return
    # value both hold instead of failing with TypeError.
    llm_text = response.choices[0].message.content or ""
    python_code = match_code_blocks(llm_text)
    if not python_code:
        st.warning("Failed to match any Python code in model's response")
        return None, llm_text
    return code_interpret(e2b_code_interpreter, python_code), llm_text
def upload_dataset(code_interpreter: Sandbox, uploaded_file) -> str:
    """
    Uploads the dataset to the E2B sandbox.

    Args:
        code_interpreter: The E2B sandbox instance
        uploaded_file: Streamlit uploaded file

    Returns:
        str: Path where file was uploaded

    Raises:
        Exception: Whatever the sandbox write raised, re-raised unchanged
            after the error has been shown in the UI.
    """
    dataset_path = f"./{uploaded_file.name}"
    try:
        # main() has already read this stream with pd.read_csv for the
        # preview, which leaves it positioned at EOF. Rewind so the whole
        # file is uploaded even if files.write reads from the current
        # position. NOTE(review): confirm whether the SDK rewinds by itself.
        rewind = getattr(uploaded_file, "seek", None)
        if callable(rewind):
            rewind(0)
        code_interpreter.files.write(dataset_path, uploaded_file)
    except Exception as error:
        st.error(f"Error during file upload: {error}")
        # Bare raise keeps the original traceback intact.
        raise
    return dataset_path
def _render_results(code_results) -> None:
    """Show each sandbox result with the Streamlit widget that fits its type."""
    for result in code_results:
        if hasattr(result, 'png') and result.png:  # Check if PNG data is available
            # Decode the base64-encoded PNG data and display it as an image
            png_data = base64.b64decode(result.png)
            image = Image.open(BytesIO(png_data))
            st.image(image, caption="Generated Visualization", use_container_width=False)
        elif hasattr(result, 'figure'):  # For matplotlib figures
            st.pyplot(result.figure)
        elif hasattr(result, 'show'):  # For plotly figures
            st.plotly_chart(result)
        elif isinstance(result, (pd.DataFrame, pd.Series)):
            st.dataframe(result)
        else:
            st.write(result)


def main():
    """Main Streamlit application.

    Collects the API keys and model name in the sidebar, previews the
    uploaded CSV, and on "Analyze" uploads the file to an E2B sandbox, asks
    the LLM for code, runs it and renders the response and results.
    """
    st.title("AI Data Visualization Assistant")
    st.write("Upload your dataset and ask questions about it!")

    # Sidebar for API keys and model name
    with st.sidebar:
        st.header("API Keys and Model Configuration")
        st.session_state.together_api_key = st.text_input("Enter Together API Key", type="password")
        st.session_state.e2b_api_key = st.text_input("Enter E2B API Key", type="password")
        st.session_state.model_name = st.text_input("Enter Model Name", value="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")

    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
    if uploaded_file is None:
        return

    # Display dataset preview; a malformed upload is reported, not crashed on.
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as error:
        st.error(f"Could not read the uploaded CSV: {error}")
        return
    # read_csv consumed the stream; rewind it so the later upload to the
    # sandbox starts from the first byte.
    uploaded_file.seek(0)
    st.write("Dataset Preview:")
    st.dataframe(df.head())

    # Query input
    query = st.text_area("What would you like to know about your data?",
                         "Can you compare the average cost for two people between different categories?")

    if not st.button("Analyze"):
        return
    if not st.session_state.together_api_key or not st.session_state.e2b_api_key:
        st.error("Please enter both API keys in the sidebar.")
        return

    with Sandbox(api_key=st.session_state.e2b_api_key) as code_interpreter:
        # Upload the dataset
        dataset_path = upload_dataset(code_interpreter, uploaded_file)
        # Pass dataset_path to chat_with_llm
        code_results, llm_response = chat_with_llm(code_interpreter, query, dataset_path)
        # Display LLM's text response
        st.write("AI Response:")
        st.write(llm_response)
        # Display results/visualizations
        if code_results:
            _render_results(code_results)
if __name__ == "__main__":
main()