use Ping timeout to handle reconnect if for some reason network cuts but ws doesn't receive Close

This commit is contained in:
mbecker20
2025-10-23 00:55:51 -07:00
parent 11fb67a35b
commit 5d271d5547
5 changed files with 124 additions and 81 deletions

View File

@@ -31,7 +31,7 @@ use transport::{
},
channel::{BufferedReceiver, Sender, buffered_channel},
websocket::{
Websocket, WebsocketMessage, WebsocketReceiver as _,
Websocket, WebsocketReceiver as _, WebsocketReceiverExt,
WebsocketSender as _,
},
};
@@ -367,8 +367,22 @@ impl PeripheryConnection {
let forward_writes = async {
loop {
let Ok(message) = receiver.recv().await else {
break;
let message = match tokio::time::timeout(
Duration::from_secs(5),
receiver.recv(),
)
.await
{
Ok(Ok(message)) => message,
Ok(Err(_)) => break,
// Handle sending Ping
Err(_) => {
if let Err(e) = ws_write.ping().await {
self.set_error(e).await;
break;
}
continue;
}
};
match ws_write.send(message.into_bytes()).await {
Ok(_) => receiver.clear_buffer(),
@@ -385,19 +399,13 @@ impl PeripheryConnection {
let handle_reads = async {
loop {
match ws_read.recv().await {
Ok(WebsocketMessage::Message(message)) => {
self.handle_incoming_message(message).await
}
Ok(WebsocketMessage::Close(_))
| Ok(WebsocketMessage::Closed) => {
self.set_error(anyhow!("Connection closed")).await;
break;
}
match ws_read.recv_message().await {
Ok(message) => self.handle_incoming_message(message).await,
Err(e) => {
self.set_error(e).await;
break;
}
};
}
}
// Cancel again if not already
cancel.cancel();
@@ -410,15 +418,8 @@ impl PeripheryConnection {
pub async fn handle_incoming_message(
&self,
message: EncodedTransportMessage,
message: TransportMessage,
) {
let message: TransportMessage = match message.decode() {
Ok(res) => res,
Err(e) => {
warn!("Failed to parse Message bytes | {e:#}");
return;
}
};
match message {
TransportMessage::Response(data) => {
match data.decode().map(ResponseMessage::into_inner) {

View File

@@ -81,11 +81,21 @@ async fn handle_socket<W: Websocket>(
let forward_writes = async {
loop {
let message = match receiver.recv().await {
Ok(message) => message,
Err(e) => {
warn!("{e:#}");
break;
let message = match tokio::time::timeout(
Duration::from_secs(5),
receiver.recv(),
)
.await
{
Ok(Ok(message)) => message,
Ok(Err(_)) => break,
// Handle sending Ping
Err(_) => {
if let Err(e) = ws_write.ping().await {
warn!("Failed to send ping | {e:?}");
break;
}
continue;
}
};
match ws_write.send(message.into_bytes()).await {
@@ -93,11 +103,11 @@ async fn handle_socket<W: Websocket>(
Ok(_) => receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send response | {e:?}");
let _ = ws_write.close().await;
break;
}
}
}
let _ = ws_write.close().await;
};
let handle_reads = async {

View File

@@ -18,8 +18,6 @@ use super::{
pub struct AxumWebsocket(pub axum::extract::ws::WebSocket);
impl Websocket for AxumWebsocket {
type CloseFrame = CloseFrame;
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
let (tx, rx) = self.0.split();
(AxumWebsocketSender(tx), AxumWebsocketReceiver::new(rx))
@@ -44,9 +42,7 @@ impl Websocket for AxumWebsocket {
fn recv_inner(
&mut self,
) -> MaybeWithTimeout<
impl Future<
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
>,
impl Future<Output = anyhow::Result<WebsocketMessage>>,
> {
MaybeWithTimeout::new(try_next(&mut self.0))
}
@@ -58,6 +54,14 @@ pub type InnerWebsocketSender =
pub struct AxumWebsocketSender(pub InnerWebsocketSender);
impl WebsocketSender for AxumWebsocketSender {
async fn ping(&mut self) -> anyhow::Result<()> {
self
.0
.send(axum::extract::ws::Message::Ping(Bytes::new()))
.await
.context("Failed to send ping over websocket")
}
async fn send(&mut self, bytes: Bytes) -> anyhow::Result<()> {
self
.0
@@ -77,7 +81,7 @@ impl WebsocketSender for AxumWebsocketSender {
async fn try_next<S>(
stream: &mut S,
) -> anyhow::Result<WebsocketMessage<CloseFrame>>
) -> anyhow::Result<WebsocketMessage>
where
S: Stream<Item = Result<axum::extract::ws::Message, axum::Error>>
+ Unpin,
@@ -95,13 +99,15 @@ where
EncodedTransportMessage::from_vec(bytes.into()),
));
}
Some(axum::extract::ws::Message::Close(frame)) => {
return Ok(WebsocketMessage::Close(frame));
Some(axum::extract::ws::Message::Ping(_)) => {
return Ok(WebsocketMessage::Ping);
}
Some(axum::extract::ws::Message::Close(_)) => {
return Ok(WebsocketMessage::Close);
}
None => return Ok(WebsocketMessage::Closed),
// Ignored messages
Some(axum::extract::ws::Message::Ping(_))
| Some(axum::extract::ws::Message::Pong(_)) => continue,
// Ignored
Some(axum::extract::ws::Message::Pong(_)) => continue,
}
}
}
@@ -130,9 +136,7 @@ impl WebsocketReceiver for AxumWebsocketReceiver {
self.cancel = Some(cancel);
}
async fn recv(
&mut self,
) -> anyhow::Result<WebsocketMessage<Self::CloseFrame>> {
async fn recv(&mut self) -> anyhow::Result<WebsocketMessage> {
let fut = try_next(&mut self.receiver);
if let Some(cancel) = &self.cancel {
tokio::select! {

View File

@@ -1,5 +1,7 @@
//! Wrappers to normalize behavior of websockets between Tungstenite and Axum
use std::time::Duration;
use anyhow::{Context, anyhow};
use bytes::Bytes;
use encoding::{
@@ -22,19 +24,20 @@ pub mod tungstenite;
/// Flattened websocket message possibilites
/// for easier handling.
pub enum WebsocketMessage<CloseFrame> {
pub enum WebsocketMessage {
/// Standard message
Message(EncodedTransportMessage),
/// Core / Periphery must receive every 10s
/// or reconnect triggered.
Ping,
/// Graceful close message
Close(Option<CloseFrame>),
Close,
/// Stream closed
Closed,
}
/// Standard traits for websocket
pub trait Websocket: Send {
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
/// Abstraction over websocket splitting
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver);
@@ -53,9 +56,7 @@ pub trait Websocket: Send {
fn recv_inner(
&mut self,
) -> MaybeWithTimeout<
impl Future<
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
> + Send,
impl Future<Output = anyhow::Result<WebsocketMessage>> + Send,
>;
}
@@ -68,19 +69,31 @@ pub trait WebsocketExt: Websocket {
}
/// Looping receiver for websocket messages which only returns on TransportMessage.
/// Also ensures either Messages or Pings are received at least every 10s.
fn recv_message(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<TransportMessage>> + Send,
> {
MaybeWithTimeout::new(async {
match self.recv_inner().await? {
WebsocketMessage::Message(message) => message.decode(),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
loop {
match tokio::time::timeout(
Duration::from_secs(10),
self.recv_inner(),
)
.await
.context("Timed out waiting for Ping")??
{
WebsocketMessage::Message(message) => {
return message.decode();
}
WebsocketMessage::Ping => continue,
WebsocketMessage::Close => {
return Err(anyhow!("Connection closed"));
}
WebsocketMessage::Closed => {
return Err(anyhow!("Connection already closed"));
}
}
}
})
@@ -91,6 +104,11 @@ impl<W: Websocket> WebsocketExt for W {}
/// Traits for split websocket receiver
pub trait WebsocketSender {
/// Streamlined pinging
fn ping(
&mut self,
) -> impl Future<Output = anyhow::Result<()>> + Send;
/// Streamlined sending on bytes
fn send(
&mut self,
@@ -165,30 +183,36 @@ pub trait WebsocketReceiver: Send {
/// on significant messages. Must implement cancel support.
fn recv(
&mut self,
) -> impl Future<
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
> + Send;
) -> impl Future<Output = anyhow::Result<WebsocketMessage>> + Send;
}
pub trait WebsocketReceiverExt: WebsocketReceiver {
/// Looping receiver for websocket messages which only returns on TransportMessage.
/// Also ensures either Messages or Pings are received at least every 10s.
fn recv_message(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<TransportMessage>> + Send,
> {
MaybeWithTimeout::new(async {
match self
.recv()
loop {
match tokio::time::timeout(
Duration::from_secs(10),
self.recv(),
)
.await
.context("Failed to read websocket message")?
{
WebsocketMessage::Message(message) => message.decode(),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
.context("Timed out waiting for Ping")??
{
WebsocketMessage::Message(message) => {
return message.decode();
}
WebsocketMessage::Ping => continue,
WebsocketMessage::Close => {
return Err(anyhow!("Connection closed"));
}
WebsocketMessage::Closed => {
return Err(anyhow!("Connection already closed"));
}
}
}
})

View File

@@ -31,8 +31,6 @@ pub type InnerWebsocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub struct TungsteniteWebsocket(pub InnerWebsocket);
impl Websocket for TungsteniteWebsocket {
type CloseFrame = CloseFrame;
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
let (tx, rx) = self.0.split();
(
@@ -44,9 +42,7 @@ impl Websocket for TungsteniteWebsocket {
fn recv_inner(
&mut self,
) -> MaybeWithTimeout<
impl Future<
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
>,
impl Future<Output = anyhow::Result<WebsocketMessage>>,
> {
MaybeWithTimeout::new(try_next(&mut self.0))
}
@@ -76,6 +72,14 @@ pub type InnerWebsocketSender = SplitSink<
pub struct TungsteniteWebsocketSender(pub InnerWebsocketSender);
impl WebsocketSender for TungsteniteWebsocketSender {
async fn ping(&mut self) -> anyhow::Result<()> {
self
.0
.send(tungstenite::Message::Ping(Bytes::new()))
.await
.context("Failed to send ping over websocket")
}
async fn send(&mut self, bytes: Bytes) -> anyhow::Result<()> {
self
.0
@@ -95,7 +99,7 @@ impl WebsocketSender for TungsteniteWebsocketSender {
async fn try_next<S>(
stream: &mut S,
) -> anyhow::Result<WebsocketMessage<CloseFrame>>
) -> anyhow::Result<WebsocketMessage>
where
S: Stream<Item = Result<tungstenite::Message, tungstenite::Error>>
+ Unpin,
@@ -113,13 +117,15 @@ where
EncodedTransportMessage::from_vec(bytes.into()),
));
}
Some(tungstenite::Message::Close(frame)) => {
return Ok(WebsocketMessage::Close(frame));
Some(tungstenite::Message::Ping(_)) => {
return Ok(WebsocketMessage::Ping);
}
Some(tungstenite::Message::Close(_)) => {
return Ok(WebsocketMessage::Close);
}
None => return Ok(WebsocketMessage::Closed),
// Ignored messages
Some(tungstenite::Message::Ping(_))
| Some(tungstenite::Message::Pong(_))
Some(tungstenite::Message::Pong(_))
| Some(tungstenite::Message::Frame(_)) => continue,
}
}
@@ -149,9 +155,7 @@ impl WebsocketReceiver for TungsteniteWebsocketReceiver {
self.cancel = Some(cancel);
}
async fn recv(
&mut self,
) -> anyhow::Result<WebsocketMessage<Self::CloseFrame>> {
async fn recv(&mut self) -> anyhow::Result<WebsocketMessage> {
let fut = try_next(&mut self.receiver);
if let Some(cancel) = &self.cancel {
tokio::select! {