diff --git a/Cargo.lock b/Cargo.lock index 5e5a11480..7ac0094c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2754,7 +2754,6 @@ dependencies = [ "slack_client_rs", "svi", "tokio", - "tokio-tungstenite 0.27.0", "tokio-util", "toml", "toml_pretty", @@ -3657,10 +3656,8 @@ dependencies = [ "resolver_api", "serde", "serde_json", - "serde_qs", "serror", "tokio", - "tokio-tungstenite 0.27.0", "tracing", "transport", "uuid", diff --git a/bin/core/Cargo.toml b/bin/core/Cargo.toml index bdc42dc73..31b4e8052 100644 --- a/bin/core/Cargo.toml +++ b/bin/core/Cargo.toml @@ -39,7 +39,6 @@ slack.workspace = true svi.workspace = true # external aws-credential-types.workspace = true -tokio-tungstenite.workspace = true english-to-cron.workspace = true openidconnect.workspace = true jsonwebtoken.workspace = true diff --git a/bin/core/src/monitor/mod.rs b/bin/core/src/monitor/mod.rs index 2df06a444..9ec8de7ca 100644 --- a/bin/core/src/monitor/mod.rs +++ b/bin/core/src/monitor/mod.rs @@ -112,7 +112,7 @@ async fn refresh_server_cache(ts: i64) { return; } }; - periphery_client::connection::manage_outbound_connections(&servers) + periphery_client::connection::manage_client_connections(&servers) .await; let futures = servers.into_iter().map(|server| async move { update_cache_for_server(&server, false).await; diff --git a/bin/core/src/ws/mod.rs b/bin/core/src/ws/mod.rs index a10155613..8f84690a5 100644 --- a/bin/core/src/ws/mod.rs +++ b/bin/core/src/ws/mod.rs @@ -5,7 +5,7 @@ use crate::{ use anyhow::anyhow; use axum::{ Router, - extract::ws::{CloseFrame, Message, Utf8Bytes, WebSocket}, + extract::ws::{Message, WebSocket}, routing::get, }; use bytes::Bytes; @@ -14,13 +14,10 @@ use komodo_client::{ entities::{server::Server, user::User}, ws::WsLoginMessage, }; -use tokio::{ - net::TcpStream, - sync::mpsc::{Receiver, Sender}, -}; -use tokio_tungstenite::{ - MaybeTlsStream, WebSocketStream, tungstenite, +use periphery_client::{ + PeripheryClient, api::terminal::DisconnectTerminal, }; +use tokio::sync::mpsc::{Receiver, Sender}; use tokio_util::sync::CancellationToken; use transport::{ MessageState, @@ -155,115 +152,39 @@ async fn handle_container_terminal( trace!("connecting to periphery container exec websocket"); - let periphery_socket = match periphery - .connect_container_exec(container, shell) - .await - { - Ok(ws) => ws, - Err(e) => { - debug!( - "Failed connect to periphery container exec websocket | {e:#}" - ); - let _ = client_socket - .send(Message::text(format!("ERROR: {e:#}"))) - .await; - let _ = client_socket.close().await; - return; - } - }; + let (periphery_connection_id, periphery_sender, periphery_receiver) = + match periphery.connect_container_exec(container, shell).await { + Ok(ws) => ws, + Err(e) => { + debug!( + "Failed connect to periphery container exec websocket | {e:#}" + ); + let _ = client_socket + .send(Message::text(format!("ERROR: {e:#}"))) + .await; + let _ = client_socket.close().await; + return; + } + }; trace!("connected to periphery container exec websocket"); - core_periphery_forward_ws(client_socket, periphery_socket).await + forward_ws_channel( + periphery, + client_socket, + periphery_connection_id, + periphery_sender, + periphery_receiver, + ) + .await } -async fn core_periphery_forward_ws( - client_socket: axum::extract::ws::WebSocket, - periphery_socket: WebSocketStream>, -) { - let (mut periphery_send, mut periphery_receive) = - periphery_socket.split(); - let (mut core_send, mut core_receive) = client_socket.split(); - let cancel = CancellationToken::new(); - - trace!("starting ws exchange"); - - let core_to_periphery = async { - loop { - let res = tokio::select! { - res = core_receive.next() => res, - _ = cancel.cancelled() => { - trace!("core to periphery read: cancelled from inside"); - break; - } - }; - match res { - Some(Ok(msg)) => { - if let Err(e) = - periphery_send.send(axum_to_tungstenite(msg)).await - { - debug!("Failed to send terminal message | {e:?}",); - cancel.cancel(); - break; - }; - } - Some(Err(_e)) => { - cancel.cancel(); - break; - } - None => { - cancel.cancel(); - break; - } - } - } - }; - - let periphery_to_core = async { - loop { - let res = tokio::select! { - res = periphery_receive.next() => res, - _ = cancel.cancelled() => { - trace!("periphery to core read: cancelled from inside"); - break; - } - }; - match res { - Some(Ok(msg)) => { - if let Err(e) = - core_send.send(tungstenite_to_axum(msg)).await - { - debug!("{e:?}"); - cancel.cancel(); - break; - }; - } - Some(Err(e)) => { - let _ = core_send - .send(Message::text(format!( - "ERROR: Failed to receive message from periphery | {e:?}" - ))) - .await; - cancel.cancel(); - break; - } - None => { - let _ = core_send.send(Message::text("STREAM EOF")).await; - cancel.cancel(); - break; - } - } - } - }; - - tokio::join!(core_to_periphery, periphery_to_core); -} - -async fn core_periphery_forward_ws_channel( +async fn forward_ws_channel( + periphery: PeripheryClient, client_socket: axum::extract::ws::WebSocket, periphery_connection_id: Uuid, - periphery_send: Sender, - mut periphery_receive: Receiver, + periphery_sender: Sender, + mut periphery_receiver: Receiver, ) { let (mut core_send, mut core_receive) = client_socket.split(); let cancel = CancellationToken::new(); @@ -281,7 +202,7 @@ async fn core_periphery_forward_ws_channel( }; match res { Some(Ok(Message::Binary(data))) => { - if let Err(e) = periphery_send + if let Err(e) = periphery_sender .send(to_transport_bytes( data.into(), periphery_connection_id, @@ -296,7 +217,7 @@ async fn core_periphery_forward_ws_channel( } Some(Ok(Message::Text(data))) => { let data: Bytes = data.into(); - if let Err(e) = periphery_send + if let Err(e) = periphery_sender .send(to_transport_bytes( data.into(), periphery_connection_id, @@ -331,7 +252,7 @@ async fn core_periphery_forward_ws_channel( let periphery_to_core = async { loop { let res = tokio::select! { - res = periphery_receive.recv() => res.map(data_from_transport_bytes), + res = periphery_receiver.recv() => res.map(data_from_transport_bytes), _ = cancel.cancelled() => { trace!("periphery to core read: cancelled from inside"); break; @@ -358,44 +279,15 @@ async fn core_periphery_forward_ws_channel( }; tokio::join!(core_to_periphery, periphery_to_core); -} -fn axum_to_tungstenite(msg: Message) -> tungstenite::Message { - match msg { - Message::Text(text) => tungstenite::Message::Text( - // TODO: improve this conversion cost from axum ws library - tungstenite::Utf8Bytes::from(text.to_string()), - ), - Message::Binary(bytes) => tungstenite::Message::Binary(bytes), - Message::Ping(bytes) => tungstenite::Message::Ping(bytes), - Message::Pong(bytes) => tungstenite::Message::Pong(bytes), - Message::Close(close_frame) => { - tungstenite::Message::Close(close_frame.map(|cf| { - tungstenite::protocol::CloseFrame { - code: cf.code.into(), - reason: tungstenite::Utf8Bytes::from(cf.reason.to_string()), - } - })) - } - } -} - -fn tungstenite_to_axum(msg: tungstenite::Message) -> Message { - match msg { - tungstenite::Message::Text(text) => { - Message::Text(Utf8Bytes::from(text.to_string())) - } - tungstenite::Message::Binary(bytes) => Message::Binary(bytes), - tungstenite::Message::Ping(bytes) => Message::Ping(bytes), - tungstenite::Message::Pong(bytes) => Message::Pong(bytes), - tungstenite::Message::Close(close_frame) => { - Message::Close(close_frame.map(|cf| CloseFrame { - code: cf.code.into(), - reason: Utf8Bytes::from(cf.reason.to_string()), - })) - } - tungstenite::Message::Frame(_) => { - unreachable!() - } + if let Err(e) = periphery + .request(DisconnectTerminal { + id: periphery_connection_id, + }) + .await + { + warn!( + "Failed to disconnect Periphery terminal forwarding | {e:#}", + ) } } diff --git a/bin/core/src/ws/terminal.rs b/bin/core/src/ws/terminal.rs index 8270f3155..4a509bd61 100644 --- a/bin/core/src/ws/terminal.rs +++ b/bin/core/src/ws/terminal.rs @@ -10,7 +10,7 @@ use komodo_client::{ use crate::{ helpers::periphery_client, permission::get_check_permissions, - ws::core_periphery_forward_ws_channel, + ws::forward_ws_channel, }; #[instrument(name = "ConnectTerminal", skip(ws))] @@ -77,7 +77,8 @@ pub async fn handler( trace!("connected to periphery terminal websocket"); - core_periphery_forward_ws_channel( + forward_ws_channel( + periphery, client_socket, periphery_connection_id, periphery_sender, diff --git a/bin/periphery/src/api/compose/helpers.rs b/bin/periphery/src/api/compose/helpers.rs index 8441c7cc9..2911aa91b 100644 --- a/bin/periphery/src/api/compose/helpers.rs +++ b/bin/periphery/src/api/compose/helpers.rs @@ -18,8 +18,11 @@ use periphery_client::api::{ }; use resolver_api::Resolve; use tokio::fs; +use uuid::Uuid; -use crate::{config::periphery_config, docker::docker_login}; +use crate::{ + api::Args, config::periphery_config, docker::docker_login, +}; use super::docker_compose; @@ -149,6 +152,9 @@ pub async fn pull_or_clone_stack( let git_token = crate::helpers::git_token(git_token, &args)?; + let req_args = Args { + req_id: Uuid::new_v4(), + }; PullOrCloneRepo { args, git_token, @@ -162,7 +168,7 @@ pub async fn pull_or_clone_stack( skip_secret_interp: Default::default(), replacers: Default::default(), } - .resolve(&crate::api::Args) + .resolve(&req_args) .await .map_err(|e| e.error)?; diff --git a/bin/periphery/src/api/compose/write.rs b/bin/periphery/src/api/compose/write.rs index ba6f23cd4..2d0f8e4b8 100644 --- a/bin/periphery/src/api/compose/write.rs +++ b/bin/periphery/src/api/compose/write.rs @@ -14,8 +14,9 @@ use periphery_client::api::{ }; use resolver_api::Resolve; use tokio::fs; +use uuid::Uuid; -use crate::{config::periphery_config, helpers}; +use crate::{api::Args, config::periphery_config, helpers}; pub trait WriteStackRes { fn logs(&mut self) -> &mut Vec; @@ -151,6 +152,9 @@ async fn write_stack_linked_repo<'a>( let on_pull = (!repo.config.on_pull.is_none()) .then_some(repo.config.on_pull.clone()); + let req_args = Args { + req_id: Uuid::new_v4(), + }; let clone_res = if stack.config.reclone { CloneRepo { args, @@ -162,7 +166,7 @@ async fn write_stack_linked_repo<'a>( skip_secret_interp: repo.config.skip_secret_interp, replacers, } - .resolve(&crate::api::Args) + .resolve(&req_args) .await .map_err(|e| e.error)? } else { @@ -176,7 +180,7 @@ async fn write_stack_linked_repo<'a>( skip_secret_interp: repo.config.skip_secret_interp, replacers, } - .resolve(&crate::api::Args) + .resolve(&req_args) .await .map_err(|e| e.error)? }; @@ -236,6 +240,9 @@ async fn write_stack_inline_repo( let git_token = stack_git_token(git_token, &args, &mut res)?; + let req_args = Args { + req_id: Uuid::new_v4(), + }; let clone_res = if stack.config.reclone { CloneRepo { args, @@ -247,7 +254,7 @@ async fn write_stack_inline_repo( skip_secret_interp: Default::default(), replacers: Default::default(), } - .resolve(&crate::api::Args) + .resolve(&req_args) .await .map_err(|e| e.error)? } else { @@ -261,7 +268,7 @@ async fn write_stack_inline_repo( skip_secret_interp: Default::default(), replacers: Default::default(), } - .resolve(&crate::api::Args) + .resolve(&req_args) .await .map_err(|e| e.error)? }; diff --git a/bin/periphery/src/api/deploy.rs b/bin/periphery/src/api/deploy.rs index 248205e22..891180f18 100644 --- a/bin/periphery/src/api/deploy.rs +++ b/bin/periphery/src/api/deploy.rs @@ -33,7 +33,7 @@ impl Resolve for Deploy { stop_time = self.stop_time, ) )] - async fn resolve(self, _: &super::Args) -> serror::Result { + async fn resolve(self, args: &super::Args) -> serror::Result { let Deploy { mut deployment, stop_signal, @@ -87,7 +87,7 @@ impl Resolve for Deploy { signal: stop_signal, time: stop_time, }) - .resolve(&super::Args) + .resolve(args) .await; debug!("container stopped and removed"); diff --git a/bin/periphery/src/api/mod.rs b/bin/periphery/src/api/mod.rs index 9d4a7cedf..32a44f8d1 100644 --- a/bin/periphery/src/api/mod.rs +++ b/bin/periphery/src/api/mod.rs @@ -14,6 +14,7 @@ use periphery_client::api::{ use resolver_api::Resolve; use response::JsonBytes; use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::{config::periphery_config, docker::docker_client}; @@ -29,7 +30,9 @@ mod network; mod stats; mod volume; -pub struct Args; +pub struct Args { + pub req_id: Uuid, +} #[derive( Serialize, Deserialize, Debug, Clone, Resolve, EnumVariants, @@ -143,6 +146,7 @@ pub enum PeripheryRequest { ListTerminals(ListTerminals), CreateTerminal(CreateTerminal), ConnectTerminal(ConnectTerminal), + ConnectContainerExec(ConnectContainerExec), DisconnectTerminal(DisconnectTerminal), DeleteTerminal(DeleteTerminal), DeleteAllTerminals(DeleteAllTerminals), @@ -220,7 +224,7 @@ impl Resolve for GetDockerLists { #[instrument(name = "GetDockerLists", level = "debug", skip_all)] async fn resolve( self, - _: &Args, + args: &Args, ) -> serror::Result { let docker = docker_client(); let containers = @@ -235,7 +239,7 @@ impl Resolve for GetDockerLists { docker.list_images(_containers).map_err(Into::into), docker.list_volumes(_containers).map_err(Into::into), ListComposeProjects {} - .resolve(&Args) + .resolve(args) .map_err(|e| e.error.into()) ); Ok(GetDockerListsResponse { diff --git a/bin/periphery/src/api/terminal.rs b/bin/periphery/src/api/terminal.rs index 5f7c061c9..65be9a854 100644 --- a/bin/periphery/src/api/terminal.rs +++ b/bin/periphery/src/api/terminal.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::{Context, anyhow}; use axum::{ extract::{ @@ -73,169 +75,56 @@ impl Resolve for DeleteAllTerminals { impl Resolve for ConnectTerminal { #[instrument(name = "ConnectTerminal", level = "debug")] async fn resolve(self, _: &super::Args) -> serror::Result { - let id = Uuid::new_v4(); - let ws_sender = ws_sender(); - let (sender, mut ws_receiver) = channel(1000); - let cancel = CancellationToken::new(); + if periphery_config().disable_terminals { + return Err( + anyhow!("Terminals are disabled in the periphery config") + .status_code(StatusCode::FORBIDDEN), + ); + } - terminal_channels() - .insert(id, (sender, cancel.clone())) - .await; + let id = Uuid::new_v4(); clean_up_terminals().await; let terminal = get_terminal(&self.terminal).await?; - tokio::spawn(async move { - let init_res = async { - let (a, b) = terminal.history.bytes_parts(); - if !a.is_empty() { - ws_sender - .send(to_transport_bytes( - a.into(), - id, - MessageState::Terminal, - )) - .await - .context("Failed to send history part a")?; - } - if !b.is_empty() { - ws_sender - .send(to_transport_bytes( - b.into(), - id, - MessageState::Terminal, - )) - .await - .context("Failed to send history part b")?; - } - anyhow::Ok(()) - } - .await; + tokio::spawn(handle_terminal_forwarding(id, terminal)); - if let Err(e) = init_res { - // TODO: Handle error - warn!("Failed to init terminal | {e:#}"); - return; - } + Ok(id) + } +} - let ws_read = async { - loop { - let res = tokio::select! { - res = ws_receiver.recv() => res, - _ = terminal.cancel.cancelled() => { - trace!("ws read: cancelled from outside"); - break - }, - _ = cancel.cancelled() => { - trace!("ws read: cancelled from inside"); - break; - } - }; - match res { - Some(bytes) if bytes.first() == Some(&0x00) => { - // println!("Got ws read bytes - for stdin"); - if let Err(e) = terminal - .stdin - .send(StdinMsg::Bytes(Bytes::copy_from_slice( - &bytes[1..], - ))) - .await - { - debug!("WS -> PTY channel send error: {e:}"); - terminal.cancel(); - break; - }; - } - Some(bytes) if bytes.first() == Some(&0xFF) => { - // println!("Got ws read bytes - for resize"); - if let Ok(dimensions) = serde_json::from_slice::< - ResizeDimensions, - >(&bytes[1..]) - && let Err(e) = terminal - .stdin - .send(StdinMsg::Resize(dimensions)) - .await - { - debug!("WS -> PTY channel send error: {e:}"); - terminal.cancel(); - break; - } - } - Some(bytes) => { - trace!("Got ws read text"); - if let Err(e) = - terminal.stdin.send(StdinMsg::Bytes(bytes)).await - { - debug!("WS -> PTY channel send error: {e:?}"); - terminal.cancel(); - break; - }; - } - None => { - debug!("Got ws read none"); - cancel.cancel(); - break; - } - } - } - }; +impl Resolve for ConnectContainerExec { + #[instrument(name = "ConnectContainerExec", level = "debug")] + async fn resolve(self, _: &super::Args) -> serror::Result { + let id = Uuid::new_v4(); - let ws_write = async { - let mut stdout = terminal.stdout.resubscribe(); - loop { - let res = tokio::select! { - res = stdout.recv() => res.context("Failed to get message over stdout receiver"), - _ = terminal.cancel.cancelled() => { - trace!("ws write: cancelled from outside"); - // let _ = ws_sender.send("PTY KILLED")).await; - // if let Err(e) = ws_write.close().await { - // debug!("Failed to close ws: {e:?}"); - // }; - break - }, - _ = cancel.cancelled() => { - // let _ = ws_write.send(Message::Text(Utf8Bytes::from_static("WS KILLED"))).await; - // if let Err(e) = ws_write.close().await { - // debug!("Failed to close ws: {e:?}"); - // }; - break - } - }; - match res { - Ok(bytes) => { - if let Err(e) = ws_sender - .send(to_transport_bytes( - bytes.into(), - id, - MessageState::Terminal, - )) - .await - { - debug!("Failed to send to WS: {e:?}"); - cancel.cancel(); - break; - } - } - Err(e) => { - debug!("PTY -> WS channel read error: {e:?}"); - let _ = ws_sender - .send(to_transport_bytes( - format!("ERROR: {e:#}").into(), - id, - MessageState::Terminal, - )) - .await; - terminal.cancel(); - break; - } - } - } - }; + if periphery_config().disable_container_exec { + return Err( + anyhow!("Container exec is disabled in the periphery config") + .into(), + ); + } - tokio::join!(ws_read, ws_write); + let ConnectContainerExec { container, shell } = self; - clean_up_terminals().await; - }); + if container.contains("&&") || shell.contains("&&") { + return Err( + anyhow!( + "The use of '&&' is forbidden in the container name or shell" + ) + .into(), + ); + } + // Create (recreate if shell changed) + let terminal = create_terminal( + container.clone(), + format!("docker exec -it {container} {shell}"), + TerminalRecreateMode::DifferentCommand, + ) + .await + .context("Failed to create terminal for container exec")?; + + tokio::spawn(handle_terminal_forwarding(id, terminal)); Ok(id) } @@ -244,10 +133,13 @@ impl Resolve for ConnectTerminal { impl Resolve for DisconnectTerminal { #[instrument(name = "DisconnectTerminal", level = "debug")] async fn resolve(self, _: &super::Args) -> serror::Result { + // TODO: Remove logs + info!("Disconnect called for {}", self.id); if let Some((_, cancel)) = - terminal_channels().remove(&self.uuid).await + terminal_channels().remove(&self.id).await { cancel.cancel(); + info!("Cancel called for {}", self.id); } Ok(NoData {}) } @@ -265,17 +157,164 @@ impl Resolve for CreateTerminalAuthToken { } } -pub async fn connect_terminal( - Query(query): Query, - ws: WebSocketUpgrade, -) -> serror::Result { - if periphery_config().disable_terminals { - return Err( - anyhow!("Terminals are disabled in the periphery config") - .status_code(StatusCode::FORBIDDEN), - ); +async fn handle_terminal_forwarding( + id: Uuid, + terminal: Arc, +) { + let ws_sender = ws_sender(); + let (sender, mut ws_receiver) = channel(1000); + let cancel = CancellationToken::new(); + + terminal_channels() + .insert(id, (sender, cancel.clone())) + .await; + + let init_res = async { + let (a, b) = terminal.history.bytes_parts(); + if !a.is_empty() { + ws_sender + .send(to_transport_bytes( + a.into(), + id, + MessageState::Terminal, + )) + .await + .context("Failed to send history part a")?; + } + if !b.is_empty() { + ws_sender + .send(to_transport_bytes( + b.into(), + id, + MessageState::Terminal, + )) + .await + .context("Failed to send history part b")?; + } + anyhow::Ok(()) } - handle_terminal_websocket(query, ws).await + .await; + + if let Err(e) = init_res { + // TODO: Handle error + warn!("Failed to init terminal | {e:#}"); + return; + } + + let ws_read = async { + loop { + let res = tokio::select! { + res = ws_receiver.recv() => res, + _ = terminal.cancel.cancelled() => { + trace!("ws read: cancelled from outside"); + break + }, + _ = cancel.cancelled() => { + trace!("ws read: cancelled from inside"); + break; + } + }; + match res { + Some(bytes) if bytes.first() == Some(&0x00) => { + // println!("Got ws read bytes - for stdin"); + if let Err(e) = terminal + .stdin + .send(StdinMsg::Bytes(Bytes::copy_from_slice( + &bytes[1..], + ))) + .await + { + debug!("WS -> PTY channel send error: {e:}"); + terminal.cancel(); + break; + }; + } + Some(bytes) if bytes.first() == Some(&0xFF) => { + // println!("Got ws read bytes - for resize"); + if let Ok(dimensions) = + serde_json::from_slice::(&bytes[1..]) + && let Err(e) = + terminal.stdin.send(StdinMsg::Resize(dimensions)).await + { + debug!("WS -> PTY channel send error: {e:}"); + terminal.cancel(); + break; + } + } + Some(bytes) => { + trace!("Got ws read text"); + if let Err(e) = + terminal.stdin.send(StdinMsg::Bytes(bytes)).await + { + debug!("WS -> PTY channel send error: {e:?}"); + terminal.cancel(); + break; + }; + } + None => { + debug!("Got ws read none"); + cancel.cancel(); + break; + } + } + } + }; + + let ws_write = async { + let mut stdout = terminal.stdout.resubscribe(); + loop { + let res = tokio::select! { + res = stdout.recv() => res.context("Failed to get message over stdout receiver"), + _ = terminal.cancel.cancelled() => { + trace!("ws write: cancelled from outside"); + // let _ = ws_sender.send("PTY KILLED")).await; + // if let Err(e) = ws_write.close().await { + // debug!("Failed to close ws: {e:?}"); + // }; + break + }, + _ = cancel.cancelled() => { + // let _ = ws_write.send(Message::Text(Utf8Bytes::from_static("WS KILLED"))).await; + // if let Err(e) = ws_write.close().await { + // debug!("Failed to close ws: {e:?}"); + // }; + break + } + }; + match res { + Ok(bytes) => { + if let Err(e) = ws_sender + .send(to_transport_bytes( + bytes.into(), + id, + MessageState::Terminal, + )) + .await + { + debug!("Failed to send to WS: {e:?}"); + cancel.cancel(); + break; + } + } + Err(e) => { + debug!("PTY -> WS channel read error: {e:?}"); + let _ = ws_sender + .send(to_transport_bytes( + format!("ERROR: {e:#}").into(), + id, + MessageState::Terminal, + )) + .await; + terminal.cancel(); + break; + } + } + } + }; + + tokio::join!(ws_read, ws_write); + + clean_up_terminals().await; } pub async fn connect_container_exec( diff --git a/bin/periphery/src/connection.rs b/bin/periphery/src/connection.rs index 5c756f4a9..ee38fd438 100644 --- a/bin/periphery/src/connection.rs +++ b/bin/periphery/src/connection.rs @@ -18,7 +18,7 @@ use transport::{ }; use uuid::Uuid; -use crate::api::PeripheryRequest; +use crate::api::{Args, PeripheryRequest}; static WS_SENDER: OnceLock> = OnceLock::new(); pub fn ws_sender() -> &'static Sender { @@ -51,7 +51,7 @@ pub fn init_response_channel() { pub async fn inbound_connection( ws: WebSocketUpgrade, ) -> serror::Result { - transport::server::inbound_connection( + transport::server::handle_server_connection( ws, PeripheryTransportHandler, ws_receiver(), @@ -84,7 +84,7 @@ impl TransportHandler for PeripheryTransportHandler { } } -fn handle_request(id: Uuid, bytes: Bytes) { +fn handle_request(req_id: Uuid, bytes: Bytes) { tokio::spawn(async move { let request = match data_from_transport_bytes(bytes) { Ok(req) if !req.is_empty() => req, @@ -105,7 +105,7 @@ fn handle_request(id: Uuid, bytes: Bytes) { let resolve_response = async { let (state, data) = - match request.resolve(&crate::api::Args).await { + match request.resolve(&Args { req_id }).await { Ok(JsonBytes::Ok(res)) => (MessageState::Successful, res), Ok(JsonBytes::Err(e)) => ( MessageState::Failed, @@ -118,8 +118,9 @@ fn handle_request(id: Uuid, bytes: Bytes) { (MessageState::Failed, serialize_error_bytes(&e.error)) } }; - if let Err(e) = - ws_sender().send(to_transport_bytes(data, id, state)).await + if let Err(e) = ws_sender() + .send(to_transport_bytes(data, req_id, state)) + .await { error!("Failed to send response over channel | {e:?}"); } @@ -131,7 +132,7 @@ fn handle_request(id: Uuid, bytes: Bytes) { if let Err(e) = ws_sender() .send(to_transport_bytes( Vec::new(), - id, + req_id, MessageState::InProgress, )) .await @@ -150,6 +151,7 @@ fn handle_request(id: Uuid, bytes: Bytes) { pub type TerminalChannels = CloneCache, CancellationToken)>; + pub fn terminal_channels() -> &'static TerminalChannels { static TERMINAL_CHANNELS: OnceLock = OnceLock::new(); diff --git a/bin/periphery/src/router.rs b/bin/periphery/src/router.rs index be45fc3d5..6800ffb91 100644 --- a/bin/periphery/src/router.rs +++ b/bin/periphery/src/router.rs @@ -15,16 +15,10 @@ use crate::config::periphery_config; pub fn router() -> Router { Router::new() - // .merge( - // Router::new() - // .route("/", post(handler)) - // .layer(middleware::from_fn(guard_request_by_passkey)), - // ) .route("/", get(crate::connection::inbound_connection)) .nest( "/terminal", Router::new() - .route("/", get(crate::api::terminal::connect_terminal)) .route( "/container", get(crate::api::terminal::connect_container_exec), @@ -43,41 +37,6 @@ pub fn router() -> Router { .layer(middleware::from_fn(guard_request_by_ip)) } -// async fn handler( -// Json(request): Json, -// ) -> serror::Result { -// let req_id = Uuid::new_v4(); - -// let res = tokio::spawn(task(req_id, request)) -// .await -// .context("task handler spawn error"); - -// if let Err(e) = &res { -// warn!("request {req_id} spawn error: {e:#}"); -// } - -// res? -// } - -// async fn task( -// req_id: Uuid, -// request: crate::api::PeripheryRequest, -// ) -> serror::Result { -// let variant = request.extract_variant(); - -// // let res = request.resolve(&crate::api::Args).await.map(|res| res.0); - -// // if let Err(e) = &res { -// // warn!( -// // "request {req_id} | type: {variant:?} | error: {:#}", -// // e.error -// // ); -// // } - -// // res -// todo!() -// } - async fn guard_request_by_passkey( req: Request, next: Next, diff --git a/bin/periphery/src/terminal.rs b/bin/periphery/src/terminal.rs index ede42e5ec..ce10576d4 100644 --- a/bin/periphery/src/terminal.rs +++ b/bin/periphery/src/terminal.rs @@ -30,7 +30,7 @@ pub async fn create_terminal( name: String, command: String, recreate: TerminalRecreateMode, -) -> anyhow::Result<()> { +) -> anyhow::Result> { trace!( "CreateTerminal: {name} | command: {command} | recreate: {recreate:?}" ); @@ -40,7 +40,7 @@ pub async fn create_terminal( && let Some(terminal) = terminals.get(&name) { if terminal.command == command { - return Ok(()); + return Ok(terminal.clone()); } else if matches!(recreate, Never) { return Err(anyhow!( "Terminal {name} already exists, but has command {} instead of {command}", @@ -48,16 +48,15 @@ pub async fn create_terminal( )); } } - if let Some(prev) = terminals.insert( - name, + let terminal = Arc::new( Terminal::new(command) .await - .context("Failed to init terminal")? - .into(), - ) { + .context("Failed to init terminal")?, + ); + if let Some(prev) = terminals.insert(name, terminal.clone()) { prev.cancel(); } - Ok(()) + Ok(terminal) } pub async fn delete_terminal(name: &str) { diff --git a/client/periphery/rs/Cargo.toml b/client/periphery/rs/Cargo.toml index 2179d796b..e2063b1a2 100644 --- a/client/periphery/rs/Cargo.toml +++ b/client/periphery/rs/Cargo.toml @@ -18,9 +18,7 @@ cache.workspace = true resolver_api.workspace = true serror.workspace = true # external -tokio-tungstenite.workspace = true serde_json.workspace = true -serde_qs.workspace = true reqwest.workspace = true tracing.workspace = true anyhow.workspace = true diff --git a/client/periphery/rs/src/api/terminal.rs b/client/periphery/rs/src/api/terminal.rs index 525a235db..a360ab42c 100644 --- a/client/periphery/rs/src/api/terminal.rs +++ b/client/periphery/rs/src/api/terminal.rs @@ -44,12 +44,26 @@ pub struct ConnectTerminal { // +#[derive(Serialize, Deserialize, Debug, Clone, Resolve)] +#[response(Uuid)] +#[error(serror::Error)] +pub struct ConnectContainerExec { + /// The name of the container to connect to. + pub container: String, + /// The shell to start inside container. + /// Default: `sh` + #[serde(default = "default_container_shell")] + pub shell: String, +} + +// + #[derive(Serialize, Deserialize, Debug, Clone, Resolve)] #[response(NoData)] #[error(serror::Error)] pub struct DisconnectTerminal { - /// The connection of the terminal to disconnect from - pub uuid: Uuid, + /// The connection id of the terminal to disconnect from + pub id: Uuid, } // diff --git a/client/periphery/rs/src/connection.rs b/client/periphery/rs/src/connection.rs index 04f209971..a1a0c46fa 100644 --- a/client/periphery/rs/src/connection.rs +++ b/client/periphery/rs/src/connection.rs @@ -6,24 +6,17 @@ use std::{ use bytes::Bytes; use cache::CloneCache; use komodo_client::entities::server::Server; -use tokio::sync::mpsc::{self, Sender}; +use tokio::sync::mpsc::{Sender, error::SendError}; use tracing::warn; use transport::{ TransportHandler, bytes::id_from_transport_bytes, channel::{BufferedReceiver, buffered_channel}, - client::{ClientConnection, fix_ws_address}, + client::ClientConnection, }; use uuid::Uuid; -// Server id => Channel sender map -pub type ResponseChannels = - CloneCache>>>; -pub fn periphery_response_channels() -> &'static ResponseChannels { - static RESPONSE_CHANNELS: OnceLock = - OnceLock::new(); - RESPONSE_CHANNELS.get_or_init(Default::default) -} +use crate::periphery_response_channels; pub struct CoreTransportHandler { response_channels: Arc>>, @@ -61,8 +54,25 @@ impl TransportHandler for CoreTransportHandler { } } +/// - Fixes server addresses: +/// - `server.domain` => `wss://server.domain` +/// - `http://server.domain` => `ws://server.domain` +/// - `https://server.domain` => `wss://server.domain` +fn fix_ws_address(address: &str) -> String { + if address.starts_with("ws://") || address.starts_with("wss://") { + return address.to_string(); + } + if address.starts_with("http://") { + return address.replace("http://", "ws://"); + } + if address.starts_with("https://") { + return address.replace("https://", "wss://"); + } + format!("wss://{address}") +} + /// Managed connections to exactly those specified by specs (ServerId -> Address) -pub async fn manage_outbound_connections(servers: &[Server]) { +pub async fn manage_client_connections(servers: &[Server]) { let periphery_connections = periphery_connections(); let periphery_response_channels = periphery_response_channels(); @@ -77,7 +87,7 @@ pub async fn manage_outbound_connections(servers: &[Server]) { periphery_connections.get_entries().await { if !specs.contains_key(&server_id) { - connection.client.cancel(); + connection.connection.cancel(); periphery_connections.remove(&server_id).await; periphery_response_channels.remove(&server_id).await; } @@ -92,7 +102,7 @@ pub async fn manage_outbound_connections(servers: &[Server]) { // All other cases re-spawn connection _ => { if let Err(e) = - spawn_outbound_connection(server_id.clone(), address).await + spawn_client_connection(server_id.clone(), address).await { warn!( "Failed to spawn new connnection for {server_id} | {e:#}" @@ -104,7 +114,7 @@ pub async fn manage_outbound_connections(servers: &[Server]) { } // Assumes address already wss formatted -async fn spawn_outbound_connection( +async fn spawn_client_connection( server_id: String, address: String, ) -> anyhow::Result<()> { @@ -112,17 +122,18 @@ async fn spawn_outbound_connection( let (connection, mut request_receiver) = PeripheryConnection::new(address.clone()); + if let Some(existing_connection) = periphery_connections() .insert(server_id, connection.clone()) .await { - existing_connection.client.cancel(); + existing_connection.connection.cancel(); } tokio::spawn(async move { - transport::client::handle_reconnecting_websocket( + transport::client::handle_client_connection( &address, - &connection.client, + &connection.connection, &transport, &mut request_receiver, ) @@ -135,6 +146,7 @@ async fn spawn_outbound_connection( /// server id => connection pub type PeripheryConnections = CloneCache>; + pub fn periphery_connections() -> &'static PeripheryConnections { static CONNECTIONS: OnceLock = OnceLock::new(); @@ -143,9 +155,9 @@ pub fn periphery_connections() -> &'static PeripheryConnections { #[derive(Debug)] pub struct PeripheryConnection { - address: String, - pub request_sender: mpsc::Sender, - pub client: ClientConnection, + pub address: String, + pub request_sender: Sender, + pub connection: ClientConnection, } impl PeripheryConnection { @@ -157,7 +169,7 @@ impl PeripheryConnection { PeripheryConnection { address, request_sender, - client: ClientConnection::new().into(), + connection: ClientConnection::new().into(), } .into(), request_receiver, @@ -167,23 +179,23 @@ impl PeripheryConnection { pub async fn send( &self, value: Bytes, - ) -> Result<(), mpsc::error::SendError> { + ) -> Result<(), SendError> { self.request_sender.send(value).await } pub fn connected(&self) -> bool { - self.client.connected() + self.connection.connected() } pub async fn error(&self) -> Option { - self.client.error().await + self.connection.error().await } pub async fn set_error(&self, e: anyhow::Error) { - self.client.set_error(e).await + self.connection.set_error(e).await } pub async fn clear_error(&self) { - self.client.clear_error().await + self.connection.clear_error().await } } diff --git a/client/periphery/rs/src/lib.rs b/client/periphery/rs/src/lib.rs index 16006e104..cce8c6076 100644 --- a/client/periphery/rs/src/lib.rs +++ b/client/periphery/rs/src/lib.rs @@ -1,5 +1,7 @@ -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; +use bytes::Bytes; +use cache::CloneCache; use resolver_api::HasResponse; use serde::{Serialize, de::DeserializeOwned}; @@ -10,6 +12,18 @@ mod request; mod terminal; pub use request::request; +use tokio::sync::mpsc::Sender; +use uuid::Uuid; + +// Server id => Channel sender map +pub type ResponseChannels = + CloneCache>>>; + +pub fn periphery_response_channels() -> &'static ResponseChannels { + static RESPONSE_CHANNELS: OnceLock = + OnceLock::new(); + RESPONSE_CHANNELS.get_or_init(Default::default) +} fn periphery_http_client() -> &'static reqwest::Client { static PERIPHERY_HTTP_CLIENT: OnceLock = diff --git a/client/periphery/rs/src/request.rs b/client/periphery/rs/src/request.rs index f39d174ed..eb6259d5a 100644 --- a/client/periphery/rs/src/request.rs +++ b/client/periphery/rs/src/request.rs @@ -11,8 +11,8 @@ use transport::{ }; use uuid::Uuid; -use crate::connection::{ - periphery_connections, periphery_response_channels, +use crate::{ + connection::periphery_connections, periphery_response_channels, }; #[tracing::instrument(name = "PeripheryRequest", level = "debug")] diff --git a/client/periphery/rs/src/terminal.rs b/client/periphery/rs/src/terminal.rs index c8c22b8db..e0f64dcd5 100644 --- a/client/periphery/rs/src/terminal.rs +++ b/client/periphery/rs/src/terminal.rs @@ -2,22 +2,15 @@ use anyhow::Context; use bytes::Bytes; use komodo_client::terminal::TerminalStreamResponse; use reqwest::RequestBuilder; -use tokio::{ - net::TcpStream, - sync::mpsc::{Receiver, Sender, channel}, -}; -use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; +use tokio::sync::mpsc::{Receiver, Sender, channel}; use uuid::Uuid; use crate::{ - PeripheryClient, - api::terminal::*, - connection::{periphery_connections, periphery_response_channels}, + PeripheryClient, api::terminal::*, + connection::periphery_connections, periphery_response_channels, }; impl PeripheryClient { - /// Handles ws connect and login. - /// Does not handle reconnect. pub async fn connect_terminal( &self, terminal: String, @@ -34,7 +27,35 @@ impl PeripheryClient { let id = self .request(ConnectTerminal { terminal }) .await - .context("Failed to create terminal connectionn")?; + .context("Failed to create terminal connection")?; + + let response_channels = periphery_response_channels() + .get_or_insert_default(&self.id) + .await; + let (response_sender, response_receiever) = channel(1000); + response_channels.insert(id, response_sender).await; + + Ok((id, connection.request_sender.clone(), response_receiever)) + } + + pub async fn connect_container_exec( + &self, + container: String, + shell: String, + ) -> anyhow::Result<(Uuid, Sender, Receiver)> { + tracing::trace!( + "request | type: ConnectContainerExec | container name: {container} | shell: {shell}", + ); + + let connection = + periphery_connections().get(&self.id).await.with_context( + || format!("No connection found for server {}", self.id), + )?; + + let id = self + .request(ConnectContainerExec { container, shell }) + .await + .context("Failed to create conntainer exec connection")?; let response_channels = periphery_response_channels() .get_or_insert_default(&self.id) @@ -74,37 +95,6 @@ impl PeripheryClient { terminal_stream_response(req).await } - /// Handles ws connect and login. - /// Does not handle reconnect. - pub async fn connect_container_exec( - &self, - container: String, - shell: String, - ) -> anyhow::Result>> { - tracing::trace!( - "request | type: ConnectContainerExec | container name: {container} | shell: {shell}", - ); - - let token = self - .request(CreateTerminalAuthToken {}) - .await - .context("Failed to create terminal auth token")?; - - let query_str = serde_qs::to_string(&ConnectContainerExecQuery { - token: token.token, - container, - shell, - }) - .context("Failed to serialize query string")?; - - let url = format!( - "{}/terminal/container?{query_str}", - self.address.replacen("http", "ws", 1) - ); - - transport::client::connect_websocket(&url).await - } - /// Executes command on specified container, /// and streams the response ending in [KOMODO_EXIT_CODE][komodo_client::entities::KOMODO_EXIT_CODE] /// sentinal value as the expected final line of the stream. diff --git a/lib/transport/src/client.rs b/lib/transport/src/client.rs index cc5497022..8ea949568 100644 --- a/lib/transport/src/client.rs +++ b/lib/transport/src/client.rs @@ -19,24 +19,8 @@ use tracing::{info, warn}; use crate::{TransportHandler, channel::BufferedReceiver}; -/// Fixes server addresses: -/// server.domain => wss://server.domain -/// http://server.domain => ws://server.domain -/// https://server.domain => wss://server.domain -pub fn fix_ws_address(address: &str) -> String { - if address.starts_with("ws://") || address.starts_with("wss://") { - return address.to_string(); - } - if address.starts_with("http://") { - return address.replace("http://", "ws://"); - } - if address.starts_with("https://") { - return address.replace("https://", "wss://"); - } - format!("wss://{address}") -} - -pub async fn handle_reconnecting_websocket< +/// Handles client side / outbound connection +pub async fn handle_client_connection< T: TransportHandler + Send + Sync + 'static, >( address: &str, diff --git a/lib/transport/src/lib.rs b/lib/transport/src/lib.rs index 51715ad5a..426f6d5ba 100644 --- a/lib/transport/src/lib.rs +++ b/lib/transport/src/lib.rs @@ -20,3 +20,5 @@ pub trait TransportHandler { bytes: Bytes, ) -> impl Future + Send; } + + diff --git a/lib/transport/src/server.rs b/lib/transport/src/server.rs index d07903dda..b00ba9f59 100644 --- a/lib/transport/src/server.rs +++ b/lib/transport/src/server.rs @@ -11,7 +11,8 @@ use tracing::{error, warn}; use crate::{TransportHandler, channel::BufferedReceiver}; -pub fn inbound_connection< +/// Handles server side / inbound connection +pub fn handle_server_connection< T: TransportHandler + Send + Sync + 'static, >( ws: WebSocketUpgrade,