From 5d271d5547a11fe0a085b3d4948b79f4738de389 Mon Sep 17 00:00:00 2001 From: mbecker20 Date: Thu, 23 Oct 2025 00:55:51 -0700 Subject: [PATCH] use Ping timeout to handle reconnect if for some reason network cuts but ws doesn't receive Close --- bin/core/src/connection/mod.rs | 43 ++++++------ bin/periphery/src/connection/mod.rs | 22 ++++-- lib/transport/src/websocket/axum.rs | 32 +++++---- lib/transport/src/websocket/mod.rs | 78 ++++++++++++++-------- lib/transport/src/websocket/tungstenite.rs | 30 +++++---- 5 files changed, 124 insertions(+), 81 deletions(-) diff --git a/bin/core/src/connection/mod.rs b/bin/core/src/connection/mod.rs index 844205971..1e8a26795 100644 --- a/bin/core/src/connection/mod.rs +++ b/bin/core/src/connection/mod.rs @@ -31,7 +31,7 @@ use transport::{ }, channel::{BufferedReceiver, Sender, buffered_channel}, websocket::{ - Websocket, WebsocketMessage, WebsocketReceiver as _, + Websocket, WebsocketReceiver as _, WebsocketReceiverExt, WebsocketSender as _, }, }; @@ -367,8 +367,22 @@ impl PeripheryConnection { let forward_writes = async { loop { - let Ok(message) = receiver.recv().await else { - break; + let message = match tokio::time::timeout( + Duration::from_secs(5), + receiver.recv(), + ) + .await + { + Ok(Ok(message)) => message, + Ok(Err(_)) => break, + // Handle sending Ping + Err(_) => { + if let Err(e) = ws_write.ping().await { + self.set_error(e).await; + break; + } + continue; + } }; match ws_write.send(message.into_bytes()).await { Ok(_) => receiver.clear_buffer(), @@ -385,19 +399,13 @@ impl PeripheryConnection { let handle_reads = async { loop { - match ws_read.recv().await { - Ok(WebsocketMessage::Message(message)) => { - self.handle_incoming_message(message).await - } - Ok(WebsocketMessage::Close(_)) - | Ok(WebsocketMessage::Closed) => { - self.set_error(anyhow!("Connection closed")).await; - break; - } + match ws_read.recv_message().await { + Ok(message) => self.handle_incoming_message(message).await, Err(e) => { self.set_error(e).await; + break; } - }; + } } // Cancel again if not already cancel.cancel(); @@ -410,15 +418,8 @@ impl PeripheryConnection { pub async fn handle_incoming_message( &self, - message: EncodedTransportMessage, + message: TransportMessage, ) { - let message: TransportMessage = match message.decode() { - Ok(res) => res, - Err(e) => { - warn!("Failed to parse Message bytes | {e:#}"); - return; - } - }; match message { TransportMessage::Response(data) => { match data.decode().map(ResponseMessage::into_inner) { diff --git a/bin/periphery/src/connection/mod.rs b/bin/periphery/src/connection/mod.rs index 929bb217e..f1931ecc4 100644 --- a/bin/periphery/src/connection/mod.rs +++ b/bin/periphery/src/connection/mod.rs @@ -81,11 +81,21 @@ async fn handle_socket( let forward_writes = async { loop { - let message = match receiver.recv().await { - Ok(message) => message, - Err(e) => { - warn!("{e:#}"); - break; + let message = match tokio::time::timeout( + Duration::from_secs(5), + receiver.recv(), + ) + .await + { + Ok(Ok(message)) => message, + Ok(Err(_)) => break, + // Handle sending Ping + Err(_) => { + if let Err(e) = ws_write.ping().await { + warn!("Failed to send ping | {e:?}"); + break; + } + continue; } }; match ws_write.send(message.into_bytes()).await { @@ -93,11 +103,11 @@ async fn handle_socket( Ok(_) => receiver.clear_buffer(), Err(e) => { warn!("Failed to send response | {e:?}"); - let _ = ws_write.close().await; break; } } } + let _ = ws_write.close().await; }; let handle_reads = async { diff --git a/lib/transport/src/websocket/axum.rs b/lib/transport/src/websocket/axum.rs index 7fa3c43de..c52a057a8 100644 --- a/lib/transport/src/websocket/axum.rs +++ b/lib/transport/src/websocket/axum.rs @@ -18,8 +18,6 @@ use super::{ pub struct AxumWebsocket(pub axum::extract::ws::WebSocket); impl Websocket for AxumWebsocket { - type CloseFrame = CloseFrame; - fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) { let (tx, rx) = self.0.split(); (AxumWebsocketSender(tx), AxumWebsocketReceiver::new(rx)) @@ -44,9 +42,7 @@ impl Websocket for AxumWebsocket { fn recv_inner( &mut self, ) -> MaybeWithTimeout< - impl Future< - Output = anyhow::Result>, - >, + impl Future>, > { MaybeWithTimeout::new(try_next(&mut self.0)) } @@ -58,6 +54,14 @@ pub type InnerWebsocketSender = pub struct AxumWebsocketSender(pub InnerWebsocketSender); impl WebsocketSender for AxumWebsocketSender { + async fn ping(&mut self) -> anyhow::Result<()> { + self + .0 + .send(axum::extract::ws::Message::Ping(Bytes::new())) + .await + .context("Failed to send ping over websocket") + } + async fn send(&mut self, bytes: Bytes) -> anyhow::Result<()> { self .0 @@ -77,7 +81,7 @@ impl WebsocketSender for AxumWebsocketSender { async fn try_next( stream: &mut S, -) -> anyhow::Result> +) -> anyhow::Result where S: Stream> + Unpin, @@ -95,13 +99,15 @@ where EncodedTransportMessage::from_vec(bytes.into()), )); } - Some(axum::extract::ws::Message::Close(frame)) => { - return Ok(WebsocketMessage::Close(frame)); + Some(axum::extract::ws::Message::Ping(_)) => { + return Ok(WebsocketMessage::Ping); + } + Some(axum::extract::ws::Message::Close(_)) => { + return Ok(WebsocketMessage::Close); } None => return Ok(WebsocketMessage::Closed), - // Ignored messages - Some(axum::extract::ws::Message::Ping(_)) - | Some(axum::extract::ws::Message::Pong(_)) => continue, + // Ignored + Some(axum::extract::ws::Message::Pong(_)) => continue, } } } @@ -130,9 +136,7 @@ impl WebsocketReceiver for AxumWebsocketReceiver { self.cancel = Some(cancel); } - async fn recv( - &mut self, - ) -> anyhow::Result> { + async fn recv(&mut self) -> anyhow::Result { let fut = try_next(&mut self.receiver); if let Some(cancel) = &self.cancel { tokio::select! { diff --git a/lib/transport/src/websocket/mod.rs b/lib/transport/src/websocket/mod.rs index 79c7ea7ff..6a46dffb9 100644 --- a/lib/transport/src/websocket/mod.rs +++ b/lib/transport/src/websocket/mod.rs @@ -1,5 +1,7 @@ //! Wrappers to normalize behavior of websockets between Tungstenite and Axum +use std::time::Duration; + use anyhow::{Context, anyhow}; use bytes::Bytes; use encoding::{ @@ -22,19 +24,20 @@ pub mod tungstenite; /// Flattened websocket message possibilites /// for easier handling. -pub enum WebsocketMessage { +pub enum WebsocketMessage { /// Standard message Message(EncodedTransportMessage), + /// Core / Periphery must receive every 10s + /// or reconnect triggered. + Ping, /// Graceful close message - Close(Option), + Close, /// Stream closed Closed, } /// Standard traits for websocket pub trait Websocket: Send { - type CloseFrame: std::fmt::Debug + Send + Sync + 'static; - /// Abstraction over websocket splitting fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver); @@ -53,9 +56,7 @@ pub trait Websocket: Send { fn recv_inner( &mut self, ) -> MaybeWithTimeout< - impl Future< - Output = anyhow::Result>, - > + Send, + impl Future> + Send, >; } @@ -68,19 +69,31 @@ pub trait WebsocketExt: Websocket { } /// Looping receiver for websocket messages which only returns on TransportMessage. + /// Also ensures either Messages or Pings are received at least every 10s. fn recv_message( &mut self, ) -> MaybeWithTimeout< impl Future> + Send, > { MaybeWithTimeout::new(async { - match self.recv_inner().await? { - WebsocketMessage::Message(message) => message.decode(), - WebsocketMessage::Close(frame) => { - Err(anyhow!("Connection closed with framed: {frame:?}")) - } - WebsocketMessage::Closed => { - Err(anyhow!("Connection already closed")) + loop { + match tokio::time::timeout( + Duration::from_secs(10), + self.recv_inner(), + ) + .await + .context("Timed out waiting for Ping")?? + { + WebsocketMessage::Message(message) => { + return message.decode(); + } + WebsocketMessage::Ping => continue, + WebsocketMessage::Close => { + return Err(anyhow!("Connection closed")); + } + WebsocketMessage::Closed => { + return Err(anyhow!("Connection already closed")); + } } } }) @@ -91,6 +104,11 @@ impl WebsocketExt for W {} /// Traits for split websocket receiver pub trait WebsocketSender { + /// Streamlined pinging + fn ping( + &mut self, + ) -> impl Future> + Send; + /// Streamlined sending on bytes fn send( &mut self, @@ -165,30 +183,36 @@ pub trait WebsocketReceiver: Send { /// on significant messages. Must implement cancel support. fn recv( &mut self, - ) -> impl Future< - Output = anyhow::Result>, - > + Send; + ) -> impl Future> + Send; } pub trait WebsocketReceiverExt: WebsocketReceiver { /// Looping receiver for websocket messages which only returns on TransportMessage. + /// Also ensures either Messages or Pings are received at least every 10s. fn recv_message( &mut self, ) -> MaybeWithTimeout< impl Future> + Send, > { MaybeWithTimeout::new(async { - match self - .recv() + loop { + match tokio::time::timeout( + Duration::from_secs(10), + self.recv(), + ) .await - .context("Failed to read websocket message")? - { - WebsocketMessage::Message(message) => message.decode(), - WebsocketMessage::Close(frame) => { - Err(anyhow!("Connection closed with framed: {frame:?}")) - } - WebsocketMessage::Closed => { - Err(anyhow!("Connection already closed")) + .context("Timed out waiting for Ping")?? + { + WebsocketMessage::Message(message) => { + return message.decode(); + } + WebsocketMessage::Ping => continue, + WebsocketMessage::Close => { + return Err(anyhow!("Connection closed")); + } + WebsocketMessage::Closed => { + return Err(anyhow!("Connection already closed")); + } } } }) diff --git a/lib/transport/src/websocket/tungstenite.rs b/lib/transport/src/websocket/tungstenite.rs index caace9daf..0cb4372f7 100644 --- a/lib/transport/src/websocket/tungstenite.rs +++ b/lib/transport/src/websocket/tungstenite.rs @@ -31,8 +31,6 @@ pub type InnerWebsocket = WebSocketStream>; pub struct TungsteniteWebsocket(pub InnerWebsocket); impl Websocket for TungsteniteWebsocket { - type CloseFrame = CloseFrame; - fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) { let (tx, rx) = self.0.split(); ( @@ -44,9 +42,7 @@ impl Websocket for TungsteniteWebsocket { fn recv_inner( &mut self, ) -> MaybeWithTimeout< - impl Future< - Output = anyhow::Result>, - >, + impl Future>, > { MaybeWithTimeout::new(try_next(&mut self.0)) } @@ -76,6 +72,14 @@ pub type InnerWebsocketSender = SplitSink< pub struct TungsteniteWebsocketSender(pub InnerWebsocketSender); impl WebsocketSender for TungsteniteWebsocketSender { + async fn ping(&mut self) -> anyhow::Result<()> { + self + .0 + .send(tungstenite::Message::Ping(Bytes::new())) + .await + .context("Failed to send ping over websocket") + } + async fn send(&mut self, bytes: Bytes) -> anyhow::Result<()> { self .0 @@ -95,7 +99,7 @@ impl WebsocketSender for TungsteniteWebsocketSender { async fn try_next( stream: &mut S, -) -> anyhow::Result> +) -> anyhow::Result where S: Stream> + Unpin, @@ -113,13 +117,15 @@ where EncodedTransportMessage::from_vec(bytes.into()), )); } - Some(tungstenite::Message::Close(frame)) => { - return Ok(WebsocketMessage::Close(frame)); + Some(tungstenite::Message::Ping(_)) => { + return Ok(WebsocketMessage::Ping); + } + Some(tungstenite::Message::Close(_)) => { + return Ok(WebsocketMessage::Close); } None => return Ok(WebsocketMessage::Closed), // Ignored messages - Some(tungstenite::Message::Ping(_)) - | Some(tungstenite::Message::Pong(_)) + Some(tungstenite::Message::Pong(_)) | Some(tungstenite::Message::Frame(_)) => continue, } } @@ -149,9 +155,7 @@ impl WebsocketReceiver for TungsteniteWebsocketReceiver { self.cancel = Some(cancel); } - async fn recv( - &mut self, - ) -> anyhow::Result> { + async fn recv(&mut self) -> anyhow::Result { let fut = try_next(&mut self.receiver); if let Some(cancel) = &self.cancel { tokio::select! {