improve ws trait ergonomics

This commit is contained in:
mbecker20
2025-10-06 20:02:06 -07:00
parent 2daa92a639
commit e9d13449bf
8 changed files with 132 additions and 132 deletions

View File

@@ -8,7 +8,6 @@ use transport::{
ConnectionIdentifiers,
},
fix_ws_address,
message::Message,
websocket::{Websocket, tungstenite::TungsteniteWebsocket},
};
@@ -107,9 +106,6 @@ impl PeripheryConnection {
.recv_result()
.with_timeout(Duration::from_secs(2))
.await
.flatten()
.flatten()
.and_then(Message::into_data)
.context("Failed to receive login type indicator")?;
match bytes.iter().as_slice() {
@@ -155,8 +151,6 @@ async fn handle_passkey_login(
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to receive authentication state message")?;
Ok(())

View File

@@ -173,19 +173,16 @@ async fn onboard_server_handler(
}
};
let res = socket
// Post onboarding login 1: Receive public key
let public_key = socket
.recv_result()
.with_timeout(Duration::from_secs(2))
.await
.flatten()
.flatten()
.and_then(|message| {
String::from_utf8(message.into_data()?.into())
.and_then(|bytes| {
String::from_utf8(bytes.into())
.context("Public key bytes are not valid utf8")
});
// Post onboarding login 1: Receive public key
let public_key = match res
let public_key = match public_key
{
Ok(public_key) => public_key,
Err(e) => {

View File

@@ -5,11 +5,10 @@ use axum::http::{HeaderValue, StatusCode};
use periphery_client::CONNECTION_RETRY_SECONDS;
use transport::{
auth::{
AddressConnectionIdentifiers, ClientLoginFlow,
AUTH_TIMEOUT, AddressConnectionIdentifiers, ClientLoginFlow,
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
},
fix_ws_address,
message::Message,
websocket::{Websocket, tungstenite::TungsteniteWebsocket},
};
@@ -68,14 +67,11 @@ pub async fn handler(address: &str) -> anyhow::Result<()> {
let flow_bytes = match socket
.recv_result()
.with_timeout(Duration::from_secs(2))
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.and_then(Message::into_data)
.context("Failed to receive login flow indicator")
{
Ok(flow_message) => flow_message,
Ok(flow_bytes) => flow_bytes,
Err(e) => {
if !already_logged_connection_error {
warn!("{e:#}");
@@ -186,10 +182,8 @@ async fn handle_onboarding(
socket
.recv_result()
.with_timeout(Duration::from_secs(2))
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to receive Server creation result")?;
info!(

View File

@@ -4,6 +4,7 @@ use std::{
};
use anyhow::anyhow;
use bytes::Bytes;
use cache::CloneCache;
use resolver_api::Resolve;
use response::JsonBytes;
@@ -120,8 +121,8 @@ async fn handle_socket<W: Websocket>(
let handle_reads = async {
loop {
match ws_read.recv().await {
Ok(WebsocketMessage::Message(bytes)) => {
handle_incoming_message(args, sender, bytes).await
Ok(WebsocketMessage::Message(message)) => {
handle_incoming_message(args, sender, message).await
}
Ok(WebsocketMessage::Close(frame)) => {
warn!("Connection closed with frame: {frame:?}");
@@ -150,7 +151,7 @@ async fn handle_incoming_message(
sender: &Sender,
message: Message,
) {
let (channel, state) = match message.channel_and_state() {
let (data, channel, state) = match message.into_parts() {
Ok(res) => res,
Err(e) => {
warn!("Failed to parse transport bytes | {e:#}");
@@ -159,10 +160,10 @@ async fn handle_incoming_message(
};
match state {
MessageState::Request => {
handle_request(args.clone(), sender.clone(), channel, message)
handle_request(args.clone(), sender.clone(), channel, data)
}
MessageState::Terminal => {
crate::terminal::handle_incoming_message(channel, message).await
crate::terminal::handle_message(channel, data).await
}
// Shouldn't be received by Periphery
MessageState::InProgress => {}
@@ -175,16 +176,9 @@ fn handle_request(
args: Arc<Args>,
sender: Sender,
req_id: Uuid,
message: Message,
request: Bytes,
) {
tokio::spawn(async move {
let request = match message.into_data() {
Ok(req) if !req.is_empty() => req,
_ => {
return;
}
};
let request =
match serde_json::from_slice::<PeripheryRequest>(&request) {
Ok(req) => req,

View File

@@ -27,7 +27,7 @@ use transport::{
AUTH_TIMEOUT, ConnectionIdentifiers, HeaderConnectionIdentifiers,
ServerLoginFlow,
},
message::{Message, MessageState},
message::MessageState,
websocket::{Websocket, axum::AxumWebsocket},
};
@@ -200,9 +200,6 @@ async fn handle_passkey_login(
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.and_then(Message::into_data)
.context("Failed to receive passkey from Core")?;
if passkeys

View File

@@ -13,7 +13,6 @@ use komodo_client::{
use portable_pty::{CommandBuilder, PtySize, native_pty_system};
use tokio::sync::{broadcast, mpsc};
use tokio_util::sync::CancellationToken;
use transport::message::Message;
use uuid::Uuid;
pub type TerminalChannels =
@@ -25,13 +24,10 @@ pub fn terminal_channels() -> &'static TerminalChannels {
TERMINAL_CHANNELS.get_or_init(Default::default)
}
pub async fn handle_incoming_message(id: Uuid, message: Message) {
let Some((channel, _)) = terminal_channels().get(&id).await else {
warn!("No terminal channel for {id}");
return;
};
let Ok(data) = message.into_data() else {
warn!("Got terminal message with no data for {id}");
pub async fn handle_message(channel_id: Uuid, data: Bytes) {
let Some((channel, _)) = terminal_channels().get(&channel_id).await
else {
warn!("No terminal channel for {channel_id}");
return;
};
let msg = match data.first() {
@@ -52,7 +48,7 @@ pub async fn handle_incoming_message(id: Uuid, message: Message) {
None => return,
};
if let Err(e) = channel.send(msg).await {
warn!("No receiver for {id} | {e:?}");
warn!("No receiver for {channel_id} | {e:?}");
};
}

View File

@@ -1,10 +1,5 @@
//! Implementes both sides of Noise handshake
//! using asymmetric private-public key authentication.
//!
//! TODO: Revisit
//! Note. Relies on Server being behind trusted TLS connection.
//! This is trivial for Periphery -> Core connection, but presents a challenge
//! for Core -> Periphery, where untrusted TLS certs are being used.
use std::time::Duration;
@@ -15,7 +10,6 @@ use noise::{NoiseHandshake, key::SpkiPublicKey};
use rand::RngCore;
use sha2::{Digest, Sha256};
use tracing::warn;
use uuid::Uuid;
use crate::{message::MessageState, websocket::Websocket};
@@ -55,11 +49,10 @@ impl LoginFlow for ServerLoginFlow {
) -> anyhow::Result<V::ValidationResult> {
// Server generates random nonce / uuid and sends to client
let nonce = nonce();
let channel = Uuid::new_v4();
let res = async {
socket
.send((nonce.to_vec(), channel, MessageState::Successful))
.send(nonce)
.await
.context("Failed to send connection nonce")?;
@@ -67,7 +60,7 @@ impl LoginFlow for ServerLoginFlow {
private_key,
// Builds the handshake using the connection-unique prologue hash.
// The prologue must be the same on both sides of connection.
&identifiers.hash(&nonce, channel.as_bytes()),
&identifiers.hash(&nonce),
)
.context("Failed to inialize handshake")?;
@@ -76,11 +69,9 @@ impl LoginFlow for ServerLoginFlow {
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to get handshake_m1")?;
handshake
.read_message(handshake_m1.data()?)
.read_message(&handshake_m1)
.context("Failed to read handshake_m1")?;
// Send handshake_m2
@@ -88,7 +79,7 @@ impl LoginFlow for ServerLoginFlow {
.next_message()
.context("Failed to write handshake_m2")?;
socket
.send((handshake_m2, channel, MessageState::Successful))
.send(handshake_m2)
.await
.context("Failed to send handshake_m2")?;
@@ -97,11 +88,9 @@ impl LoginFlow for ServerLoginFlow {
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to get handshake_m3")?;
handshake
.read_message(handshake_m3.data()?)
.read_message(&handshake_m3)
.context("Failed to read handshake_m3")?;
// Server now has client public key
@@ -117,14 +106,14 @@ impl LoginFlow for ServerLoginFlow {
match res {
Ok(res) => {
socket
.send((channel, MessageState::Successful))
.send(MessageState::Successful)
.await
.context("Failed to send login successful to client")?;
Ok(res)
}
Err(e) => {
if let Err(e) = socket
.send((&e, channel))
.send(&e)
.await
.context("Failed to send login failed to client")
{
@@ -151,34 +140,19 @@ impl LoginFlow for ClientLoginFlow {
socket,
}: LoginFlowArgs<'a, 's, V, W>,
) -> anyhow::Result<V::ValidationResult> {
// Receive nonce and channel from server
let (channel, nonce) = match socket
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.and_then(|message| {
Ok((message.channel()?, message.into_data()?))
})
.context("Failed to receive connection nonce")
{
Ok(message) => message,
Err(e) => {
let _ = socket.close(None).await;
warn!(
"Could not get login channel and nonce, closing connection | {e:#}"
);
return Err(e);
}
};
let res = async {
// Receive nonce and channel from server
let nonce = socket
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.context("Failed to receive connection nonce")?;
let mut handshake = NoiseHandshake::new_initiator(
private_key,
// Builds the handshake using the connection-unique prologue hash.
// The prologue must be the same on both sides of connection.
&identifiers.hash(&nonce, channel.as_bytes()),
&identifiers.hash(&nonce),
)
.context("Failed to inialize handshake")?;
@@ -187,7 +161,7 @@ impl LoginFlow for ClientLoginFlow {
.next_message()
.context("Failed to write handshake m1")?;
socket
.send((handshake_m1, channel, MessageState::Successful))
.send(handshake_m1)
.await
.context("Failed to send handshake_m1")?;
@@ -196,11 +170,9 @@ impl LoginFlow for ClientLoginFlow {
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to get handshake_m2")?;
handshake
.read_message(handshake_m2.data()?)
.read_message(&handshake_m2)
.context("Failed to read handshake_m2")?;
// Client now has server public key.
@@ -217,7 +189,7 @@ impl LoginFlow for ClientLoginFlow {
.next_message()
.context("Failed to write handshake_m3")?;
socket
.send((handshake_m3, channel, MessageState::Successful))
.send(handshake_m3)
.await
.context("Failed to send handshake_m3")?;
@@ -226,8 +198,6 @@ impl LoginFlow for ClientLoginFlow {
.recv_result()
.with_timeout(AUTH_TIMEOUT)
.await
.flatten()
.flatten()
.context("Failed to receive authentication state message")?;
anyhow::Ok(validation_result)
@@ -238,7 +208,7 @@ impl LoginFlow for ClientLoginFlow {
Ok(res) => Ok(res),
Err(e) => {
if let Err(e) = socket
.send((&e, channel))
.send(&e)
.await
.context("Failed to send login failed to client")
{
@@ -272,7 +242,7 @@ pub struct ConnectionIdentifiers<'a> {
impl ConnectionIdentifiers<'_> {
/// nonce: Server computed random connection nonce, sent to client before auth handshake
pub fn hash(&self, nonce: &[u8], channel: &[u8]) -> [u8; 32] {
pub fn hash(&self, nonce: &[u8]) -> [u8; 32] {
let mut hash = Sha256::new();
hash.update(b"noise-wss-v1|");
hash.update(self.host);
@@ -282,8 +252,6 @@ impl ConnectionIdentifiers<'_> {
hash.update(self.accept);
hash.update(b"|");
hash.update(nonce);
hash.update(b"|");
hash.update(channel);
hash.finalize().into()
}
}

View File

@@ -3,10 +3,12 @@
use std::time::Duration;
use anyhow::anyhow;
use anyhow::{Context, anyhow};
use bytes::Bytes;
use futures_util::FutureExt;
use pin_project_lite::pin_project;
use serror::deserialize_error_bytes;
use uuid::Uuid;
use crate::message::{Message, MessageState};
@@ -45,11 +47,22 @@ pub trait Websocket: Send {
> + Send,
>;
fn send(
&mut self,
message: impl Into<Message>,
) -> impl Future<Output = Result<(), Self::Error>>;
/// Send close message
fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> impl Future<Output = Result<(), Self::Error>>;
/// Looping receiver for websocket messages which only returns on messages.
fn recv_message(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = Result<Message, anyhow::Error>> + Send,
impl Future<Output = anyhow::Result<Message>> + Send,
> {
MaybeWithTimeout {
inner: async {
@@ -66,38 +79,42 @@ pub trait Websocket: Send {
}
}
/// Auto deserializes non-successful message errors
fn recv_result(
/// Receive message + message.into_parts
fn recv_parts(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = Result<anyhow::Result<Message>, anyhow::Error>>
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
+ Send,
> {
MaybeWithTimeout {
inner: self.recv_message().map(|res| {
res.map(|message| match message.state()? {
MessageState::Successful => Ok(message),
_ => Err(deserialize_error_bytes(message.data()?)),
})
}),
inner: self
.recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()),
}
}
/// Streamlined sending on bytes
fn send(
/// Auto deserializes non-successful message errors.
/// Discards the channels.
fn recv_result(
&mut self,
message: impl Into<Message>,
) -> impl Future<Output = Result<(), Self::Error>>;
/// Send close message
fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> impl Future<Output = Result<(), Self::Error>>;
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Bytes>> + Send,
> {
MaybeWithTimeout {
inner: self.recv_parts().map(|res| {
res
.map(|(data, _, state)| match state {
MessageState::Successful => Ok(data),
_ => Err(deserialize_error_bytes(&data)),
})
.flatten()
}),
}
}
}
/// Traits for split websocket receiver
pub trait WebsocketReceiver {
pub trait WebsocketReceiver: Send {
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
@@ -107,8 +124,42 @@ pub trait WebsocketReceiver {
&mut self,
) -> impl Future<
Output = Result<WebsocketMessage<Self::CloseFrame>, Self::Error>,
> + Send
+ Sync;
> + Send;
/// Looping receiver for websocket messages which only returns on messages.
fn recv_message(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
> {
MaybeWithTimeout {
inner: async {
match self.recv().await? {
WebsocketMessage::Message(message) => Ok(message),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
}
}
},
}
}
/// Receive message + message.into_parts
fn recv_parts(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
+ Send,
> {
MaybeWithTimeout {
inner: self
.recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()),
}
}
}
/// Traits for split websocket receiver
@@ -120,13 +171,13 @@ pub trait WebsocketSender {
fn send(
&mut self,
message: Message,
) -> impl Future<Output = Result<(), Self::Error>> + Send + Sync;
) -> impl Future<Output = Result<(), Self::Error>> + Send;
/// Send close message
fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> impl Future<Output = Result<(), Self::Error>> + Send + Sync;
) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
pin_project! {
@@ -147,12 +198,21 @@ impl<F: Future> Future for MaybeWithTimeout<F> {
}
}
impl<F: Future + Send> MaybeWithTimeout<F> {
impl<
O,
E: Into<anyhow::Error>,
F: Future<Output = Result<O, E>> + Send,
> MaybeWithTimeout<F>
{
pub fn with_timeout(
self,
timeout: Duration,
) -> impl Future<Output = anyhow::Result<F::Output>> + Send {
tokio::time::timeout(timeout, self.inner)
.map(|res| res.map_err(|_| anyhow!("Timed out")))
) -> impl Future<Output = anyhow::Result<O>> + Send {
tokio::time::timeout(timeout, self.inner).map(|res| {
res
.context("Timed out waiting for message.")
.map(|inner| inner.map_err(Into::into))
.flatten()
})
}
}