mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-02 18:33:33 -05:00
refac
This commit is contained in:
@@ -16,9 +16,9 @@ class ResultModel(BaseModel):
|
||||
Execute Code Result Model
|
||||
"""
|
||||
|
||||
stdout: Optional[str] = ""
|
||||
stderr: Optional[str] = ""
|
||||
result: Optional[str] = ""
|
||||
stdout: Optional[str] = ''
|
||||
stderr: Optional[str] = ''
|
||||
result: Optional[str] = ''
|
||||
|
||||
|
||||
class JupyterCodeExecuter:
|
||||
@@ -30,8 +30,8 @@ class JupyterCodeExecuter:
|
||||
self,
|
||||
base_url: str,
|
||||
code: str,
|
||||
token: str = "",
|
||||
password: str = "",
|
||||
token: str = '',
|
||||
password: str = '',
|
||||
timeout: int = 60,
|
||||
):
|
||||
"""
|
||||
@@ -46,9 +46,9 @@ class JupyterCodeExecuter:
|
||||
self.token = token
|
||||
self.password = password
|
||||
self.timeout = timeout
|
||||
self.kernel_id = ""
|
||||
if self.base_url[-1] != "/":
|
||||
self.base_url += "/"
|
||||
self.kernel_id = ''
|
||||
if self.base_url[-1] != '/':
|
||||
self.base_url += '/'
|
||||
self.session = aiohttp.ClientSession(trust_env=True, base_url=self.base_url)
|
||||
self.params = {}
|
||||
self.result = ResultModel()
|
||||
@@ -59,12 +59,10 @@ class JupyterCodeExecuter:
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.kernel_id:
|
||||
try:
|
||||
async with self.session.delete(
|
||||
f"api/kernels/{self.kernel_id}", params=self.params
|
||||
) as response:
|
||||
async with self.session.delete(f'api/kernels/{self.kernel_id}', params=self.params) as response:
|
||||
response.raise_for_status()
|
||||
except Exception as err:
|
||||
logger.exception("close kernel failed, %s", err)
|
||||
logger.exception('close kernel failed, %s', err)
|
||||
await self.session.close()
|
||||
|
||||
async def run(self) -> ResultModel:
|
||||
@@ -73,23 +71,23 @@ class JupyterCodeExecuter:
|
||||
await self.init_kernel()
|
||||
await self.execute_code()
|
||||
except Exception as err:
|
||||
logger.exception("execute code failed, %s", err)
|
||||
self.result.stderr = f"Error: {err}"
|
||||
logger.exception('execute code failed, %s', err)
|
||||
self.result.stderr = f'Error: {err}'
|
||||
return self.result
|
||||
|
||||
async def sign_in(self) -> None:
|
||||
# password authentication
|
||||
if self.password and not self.token:
|
||||
async with self.session.get("login") as response:
|
||||
async with self.session.get('login') as response:
|
||||
response.raise_for_status()
|
||||
xsrf_token = response.cookies["_xsrf"].value
|
||||
xsrf_token = response.cookies['_xsrf'].value
|
||||
if not xsrf_token:
|
||||
raise ValueError("_xsrf token not found")
|
||||
raise ValueError('_xsrf token not found')
|
||||
self.session.cookie_jar.update_cookies(response.cookies)
|
||||
self.session.headers.update({"X-XSRFToken": xsrf_token})
|
||||
self.session.headers.update({'X-XSRFToken': xsrf_token})
|
||||
async with self.session.post(
|
||||
"login",
|
||||
data={"_xsrf": xsrf_token, "password": self.password},
|
||||
'login',
|
||||
data={'_xsrf': xsrf_token, 'password': self.password},
|
||||
allow_redirects=False,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
@@ -97,27 +95,22 @@ class JupyterCodeExecuter:
|
||||
|
||||
# token authentication
|
||||
if self.token:
|
||||
self.params.update({"token": self.token})
|
||||
self.params.update({'token': self.token})
|
||||
|
||||
async def init_kernel(self) -> None:
|
||||
async with self.session.post(url="api/kernels", params=self.params) as response:
|
||||
async with self.session.post(url='api/kernels', params=self.params) as response:
|
||||
response.raise_for_status()
|
||||
kernel_data = await response.json()
|
||||
self.kernel_id = kernel_data["id"]
|
||||
self.kernel_id = kernel_data['id']
|
||||
|
||||
def init_ws(self) -> (str, dict):
|
||||
ws_base = self.base_url.replace("http", "ws", 1)
|
||||
ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
|
||||
websocket_url = f"{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
|
||||
ws_base = self.base_url.replace('http', 'ws', 1)
|
||||
ws_params = '?' + '&'.join([f'{key}={val}' for key, val in self.params.items()])
|
||||
websocket_url = f'{ws_base}api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ""}'
|
||||
ws_headers = {}
|
||||
if self.password and not self.token:
|
||||
ws_headers = {
|
||||
"Cookie": "; ".join(
|
||||
[
|
||||
f"{cookie.key}={cookie.value}"
|
||||
for cookie in self.session.cookie_jar
|
||||
]
|
||||
),
|
||||
'Cookie': '; '.join([f'{cookie.key}={cookie.value}' for cookie in self.session.cookie_jar]),
|
||||
**self.session.headers,
|
||||
}
|
||||
return websocket_url, ws_headers
|
||||
@@ -126,9 +119,7 @@ class JupyterCodeExecuter:
|
||||
# initialize ws
|
||||
websocket_url, ws_headers = self.init_ws()
|
||||
# execute
|
||||
async with websockets.connect(
|
||||
websocket_url, additional_headers=ws_headers
|
||||
) as ws:
|
||||
async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
|
||||
await self.execute_in_jupyter(ws)
|
||||
|
||||
async def execute_in_jupyter(self, ws) -> None:
|
||||
@@ -137,71 +128,69 @@ class JupyterCodeExecuter:
|
||||
await ws.send(
|
||||
json.dumps(
|
||||
{
|
||||
"header": {
|
||||
"msg_id": msg_id,
|
||||
"msg_type": "execute_request",
|
||||
"username": "user",
|
||||
"session": uuid.uuid4().hex,
|
||||
"date": "",
|
||||
"version": "5.3",
|
||||
'header': {
|
||||
'msg_id': msg_id,
|
||||
'msg_type': 'execute_request',
|
||||
'username': 'user',
|
||||
'session': uuid.uuid4().hex,
|
||||
'date': '',
|
||||
'version': '5.3',
|
||||
},
|
||||
"parent_header": {},
|
||||
"metadata": {},
|
||||
"content": {
|
||||
"code": self.code,
|
||||
"silent": False,
|
||||
"store_history": True,
|
||||
"user_expressions": {},
|
||||
"allow_stdin": False,
|
||||
"stop_on_error": True,
|
||||
'parent_header': {},
|
||||
'metadata': {},
|
||||
'content': {
|
||||
'code': self.code,
|
||||
'silent': False,
|
||||
'store_history': True,
|
||||
'user_expressions': {},
|
||||
'allow_stdin': False,
|
||||
'stop_on_error': True,
|
||||
},
|
||||
"channel": "shell",
|
||||
'channel': 'shell',
|
||||
}
|
||||
)
|
||||
)
|
||||
# parse message
|
||||
stdout, stderr, result = "", "", []
|
||||
stdout, stderr, result = '', '', []
|
||||
while True:
|
||||
try:
|
||||
# wait for message
|
||||
message = await asyncio.wait_for(ws.recv(), self.timeout)
|
||||
message_data = json.loads(message)
|
||||
# msg id not match, skip
|
||||
if message_data.get("parent_header", {}).get("msg_id") != msg_id:
|
||||
if message_data.get('parent_header', {}).get('msg_id') != msg_id:
|
||||
continue
|
||||
# check message type
|
||||
msg_type = message_data.get("msg_type")
|
||||
msg_type = message_data.get('msg_type')
|
||||
match msg_type:
|
||||
case "stream":
|
||||
if message_data["content"]["name"] == "stdout":
|
||||
stdout += message_data["content"]["text"]
|
||||
elif message_data["content"]["name"] == "stderr":
|
||||
stderr += message_data["content"]["text"]
|
||||
case "execute_result" | "display_data":
|
||||
data = message_data["content"]["data"]
|
||||
if "image/png" in data:
|
||||
result.append(f"data:image/png;base64,{data['image/png']}")
|
||||
elif "text/plain" in data:
|
||||
result.append(data["text/plain"])
|
||||
case "error":
|
||||
stderr += "\n".join(message_data["content"]["traceback"])
|
||||
case "status":
|
||||
if message_data["content"]["execution_state"] == "idle":
|
||||
case 'stream':
|
||||
if message_data['content']['name'] == 'stdout':
|
||||
stdout += message_data['content']['text']
|
||||
elif message_data['content']['name'] == 'stderr':
|
||||
stderr += message_data['content']['text']
|
||||
case 'execute_result' | 'display_data':
|
||||
data = message_data['content']['data']
|
||||
if 'image/png' in data:
|
||||
result.append(f'data:image/png;base64,{data["image/png"]}')
|
||||
elif 'text/plain' in data:
|
||||
result.append(data['text/plain'])
|
||||
case 'error':
|
||||
stderr += '\n'.join(message_data['content']['traceback'])
|
||||
case 'status':
|
||||
if message_data['content']['execution_state'] == 'idle':
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
stderr += "\nExecution timed out."
|
||||
stderr += '\nExecution timed out.'
|
||||
break
|
||||
self.result.stdout = stdout.strip()
|
||||
self.result.stderr = stderr.strip()
|
||||
self.result.result = "\n".join(result).strip() if result else ""
|
||||
self.result.result = '\n'.join(result).strip() if result else ''
|
||||
|
||||
|
||||
async def execute_code_jupyter(
|
||||
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
|
||||
base_url: str, code: str, token: str = '', password: str = '', timeout: int = 60
|
||||
) -> dict:
|
||||
async with JupyterCodeExecuter(
|
||||
base_url, code, token, password, timeout
|
||||
) as executor:
|
||||
async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
|
||||
result = await executor.run()
|
||||
return result.model_dump()
|
||||
|
||||
Reference in New Issue
Block a user