mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-01 09:49:03 -05:00
197 lines
7.1 KiB
Python
197 lines
7.1 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
import aiohttp
|
|
import websockets
|
|
from pydantic import BaseModel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ResultModel(BaseModel):
    """
    Outcome of running one snippet of code on a Jupyter kernel.

    All three fields default to the empty string.
    """

    # Accumulated text from the kernel's stdout stream.
    stdout: Optional[str] = ''
    # Accumulated stderr text, error tracebacks and executor failure messages.
    stderr: Optional[str] = ''
    # Rich outputs (execute_result / display_data) joined with newlines.
    result: Optional[str] = ''
|
|
|
|
|
|
class JupyterCodeExecuter:
|
|
"""
|
|
Execute code in jupyter notebook
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
base_url: str,
|
|
code: str,
|
|
token: str = '',
|
|
password: str = '',
|
|
timeout: int = 60,
|
|
):
|
|
"""
|
|
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
|
|
:param code: Code to execute
|
|
:param token: Jupyter authentication token (optional)
|
|
:param password: Jupyter password (optional)
|
|
:param timeout: WebSocket timeout in seconds (default: 60s)
|
|
"""
|
|
self.base_url = base_url
|
|
self.code = code
|
|
self.token = token
|
|
self.password = password
|
|
self.timeout = timeout
|
|
self.kernel_id = ''
|
|
if self.base_url[-1] != '/':
|
|
self.base_url += '/'
|
|
self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url)
|
|
self.params = {}
|
|
self.result = ResultModel()
|
|
|
|
async def __aenter__(self):
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
if self.kernel_id:
|
|
try:
|
|
async with self.session.delete(f'api/kernels/{self.kernel_id}', params=self.params) as response:
|
|
response.raise_for_status()
|
|
except Exception as err:
|
|
logger.exception('close kernel failed, %s', err)
|
|
await self.session.close()
|
|
|
|
async def run(self) -> ResultModel:
|
|
try:
|
|
await self.sign_in()
|
|
await self.init_kernel()
|
|
await self.execute_code()
|
|
except Exception as err:
|
|
logger.exception('execute code failed, %s', err)
|
|
self.result.stderr = f'Error: {err}'
|
|
return self.result
|
|
|
|
async def sign_in(self) -> None:
|
|
# password authentication
|
|
if self.password and not self.token:
|
|
async with self.session.get('login') as response:
|
|
response.raise_for_status()
|
|
xsrf_token = response.cookies['_xsrf'].value
|
|
if not xsrf_token:
|
|
raise ValueError('_xsrf token not found')
|
|
self.session.cookie_jar.update_cookies(response.cookies)
|
|
self.session.headers.update({'X-XSRFToken': xsrf_token})
|
|
async with self.session.post(
|
|
'login',
|
|
data={'_xsrf': xsrf_token, 'password': self.password},
|
|
allow_redirects=False,
|
|
) as response:
|
|
response.raise_for_status()
|
|
self.session.cookie_jar.update_cookies(response.cookies)
|
|
|
|
# token authentication
|
|
if self.token:
|
|
self.params.update({'token': self.token})
|
|
|
|
async def init_kernel(self) -> None:
|
|
async with self.session.post(url='api/kernels', params=self.params) as response:
|
|
response.raise_for_status()
|
|
kernel_data = await response.json()
|
|
self.kernel_id = kernel_data['id']
|
|
|
|
def init_ws(self) -> (str, dict):
|
|
ws_base = self.base_url.replace('http', 'ws', 1)
|
|
ws_params = '?' + '&'.join([f'{key}={val}' for key, val in self.params.items()])
|
|
websocket_url = f'{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ""}'
|
|
ws_headers = {}
|
|
if self.password and not self.token:
|
|
ws_headers = {
|
|
'Cookie': '; '.join([f'{cookie.key}={cookie.value}' for cookie in self.session.cookie_jar]),
|
|
**self.session.headers,
|
|
}
|
|
return websocket_url, ws_headers
|
|
|
|
async def execute_code(self) -> None:
|
|
# initialize ws
|
|
websocket_url, ws_headers = self.init_ws()
|
|
# execute
|
|
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
|
|
await self.execute_in_jupyter(ws)
|
|
|
|
async def execute_in_jupyter(self, ws) -> None:
|
|
# send message
|
|
msg_id = uuid.uuid4().hex
|
|
await ws.send(
|
|
json.dumps(
|
|
{
|
|
'header': {
|
|
'msg_id': msg_id,
|
|
'msg_type': 'execute_request',
|
|
'username': 'user',
|
|
'session': uuid.uuid4().hex,
|
|
'date': '',
|
|
'version': '5.3',
|
|
},
|
|
'parent_header': {},
|
|
'metadata': {},
|
|
'content': {
|
|
'code': self.code,
|
|
'silent': False,
|
|
'store_history': True,
|
|
'user_expressions': {},
|
|
'allow_stdin': False,
|
|
'stop_on_error': True,
|
|
},
|
|
'channel': 'shell',
|
|
}
|
|
)
|
|
)
|
|
# parse message
|
|
stdout, stderr, result = '', '', []
|
|
while True:
|
|
try:
|
|
# wait for message
|
|
message = await asyncio.wait_for(ws.recv(), self.timeout)
|
|
message_data = json.loads(message)
|
|
# msg id not match, skip
|
|
if message_data.get('parent_header', {}).get('msg_id') != msg_id:
|
|
continue
|
|
# check message type
|
|
msg_type = message_data.get('msg_type')
|
|
match msg_type:
|
|
case 'stream':
|
|
if message_data['content']['name'] == 'stdout':
|
|
stdout += message_data['content']['text']
|
|
elif message_data['content']['name'] == 'stderr':
|
|
stderr += message_data['content']['text']
|
|
case 'execute_result' | 'display_data':
|
|
data = message_data['content']['data']
|
|
if 'image/png' in data:
|
|
result.append(f'data:image/png;base64,{data["image/png"]}')
|
|
elif 'text/plain' in data:
|
|
result.append(data['text/plain'])
|
|
case 'error':
|
|
stderr += '\n'.join(message_data['content']['traceback'])
|
|
case 'status':
|
|
if message_data['content']['execution_state'] == 'idle':
|
|
break
|
|
|
|
except asyncio.TimeoutError:
|
|
stderr += '\nExecution timed out.'
|
|
break
|
|
self.result.stdout = stdout.strip()
|
|
self.result.stderr = stderr.strip()
|
|
self.result.result = '\n'.join(result).strip() if result else ''
|
|
|
|
|
|
async def execute_code_jupyter(
    base_url: str, code: str, token: str = '', password: str = '', timeout: int = 60
) -> dict:
    """
    Run ``code`` on the Jupyter server at ``base_url`` and return the outcome.

    The executor context deletes the kernel it started (best effort) and
    closes its HTTP session on exit. Execution errors are reported in the
    returned ``stderr`` rather than raised.

    :return: dict with ``stdout``, ``stderr`` and ``result`` entries.
    """
    runner = JupyterCodeExecuter(
        base_url=base_url,
        code=code,
        token=token,
        password=password,
        timeout=timeout,
    )
    async with runner as executor:
        outcome = await executor.run()
    return outcome.model_dump()
|