forked from github-starred/komodo
working with safer transport message api
This commit is contained in:
@@ -2,14 +2,13 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use periphery_client::CONNECTION_RETRY_SECONDS;
|
||||
use serror::{deserialize_error_bytes, serialize_error_bytes};
|
||||
use transport::{
|
||||
MessageState,
|
||||
auth::{
|
||||
AddressConnectionIdentifiers, ClientLoginFlow,
|
||||
AUTH_TIMEOUT, AddressConnectionIdentifiers, ClientLoginFlow,
|
||||
ConnectionIdentifiers,
|
||||
},
|
||||
fix_ws_address,
|
||||
message::Message,
|
||||
websocket::{Websocket, tungstenite::TungsteniteWebsocket},
|
||||
};
|
||||
|
||||
@@ -105,9 +104,12 @@ impl PeripheryConnection {
|
||||
) -> anyhow::Result<()> {
|
||||
// Get the required auth type
|
||||
let bytes = socket
|
||||
.recv_bytes()
|
||||
.recv_result()
|
||||
.with_timeout(Duration::from_secs(2))
|
||||
.await?
|
||||
.await
|
||||
.flatten()
|
||||
.flatten()
|
||||
.and_then(Message::into_data)
|
||||
.context("Failed to receive login type indicator")?;
|
||||
|
||||
match bytes.iter().as_slice() {
|
||||
@@ -132,7 +134,7 @@ async fn handle_passkey_login(
|
||||
passkey: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = async {
|
||||
let mut passkey = if passkey.is_empty() {
|
||||
let passkey = if passkey.is_empty() {
|
||||
core_config()
|
||||
.passkey
|
||||
.as_deref()
|
||||
@@ -142,34 +144,27 @@ async fn handle_passkey_login(
|
||||
} else {
|
||||
passkey.as_bytes().to_vec()
|
||||
};
|
||||
passkey.push(MessageState::Successful.as_byte());
|
||||
|
||||
socket
|
||||
.send(passkey.into())
|
||||
.send(passkey)
|
||||
.await
|
||||
.context("Failed to send passkey")?;
|
||||
|
||||
// Receive login state message and return based on value
|
||||
let state_msg = socket
|
||||
.recv_bytes()
|
||||
socket
|
||||
.recv_result()
|
||||
.with_timeout(AUTH_TIMEOUT)
|
||||
.await
|
||||
.flatten()
|
||||
.flatten()
|
||||
.context("Failed to receive authentication state message")?;
|
||||
let state = state_msg.last().context(
|
||||
"Authentication state message did not contain state byte",
|
||||
)?;
|
||||
match MessageState::from_byte(*state) {
|
||||
MessageState::Successful => anyhow::Ok(()),
|
||||
_ => Err(deserialize_error_bytes(
|
||||
&state_msg[..(state_msg.len() - 1)],
|
||||
)),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.await;
|
||||
if let Err(e) = res {
|
||||
let mut bytes = serialize_error_bytes(&e);
|
||||
bytes.push(MessageState::Failed.as_byte());
|
||||
if let Err(e) = socket
|
||||
.send(bytes.into())
|
||||
.send(&e)
|
||||
.await
|
||||
.context("Failed to send login failed to client")
|
||||
{
|
||||
|
||||
@@ -7,7 +7,6 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use bytes::Bytes;
|
||||
use cache::CloneCache;
|
||||
use database::mungos::{by_id::update_one_by_id, mongodb::bson::doc};
|
||||
use komodo_client::entities::{
|
||||
@@ -16,18 +15,15 @@ use komodo_client::entities::{
|
||||
server::Server,
|
||||
};
|
||||
use serror::serror_into_anyhow_error;
|
||||
use tokio::sync::{
|
||||
RwLock,
|
||||
mpsc::{Sender, error::SendError},
|
||||
};
|
||||
use tokio::sync::{RwLock, mpsc::error::SendError};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use transport::{
|
||||
auth::{
|
||||
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
||||
PublicKeyValidator,
|
||||
},
|
||||
bytes::id_from_transport_bytes,
|
||||
channel::{BufferedReceiver, buffered_channel},
|
||||
channel::{BufferedReceiver, Sender, buffered_channel},
|
||||
message::Message,
|
||||
websocket::{
|
||||
Websocket, WebsocketMessage, WebsocketReceiver as _,
|
||||
WebsocketSender as _,
|
||||
@@ -56,7 +52,7 @@ impl PeripheryConnections {
|
||||
&self,
|
||||
server_id: String,
|
||||
args: PeripheryConnectionArgs<'_>,
|
||||
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
|
||||
) -> (Arc<PeripheryConnection>, BufferedReceiver) {
|
||||
let (connection, receiver) = if let Some(existing_connection) =
|
||||
self.0.remove(&server_id).await
|
||||
{
|
||||
@@ -110,12 +106,11 @@ impl PublicKeyValidator for PeripheryConnectionArgs<'_> {
|
||||
self.id.to_string(),
|
||||
Some(public_key.clone()),
|
||||
);
|
||||
let e = anyhow!("{public_key} is invalid")
|
||||
anyhow!("{public_key} is invalid")
|
||||
.context(
|
||||
"Ensure public key matches configured Periphery Public Key",
|
||||
)
|
||||
.context("Core failed to validate Periphery public key");
|
||||
e
|
||||
.context("Core failed to validate Periphery public key")
|
||||
};
|
||||
let core_to_periphery = self.address.is_some();
|
||||
match (self.periphery_public_key, core_to_periphery) {
|
||||
@@ -242,7 +237,7 @@ pub struct PeripheryConnection {
|
||||
/// The connection args
|
||||
pub args: OwnedPeripheryConnectionArgs,
|
||||
/// Send and receive bytes over the connection socket.
|
||||
pub sender: Sender<Bytes>,
|
||||
pub sender: Sender,
|
||||
/// Cancel the connection
|
||||
pub cancel: CancellationToken,
|
||||
/// Whether Periphery is currently connected.
|
||||
@@ -258,7 +253,7 @@ pub struct PeripheryConnection {
|
||||
impl PeripheryConnection {
|
||||
pub fn new(
|
||||
args: impl Into<OwnedPeripheryConnectionArgs>,
|
||||
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
|
||||
) -> (Arc<PeripheryConnection>, BufferedReceiver) {
|
||||
let (sender, receiever) = buffered_channel();
|
||||
(
|
||||
PeripheryConnection {
|
||||
@@ -277,7 +272,7 @@ impl PeripheryConnection {
|
||||
pub fn with_new_args(
|
||||
&self,
|
||||
args: impl Into<OwnedPeripheryConnectionArgs>,
|
||||
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
|
||||
) -> (Arc<PeripheryConnection>, BufferedReceiver) {
|
||||
// Ensure this connection is cancelled.
|
||||
self.cancel();
|
||||
let (sender, receiever) = buffered_channel();
|
||||
@@ -315,7 +310,7 @@ impl PeripheryConnection {
|
||||
pub async fn handle_socket<W: Websocket>(
|
||||
&self,
|
||||
socket: W,
|
||||
receiver: &mut BufferedReceiver<Bytes>,
|
||||
receiver: &mut BufferedReceiver,
|
||||
) {
|
||||
let cancel = self.cancel.child_token();
|
||||
|
||||
@@ -332,7 +327,7 @@ impl PeripheryConnection {
|
||||
};
|
||||
|
||||
let message = match next {
|
||||
Some(request) => Bytes::copy_from_slice(request),
|
||||
Some(request) => request.to_message(),
|
||||
// Sender Dropped (shouldn't happen, a reference is held on 'connection').
|
||||
None => break,
|
||||
};
|
||||
@@ -358,8 +353,8 @@ impl PeripheryConnection {
|
||||
};
|
||||
|
||||
match next {
|
||||
Ok(WebsocketMessage::Binary(bytes)) => {
|
||||
self.handle_incoming_bytes(bytes).await
|
||||
Ok(WebsocketMessage::Message(message)) => {
|
||||
self.handle_incoming_message(message).await
|
||||
}
|
||||
Ok(WebsocketMessage::Close(_))
|
||||
| Ok(WebsocketMessage::Closed) => {
|
||||
@@ -380,21 +375,21 @@ impl PeripheryConnection {
|
||||
self.set_connected(false);
|
||||
}
|
||||
|
||||
pub async fn handle_incoming_bytes(&self, bytes: Bytes) {
|
||||
let id = match id_from_transport_bytes(&bytes) {
|
||||
pub async fn handle_incoming_message(&self, message: Message) {
|
||||
let channel = match message.channel() {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
// TODO: handle better
|
||||
warn!("Failed to read id | {e:#}");
|
||||
warn!("Failed to read channel | {e:#}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let Some(channel) = self.channels.get(&id).await else {
|
||||
let Some(channel) = self.channels.get(&channel).await else {
|
||||
// TODO: handle better
|
||||
debug!("Failed to send response | No response channel found");
|
||||
return;
|
||||
};
|
||||
if let Err(e) = channel.send(bytes).await {
|
||||
if let Err(e) = channel.send(message).await {
|
||||
// TODO: handle better
|
||||
warn!("Failed to send response | Channel failure | {e:#}");
|
||||
}
|
||||
@@ -402,9 +397,9 @@ impl PeripheryConnection {
|
||||
|
||||
pub async fn send(
|
||||
&self,
|
||||
value: Bytes,
|
||||
) -> Result<(), SendError<Bytes>> {
|
||||
self.sender.send(value).await
|
||||
message: impl Into<Message>,
|
||||
) -> Result<(), SendError<Message>> {
|
||||
self.sender.send(message).await
|
||||
}
|
||||
|
||||
pub fn set_connected(&self, connected: bool) {
|
||||
|
||||
@@ -6,7 +6,6 @@ use axum::{
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::Response,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use database::mungos::mongodb::bson::{doc, oid::ObjectId};
|
||||
use komodo_client::{
|
||||
api::write::{CreateBuilder, CreateServer, UpdateResourceMeta},
|
||||
@@ -20,11 +19,12 @@ use komodo_client::{
|
||||
use resolver_api::Resolve;
|
||||
use serror::{AddStatusCode, AddStatusCodeError};
|
||||
use transport::{
|
||||
MessageState, PeripheryConnectionQuery,
|
||||
PeripheryConnectionQuery,
|
||||
auth::{
|
||||
HeaderConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
||||
PublicKeyValidator, ServerLoginFlow,
|
||||
},
|
||||
message::MessageState,
|
||||
websocket::{Websocket, axum::AxumWebsocket},
|
||||
};
|
||||
|
||||
@@ -119,7 +119,7 @@ async fn existing_server_handler(
|
||||
format!("server={}", urlencoding::encode(&server_query));
|
||||
let mut socket = AxumWebsocket(socket);
|
||||
|
||||
if let Err(e) = socket.send(Bytes::from_owner([0])).await.context(
|
||||
if let Err(e) = socket.send([0]).await.context(
|
||||
"Failed to send the login flow indicator over connnection",
|
||||
) {
|
||||
connection.set_error(e).await;
|
||||
@@ -151,7 +151,7 @@ async fn onboard_server_handler(
|
||||
format!("server={}", urlencoding::encode(&server_query));
|
||||
let mut socket = AxumWebsocket(socket);
|
||||
|
||||
if let Err(e) = socket.send(Bytes::from_owner([1])).await.context(
|
||||
if let Err(e) = socket.send([1]).await.context(
|
||||
"Failed to send the login flow indicator over connnection",
|
||||
).context("Server onboarding error") {
|
||||
warn!("{e:#}");
|
||||
@@ -174,14 +174,14 @@ async fn onboard_server_handler(
|
||||
};
|
||||
|
||||
let res = socket
|
||||
.recv_bytes()
|
||||
.recv_result()
|
||||
.with_timeout(Duration::from_secs(2))
|
||||
.await
|
||||
.and_then(|res| {
|
||||
res.and_then(|public_key_bytes| {
|
||||
String::from_utf8(public_key_bytes.into())
|
||||
.context("Public key bytes are not valid utf8")
|
||||
})
|
||||
.flatten()
|
||||
.flatten()
|
||||
.and_then(|message| {
|
||||
String::from_utf8(message.into_data()?.into())
|
||||
.context("Public key bytes are not valid utf8")
|
||||
});
|
||||
|
||||
// Post onboarding login 1: Receive public key
|
||||
@@ -205,7 +205,7 @@ async fn onboard_server_handler(
|
||||
Err(e) => {
|
||||
warn!("{e:#}");
|
||||
if let Err(e) = socket
|
||||
.send_error(&e)
|
||||
.send(&e)
|
||||
.await
|
||||
.context("Failed to send Server creation failed to client")
|
||||
{
|
||||
@@ -217,7 +217,7 @@ async fn onboard_server_handler(
|
||||
};
|
||||
|
||||
if let Err(e) = socket
|
||||
.send(MessageState::Successful.into())
|
||||
.send(MessageState::Successful)
|
||||
.await
|
||||
.context("Failed to send Server creation successful to client")
|
||||
{
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use bytes::Bytes;
|
||||
use cache::CloneCache;
|
||||
use periphery_client::api;
|
||||
use resolver_api::HasResponse;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use serde_json::json;
|
||||
use serror::deserialize_error_bytes;
|
||||
use tokio::sync::mpsc::{self, Sender};
|
||||
use tracing::warn;
|
||||
use transport::{
|
||||
MessageState,
|
||||
bytes::{from_transport_bytes, to_transport_bytes},
|
||||
channel::{Sender, channel},
|
||||
message::MessageState,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -23,7 +21,7 @@ use crate::{
|
||||
|
||||
pub mod terminal;
|
||||
|
||||
pub type ConnectionChannels = CloneCache<Uuid, Sender<Bytes>>;
|
||||
pub type ConnectionChannels = CloneCache<Uuid, Sender>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PeripheryClient {
|
||||
@@ -124,8 +122,7 @@ impl PeripheryClient {
|
||||
connection.bail_if_not_connected().await?;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let (response_sender, mut response_receiever) =
|
||||
mpsc::channel(1000);
|
||||
let (response_sender, mut response_receiever) = channel();
|
||||
self.channels.insert(id, response_sender).await;
|
||||
|
||||
let req_type = T::req_type();
|
||||
@@ -136,7 +133,7 @@ impl PeripheryClient {
|
||||
.context("Failed to serialize request to bytes")?;
|
||||
|
||||
if let Err(e) = connection
|
||||
.send(to_transport_bytes(data, id, MessageState::Request))
|
||||
.send((data, id, MessageState::Request))
|
||||
.await
|
||||
.context("Failed to send request over channel")
|
||||
{
|
||||
@@ -147,35 +144,13 @@ impl PeripheryClient {
|
||||
|
||||
// Poll for the associated response
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
msg = response_receiever.recv() => msg,
|
||||
let (data, _, state) = tokio::select! {
|
||||
msg = response_receiever.recv_parts() => msg?,
|
||||
// Periphery will send InProgress every 5s to avoid timeout
|
||||
_ = tokio::time::sleep(Duration::from_secs(10)) => {
|
||||
return Err(anyhow!("Response timed out"));
|
||||
}
|
||||
};
|
||||
|
||||
let bytes = match next {
|
||||
Some(bytes) => bytes,
|
||||
None => {
|
||||
return Err(anyhow!(
|
||||
"Sender dropped before response was recieved"
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let (state, data) = match from_transport_bytes(bytes) {
|
||||
Ok((data, _, state)) if !data.is_empty() => (state, data),
|
||||
// Ignore no data cases
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Server {} | Received invalid message | {e:#}",
|
||||
self.id
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match state {
|
||||
// TODO: improve the allocation in .to_vec
|
||||
MessageState::Successful => {
|
||||
|
||||
@@ -12,8 +12,10 @@ use periphery_client::api::terminal::{
|
||||
ConnectContainerExec, ConnectTerminal, END_OF_OUTPUT,
|
||||
ExecuteContainerExec, ExecuteTerminal,
|
||||
};
|
||||
use tokio::sync::mpsc::{Receiver, Sender, channel};
|
||||
use transport::bytes::data_from_transport_bytes;
|
||||
use transport::{
|
||||
channel::{Receiver, Sender, channel},
|
||||
message::Message,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
@@ -24,7 +26,7 @@ impl PeripheryClient {
|
||||
pub async fn connect_terminal(
|
||||
&self,
|
||||
terminal: String,
|
||||
) -> anyhow::Result<(Uuid, Sender<Bytes>, Receiver<Bytes>)> {
|
||||
) -> anyhow::Result<(Uuid, Sender, Receiver)> {
|
||||
tracing::trace!(
|
||||
"request | type: ConnectTerminal | terminal name: {terminal}",
|
||||
);
|
||||
@@ -39,7 +41,7 @@ impl PeripheryClient {
|
||||
.await
|
||||
.context("Failed to create terminal connection")?;
|
||||
|
||||
let (sender, receiever) = channel(1024);
|
||||
let (sender, receiever) = channel();
|
||||
connection.channels.insert(id, sender).await;
|
||||
|
||||
Ok((id, connection.sender.clone(), receiever))
|
||||
@@ -49,7 +51,7 @@ impl PeripheryClient {
|
||||
&self,
|
||||
container: String,
|
||||
shell: String,
|
||||
) -> anyhow::Result<(Uuid, Sender<Bytes>, Receiver<Bytes>)> {
|
||||
) -> anyhow::Result<(Uuid, Sender, Receiver)> {
|
||||
tracing::trace!(
|
||||
"request | type: ConnectContainerExec | container name: {container} | shell: {shell}",
|
||||
);
|
||||
@@ -64,7 +66,7 @@ impl PeripheryClient {
|
||||
.await
|
||||
.context("Failed to create container exec connection")?;
|
||||
|
||||
let (sender, receiever) = channel(1000);
|
||||
let (sender, receiever) = channel();
|
||||
connection.channels.insert(id, sender).await;
|
||||
|
||||
Ok((id, connection.sender.clone(), receiever))
|
||||
@@ -105,7 +107,7 @@ impl PeripheryClient {
|
||||
.await
|
||||
.context("Failed to create execute terminal connection")?;
|
||||
|
||||
let (sender, receiver) = channel(1000);
|
||||
let (sender, receiver) = channel();
|
||||
|
||||
connection.channels.insert(id, sender).await;
|
||||
|
||||
@@ -154,7 +156,7 @@ impl PeripheryClient {
|
||||
.await
|
||||
.context("Failed to create execute terminal connection")?;
|
||||
|
||||
let (sender, receiver) = channel(1000);
|
||||
let (sender, receiver) = channel();
|
||||
|
||||
connection.channels.insert(id, sender).await;
|
||||
|
||||
@@ -168,8 +170,8 @@ impl PeripheryClient {
|
||||
|
||||
pub struct ReceiverStream {
|
||||
id: Uuid,
|
||||
channels: Arc<CloneCache<Uuid, Sender<Bytes>>>,
|
||||
receiver: Receiver<Bytes>,
|
||||
channels: Arc<CloneCache<Uuid, Sender>>,
|
||||
receiver: Receiver,
|
||||
}
|
||||
|
||||
impl Stream for ReceiverStream {
|
||||
@@ -181,9 +183,11 @@ impl Stream for ReceiverStream {
|
||||
match self
|
||||
.receiver
|
||||
.poll_recv(cx)
|
||||
.map(|bytes| bytes.map(data_from_transport_bytes))
|
||||
.map(|message| message.map(Message::into_data))
|
||||
{
|
||||
Poll::Ready(Some(Ok(bytes))) if bytes == END_OF_OUTPUT => {
|
||||
Poll::Ready(Some(Ok(bytes)))
|
||||
if bytes == END_OF_OUTPUT.as_bytes() =>
|
||||
{
|
||||
self.cleanup();
|
||||
Poll::Ready(None)
|
||||
}
|
||||
|
||||
@@ -281,7 +281,7 @@ pub async fn update_server_public_key(
|
||||
pub async fn rotate_server_keys(
|
||||
server: &Server,
|
||||
) -> anyhow::Result<()> {
|
||||
let periphery = periphery_client(&server).await?;
|
||||
let periphery = periphery_client(server).await?;
|
||||
let public_key = periphery
|
||||
.request(api::keys::RotatePrivateKey {})
|
||||
.await
|
||||
|
||||
@@ -17,11 +17,10 @@ use komodo_client::{
|
||||
ws::WsLoginMessage,
|
||||
};
|
||||
use periphery_client::api::terminal::DisconnectTerminal;
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use transport::{
|
||||
MessageState,
|
||||
bytes::{data_from_transport_bytes, to_transport_bytes},
|
||||
channel::{Receiver, Sender},
|
||||
message::{Message as TransportMessage, MessageState},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -186,8 +185,8 @@ async fn forward_ws_channel(
|
||||
periphery: PeripheryClient,
|
||||
client_socket: axum::extract::ws::WebSocket,
|
||||
periphery_connection_id: Uuid,
|
||||
periphery_sender: Sender<Bytes>,
|
||||
mut periphery_receiver: Receiver<Bytes>,
|
||||
periphery_sender: Sender,
|
||||
mut periphery_receiver: Receiver,
|
||||
) {
|
||||
let (mut core_send, mut core_receive) = client_socket.split();
|
||||
let cancel = CancellationToken::new();
|
||||
@@ -204,10 +203,10 @@ async fn forward_ws_channel(
|
||||
}
|
||||
};
|
||||
match res {
|
||||
Some(Ok(Message::Binary(data))) => {
|
||||
Some(Ok(Message::Binary(bytes))) => {
|
||||
if let Err(e) = periphery_sender
|
||||
.send(to_transport_bytes(
|
||||
data.into(),
|
||||
.send((
|
||||
bytes,
|
||||
periphery_connection_id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
@@ -218,11 +217,11 @@ async fn forward_ws_channel(
|
||||
break;
|
||||
};
|
||||
}
|
||||
Some(Ok(Message::Text(data))) => {
|
||||
let data: Bytes = data.into();
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
let bytes: Bytes = text.into();
|
||||
if let Err(e) = periphery_sender
|
||||
.send(to_transport_bytes(
|
||||
data.into(),
|
||||
.send((
|
||||
bytes,
|
||||
periphery_connection_id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
@@ -255,7 +254,7 @@ async fn forward_ws_channel(
|
||||
let periphery_to_core = async {
|
||||
loop {
|
||||
let res = tokio::select! {
|
||||
res = periphery_receiver.recv() => res.map(data_from_transport_bytes),
|
||||
res = periphery_receiver.recv() => res.map(TransportMessage::into_data),
|
||||
_ = cancel.cancelled() => {
|
||||
trace!("periphery to core read: cancelled from inside");
|
||||
break;
|
||||
|
||||
@@ -11,9 +11,8 @@ use komodo_client::{
|
||||
use periphery_client::api::terminal::*;
|
||||
use resolver_api::Resolve;
|
||||
use serror::AddStatusCodeError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio_util::{codec::LinesCodecError, sync::CancellationToken};
|
||||
use transport::{MessageState, bytes::to_transport_bytes};
|
||||
use transport::{channel::Sender, message::MessageState};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
@@ -92,13 +91,18 @@ impl Resolve<super::Args> for ConnectTerminal {
|
||||
|
||||
let terminal = get_terminal(&self.terminal).await?;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let channel_id = Uuid::new_v4();
|
||||
|
||||
tokio::spawn(async move {
|
||||
handle_terminal_forwarding(&channel.sender, id, terminal).await
|
||||
handle_terminal_forwarding(
|
||||
&channel.sender,
|
||||
channel_id,
|
||||
terminal,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
Ok(id)
|
||||
Ok(channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,13 +143,18 @@ impl Resolve<super::Args> for ConnectContainerExec {
|
||||
.await
|
||||
.context("Failed to create terminal for container exec")?;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let channel_id = Uuid::new_v4();
|
||||
|
||||
tokio::spawn(async move {
|
||||
handle_terminal_forwarding(&channel.sender, id, terminal).await
|
||||
handle_terminal_forwarding(
|
||||
&channel.sender,
|
||||
channel_id,
|
||||
terminal,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
Ok(id)
|
||||
Ok(channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -186,18 +195,18 @@ impl Resolve<super::Args> for ExecuteTerminal {
|
||||
setup_execute_command_on_terminal(&terminal, &self.command)
|
||||
.await?;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let channel_id = Uuid::new_v4();
|
||||
|
||||
tokio::spawn(async move {
|
||||
forward_execute_command_on_terminal_response(
|
||||
&channel.sender,
|
||||
id,
|
||||
channel_id,
|
||||
stdout,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
Ok(id)
|
||||
Ok(channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +252,7 @@ impl Resolve<super::Args> for ExecuteContainerExec {
|
||||
let stdout =
|
||||
setup_execute_command_on_terminal(&terminal, &command).await?;
|
||||
|
||||
let id = Uuid::new_v4();
|
||||
let channel_id = Uuid::new_v4();
|
||||
|
||||
let channel =
|
||||
core_channels().get(&args.core).await.with_context(|| {
|
||||
@@ -253,46 +262,38 @@ impl Resolve<super::Args> for ExecuteContainerExec {
|
||||
tokio::spawn(async move {
|
||||
forward_execute_command_on_terminal_response(
|
||||
&channel.sender,
|
||||
id,
|
||||
channel_id,
|
||||
stdout,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
Ok(id)
|
||||
Ok(channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_terminal_forwarding(
|
||||
sender: &Sender<Bytes>,
|
||||
id: Uuid,
|
||||
sender: &Sender,
|
||||
channel: Uuid,
|
||||
terminal: Arc<Terminal>,
|
||||
) {
|
||||
let cancel = CancellationToken::new();
|
||||
|
||||
terminal_channels()
|
||||
.insert(id, (terminal.stdin.clone(), cancel.clone()))
|
||||
.insert(channel, (terminal.stdin.clone(), cancel.clone()))
|
||||
.await;
|
||||
|
||||
let init_res = async {
|
||||
let (a, b) = terminal.history.bytes_parts();
|
||||
if !a.is_empty() {
|
||||
sender
|
||||
.send(to_transport_bytes(
|
||||
a.into(),
|
||||
id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
.send((a, channel, MessageState::Terminal))
|
||||
.await
|
||||
.context("Failed to send history part a")?;
|
||||
}
|
||||
if !b.is_empty() {
|
||||
sender
|
||||
.send(to_transport_bytes(
|
||||
b.into(),
|
||||
id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
.send((b, channel, MessageState::Terminal))
|
||||
.await
|
||||
.context("Failed to send history part b")?;
|
||||
}
|
||||
@@ -303,7 +304,7 @@ async fn handle_terminal_forwarding(
|
||||
if let Err(e) = init_res {
|
||||
// TODO: Handle error
|
||||
warn!("Failed to init terminal | {e:#}");
|
||||
terminal_channels().remove(&id).await;
|
||||
terminal_channels().remove(&channel).await;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -330,13 +331,8 @@ async fn handle_terminal_forwarding(
|
||||
};
|
||||
match res {
|
||||
Ok(bytes) => {
|
||||
if let Err(e) = sender
|
||||
.send(to_transport_bytes(
|
||||
bytes.into(),
|
||||
id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
.await
|
||||
if let Err(e) =
|
||||
sender.send((bytes, channel, MessageState::Terminal)).await
|
||||
{
|
||||
debug!("Failed to send to WS: {e:?}");
|
||||
cancel.cancel();
|
||||
@@ -346,9 +342,9 @@ async fn handle_terminal_forwarding(
|
||||
Err(e) => {
|
||||
debug!("PTY -> WS channel read error: {e:?}");
|
||||
let _ = sender
|
||||
.send(to_transport_bytes(
|
||||
format!("ERROR: {e:#}").into(),
|
||||
id,
|
||||
.send((
|
||||
format!("ERROR: {e:#}"),
|
||||
channel,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
.await;
|
||||
@@ -359,8 +355,10 @@ async fn handle_terminal_forwarding(
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if let Some((_, cancel)) = terminal_channels().remove(&id).await {
|
||||
trace!("Cancel called for {id}");
|
||||
if let Some((_, cancel)) =
|
||||
terminal_channels().remove(&channel).await
|
||||
{
|
||||
trace!("Cancel called for {channel}");
|
||||
cancel.cancel();
|
||||
}
|
||||
clean_up_terminals().await;
|
||||
@@ -420,33 +418,23 @@ async fn setup_execute_command_on_terminal(
|
||||
}
|
||||
|
||||
async fn forward_execute_command_on_terminal_response(
|
||||
sender: &Sender<Bytes>,
|
||||
sender: &Sender,
|
||||
id: Uuid,
|
||||
mut stdout: impl Stream<Item = Result<String, LinesCodecError>> + Unpin,
|
||||
) {
|
||||
loop {
|
||||
match stdout.next().await {
|
||||
Some(Ok(line)) if line.as_str() == END_OF_OUTPUT => {
|
||||
if let Err(e) = sender
|
||||
.send(to_transport_bytes(
|
||||
line.into(),
|
||||
id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
.await
|
||||
if let Err(e) =
|
||||
sender.send((line, id, MessageState::Terminal)).await
|
||||
{
|
||||
warn!("Got ws_sender send error on END_OF_OUTPUT | {e:?}");
|
||||
}
|
||||
break;
|
||||
}
|
||||
Some(Ok(line)) => {
|
||||
if let Err(e) = sender
|
||||
.send(to_transport_bytes(
|
||||
(line + "\n").into(),
|
||||
id,
|
||||
MessageState::Terminal,
|
||||
))
|
||||
.await
|
||||
if let Err(e) =
|
||||
sender.send((line + "\n", id, MessageState::Terminal)).await
|
||||
{
|
||||
warn!("Got ws_sender send error | {e:?}");
|
||||
break;
|
||||
|
||||
@@ -2,16 +2,14 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::http::{HeaderValue, StatusCode};
|
||||
use bytes::Bytes;
|
||||
use periphery_client::CONNECTION_RETRY_SECONDS;
|
||||
use serror::deserialize_error_bytes;
|
||||
use transport::{
|
||||
MessageState,
|
||||
auth::{
|
||||
AddressConnectionIdentifiers, ClientLoginFlow,
|
||||
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
||||
},
|
||||
fix_ws_address,
|
||||
message::Message,
|
||||
websocket::{Websocket, tungstenite::TungsteniteWebsocket},
|
||||
};
|
||||
|
||||
@@ -69,12 +67,15 @@ pub async fn handler(address: &str) -> anyhow::Result<()> {
|
||||
// Receive whether to use Server connection flow vs Server onboarding flow.
|
||||
|
||||
let flow_bytes = match socket
|
||||
.recv_bytes()
|
||||
.recv_result()
|
||||
.with_timeout(Duration::from_secs(2))
|
||||
.await?
|
||||
.await
|
||||
.flatten()
|
||||
.flatten()
|
||||
.and_then(Message::into_data)
|
||||
.context("Failed to receive login flow indicator")
|
||||
{
|
||||
Ok(flow_bytes) => flow_bytes,
|
||||
Ok(flow_message) => flow_message,
|
||||
Err(e) => {
|
||||
if !already_logged_connection_error {
|
||||
warn!("{e:#}");
|
||||
@@ -179,33 +180,24 @@ async fn handle_onboarding(
|
||||
|
||||
// Post onboarding login 1: Send public key
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(
|
||||
periphery_public_key().load().as_bytes(),
|
||||
))
|
||||
.send(periphery_public_key().load().as_bytes())
|
||||
.await
|
||||
.context("Failed to send public key bytes")?;
|
||||
|
||||
let res = socket
|
||||
.recv_bytes()
|
||||
socket
|
||||
.recv_result()
|
||||
.with_timeout(Duration::from_secs(2))
|
||||
.await?
|
||||
.await
|
||||
.flatten()
|
||||
.flatten()
|
||||
.context("Failed to receive Server creation result")?;
|
||||
|
||||
match res.last().map(|byte| MessageState::from_byte(*byte)) {
|
||||
Some(MessageState::Successful) => {
|
||||
info!(
|
||||
"Server onboarding flow for '{}' successful ✅",
|
||||
config.connect_as
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Some(MessageState::Failed) => {
|
||||
Err(deserialize_error_bytes(&res[..(res.len() - 1)]))
|
||||
}
|
||||
other => Err(anyhow!(
|
||||
"Got unrecognized onboarding flow response: {other:?}"
|
||||
)),
|
||||
}
|
||||
info!(
|
||||
"Server onboarding flow for '{}' successful ✅",
|
||||
config.connect_as
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn connect_websocket(
|
||||
|
||||
@@ -4,23 +4,17 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use bytes::Bytes;
|
||||
use cache::CloneCache;
|
||||
use resolver_api::Resolve;
|
||||
use response::JsonBytes;
|
||||
use serror::serialize_error_bytes;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use transport::{
|
||||
MessageState,
|
||||
auth::{
|
||||
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
||||
PublicKeyValidator,
|
||||
},
|
||||
bytes::{
|
||||
data_from_transport_bytes, id_state_from_transport_bytes,
|
||||
to_transport_bytes,
|
||||
},
|
||||
channel::{BufferedChannel, BufferedReceiver},
|
||||
channel::{BufferedChannel, BufferedReceiver, Sender},
|
||||
message::{Message, MessageState},
|
||||
websocket::{
|
||||
Websocket, WebsocketMessage, WebsocketReceiver,
|
||||
WebsocketSender as _,
|
||||
@@ -39,8 +33,7 @@ pub mod client;
|
||||
pub mod server;
|
||||
|
||||
// Core Address / Host -> Channel
|
||||
pub type CoreChannels =
|
||||
CloneCache<String, Arc<BufferedChannel<Bytes>>>;
|
||||
pub type CoreChannels = CloneCache<String, Arc<BufferedChannel>>;
|
||||
|
||||
pub fn core_channels() -> &'static CoreChannels {
|
||||
static CORE_CHANNELS: OnceLock<CoreChannels> = OnceLock::new();
|
||||
@@ -85,8 +78,8 @@ async fn handle_login<W: Websocket, L: LoginFlow>(
|
||||
async fn handle_socket<W: Websocket>(
|
||||
socket: W,
|
||||
args: &Arc<Args>,
|
||||
sender: &Sender<Bytes>,
|
||||
receiver: &mut BufferedReceiver<Bytes>,
|
||||
sender: &Sender,
|
||||
receiver: &mut BufferedReceiver,
|
||||
) {
|
||||
let config = periphery_config();
|
||||
info!(
|
||||
@@ -109,9 +102,9 @@ async fn handle_socket<W: Websocket>(
|
||||
// Sender Dropped (shouldn't happen, it is static).
|
||||
None => break,
|
||||
// This has to copy the bytes to follow ownership rules.
|
||||
Some(msg) => Bytes::copy_from_slice(msg),
|
||||
Some(msg) => msg,
|
||||
};
|
||||
match ws_write.send(msg).await {
|
||||
match ws_write.send(msg.to_message()).await {
|
||||
// Clears the stored message from receiver buffer.
|
||||
// TODO: Move after response ack.
|
||||
Ok(_) => receiver.clear_buffer(),
|
||||
@@ -127,8 +120,8 @@ async fn handle_socket<W: Websocket>(
|
||||
let handle_reads = async {
|
||||
loop {
|
||||
match ws_read.recv().await {
|
||||
Ok(WebsocketMessage::Binary(bytes)) => {
|
||||
handle_incoming_bytes(args, sender, bytes).await
|
||||
Ok(WebsocketMessage::Message(bytes)) => {
|
||||
handle_incoming_message(args, sender, bytes).await
|
||||
}
|
||||
Ok(WebsocketMessage::Close(frame)) => {
|
||||
warn!("Connection closed with frame: {frame:?}");
|
||||
@@ -152,12 +145,12 @@ async fn handle_socket<W: Websocket>(
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_incoming_bytes(
|
||||
async fn handle_incoming_message(
|
||||
args: &Arc<Args>,
|
||||
sender: &Sender<Bytes>,
|
||||
bytes: Bytes,
|
||||
sender: &Sender,
|
||||
message: Message,
|
||||
) {
|
||||
let (id, state) = match id_state_from_transport_bytes(&bytes) {
|
||||
let (channel, state) = match message.channel_and_state() {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse transport bytes | {e:#}");
|
||||
@@ -166,10 +159,10 @@ async fn handle_incoming_bytes(
|
||||
};
|
||||
match state {
|
||||
MessageState::Request => {
|
||||
handle_request(args.clone(), sender.clone(), id, bytes)
|
||||
handle_request(args.clone(), sender.clone(), channel, message)
|
||||
}
|
||||
MessageState::Terminal => {
|
||||
crate::terminal::handle_incoming_message(id, bytes).await
|
||||
crate::terminal::handle_incoming_message(channel, message).await
|
||||
}
|
||||
// Shouldn't be received by Periphery
|
||||
MessageState::InProgress => {}
|
||||
@@ -180,12 +173,12 @@ async fn handle_incoming_bytes(
|
||||
|
||||
fn handle_request(
|
||||
args: Arc<Args>,
|
||||
sender: Sender<Bytes>,
|
||||
sender: Sender,
|
||||
req_id: Uuid,
|
||||
bytes: Bytes,
|
||||
message: Message,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
let request = match data_from_transport_bytes(bytes) {
|
||||
let request = match message.into_data() {
|
||||
Ok(req) if !req.is_empty() => req,
|
||||
_ => {
|
||||
return;
|
||||
@@ -216,9 +209,7 @@ fn handle_request(
|
||||
(MessageState::Failed, serialize_error_bytes(&e.error))
|
||||
}
|
||||
};
|
||||
if let Err(e) =
|
||||
sender.send(to_transport_bytes(data, req_id, state)).await
|
||||
{
|
||||
if let Err(e) = sender.send((data, req_id, state)).await {
|
||||
error!("Failed to send response over channel | {e:?}");
|
||||
}
|
||||
};
|
||||
@@ -226,13 +217,8 @@ fn handle_request(
|
||||
let ping_in_progress = async {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
if let Err(e) = sender
|
||||
.send(to_transport_bytes(
|
||||
Vec::new(),
|
||||
req_id,
|
||||
MessageState::InProgress,
|
||||
))
|
||||
.await
|
||||
if let Err(e) =
|
||||
sender.send((req_id, MessageState::InProgress)).await
|
||||
{
|
||||
error!("Failed to ping in progress over channel | {e:?}");
|
||||
}
|
||||
|
||||
@@ -18,17 +18,16 @@ use axum::{
|
||||
routing::get,
|
||||
};
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use bytes::Bytes;
|
||||
use serror::{
|
||||
AddStatusCode, AddStatusCodeError, deserialize_error_bytes,
|
||||
serialize_error_bytes,
|
||||
AddStatusCode, AddStatusCodeError, serialize_error_bytes,
|
||||
};
|
||||
use transport::{
|
||||
CoreConnectionQuery, MessageState,
|
||||
CoreConnectionQuery,
|
||||
auth::{
|
||||
ConnectionIdentifiers, HeaderConnectionIdentifiers,
|
||||
AUTH_TIMEOUT, ConnectionIdentifiers, HeaderConnectionIdentifiers,
|
||||
ServerLoginFlow,
|
||||
},
|
||||
message::{Message, MessageState},
|
||||
websocket::{Websocket, axum::AxumWebsocket},
|
||||
};
|
||||
|
||||
@@ -134,7 +133,7 @@ async fn handler(
|
||||
bytes.push(MessageState::Failed.as_byte());
|
||||
|
||||
if let Err(e) = socket
|
||||
.send(bytes.into())
|
||||
.send(&e)
|
||||
.await
|
||||
.context("Failed to send forward failed to client")
|
||||
{
|
||||
@@ -169,7 +168,7 @@ async fn handle_login(
|
||||
(Some(_), _) | (_, None) => {
|
||||
// Send login type [0] (Noise auth)
|
||||
socket
|
||||
.send(Bytes::from_owner([0]))
|
||||
.send([0])
|
||||
.await
|
||||
.context("Failed to send login type indicator")?;
|
||||
super::handle_login::<_, ServerLoginFlow>(socket, identifiers)
|
||||
@@ -192,43 +191,33 @@ async fn handle_passkey_login(
|
||||
// Send login type
|
||||
socket
|
||||
// Passkey auth: [1]
|
||||
.send(Bytes::from_owner([1]))
|
||||
.send([1])
|
||||
.await
|
||||
.context("Failed to send login type indicator")?;
|
||||
|
||||
// Receieve passkey
|
||||
let bytes = socket
|
||||
.recv_bytes()
|
||||
let passkey = socket
|
||||
.recv_result()
|
||||
.with_timeout(AUTH_TIMEOUT)
|
||||
.await
|
||||
.flatten()
|
||||
.flatten()
|
||||
.and_then(Message::into_data)
|
||||
.context("Failed to receive passkey from Core")?;
|
||||
let passkey = match MessageState::from_byte(
|
||||
*bytes.last().context("passkey message is empty")?,
|
||||
) {
|
||||
MessageState::Successful => &bytes[..(bytes.len() - 1)],
|
||||
_ => {
|
||||
return Err(deserialize_error_bytes(
|
||||
&bytes[..(bytes.len() - 1)],
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
if passkeys
|
||||
.iter()
|
||||
.any(|expected_passkey| expected_passkey.as_bytes() == passkey)
|
||||
{
|
||||
socket
|
||||
.send(MessageState::Successful.into())
|
||||
.send(MessageState::Successful)
|
||||
.await
|
||||
.context("Failed to send login type indicator")?;
|
||||
Ok(())
|
||||
} else {
|
||||
let e = anyhow!("Invalid passkey");
|
||||
let mut bytes = serialize_error_bytes(&e);
|
||||
bytes.push(MessageState::Failed.as_byte());
|
||||
if let Err(e) = socket
|
||||
.send(bytes.into())
|
||||
.await
|
||||
.context("Failed to send login failed")
|
||||
if let Err(e) =
|
||||
socket.send(&e).await.context("Failed to send login failed")
|
||||
{
|
||||
// Log additional error
|
||||
warn!("{e:#}");
|
||||
@@ -241,10 +230,8 @@ async fn handle_passkey_login(
|
||||
}
|
||||
.await;
|
||||
if let Err(e) = res {
|
||||
let mut bytes = serialize_error_bytes(&e);
|
||||
bytes.push(MessageState::Failed.as_byte());
|
||||
if let Err(e) = socket
|
||||
.send(bytes.into())
|
||||
.send(&e)
|
||||
.await
|
||||
.context("Failed to send login failed to client")
|
||||
{
|
||||
|
||||
@@ -13,7 +13,7 @@ use komodo_client::{
|
||||
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
|
||||
use tokio::sync::{broadcast, mpsc};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use transport::bytes::data_from_transport_bytes;
|
||||
use transport::message::Message;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub type TerminalChannels =
|
||||
@@ -25,12 +25,12 @@ pub fn terminal_channels() -> &'static TerminalChannels {
|
||||
TERMINAL_CHANNELS.get_or_init(Default::default)
|
||||
}
|
||||
|
||||
pub async fn handle_incoming_message(id: Uuid, bytes: Bytes) {
|
||||
pub async fn handle_incoming_message(id: Uuid, message: Message) {
|
||||
let Some((channel, _)) = terminal_channels().get(&id).await else {
|
||||
warn!("No terminal channel for {id}");
|
||||
return;
|
||||
};
|
||||
let Ok(data) = data_from_transport_bytes(bytes) else {
|
||||
let Ok(data) = message.into_data() else {
|
||||
warn!("Got terminal message with no data for {id}");
|
||||
return;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user