fix passkey support

This commit is contained in:
mbecker20
2025-10-06 22:12:15 -07:00
parent e9d13449bf
commit c3ea0239d6
19 changed files with 412 additions and 385 deletions

1
Cargo.lock generated
View File

@@ -5686,6 +5686,7 @@ dependencies = [
"sha2",
"tokio",
"tokio-tungstenite",
"tokio-util",
"tracing",
"url",
"uuid",

View File

@@ -439,7 +439,6 @@ async fn get_on_host_periphery(
&config,
),
config.insecure_tls,
&config.passkey,
)
.await?;
// Poll for connection to be estalished

View File

@@ -24,7 +24,6 @@ impl PeripheryConnectionArgs<'_> {
self,
id: String,
insecure: bool,
passkey: String,
) -> anyhow::Result<Arc<ConnectionChannels>> {
let Some(address) = self.address else {
return Err(anyhow!(
@@ -71,9 +70,8 @@ impl PeripheryConnectionArgs<'_> {
core_connection_query().as_bytes(),
);
if let Err(e) = connection
.client_login(&mut socket, identifiers, &passkey)
.await
if let Err(e) =
connection.client_login(&mut socket, identifiers).await
{
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(
@@ -98,13 +96,11 @@ impl PeripheryConnection {
&self,
socket: &mut TungsteniteWebsocket,
identifiers: ConnectionIdentifiers<'_>,
// for legacy auth
passkey: &str,
) -> anyhow::Result<()> {
// Get the required auth type
let bytes = socket
.recv_result()
.with_timeout(Duration::from_secs(2))
.with_timeout(AUTH_TIMEOUT)
.await
.context("Failed to receive login type indicator")?;
@@ -116,7 +112,10 @@ impl PeripheryConnection {
.await
}
// Passkey auth
&[1] => handle_passkey_login(socket, passkey).await,
&[1] => {
handle_passkey_login(socket, self.args.passkey.as_deref())
.await
}
other => Err(anyhow!(
"Receieved invalid login type pattern: {other:?}"
)),
@@ -127,18 +126,18 @@ impl PeripheryConnection {
async fn handle_passkey_login(
socket: &mut TungsteniteWebsocket,
// for legacy auth
passkey: &str,
passkey: Option<&str>,
) -> anyhow::Result<()> {
let res = async {
let passkey = if passkey.is_empty() {
let passkey = if let Some(passkey) = passkey {
passkey.as_bytes().to_vec()
} else {
core_config()
.passkey
.as_deref()
.context("Periphery requires passkey auth")?
.as_bytes()
.to_vec()
} else {
passkey.as_bytes().to_vec()
};
socket

View File

@@ -93,6 +93,9 @@ pub struct PeripheryConnectionArgs<'a> {
pub id: &'a str,
pub address: Option<&'a str>,
periphery_public_key: Option<&'a str>,
/// V1 legacy support.
/// Only possible for Core -> Periphery.
passkey: Option<&'a str>,
}
impl PublicKeyValidator for PeripheryConnectionArgs<'_> {
@@ -145,6 +148,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
id: &server.id,
address: optional_str(&server.config.address),
periphery_public_key: optional_str(&server.info.public_key),
passkey: optional_str(&server.config.passkey),
}
}
@@ -158,6 +162,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
periphery_public_key: optional_str(
&config.periphery_public_key,
),
passkey: optional_str(&config.passkey),
}
}
@@ -172,6 +177,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
periphery_public_key: optional_str(
&config.periphery_public_key,
),
passkey: None,
}
}
@@ -182,6 +188,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
periphery_public_key: self
.periphery_public_key
.map(str::to_string),
passkey: self.passkey.map(str::to_string),
}
}
@@ -204,6 +211,9 @@ pub struct OwnedPeripheryConnectionArgs {
/// If None, must have 'periphery_public_keys' set
/// in Core config, or will error
pub periphery_public_key: Option<String>,
/// V1 legacy support.
/// Only possible for Core -> Periphery connection.
pub passkey: Option<String>,
}
impl OwnedPeripheryConnectionArgs {
@@ -212,6 +222,7 @@ impl OwnedPeripheryConnectionArgs {
id: &self.id,
address: self.address.as_deref(),
periphery_public_key: self.periphery_public_key.as_deref(),
passkey: self.passkey.as_deref(),
}
}
}
@@ -319,19 +330,14 @@ impl PeripheryConnection {
let (mut ws_write, mut ws_read) = socket.split();
ws_read.set_cancel(cancel.clone());
receiver.set_cancel(cancel.clone());
let forward_writes = async {
loop {
let next = tokio::select! {
next = receiver.recv() => next,
_ = cancel.cancelled() => break,
let Ok(message) = receiver.recv().await else {
break;
};
let message = match next {
Some(request) => request.to_message(),
// Sender Dropped (shouldn't happen, a reference is held on 'connection').
None => break,
};
match ws_write.send(message).await {
Ok(_) => receiver.clear_buffer(),
Err(e) => {
@@ -347,12 +353,7 @@ impl PeripheryConnection {
let handle_reads = async {
loop {
let next = tokio::select! {
next = ws_read.recv() => next,
_ = cancel.cancelled() => break,
};
match next {
match ws_read.recv().await {
Ok(WebsocketMessage::Message(message)) => {
self.handle_incoming_message(message).await
}

View File

@@ -1,4 +1,4 @@
use std::{str::FromStr, time::Duration};
use std::str::FromStr;
use anyhow::{Context, anyhow};
use axum::{
@@ -21,8 +21,8 @@ use serror::{AddStatusCode, AddStatusCodeError};
use transport::{
PeripheryConnectionQuery,
auth::{
HeaderConnectionIdentifiers, LoginFlow, LoginFlowArgs,
PublicKeyValidator, ServerLoginFlow,
AUTH_TIMEOUT, HeaderConnectionIdentifiers, LoginFlow,
LoginFlowArgs, PublicKeyValidator, ServerLoginFlow,
},
message::MessageState,
websocket::{Websocket, axum::AxumWebsocket},
@@ -176,7 +176,7 @@ async fn onboard_server_handler(
// Post onboarding login 1: Receive public key
let public_key = socket
.recv_result()
.with_timeout(Duration::from_secs(2))
.with_timeout(AUTH_TIMEOUT)
.await
.and_then(|bytes| {
String::from_utf8(bytes.into())

View File

@@ -54,7 +54,6 @@ pub async fn get_builder_periphery(
&config,
),
config.insecure_tls,
&config.passkey,
)
.await?;
periphery
@@ -116,7 +115,6 @@ async fn get_aws_builder(
&config,
),
config.insecure_tls,
"",
)
.await?;

View File

@@ -195,7 +195,6 @@ pub async fn periphery_client(
PeripheryClient::new(
PeripheryConnectionArgs::from_server(server),
server.config.insecure_tls,
&server.config.passkey,
)
.await
}

View File

@@ -33,8 +33,6 @@ impl PeripheryClient {
pub async fn new(
args: PeripheryConnectionArgs<'_>,
insecure_tls: bool,
// deprecated.
passkey: &str,
) -> anyhow::Result<PeripheryClient> {
let connections = periphery_connections();
@@ -49,7 +47,6 @@ impl PeripheryClient {
.spawn_client_connection(
id.clone(),
insecure_tls,
passkey.to_string(),
)
.await?;
return Ok(PeripheryClient { id, channels });
@@ -83,7 +80,6 @@ impl PeripheryClient {
.spawn_client_connection(
id.clone(),
insecure_tls,
passkey.to_string(),
)
.await?;
Ok(PeripheryClient { id, channels })

View File

@@ -191,6 +191,8 @@ async fn forward_ws_channel(
let (mut core_send, mut core_receive) = client_socket.split();
let cancel = CancellationToken::new();
periphery_receiver.set_cancel(cancel.clone());
trace!("starting ws exchange");
let core_to_periphery = async {
@@ -253,15 +255,13 @@ async fn forward_ws_channel(
let periphery_to_core = async {
loop {
let res = tokio::select! {
res = periphery_receiver.recv() => res.map(TransportMessage::into_data),
_ = cancel.cancelled() => {
trace!("periphery to core read: cancelled from inside");
break;
}
};
// Already adheres to cancellation token
let res = periphery_receiver
.recv()
.await
.and_then(TransportMessage::into_data);
match res {
Some(Ok(bytes)) => {
Ok(bytes) => {
if let Err(e) = core_send.send(Message::Binary(bytes)).await
{
debug!("{e:?}");
@@ -269,9 +269,7 @@ async fn forward_ws_channel(
break;
};
}
// No data, ignore
Some(Err(_e)) => {}
None => {
Err(_) => {
let _ = core_send.send(Message::text("STREAM EOF")).await;
cancel.cancel();
break;

View File

@@ -8,7 +8,6 @@ use bytes::Bytes;
use cache::CloneCache;
use resolver_api::Resolve;
use response::JsonBytes;
use serror::serialize_error_bytes;
use transport::{
auth::{
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
@@ -16,10 +15,7 @@ use transport::{
},
channel::{BufferedChannel, BufferedReceiver, Sender},
message::{Message, MessageState},
websocket::{
Websocket, WebsocketMessage, WebsocketReceiver,
WebsocketSender as _,
},
websocket::{Websocket, WebsocketReceiver, WebsocketSender as _},
};
use uuid::Uuid;
@@ -99,15 +95,15 @@ async fn handle_socket<W: Websocket>(
let forward_writes = async {
loop {
let msg = match receiver.recv().await {
// Sender Dropped (shouldn't happen, it is static).
None => break,
// This has to copy the bytes to follow ownership rules.
Some(msg) => msg,
let message = match receiver.recv().await {
Ok(message) => message,
Err(e) => {
warn!("{e:#}");
break;
}
};
match ws_write.send(msg.to_message()).await {
match ws_write.send(message).await {
// Clears the stored message from receiver buffer.
// TODO: Move after response ack.
Ok(_) => receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send response | {e:?}");
@@ -120,23 +116,23 @@ async fn handle_socket<W: Websocket>(
let handle_reads = async {
loop {
match ws_read.recv().await {
Ok(WebsocketMessage::Message(message)) => {
handle_incoming_message(args, sender, message).await
}
Ok(WebsocketMessage::Close(frame)) => {
warn!("Connection closed with frame: {frame:?}");
break;
}
Ok(WebsocketMessage::Closed) => {
warn!("Connection already closed");
break;
}
let (data, channel, state) = match ws_read.recv_parts().await {
Ok(res) => res,
Err(e) => {
warn!("Failed to read websocket message | {e:?}");
warn!("{e:#}");
break;
}
};
match state {
MessageState::Request => {
handle_request(args.clone(), sender.clone(), channel, data)
}
MessageState::Terminal => {
crate::terminal::handle_message(channel, data).await
}
// Rest shouldn't be received by Periphery
_ => {}
}
}
};
@@ -146,32 +142,6 @@ async fn handle_socket<W: Websocket>(
}
}
async fn handle_incoming_message(
args: &Arc<Args>,
sender: &Sender,
message: Message,
) {
let (data, channel, state) = match message.into_parts() {
Ok(res) => res,
Err(e) => {
warn!("Failed to parse transport bytes | {e:#}");
return;
}
};
match state {
MessageState::Request => {
handle_request(args.clone(), sender.clone(), channel, data)
}
MessageState::Terminal => {
crate::terminal::handle_message(channel, data).await
}
// Shouldn't be received by Periphery
MessageState::InProgress => {}
MessageState::Successful => {}
MessageState::Failed => {}
}
}
fn handle_request(
args: Arc<Args>,
sender: Sender,
@@ -190,20 +160,14 @@ fn handle_request(
};
let resolve_response = async {
let (state, data) = match request.resolve(&args).await {
Ok(JsonBytes::Ok(res)) => (MessageState::Successful, res),
Ok(JsonBytes::Err(e)) => (
MessageState::Failed,
serialize_error_bytes(
&anyhow::Error::new(e)
.context("Failed to serialize response body"),
),
),
Err(e) => {
(MessageState::Failed, serialize_error_bytes(&e.error))
let message: Message = match request.resolve(&args).await {
Ok(JsonBytes::Ok(res)) => {
(res, req_id, MessageState::Successful).into()
}
Ok(JsonBytes::Err(e)) => (&e.into(), req_id).into(),
Err(e) => (&e.error, req_id).into(),
};
if let Err(e) = sender.send((data, req_id, state)).await {
if let Err(e) = sender.send(message).await {
error!("Failed to send response over channel | {e:?}");
}
};

View File

@@ -18,9 +18,7 @@ use axum::{
routing::get,
};
use axum_server::tls_rustls::RustlsConfig;
use serror::{
AddStatusCode, AddStatusCodeError, serialize_error_bytes,
};
use serror::{AddStatusCode, AddStatusCodeError};
use transport::{
CoreConnectionQuery,
auth::{
@@ -106,6 +104,28 @@ async fn handler(
Ok(ws.on_upgrade(|socket| async move {
let mut socket = AxumWebsocket(socket);
// Make sure receiver locked over the login.
let mut receiver = match channel.receiver() {
Ok(receiver) => receiver,
Err(e) => {
warn!("Failed to forward connection | {e:#}");
if let Err(e) = socket
.send(&e)
.await
.context("Failed to send forward failed to client")
{
// Log additional error
warn!("{e:#}");
}
// Close socket
let _ = socket.close(None).await;
return;
}
};
let query = format!("core={}", urlencoding::encode(&args.core));
if let Err(e) =
@@ -124,30 +144,6 @@ async fn handler(
already_logged_login_error()
.store(false, atomic::Ordering::Relaxed);
let mut receiver = match channel.receiver() {
Ok(receiver) => receiver,
Err(e) => {
warn!("Failed to forward connection | {e:#}");
let mut bytes = serialize_error_bytes(&e);
bytes.push(MessageState::Failed.as_byte());
if let Err(e) = socket
.send(&e)
.await
.context("Failed to send forward failed to client")
{
// Log additional error
warn!("{e:#}");
}
// Close socket
let _ = socket.close(None).await;
return;
}
};
super::handle_socket(
socket,
&args,
@@ -184,9 +180,11 @@ async fn handle_passkey_login(
socket: &mut AxumWebsocket,
passkeys: &[String],
) -> anyhow::Result<()> {
warn!(
"Authenticating using Passkeys. Set 'core_public_key' (PERIPHERY_CORE_PUBLIC_KEY) instead to enhance security."
);
if !already_logged_login_error().load(atomic::Ordering::Relaxed) {
warn!(
"Authenticating using Passkeys. Set 'core_public_key' (PERIPHERY_CORE_PUBLIC_KEY) instead to enhance security."
);
};
let res = async {
// Send login type
socket

View File

@@ -15,6 +15,7 @@ serror.workspace = true
tokio-tungstenite.workspace = true
pin-project-lite.workspace = true
futures-util.workspace = true
tokio-util.workspace = true
tracing.workspace = true
anyhow.workspace = true
base64.workspace = true

View File

@@ -1,9 +1,14 @@
use anyhow::Context;
use anyhow::{Context, anyhow};
use bytes::Bytes;
use futures_util::FutureExt;
use tokio::sync::{Mutex, MutexGuard, mpsc};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::message::{BorrowedMessage, Message, MessageState};
use crate::{
message::{Message, MessageState},
timeout::MaybeWithTimeout,
};
const RESPONSE_BUFFER_MAX_LEN: usize = 1_024;
@@ -37,7 +42,13 @@ impl BufferedChannel {
/// Create a channel
pub fn channel() -> (Sender, Receiver) {
let (sender, receiver) = mpsc::channel(RESPONSE_BUFFER_MAX_LEN);
(Sender(sender), Receiver(receiver))
(
Sender(sender),
Receiver {
receiver,
cancel: None,
},
)
}
/// Create a buffered channel
@@ -59,25 +70,60 @@ impl Sender {
}
#[derive(Debug)]
pub struct Receiver(mpsc::Receiver<Message>);
pub struct Receiver {
receiver: mpsc::Receiver<Message>,
cancel: Option<CancellationToken>,
}
impl Receiver {
pub async fn recv(&mut self) -> Option<Message> {
self.0.recv().await
}
pub async fn recv_parts(
&mut self,
) -> anyhow::Result<(Bytes, Uuid, MessageState)> {
let message = self.recv().await.context("Channel is broken")?;
message.into_parts()
pub fn set_cancel(&mut self, cancel: CancellationToken) {
self.cancel = Some(cancel);
}
pub fn poll_recv(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Message>> {
self.0.poll_recv(cx)
if let Some(cancel) = &self.cancel
&& cancel.is_cancelled()
{
return std::task::Poll::Ready(None);
}
self.receiver.poll_recv(cx)
}
pub fn recv(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
> {
MaybeWithTimeout::new(async {
let recv = self
.receiver
.recv()
.map(|res| res.context("Channel is permanently closed"));
if let Some(cancel) = &self.cancel {
tokio::select! {
message = recv => message,
_ = cancel.cancelled() => Err(anyhow!("Stream cancelled"))
}
} else {
recv.await
}
})
}
pub fn recv_parts(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
+ Send,
> {
MaybeWithTimeout::new(self.recv().map(|res| {
res
.context("Channel is permanently closed.")
.and_then(Message::into_parts)
}))
}
}
@@ -96,17 +142,41 @@ impl BufferedReceiver {
}
}
pub fn set_cancel(&mut self, cancel: CancellationToken) {
self.receiver.set_cancel(cancel);
}
/// - If 'buffer: Some(bytes)':
/// - Immediately returns borrow of buffer.
/// - Else:
/// - Wait for next item.
/// - store in buffer.
/// - return borrow of buffer.
pub async fn recv(&mut self) -> Option<BorrowedMessage<'_>> {
if self.buffer.is_none() {
self.buffer = Some(self.receiver.recv().await?);
}
self.buffer.as_ref().map(Message::borrow)
pub fn recv(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
> {
MaybeWithTimeout::new(async {
if let Some(buffer) = self.buffer.clone() {
Ok(buffer)
} else {
let message = self.receiver.recv().await?;
self.buffer = Some(message.clone());
Ok(message)
}
})
}
pub fn recv_parts(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
+ Send,
> {
MaybeWithTimeout::new(
self.recv().map(|res| res.and_then(Message::into_parts)),
)
}
/// Clears buffer.

View File

@@ -4,6 +4,7 @@ use serde::Deserialize;
pub mod auth;
pub mod channel;
pub mod message;
pub mod timeout;
pub mod websocket;
pub trait TransportHandler {

View File

@@ -81,7 +81,7 @@ impl From<&anyhow::Error> for Message {
}
}
#[derive(Default, Debug)]
#[derive(Default, Clone, Debug)]
pub struct Message(Bytes);
impl Message {
@@ -101,71 +101,11 @@ impl Message {
self.0
}
pub fn borrow(&self) -> BorrowedMessage<'_> {
BorrowedMessage(&self.0)
}
pub fn is_empty(&self) -> bool {
self.borrow().is_empty()
}
pub fn state(&self) -> anyhow::Result<MessageState> {
self.borrow().state()
}
pub fn channel(&self) -> anyhow::Result<Uuid> {
self.borrow().channel()
}
pub fn channel_and_state(
&self,
) -> anyhow::Result<(Uuid, MessageState)> {
self.borrow().channel_and_state()
}
pub fn data(&self) -> anyhow::Result<&[u8]> {
// Does not work with .borrow() due to lifetime issue
let len = self.0.len();
if len < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid + state + data"
));
}
Ok(&self.0[..(len - 17)])
}
pub fn into_data(self) -> anyhow::Result<Bytes> {
let len = self.0.len();
if len < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid + state + data"
));
}
let mut res: Vec<u8> = self.0.into();
res.drain((len - 17)..);
Ok(res.into())
}
pub fn into_parts(
self,
) -> anyhow::Result<(Bytes, Uuid, MessageState)> {
let (channel, state) = self.channel_and_state()?;
let data = self.into_data()?;
Ok((data, channel, state))
}
}
// Borrowed Message
#[derive(Debug, Clone, Copy)]
pub struct BorrowedMessage<'a>(&'a [u8]);
impl BorrowedMessage<'_> {
pub fn is_empty(self) -> bool {
self.0.is_empty()
}
pub fn state(self) -> anyhow::Result<MessageState> {
pub fn state(&self) -> anyhow::Result<MessageState> {
self
.0
.last()
@@ -200,6 +140,7 @@ impl BorrowedMessage<'_> {
}
pub fn data(&self) -> anyhow::Result<&[u8]> {
// Does not work with .borrow() due to lifetime issue
let len = self.0.len();
if len < 17 {
return Err(anyhow!(
@@ -209,20 +150,27 @@ impl BorrowedMessage<'_> {
Ok(&self.0[..(len - 17)])
}
/// This will clone the bytes
pub fn to_message(self) -> Message {
Message(Bytes::copy_from_slice(self.0))
pub fn into_data(self) -> anyhow::Result<Bytes> {
let len = self.0.len();
if len < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid + state + data"
));
}
let mut res: Vec<u8> = self.0.into();
res.drain((len - 17)..);
Ok(res.into())
}
pub fn into_parts(
self,
) -> anyhow::Result<(Bytes, Uuid, MessageState)> {
let (channel, state) = self.channel_and_state()?;
let data = self.into_data()?;
Ok((data, channel, state))
}
}
impl<'a> BorrowedMessage<'a> {
pub fn into_inner(self) -> &'a [u8] {
self.0
}
}
// Message State
#[derive(Debug, Clone, Copy)]
pub enum MessageState {
Successful = 0,

View File

@@ -0,0 +1,48 @@
use std::time::Duration;
use anyhow::Context as _;
use futures_util::FutureExt as _;
use pin_project_lite::pin_project;
pin_project! {
pub struct MaybeWithTimeout<F> {
#[pin]
inner: F,
}
}
impl<F> MaybeWithTimeout<F> {
pub fn new(inner: F) -> MaybeWithTimeout<F> {
MaybeWithTimeout { inner }
}
}
impl<F: Future> Future for MaybeWithTimeout<F> {
type Output = F::Output;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let mut inner = self.project().inner;
inner.as_mut().poll(cx)
}
}
impl<
O,
E: Into<anyhow::Error>,
F: Future<Output = Result<O, E>> + Send,
> MaybeWithTimeout<F>
{
pub fn with_timeout(
self,
timeout: Duration,
) -> impl Future<Output = anyhow::Result<O>> + Send {
tokio::time::timeout(timeout, self.inner).map(|res| {
res
.context("Timed out waiting for message.")
.map(|inner| inner.map_err(Into::into))
.flatten()
})
}
}

View File

@@ -1,75 +1,98 @@
use anyhow::{Context, anyhow};
use axum::extract::ws::CloseFrame;
use futures_util::{
SinkExt, Stream, StreamExt, TryStreamExt,
stream::{SplitSink, SplitStream},
};
use tokio_util::sync::CancellationToken;
use crate::message::Message;
use crate::{message::Message, timeout::MaybeWithTimeout};
use super::{
MaybeWithTimeout, Websocket, WebsocketMessage, WebsocketReceiver,
WebsocketSender,
Websocket, WebsocketMessage, WebsocketReceiver, WebsocketSender,
};
pub struct AxumWebsocket(pub axum::extract::ws::WebSocket);
impl Websocket for AxumWebsocket {
type CloseFrame = CloseFrame;
type Error = axum::Error;
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
let (tx, rx) = self.0.split();
(AxumWebsocketSender(tx), AxumWebsocketReceiver(rx))
(AxumWebsocketSender(tx), AxumWebsocketReceiver::new(rx))
}
fn recv(
&mut self,
) -> MaybeWithTimeout<
impl Future<
Output = Result<
WebsocketMessage<Self::CloseFrame>,
Self::Error,
>,
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
>,
> {
MaybeWithTimeout {
inner: try_next(&mut self.0),
}
MaybeWithTimeout::new(try_next(&mut self.0))
}
async fn send(
&mut self,
message: impl Into<Message>,
) -> Result<(), Self::Error> {
) -> anyhow::Result<()> {
self
.0
.send(axum::extract::ws::Message::Binary(
message.into().into_inner(),
))
.await
.context("Failed to send message bytes over websocket")
}
async fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> Result<(), Self::Error> {
self.0.send(axum::extract::ws::Message::Close(frame)).await
) -> anyhow::Result<()> {
self
.0
.send(axum::extract::ws::Message::Close(frame))
.await
.context("Failed to send websocket close frame")
}
}
pub type InnerWebsocketReceiver =
SplitStream<axum::extract::ws::WebSocket>;
pub struct AxumWebsocketReceiver(pub InnerWebsocketReceiver);
pub struct AxumWebsocketReceiver {
receiver: InnerWebsocketReceiver,
cancel: Option<CancellationToken>,
}
impl AxumWebsocketReceiver {
pub fn new(receiver: InnerWebsocketReceiver) -> Self {
Self {
receiver,
cancel: None,
}
}
}
impl WebsocketReceiver for AxumWebsocketReceiver {
type CloseFrame = CloseFrame;
type Error = axum::Error;
fn set_cancel(&mut self, cancel: CancellationToken) {
self.cancel = Some(cancel);
}
async fn recv(
&mut self,
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
try_next(&mut self.0).await
) -> anyhow::Result<WebsocketMessage<Self::CloseFrame>> {
let fut = try_next(&mut self.receiver);
if let Some(cancel) = &self.cancel {
tokio::select! {
res = fut => res,
_ = cancel.cancelled() => Err(anyhow!("Cancelled before receive"))
}
} else {
fut.await
}
}
}
@@ -80,29 +103,30 @@ pub struct AxumWebsocketSender(pub InnerWebsocketSender);
impl WebsocketSender for AxumWebsocketSender {
type CloseFrame = CloseFrame;
type Error = axum::Error;
async fn send(
&mut self,
message: Message,
) -> Result<(), Self::Error> {
async fn send(&mut self, message: Message) -> anyhow::Result<()> {
self
.0
.send(axum::extract::ws::Message::Binary(message.into_inner()))
.await
.context("Failed to send message over websocket")
}
async fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> Result<(), Self::Error> {
self.0.send(axum::extract::ws::Message::Close(frame)).await
) -> anyhow::Result<()> {
self
.0
.send(axum::extract::ws::Message::Close(frame))
.await
.context("Failed to send websocket close frame")
}
}
async fn try_next<S>(
stream: &mut S,
) -> Result<WebsocketMessage<CloseFrame>, axum::Error>
) -> anyhow::Result<WebsocketMessage<CloseFrame>>
where
S: Stream<Item = Result<axum::extract::ws::Message, axum::Error>>
+ Unpin,

View File

@@ -1,16 +1,17 @@
//! Wrappers to normalize behavior of websockets between Tungstenite and Axum,
//! as well as streamline process of handling socket messages.
use std::time::Duration;
use anyhow::{Context, anyhow};
use bytes::Bytes;
use futures_util::FutureExt;
use pin_project_lite::pin_project;
use serror::deserialize_error_bytes;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::message::{Message, MessageState};
use crate::{
message::{Message, MessageState},
timeout::MaybeWithTimeout,
};
pub mod axum;
pub mod tungstenite;
@@ -29,7 +30,6 @@ pub enum WebsocketMessage<CloseFrame> {
/// Standard traits for websocket
pub trait Websocket: Send {
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
/// Abstraction over websocket splitting
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver);
@@ -40,23 +40,20 @@ pub trait Websocket: Send {
&mut self,
) -> MaybeWithTimeout<
impl Future<
Output = Result<
WebsocketMessage<Self::CloseFrame>,
Self::Error,
>,
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
> + Send,
>;
fn send(
&mut self,
message: impl Into<Message>,
) -> impl Future<Output = Result<(), Self::Error>>;
) -> impl Future<Output = anyhow::Result<()>>;
/// Send close message
fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> impl Future<Output = Result<(), Self::Error>>;
) -> impl Future<Output = anyhow::Result<()>>;
/// Looping receiver for websocket messages which only returns on messages.
fn recv_message(
@@ -64,19 +61,17 @@ pub trait Websocket: Send {
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
> {
MaybeWithTimeout {
inner: async {
match self.recv().await? {
WebsocketMessage::Message(message) => Ok(message),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
}
MaybeWithTimeout::new(async {
match self.recv().await? {
WebsocketMessage::Message(message) => Ok(message),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
},
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
}
}
})
}
/// Receive message + message.into_parts
@@ -86,11 +81,11 @@ pub trait Websocket: Send {
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
+ Send,
> {
MaybeWithTimeout {
inner: self
MaybeWithTimeout::new(
self
.recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()),
}
)
}
/// Auto deserializes non-successful message errors.
@@ -100,30 +95,30 @@ pub trait Websocket: Send {
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Bytes>> + Send,
> {
MaybeWithTimeout {
inner: self.recv_parts().map(|res| {
res
.map(|(data, _, state)| match state {
MessageState::Successful => Ok(data),
_ => Err(deserialize_error_bytes(&data)),
})
.flatten()
}),
}
MaybeWithTimeout::new(self.recv_parts().map(|res| {
res
.map(|(data, _, state)| match state {
MessageState::Successful => Ok(data),
_ => Err(deserialize_error_bytes(&data)),
})
.flatten()
}))
}
}
/// Traits for split websocket receiver
pub trait WebsocketReceiver: Send {
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
/// Cancellation sensitive receive.
fn set_cancel(&mut self, _cancel: CancellationToken);
/// Looping receiver for websocket messages which only returns
/// on significant messages.
/// on significant messages. Must implement cancel support.
fn recv(
&mut self,
) -> impl Future<
Output = Result<WebsocketMessage<Self::CloseFrame>, Self::Error>,
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
> + Send;
/// Looping receiver for websocket messages which only returns on messages.
@@ -132,19 +127,21 @@ pub trait WebsocketReceiver: Send {
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
> {
MaybeWithTimeout {
inner: async {
match self.recv().await? {
WebsocketMessage::Message(message) => Ok(message),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
}
MaybeWithTimeout::new(async {
match self
.recv()
.await
.context("Failed to read websocket message")?
{
WebsocketMessage::Message(message) => Ok(message),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
},
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
}
}
})
}
/// Receive message + message.into_parts
@@ -154,65 +151,27 @@ pub trait WebsocketReceiver: Send {
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
+ Send,
> {
MaybeWithTimeout {
inner: self
MaybeWithTimeout::new(
self
.recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()),
}
)
}
}
/// Traits for split websocket receiver
pub trait WebsocketSender {
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
/// Streamlined sending on bytes
fn send(
&mut self,
message: Message,
) -> impl Future<Output = Result<(), Self::Error>> + Send;
) -> impl Future<Output = anyhow::Result<()>> + Send;
/// Send close message
fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
pin_project! {
pub struct MaybeWithTimeout<F> {
#[pin]
inner: F,
}
}
impl<F: Future> Future for MaybeWithTimeout<F> {
type Output = F::Output;
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let mut inner = self.project().inner;
inner.as_mut().poll(cx)
}
}
impl<
O,
E: Into<anyhow::Error>,
F: Future<Output = Result<O, E>> + Send,
> MaybeWithTimeout<F>
{
pub fn with_timeout(
self,
timeout: Duration,
) -> impl Future<Output = anyhow::Result<O>> + Send {
tokio::time::timeout(timeout, self.inner).map(|res| {
res
.context("Timed out waiting for message.")
.map(|inner| inner.map_err(Into::into))
.flatten()
})
}
) -> impl Future<Output = anyhow::Result<()>> + Send;
}

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use anyhow::Context;
use anyhow::{Context, anyhow};
use axum::http::HeaderValue;
use futures_util::{
SinkExt, Stream, StreamExt, TryStreamExt,
@@ -15,12 +15,12 @@ use tokio_tungstenite::{
self, handshake::client::Response, protocol::CloseFrame,
},
};
use tokio_util::sync::CancellationToken;
use crate::message::Message;
use crate::{message::Message, timeout::MaybeWithTimeout};
use super::{
MaybeWithTimeout, Websocket, WebsocketMessage, WebsocketReceiver,
WebsocketSender,
Websocket, WebsocketMessage, WebsocketReceiver, WebsocketSender,
};
pub type InnerWebsocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
@@ -29,13 +29,12 @@ pub struct TungsteniteWebsocket(pub InnerWebsocket);
impl Websocket for TungsteniteWebsocket {
type CloseFrame = CloseFrame;
type Error = tungstenite::Error;
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
let (tx, rx) = self.0.split();
(
TungsteniteWebsocketSender(tx),
TungsteniteWebsocketReceiver(rx),
TungsteniteWebsocketReceiver::new(rx),
)
}
@@ -43,48 +42,71 @@ impl Websocket for TungsteniteWebsocket {
&mut self,
) -> MaybeWithTimeout<
impl Future<
Output = Result<
WebsocketMessage<Self::CloseFrame>,
Self::Error,
>,
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
>,
> {
MaybeWithTimeout {
inner: try_next(&mut self.0),
}
MaybeWithTimeout::new(try_next(&mut self.0))
}
async fn send(
&mut self,
message: impl Into<Message>,
) -> Result<(), Self::Error> {
) -> anyhow::Result<()> {
self
.0
.send(tungstenite::Message::Binary(message.into().into_inner()))
.await
.context("Failed to send message over websocket")
}
async fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> Result<(), Self::Error> {
self.0.close(frame).await
) -> anyhow::Result<()> {
self
.0
.close(frame)
.await
.context("Failed to send websocket close frame")
}
}
pub type InnerWebsocketReceiver =
SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
pub struct TungsteniteWebsocketReceiver(pub InnerWebsocketReceiver);
pub struct TungsteniteWebsocketReceiver {
receiver: InnerWebsocketReceiver,
cancel: Option<CancellationToken>,
}
impl TungsteniteWebsocketReceiver {
pub fn new(receiver: InnerWebsocketReceiver) -> Self {
Self {
receiver,
cancel: None,
}
}
}
impl WebsocketReceiver for TungsteniteWebsocketReceiver {
type CloseFrame = CloseFrame;
type Error = tungstenite::Error;
fn set_cancel(&mut self, cancel: CancellationToken) {
self.cancel = Some(cancel);
}
async fn recv(
&mut self,
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
try_next(&mut self.0).await
) -> anyhow::Result<WebsocketMessage<Self::CloseFrame>> {
let fut = try_next(&mut self.receiver);
if let Some(cancel) = &self.cancel {
tokio::select! {
res = fut => res,
_ = cancel.cancelled() => Err(anyhow!("Cancelled before receive"))
}
} else {
fut.await
}
}
}
@@ -97,29 +119,30 @@ pub struct TungsteniteWebsocketSender(pub InnerWebsocketSender);
impl WebsocketSender for TungsteniteWebsocketSender {
type CloseFrame = CloseFrame;
type Error = tungstenite::Error;
async fn send(
&mut self,
message: Message,
) -> Result<(), Self::Error> {
async fn send(&mut self, message: Message) -> anyhow::Result<()> {
self
.0
.send(tungstenite::Message::Binary(message.into_inner()))
.await
.context("Failed to send message over websocket")
}
async fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> Result<(), Self::Error> {
self.0.send(tungstenite::Message::Close(frame)).await
) -> anyhow::Result<()> {
self
.0
.send(tungstenite::Message::Close(frame))
.await
.context("Failed to send websocket close frame")
}
}
async fn try_next<S>(
stream: &mut S,
) -> Result<WebsocketMessage<CloseFrame>, tungstenite::Error>
) -> anyhow::Result<WebsocketMessage<CloseFrame>>
where
S: Stream<Item = Result<tungstenite::Message, tungstenite::Error>>
+ Unpin,