working with safer transport message api

This commit is contained in:
mbecker20
2025-10-06 19:16:03 -07:00
parent 6473080078
commit 2daa92a639
21 changed files with 622 additions and 496 deletions

View File

@@ -2,14 +2,13 @@ use std::{sync::Arc, time::Duration};
use anyhow::{Context, anyhow};
use periphery_client::CONNECTION_RETRY_SECONDS;
use serror::{deserialize_error_bytes, serialize_error_bytes};
use transport::{
MessageState,
auth::{
AddressConnectionIdentifiers, ClientLoginFlow,
AUTH_TIMEOUT, AddressConnectionIdentifiers, ClientLoginFlow,
ConnectionIdentifiers,
},
fix_ws_address,
message::Message,
websocket::{Websocket, tungstenite::TungsteniteWebsocket},
};
@@ -105,9 +104,12 @@ impl PeripheryConnection {
) -> anyhow::Result<()> {
// Get the required auth type
let bytes = socket
.recv_bytes()
.recv_result()
.with_timeout(Duration::from_secs(2))
.await?
.await
.flatten()
.flatten()
.and_then(Message::into_data)
.context("Failed to receive login type indicator")?;
match bytes.iter().as_slice() {
@@ -132,7 +134,7 @@ async fn handle_passkey_login(
passkey: &str,
) -> anyhow::Result<()> {
let res = async {
let mut passkey = if passkey.is_empty() {
let passkey = if passkey.is_empty() {
core_config()
.passkey
.as_deref()
@@ -142,34 +144,27 @@ async fn handle_passkey_login(
} else {
passkey.as_bytes().to_vec()
};
passkey.push(MessageState::Successful.as_byte());
socket
.send(passkey.into())
.send(passkey)
.await
.context("Failed to send passkey")?;
// Receive login state message and return based on value
let state_msg = socket
.recv_bytes()
socket
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to receive authentication state message")?;
let state = state_msg.last().context(
"Authentication state message did not contain state byte",
)?;
match MessageState::from_byte(*state) {
MessageState::Successful => anyhow::Ok(()),
_ => Err(deserialize_error_bytes(
&state_msg[..(state_msg.len() - 1)],
)),
}
Ok(())
}
.await;
if let Err(e) = res {
let mut bytes = serialize_error_bytes(&e);
bytes.push(MessageState::Failed.as_byte());
if let Err(e) = socket
.send(bytes.into())
.send(&e)
.await
.context("Failed to send login failed to client")
{

View File

@@ -7,7 +7,6 @@ use std::{
};
use anyhow::anyhow;
use bytes::Bytes;
use cache::CloneCache;
use database::mungos::{by_id::update_one_by_id, mongodb::bson::doc};
use komodo_client::entities::{
@@ -16,18 +15,15 @@ use komodo_client::entities::{
server::Server,
};
use serror::serror_into_anyhow_error;
use tokio::sync::{
RwLock,
mpsc::{Sender, error::SendError},
};
use tokio::sync::{RwLock, mpsc::error::SendError};
use tokio_util::sync::CancellationToken;
use transport::{
auth::{
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
PublicKeyValidator,
},
bytes::id_from_transport_bytes,
channel::{BufferedReceiver, buffered_channel},
channel::{BufferedReceiver, Sender, buffered_channel},
message::Message,
websocket::{
Websocket, WebsocketMessage, WebsocketReceiver as _,
WebsocketSender as _,
@@ -56,7 +52,7 @@ impl PeripheryConnections {
&self,
server_id: String,
args: PeripheryConnectionArgs<'_>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
) -> (Arc<PeripheryConnection>, BufferedReceiver) {
let (connection, receiver) = if let Some(existing_connection) =
self.0.remove(&server_id).await
{
@@ -110,12 +106,11 @@ impl PublicKeyValidator for PeripheryConnectionArgs<'_> {
self.id.to_string(),
Some(public_key.clone()),
);
let e = anyhow!("{public_key} is invalid")
anyhow!("{public_key} is invalid")
.context(
"Ensure public key matches configured Periphery Public Key",
)
.context("Core failed to validate Periphery public key");
e
.context("Core failed to validate Periphery public key")
};
let core_to_periphery = self.address.is_some();
match (self.periphery_public_key, core_to_periphery) {
@@ -242,7 +237,7 @@ pub struct PeripheryConnection {
/// The connection args
pub args: OwnedPeripheryConnectionArgs,
/// Send and receive bytes over the connection socket.
pub sender: Sender<Bytes>,
pub sender: Sender,
/// Cancel the connection
pub cancel: CancellationToken,
/// Whether Periphery is currently connected.
@@ -258,7 +253,7 @@ pub struct PeripheryConnection {
impl PeripheryConnection {
pub fn new(
args: impl Into<OwnedPeripheryConnectionArgs>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
) -> (Arc<PeripheryConnection>, BufferedReceiver) {
let (sender, receiever) = buffered_channel();
(
PeripheryConnection {
@@ -277,7 +272,7 @@ impl PeripheryConnection {
pub fn with_new_args(
&self,
args: impl Into<OwnedPeripheryConnectionArgs>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
) -> (Arc<PeripheryConnection>, BufferedReceiver) {
// Ensure this connection is cancelled.
self.cancel();
let (sender, receiever) = buffered_channel();
@@ -315,7 +310,7 @@ impl PeripheryConnection {
pub async fn handle_socket<W: Websocket>(
&self,
socket: W,
receiver: &mut BufferedReceiver<Bytes>,
receiver: &mut BufferedReceiver,
) {
let cancel = self.cancel.child_token();
@@ -332,7 +327,7 @@ impl PeripheryConnection {
};
let message = match next {
Some(request) => Bytes::copy_from_slice(request),
Some(request) => request.to_message(),
// Sender Dropped (shouldn't happen, a reference is held on 'connection').
None => break,
};
@@ -358,8 +353,8 @@ impl PeripheryConnection {
};
match next {
Ok(WebsocketMessage::Binary(bytes)) => {
self.handle_incoming_bytes(bytes).await
Ok(WebsocketMessage::Message(message)) => {
self.handle_incoming_message(message).await
}
Ok(WebsocketMessage::Close(_))
| Ok(WebsocketMessage::Closed) => {
@@ -380,21 +375,21 @@ impl PeripheryConnection {
self.set_connected(false);
}
pub async fn handle_incoming_bytes(&self, bytes: Bytes) {
let id = match id_from_transport_bytes(&bytes) {
pub async fn handle_incoming_message(&self, message: Message) {
let channel = match message.channel() {
Ok(res) => res,
Err(e) => {
// TODO: handle better
warn!("Failed to read id | {e:#}");
warn!("Failed to read channel | {e:#}");
return;
}
};
let Some(channel) = self.channels.get(&id).await else {
let Some(channel) = self.channels.get(&channel).await else {
// TODO: handle better
debug!("Failed to send response | No response channel found");
return;
};
if let Err(e) = channel.send(bytes).await {
if let Err(e) = channel.send(message).await {
// TODO: handle better
warn!("Failed to send response | Channel failure | {e:#}");
}
@@ -402,9 +397,9 @@ impl PeripheryConnection {
pub async fn send(
&self,
value: Bytes,
) -> Result<(), SendError<Bytes>> {
self.sender.send(value).await
message: impl Into<Message>,
) -> Result<(), SendError<Message>> {
self.sender.send(message).await
}
pub fn set_connected(&self, connected: bool) {

View File

@@ -6,7 +6,6 @@ use axum::{
http::{HeaderMap, StatusCode},
response::Response,
};
use bytes::Bytes;
use database::mungos::mongodb::bson::{doc, oid::ObjectId};
use komodo_client::{
api::write::{CreateBuilder, CreateServer, UpdateResourceMeta},
@@ -20,11 +19,12 @@ use komodo_client::{
use resolver_api::Resolve;
use serror::{AddStatusCode, AddStatusCodeError};
use transport::{
MessageState, PeripheryConnectionQuery,
PeripheryConnectionQuery,
auth::{
HeaderConnectionIdentifiers, LoginFlow, LoginFlowArgs,
PublicKeyValidator, ServerLoginFlow,
},
message::MessageState,
websocket::{Websocket, axum::AxumWebsocket},
};
@@ -119,7 +119,7 @@ async fn existing_server_handler(
format!("server={}", urlencoding::encode(&server_query));
let mut socket = AxumWebsocket(socket);
if let Err(e) = socket.send(Bytes::from_owner([0])).await.context(
if let Err(e) = socket.send([0]).await.context(
"Failed to send the login flow indicator over connnection",
) {
connection.set_error(e).await;
@@ -151,7 +151,7 @@ async fn onboard_server_handler(
format!("server={}", urlencoding::encode(&server_query));
let mut socket = AxumWebsocket(socket);
if let Err(e) = socket.send(Bytes::from_owner([1])).await.context(
if let Err(e) = socket.send([1]).await.context(
"Failed to send the login flow indicator over connnection",
).context("Server onboarding error") {
warn!("{e:#}");
@@ -174,14 +174,14 @@ async fn onboard_server_handler(
};
let res = socket
.recv_bytes()
.recv_result()
.with_timeout(Duration::from_secs(2))
.await
.and_then(|res| {
res.and_then(|public_key_bytes| {
String::from_utf8(public_key_bytes.into())
.context("Public key bytes are not valid utf8")
})
.flatten()
.flatten()
.and_then(|message| {
String::from_utf8(message.into_data()?.into())
.context("Public key bytes are not valid utf8")
});
// Post onboarding login 1: Receive public key
@@ -205,7 +205,7 @@ async fn onboard_server_handler(
Err(e) => {
warn!("{e:#}");
if let Err(e) = socket
.send_error(&e)
.send(&e)
.await
.context("Failed to send Server creation failed to client")
{
@@ -217,7 +217,7 @@ async fn onboard_server_handler(
};
if let Err(e) = socket
.send(MessageState::Successful.into())
.send(MessageState::Successful)
.await
.context("Failed to send Server creation successful to client")
{

View File

@@ -1,18 +1,16 @@
use std::{sync::Arc, time::Duration};
use anyhow::{Context, anyhow};
use bytes::Bytes;
use cache::CloneCache;
use periphery_client::api;
use resolver_api::HasResponse;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::json;
use serror::deserialize_error_bytes;
use tokio::sync::mpsc::{self, Sender};
use tracing::warn;
use transport::{
MessageState,
bytes::{from_transport_bytes, to_transport_bytes},
channel::{Sender, channel},
message::MessageState,
};
use uuid::Uuid;
@@ -23,7 +21,7 @@ use crate::{
pub mod terminal;
pub type ConnectionChannels = CloneCache<Uuid, Sender<Bytes>>;
pub type ConnectionChannels = CloneCache<Uuid, Sender>;
#[derive(Debug)]
pub struct PeripheryClient {
@@ -124,8 +122,7 @@ impl PeripheryClient {
connection.bail_if_not_connected().await?;
let id = Uuid::new_v4();
let (response_sender, mut response_receiever) =
mpsc::channel(1000);
let (response_sender, mut response_receiever) = channel();
self.channels.insert(id, response_sender).await;
let req_type = T::req_type();
@@ -136,7 +133,7 @@ impl PeripheryClient {
.context("Failed to serialize request to bytes")?;
if let Err(e) = connection
.send(to_transport_bytes(data, id, MessageState::Request))
.send((data, id, MessageState::Request))
.await
.context("Failed to send request over channel")
{
@@ -147,35 +144,13 @@ impl PeripheryClient {
// Poll for the associated response
loop {
let next = tokio::select! {
msg = response_receiever.recv() => msg,
let (data, _, state) = tokio::select! {
msg = response_receiever.recv_parts() => msg?,
// Periphery will send InProgress every 5s to avoid timeout
_ = tokio::time::sleep(Duration::from_secs(10)) => {
return Err(anyhow!("Response timed out"));
}
};
let bytes = match next {
Some(bytes) => bytes,
None => {
return Err(anyhow!(
"Sender dropped before response was recieved"
));
}
};
let (state, data) = match from_transport_bytes(bytes) {
Ok((data, _, state)) if !data.is_empty() => (state, data),
// Ignore no data cases
Ok(_) => continue,
Err(e) => {
warn!(
"Server {} | Received invalid message | {e:#}",
self.id
);
continue;
}
};
match state {
// TODO: improve the allocation in .to_vec
MessageState::Successful => {

View File

@@ -12,8 +12,10 @@ use periphery_client::api::terminal::{
ConnectContainerExec, ConnectTerminal, END_OF_OUTPUT,
ExecuteContainerExec, ExecuteTerminal,
};
use tokio::sync::mpsc::{Receiver, Sender, channel};
use transport::bytes::data_from_transport_bytes;
use transport::{
channel::{Receiver, Sender, channel},
message::Message,
};
use uuid::Uuid;
use crate::{
@@ -24,7 +26,7 @@ impl PeripheryClient {
pub async fn connect_terminal(
&self,
terminal: String,
) -> anyhow::Result<(Uuid, Sender<Bytes>, Receiver<Bytes>)> {
) -> anyhow::Result<(Uuid, Sender, Receiver)> {
tracing::trace!(
"request | type: ConnectTerminal | terminal name: {terminal}",
);
@@ -39,7 +41,7 @@ impl PeripheryClient {
.await
.context("Failed to create terminal connection")?;
let (sender, receiever) = channel(1024);
let (sender, receiever) = channel();
connection.channels.insert(id, sender).await;
Ok((id, connection.sender.clone(), receiever))
@@ -49,7 +51,7 @@ impl PeripheryClient {
&self,
container: String,
shell: String,
) -> anyhow::Result<(Uuid, Sender<Bytes>, Receiver<Bytes>)> {
) -> anyhow::Result<(Uuid, Sender, Receiver)> {
tracing::trace!(
"request | type: ConnectContainerExec | container name: {container} | shell: {shell}",
);
@@ -64,7 +66,7 @@ impl PeripheryClient {
.await
.context("Failed to create container exec connection")?;
let (sender, receiever) = channel(1000);
let (sender, receiever) = channel();
connection.channels.insert(id, sender).await;
Ok((id, connection.sender.clone(), receiever))
@@ -105,7 +107,7 @@ impl PeripheryClient {
.await
.context("Failed to create execute terminal connection")?;
let (sender, receiver) = channel(1000);
let (sender, receiver) = channel();
connection.channels.insert(id, sender).await;
@@ -154,7 +156,7 @@ impl PeripheryClient {
.await
.context("Failed to create execute terminal connection")?;
let (sender, receiver) = channel(1000);
let (sender, receiver) = channel();
connection.channels.insert(id, sender).await;
@@ -168,8 +170,8 @@ impl PeripheryClient {
pub struct ReceiverStream {
id: Uuid,
channels: Arc<CloneCache<Uuid, Sender<Bytes>>>,
receiver: Receiver<Bytes>,
channels: Arc<CloneCache<Uuid, Sender>>,
receiver: Receiver,
}
impl Stream for ReceiverStream {
@@ -181,9 +183,11 @@ impl Stream for ReceiverStream {
match self
.receiver
.poll_recv(cx)
.map(|bytes| bytes.map(data_from_transport_bytes))
.map(|message| message.map(Message::into_data))
{
Poll::Ready(Some(Ok(bytes))) if bytes == END_OF_OUTPUT => {
Poll::Ready(Some(Ok(bytes)))
if bytes == END_OF_OUTPUT.as_bytes() =>
{
self.cleanup();
Poll::Ready(None)
}

View File

@@ -281,7 +281,7 @@ pub async fn update_server_public_key(
pub async fn rotate_server_keys(
server: &Server,
) -> anyhow::Result<()> {
let periphery = periphery_client(&server).await?;
let periphery = periphery_client(server).await?;
let public_key = periphery
.request(api::keys::RotatePrivateKey {})
.await

View File

@@ -17,11 +17,10 @@ use komodo_client::{
ws::WsLoginMessage,
};
use periphery_client::api::terminal::DisconnectTerminal;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio_util::sync::CancellationToken;
use transport::{
MessageState,
bytes::{data_from_transport_bytes, to_transport_bytes},
channel::{Receiver, Sender},
message::{Message as TransportMessage, MessageState},
};
use uuid::Uuid;
@@ -186,8 +185,8 @@ async fn forward_ws_channel(
periphery: PeripheryClient,
client_socket: axum::extract::ws::WebSocket,
periphery_connection_id: Uuid,
periphery_sender: Sender<Bytes>,
mut periphery_receiver: Receiver<Bytes>,
periphery_sender: Sender,
mut periphery_receiver: Receiver,
) {
let (mut core_send, mut core_receive) = client_socket.split();
let cancel = CancellationToken::new();
@@ -204,10 +203,10 @@ async fn forward_ws_channel(
}
};
match res {
Some(Ok(Message::Binary(data))) => {
Some(Ok(Message::Binary(bytes))) => {
if let Err(e) = periphery_sender
.send(to_transport_bytes(
data.into(),
.send((
bytes,
periphery_connection_id,
MessageState::Terminal,
))
@@ -218,11 +217,11 @@ async fn forward_ws_channel(
break;
};
}
Some(Ok(Message::Text(data))) => {
let data: Bytes = data.into();
Some(Ok(Message::Text(text))) => {
let bytes: Bytes = text.into();
if let Err(e) = periphery_sender
.send(to_transport_bytes(
data.into(),
.send((
bytes,
periphery_connection_id,
MessageState::Terminal,
))
@@ -255,7 +254,7 @@ async fn forward_ws_channel(
let periphery_to_core = async {
loop {
let res = tokio::select! {
res = periphery_receiver.recv() => res.map(data_from_transport_bytes),
res = periphery_receiver.recv() => res.map(TransportMessage::into_data),
_ = cancel.cancelled() => {
trace!("periphery to core read: cancelled from inside");
break;

View File

@@ -11,9 +11,8 @@ use komodo_client::{
use periphery_client::api::terminal::*;
use resolver_api::Resolve;
use serror::AddStatusCodeError;
use tokio::sync::mpsc::Sender;
use tokio_util::{codec::LinesCodecError, sync::CancellationToken};
use transport::{MessageState, bytes::to_transport_bytes};
use transport::{channel::Sender, message::MessageState};
use uuid::Uuid;
use crate::{
@@ -92,13 +91,18 @@ impl Resolve<super::Args> for ConnectTerminal {
let terminal = get_terminal(&self.terminal).await?;
let id = Uuid::new_v4();
let channel_id = Uuid::new_v4();
tokio::spawn(async move {
handle_terminal_forwarding(&channel.sender, id, terminal).await
handle_terminal_forwarding(
&channel.sender,
channel_id,
terminal,
)
.await
});
Ok(id)
Ok(channel_id)
}
}
@@ -139,13 +143,18 @@ impl Resolve<super::Args> for ConnectContainerExec {
.await
.context("Failed to create terminal for container exec")?;
let id = Uuid::new_v4();
let channel_id = Uuid::new_v4();
tokio::spawn(async move {
handle_terminal_forwarding(&channel.sender, id, terminal).await
handle_terminal_forwarding(
&channel.sender,
channel_id,
terminal,
)
.await
});
Ok(id)
Ok(channel_id)
}
}
@@ -186,18 +195,18 @@ impl Resolve<super::Args> for ExecuteTerminal {
setup_execute_command_on_terminal(&terminal, &self.command)
.await?;
let id = Uuid::new_v4();
let channel_id = Uuid::new_v4();
tokio::spawn(async move {
forward_execute_command_on_terminal_response(
&channel.sender,
id,
channel_id,
stdout,
)
.await
});
Ok(id)
Ok(channel_id)
}
}
@@ -243,7 +252,7 @@ impl Resolve<super::Args> for ExecuteContainerExec {
let stdout =
setup_execute_command_on_terminal(&terminal, &command).await?;
let id = Uuid::new_v4();
let channel_id = Uuid::new_v4();
let channel =
core_channels().get(&args.core).await.with_context(|| {
@@ -253,46 +262,38 @@ impl Resolve<super::Args> for ExecuteContainerExec {
tokio::spawn(async move {
forward_execute_command_on_terminal_response(
&channel.sender,
id,
channel_id,
stdout,
)
.await
});
Ok(id)
Ok(channel_id)
}
}
async fn handle_terminal_forwarding(
sender: &Sender<Bytes>,
id: Uuid,
sender: &Sender,
channel: Uuid,
terminal: Arc<Terminal>,
) {
let cancel = CancellationToken::new();
terminal_channels()
.insert(id, (terminal.stdin.clone(), cancel.clone()))
.insert(channel, (terminal.stdin.clone(), cancel.clone()))
.await;
let init_res = async {
let (a, b) = terminal.history.bytes_parts();
if !a.is_empty() {
sender
.send(to_transport_bytes(
a.into(),
id,
MessageState::Terminal,
))
.send((a, channel, MessageState::Terminal))
.await
.context("Failed to send history part a")?;
}
if !b.is_empty() {
sender
.send(to_transport_bytes(
b.into(),
id,
MessageState::Terminal,
))
.send((b, channel, MessageState::Terminal))
.await
.context("Failed to send history part b")?;
}
@@ -303,7 +304,7 @@ async fn handle_terminal_forwarding(
if let Err(e) = init_res {
// TODO: Handle error
warn!("Failed to init terminal | {e:#}");
terminal_channels().remove(&id).await;
terminal_channels().remove(&channel).await;
return;
}
@@ -330,13 +331,8 @@ async fn handle_terminal_forwarding(
};
match res {
Ok(bytes) => {
if let Err(e) = sender
.send(to_transport_bytes(
bytes.into(),
id,
MessageState::Terminal,
))
.await
if let Err(e) =
sender.send((bytes, channel, MessageState::Terminal)).await
{
debug!("Failed to send to WS: {e:?}");
cancel.cancel();
@@ -346,9 +342,9 @@ async fn handle_terminal_forwarding(
Err(e) => {
debug!("PTY -> WS channel read error: {e:?}");
let _ = sender
.send(to_transport_bytes(
format!("ERROR: {e:#}").into(),
id,
.send((
format!("ERROR: {e:#}"),
channel,
MessageState::Terminal,
))
.await;
@@ -359,8 +355,10 @@ async fn handle_terminal_forwarding(
}
// Clean up
if let Some((_, cancel)) = terminal_channels().remove(&id).await {
trace!("Cancel called for {id}");
if let Some((_, cancel)) =
terminal_channels().remove(&channel).await
{
trace!("Cancel called for {channel}");
cancel.cancel();
}
clean_up_terminals().await;
@@ -420,33 +418,23 @@ async fn setup_execute_command_on_terminal(
}
async fn forward_execute_command_on_terminal_response(
sender: &Sender<Bytes>,
sender: &Sender,
id: Uuid,
mut stdout: impl Stream<Item = Result<String, LinesCodecError>> + Unpin,
) {
loop {
match stdout.next().await {
Some(Ok(line)) if line.as_str() == END_OF_OUTPUT => {
if let Err(e) = sender
.send(to_transport_bytes(
line.into(),
id,
MessageState::Terminal,
))
.await
if let Err(e) =
sender.send((line, id, MessageState::Terminal)).await
{
warn!("Got ws_sender send error on END_OF_OUTPUT | {e:?}");
}
break;
}
Some(Ok(line)) => {
if let Err(e) = sender
.send(to_transport_bytes(
(line + "\n").into(),
id,
MessageState::Terminal,
))
.await
if let Err(e) =
sender.send((line + "\n", id, MessageState::Terminal)).await
{
warn!("Got ws_sender send error | {e:?}");
break;

View File

@@ -2,16 +2,14 @@ use std::{sync::Arc, time::Duration};
use anyhow::{Context, anyhow};
use axum::http::{HeaderValue, StatusCode};
use bytes::Bytes;
use periphery_client::CONNECTION_RETRY_SECONDS;
use serror::deserialize_error_bytes;
use transport::{
MessageState,
auth::{
AddressConnectionIdentifiers, ClientLoginFlow,
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
},
fix_ws_address,
message::Message,
websocket::{Websocket, tungstenite::TungsteniteWebsocket},
};
@@ -69,12 +67,15 @@ pub async fn handler(address: &str) -> anyhow::Result<()> {
// Receive whether to use Server connection flow vs Server onboarding flow.
let flow_bytes = match socket
.recv_bytes()
.recv_result()
.with_timeout(Duration::from_secs(2))
.await?
.await
.flatten()
.flatten()
.and_then(Message::into_data)
.context("Failed to receive login flow indicator")
{
Ok(flow_bytes) => flow_bytes,
Ok(flow_message) => flow_message,
Err(e) => {
if !already_logged_connection_error {
warn!("{e:#}");
@@ -179,33 +180,24 @@ async fn handle_onboarding(
// Post onboarding login 1: Send public key
socket
.send(Bytes::copy_from_slice(
periphery_public_key().load().as_bytes(),
))
.send(periphery_public_key().load().as_bytes())
.await
.context("Failed to send public key bytes")?;
let res = socket
.recv_bytes()
socket
.recv_result()
.with_timeout(Duration::from_secs(2))
.await?
.await
.flatten()
.flatten()
.context("Failed to receive Server creation result")?;
match res.last().map(|byte| MessageState::from_byte(*byte)) {
Some(MessageState::Successful) => {
info!(
"Server onboarding flow for '{}' successful ✅",
config.connect_as
);
Ok(())
}
Some(MessageState::Failed) => {
Err(deserialize_error_bytes(&res[..(res.len() - 1)]))
}
other => Err(anyhow!(
"Got unrecognized onboarding flow response: {other:?}"
)),
}
info!(
"Server onboarding flow for '{}' successful ✅",
config.connect_as
);
Ok(())
}
async fn connect_websocket(

View File

@@ -4,23 +4,17 @@ use std::{
};
use anyhow::anyhow;
use bytes::Bytes;
use cache::CloneCache;
use resolver_api::Resolve;
use response::JsonBytes;
use serror::serialize_error_bytes;
use tokio::sync::mpsc::Sender;
use transport::{
MessageState,
auth::{
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
PublicKeyValidator,
},
bytes::{
data_from_transport_bytes, id_state_from_transport_bytes,
to_transport_bytes,
},
channel::{BufferedChannel, BufferedReceiver},
channel::{BufferedChannel, BufferedReceiver, Sender},
message::{Message, MessageState},
websocket::{
Websocket, WebsocketMessage, WebsocketReceiver,
WebsocketSender as _,
@@ -39,8 +33,7 @@ pub mod client;
pub mod server;
// Core Address / Host -> Channel
pub type CoreChannels =
CloneCache<String, Arc<BufferedChannel<Bytes>>>;
pub type CoreChannels = CloneCache<String, Arc<BufferedChannel>>;
pub fn core_channels() -> &'static CoreChannels {
static CORE_CHANNELS: OnceLock<CoreChannels> = OnceLock::new();
@@ -85,8 +78,8 @@ async fn handle_login<W: Websocket, L: LoginFlow>(
async fn handle_socket<W: Websocket>(
socket: W,
args: &Arc<Args>,
sender: &Sender<Bytes>,
receiver: &mut BufferedReceiver<Bytes>,
sender: &Sender,
receiver: &mut BufferedReceiver,
) {
let config = periphery_config();
info!(
@@ -109,9 +102,9 @@ async fn handle_socket<W: Websocket>(
// Sender Dropped (shouldn't happen, it is static).
None => break,
// This has to copy the bytes to follow ownership rules.
Some(msg) => Bytes::copy_from_slice(msg),
Some(msg) => msg,
};
match ws_write.send(msg).await {
match ws_write.send(msg.to_message()).await {
// Clears the stored message from receiver buffer.
// TODO: Move after response ack.
Ok(_) => receiver.clear_buffer(),
@@ -127,8 +120,8 @@ async fn handle_socket<W: Websocket>(
let handle_reads = async {
loop {
match ws_read.recv().await {
Ok(WebsocketMessage::Binary(bytes)) => {
handle_incoming_bytes(args, sender, bytes).await
Ok(WebsocketMessage::Message(bytes)) => {
handle_incoming_message(args, sender, bytes).await
}
Ok(WebsocketMessage::Close(frame)) => {
warn!("Connection closed with frame: {frame:?}");
@@ -152,12 +145,12 @@ async fn handle_socket<W: Websocket>(
}
}
async fn handle_incoming_bytes(
async fn handle_incoming_message(
args: &Arc<Args>,
sender: &Sender<Bytes>,
bytes: Bytes,
sender: &Sender,
message: Message,
) {
let (id, state) = match id_state_from_transport_bytes(&bytes) {
let (channel, state) = match message.channel_and_state() {
Ok(res) => res,
Err(e) => {
warn!("Failed to parse transport bytes | {e:#}");
@@ -166,10 +159,10 @@ async fn handle_incoming_bytes(
};
match state {
MessageState::Request => {
handle_request(args.clone(), sender.clone(), id, bytes)
handle_request(args.clone(), sender.clone(), channel, message)
}
MessageState::Terminal => {
crate::terminal::handle_incoming_message(id, bytes).await
crate::terminal::handle_incoming_message(channel, message).await
}
// Shouldn't be received by Periphery
MessageState::InProgress => {}
@@ -180,12 +173,12 @@ async fn handle_incoming_bytes(
fn handle_request(
args: Arc<Args>,
sender: Sender<Bytes>,
sender: Sender,
req_id: Uuid,
bytes: Bytes,
message: Message,
) {
tokio::spawn(async move {
let request = match data_from_transport_bytes(bytes) {
let request = match message.into_data() {
Ok(req) if !req.is_empty() => req,
_ => {
return;
@@ -216,9 +209,7 @@ fn handle_request(
(MessageState::Failed, serialize_error_bytes(&e.error))
}
};
if let Err(e) =
sender.send(to_transport_bytes(data, req_id, state)).await
{
if let Err(e) = sender.send((data, req_id, state)).await {
error!("Failed to send response over channel | {e:?}");
}
};
@@ -226,13 +217,8 @@ fn handle_request(
let ping_in_progress = async {
loop {
tokio::time::sleep(Duration::from_secs(5)).await;
if let Err(e) = sender
.send(to_transport_bytes(
Vec::new(),
req_id,
MessageState::InProgress,
))
.await
if let Err(e) =
sender.send((req_id, MessageState::InProgress)).await
{
error!("Failed to ping in progress over channel | {e:?}");
}

View File

@@ -18,17 +18,16 @@ use axum::{
routing::get,
};
use axum_server::tls_rustls::RustlsConfig;
use bytes::Bytes;
use serror::{
AddStatusCode, AddStatusCodeError, deserialize_error_bytes,
serialize_error_bytes,
AddStatusCode, AddStatusCodeError, serialize_error_bytes,
};
use transport::{
CoreConnectionQuery, MessageState,
CoreConnectionQuery,
auth::{
ConnectionIdentifiers, HeaderConnectionIdentifiers,
AUTH_TIMEOUT, ConnectionIdentifiers, HeaderConnectionIdentifiers,
ServerLoginFlow,
},
message::{Message, MessageState},
websocket::{Websocket, axum::AxumWebsocket},
};
@@ -134,7 +133,7 @@ async fn handler(
bytes.push(MessageState::Failed.as_byte());
if let Err(e) = socket
.send(bytes.into())
.send(&e)
.await
.context("Failed to send forward failed to client")
{
@@ -169,7 +168,7 @@ async fn handle_login(
(Some(_), _) | (_, None) => {
// Send login type [0] (Noise auth)
socket
.send(Bytes::from_owner([0]))
.send([0])
.await
.context("Failed to send login type indicator")?;
super::handle_login::<_, ServerLoginFlow>(socket, identifiers)
@@ -192,43 +191,33 @@ async fn handle_passkey_login(
// Send login type
socket
// Passkey auth: [1]
.send(Bytes::from_owner([1]))
.send([1])
.await
.context("Failed to send login type indicator")?;
// Receieve passkey
let bytes = socket
.recv_bytes()
let passkey = socket
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.and_then(Message::into_data)
.context("Failed to receive passkey from Core")?;
let passkey = match MessageState::from_byte(
*bytes.last().context("passkey message is empty")?,
) {
MessageState::Successful => &bytes[..(bytes.len() - 1)],
_ => {
return Err(deserialize_error_bytes(
&bytes[..(bytes.len() - 1)],
));
}
};
if passkeys
.iter()
.any(|expected_passkey| expected_passkey.as_bytes() == passkey)
{
socket
.send(MessageState::Successful.into())
.send(MessageState::Successful)
.await
.context("Failed to send login type indicator")?;
Ok(())
} else {
let e = anyhow!("Invalid passkey");
let mut bytes = serialize_error_bytes(&e);
bytes.push(MessageState::Failed.as_byte());
if let Err(e) = socket
.send(bytes.into())
.await
.context("Failed to send login failed")
if let Err(e) =
socket.send(&e).await.context("Failed to send login failed")
{
// Log additional error
warn!("{e:#}");
@@ -241,10 +230,8 @@ async fn handle_passkey_login(
}
.await;
if let Err(e) = res {
let mut bytes = serialize_error_bytes(&e);
bytes.push(MessageState::Failed.as_byte());
if let Err(e) = socket
.send(bytes.into())
.send(&e)
.await
.context("Failed to send login failed to client")
{

View File

@@ -13,7 +13,7 @@ use komodo_client::{
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use tokio::sync::{broadcast, mpsc};
use tokio_util::sync::CancellationToken;
use transport::bytes::data_from_transport_bytes;
use transport::message::Message;
use uuid::Uuid;
pub type TerminalChannels =
@@ -25,12 +25,12 @@ pub fn terminal_channels() -> &'static TerminalChannels {
TERMINAL_CHANNELS.get_or_init(Default::default)
}
pub async fn handle_incoming_message(id: Uuid, bytes: Bytes) {
pub async fn handle_incoming_message(id: Uuid, message: Message) {
let Some((channel, _)) = terminal_channels().get(&id).await else {
warn!("No terminal channel for {id}");
return;
};
let Ok(data) = data_from_transport_bytes(bytes) else {
let Ok(data) = message.into_data() else {
warn!("Got terminal message with no data for {id}");
return;
};