multiplex requests + terminal over single WS

This commit is contained in:
mbecker20
2025-09-16 17:47:58 -07:00
parent 39f900d651
commit 673c7f3a6b
23 changed files with 809 additions and 344 deletions

View File

@@ -8,8 +8,6 @@ repository.workspace = true
homepage.workspace = true
[dependencies]
cache.workspace = true
#
serror.workspace = true
#
tokio-tungstenite.workspace = true

View File

@@ -4,65 +4,93 @@ use uuid::Uuid;
use crate::MessageState;
/// Serializes channel id + data to byte vec.
/// The first 16 bytes are the Uuid, followed by the json serialized data bytes.
/// Serializes data + channel id + state to byte vec.
/// The last byte is the State, and the 16 before that is the Uuid.
pub fn to_transport_bytes(
mut data: Vec<u8>,
id: Uuid,
state: MessageState,
data: &[u8],
) -> Bytes {
// Index 0..15
let mut res = id.into_bytes().to_vec();
// Index 16
res.push(state.as_byte());
// Index 17..end
res.extend_from_slice(data);
Bytes::from(res)
data.extend(id.into_bytes());
data.push(state.as_byte());
data.into()
}
/// Deserializes channel id from
/// incoming transport bytes.
pub fn id_from_transport_bytes(bytes: &[u8]) -> anyhow::Result<Uuid> {
if bytes.len() < 16 {
return Err(anyhow!("Transport bytes too short to include uuid"));
let len = bytes.len();
if len < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid + state"
));
}
Uuid::from_slice(&bytes[..16]).context("Invalid Uuid bytes")
Uuid::from_slice(&bytes[(len - 17)..(len - 1)])
.context("Invalid Uuid bytes")
}
/// Deserializes channel id from
/// incoming transport bytes.
pub fn id_state_from_transport_bytes(
bytes: &[u8],
) -> anyhow::Result<(Uuid, MessageState)> {
let len = bytes.len();
if len < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid + state"
));
}
let uuid = Uuid::from_slice(&bytes[(len - 17)..(len - 1)])
.context("Invalid Uuid bytes")?;
let state = MessageState::from_byte(bytes[len - 1]);
Ok((uuid, state))
}
/// extracts data from incoming transport bytes,
/// consuming bytes in the process.
pub fn data_from_transport_bytes(
bytes: Bytes,
) -> anyhow::Result<Bytes> {
let len = bytes.len();
if len < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid + state + data"
));
}
let mut res: Vec<u8> = bytes.into();
res.drain((len - 17)..);
Ok(res.into())
}
/// Deserializes channel id + data from
/// incoming transport bytes.
pub fn from_transport_bytes(
bytes: &[u8],
) -> anyhow::Result<(Uuid, MessageState, Option<&[u8]>)> {
if bytes.len() < 17 {
return Err(anyhow!(
"Transport bytes too short to include uuid and state"
));
}
let (id, state, data) = (&bytes[..16], bytes[16], bytes.get(17..));
let id = Uuid::from_slice(id).context("Invalid Uuid bytes")?;
let state = MessageState::from_byte(state);
Ok((id, state, data))
bytes: Bytes,
) -> anyhow::Result<(Uuid, MessageState, Bytes)> {
let (id, state) = id_state_from_transport_bytes(&bytes)?;
let mut res: Vec<u8> = bytes.into();
res.drain((res.len() - 17)..);
Ok((id, state, res.into()))
}
impl MessageState {
/// - 0 => Successful
/// - 1 => Failed
/// - other => InProgress
pub fn from_byte(byte: u8) -> MessageState {
match byte {
0 => MessageState::Successful,
1 => MessageState::Failed,
0 => MessageState::Request,
1 => MessageState::Terminal,
2 => MessageState::Successful,
3 => MessageState::Failed,
_ => MessageState::InProgress,
}
}
pub fn as_byte(&self) -> u8 {
match self {
MessageState::Successful => 0,
MessageState::Failed => 1,
MessageState::InProgress => 2,
MessageState::Request => 0,
MessageState::Terminal => 1,
MessageState::Successful => 2,
MessageState::Failed => 3,
MessageState::InProgress => 4,
}
}
}

View File

@@ -8,136 +8,128 @@ use std::{
use anyhow::Context;
use bytes::Bytes;
use cache::CloneCache;
use futures_util::{SinkExt, StreamExt};
use rustls::{ClientConfig, client::danger::ServerCertVerifier};
use tokio::{
net::TcpStream,
sync::{RwLock, mpsc::Sender},
};
use tokio::{net::TcpStream, sync::RwLock};
use tokio_tungstenite::{
Connector, MaybeTlsStream, WebSocketStream, tungstenite::Message,
};
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use uuid::Uuid;
use crate::{
bytes::id_from_transport_bytes, channel::BufferedReceiver,
};
use crate::{TransportHandler, channel::BufferedReceiver};
pub fn spawn_reconnecting_websocket(
address: String,
connection: Arc<ClientConnection>,
mut request_receiver: BufferedReceiver<Bytes>,
response_channels: Arc<CloneCache<Uuid, Sender<Bytes>>>,
/// Fixes server addresses:
/// server.domain => wss://server.domain
/// http://server.domain => ws://server.domain
/// https://server.domain => wss://server.domain
pub fn fix_ws_address(address: &str) -> String {
if address.starts_with("ws://") || address.starts_with("wss://") {
return address.to_string();
}
if address.starts_with("http://") {
return address.replace("http://", "ws://");
}
if address.starts_with("https://") {
return address.replace("https://", "wss://");
}
format!("wss://{address}")
}
pub async fn handle_reconnecting_websocket<
T: TransportHandler + Send + Sync + 'static,
>(
address: &str,
connection: &ClientConnection,
transport: &T,
write_receiver: &mut BufferedReceiver<Bytes>,
) {
tokio::spawn(async move {
loop {
let socket = match connect_websocket(&address).await {
Ok(socket) => socket,
Err(e) => {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
loop {
let socket = match connect_websocket(address).await {
Ok(socket) => socket,
Err(e) => {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
info!("Connected to {address}");
connection.connected.store(true, atomic::Ordering::Relaxed);
connection.clear_error().await;
info!("Connected to {address}");
connection.connected.store(true, atomic::Ordering::Relaxed);
connection.clear_error().await;
let (mut ws_write, mut ws_read) = socket.split();
let (mut ws_write, mut ws_read) = socket.split();
let forward_requests = async {
loop {
let next = tokio::select! {
next = request_receiver.recv() => next,
_ = connection.cancel.cancelled() => {
let _ = ws_write.close().await;
break;
}
};
let forward_writes = async {
loop {
let next = tokio::select! {
next = write_receiver.recv() => next,
_ = connection.cancel.cancelled() => break,
};
let message = match next {
None => {
info!("Got None over request reciever for {address}");
break;
}
Some(request) => {
Message::Binary(Bytes::copy_from_slice(request))
}
};
let message = match next {
None => {
info!("Got None over request reciever for {address}");
break;
}
Some(request) => {
Message::Binary(Bytes::copy_from_slice(request))
}
};
match ws_write.send(message).await {
Ok(_) => request_receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send request to {address} | {e:#}");
break;
}
match ws_write.send(message).await {
Ok(_) => write_receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send request to {address} | {e:#}");
break;
}
}
};
}
// Cancel again if not already
let _ = ws_write.close().await;
connection.cancel();
};
let read_responses = async {
loop {
let next = tokio::select! {
next = ws_read.next() => next,
_ = connection.cancel.cancelled() => break,
};
let handle_reads = async {
loop {
let next = tokio::select! {
next = ws_read.next() => next,
_ = connection.cancel.cancelled() => break,
};
let bytes = match next {
Some(Ok(Message::Binary(bytes))) => bytes,
Some(Ok(Message::Close(frame))) => {
warn!(
"Connection to {address} broken with frame: {frame:?}"
);
break;
}
Some(Err(e)) => {
warn!(
"Connection to {address} broken with error: {e:?}"
);
break;
}
None => {
warn!("Connection to {address} closed");
break;
}
// Can ignore other message types
Some(Ok(_)) => {
continue;
}
};
let id = match id_from_transport_bytes(&bytes) {
Ok(res) => res,
Err(e) => {
warn!("Failed to read id from {address} | {e:#}");
continue;
}
};
let Some(channel) = response_channels.get(&id).await else {
match next {
Some(Ok(Message::Binary(bytes))) => {
transport.handle_incoming_bytes(bytes).await
}
Some(Ok(Message::Close(frame))) => {
warn!(
"Failed to send response for {address} | No response channel found"
"Connection to {address} broken with frame: {frame:?}"
);
break;
}
Some(Err(e)) => {
warn!("Connection to {address} broken with error: {e:?}");
break;
}
None => {
warn!("Connection to {address} closed");
break;
}
// Can ignore other message types
Some(Ok(_)) => {
continue;
};
if let Err(e) = channel.send(bytes).await {
warn!(
"Failed to send response for {address} | Channel failure | {e:#}"
);
}
}
};
};
}
// Cancel again if not already
connection.cancel();
};
tokio::join!(forward_requests, read_responses);
tokio::join!(forward_writes, handle_reads);
warn!("Disconnnected from {address}");
connection.connected.store(false, atomic::Ordering::Relaxed);
}
});
warn!("Disconnnected from {address}");
connection.connected.store(false, atomic::Ordering::Relaxed);
}
}
#[derive(Debug)]

View File

@@ -1,3 +1,5 @@
use ::bytes::Bytes;
pub mod bytes;
pub mod channel;
pub mod client;
@@ -7,5 +9,14 @@ pub mod server;
pub enum MessageState {
Successful,
Failed,
Request,
InProgress,
Terminal,
}
pub trait TransportHandler {
fn handle_incoming_bytes(
&self,
bytes: Bytes,
) -> impl Future<Output = ()> + Send;
}

View File

@@ -9,15 +9,17 @@ use serror::AddStatusCode;
use tokio::sync::Mutex;
use tracing::{error, warn};
use crate::channel::BufferedReceiver;
use crate::{TransportHandler, channel::BufferedReceiver};
pub fn inbound_connection(
pub fn inbound_connection<
T: TransportHandler + Send + Sync + 'static,
>(
ws: WebSocketUpgrade,
handle_read: impl Fn(Bytes) + Send + Sync + 'static,
transport: T,
write_receiver: &'static Mutex<BufferedReceiver<Bytes>>,
) -> serror::Result<Response> {
// Limits to only one active websocket connection.
let mut response_receiver = write_receiver
let mut write_receiver = write_receiver
.try_lock()
.status_code(StatusCode::FORBIDDEN)?;
@@ -26,12 +28,34 @@ pub fn inbound_connection(
let (mut ws_write, mut ws_read) = socket.split();
// Handle incoming messages
let ws_read = async {
let forward_writes = async {
loop {
let msg = match write_receiver.recv().await {
// Sender Dropped (shouldn't happen, it is static).
None => break,
// This has to copy the bytes to follow ownership rules.
Some(msg) => Message::Binary(Bytes::copy_from_slice(msg)),
};
match ws_write.send(msg).await {
// Clears the stored message from receiver buffer.
// TODO: Move after response ack.
Ok(_) => write_receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send response | {e:?}");
let _ = ws_write.close().await;
break;
}
}
}
};
let handle_reads = async {
loop {
match ws_read.next().await {
// Incoming core msg
Some(Ok(Message::Binary(msg))) => handle_read(msg),
Some(Ok(Message::Binary(bytes))) => {
transport.handle_incoming_bytes(bytes).await
}
// Disconnection cases.
Some(Ok(Message::Close(frame))) => {
warn!("Connection closed with frame: {frame:?}");
@@ -50,31 +74,9 @@ pub fn inbound_connection(
}
};
// Handle outgoing messages
let ws_write = async {
loop {
let msg = match response_receiver.recv().await {
// Sender Dropped (shouldn't happen, it is static).
None => break,
// This has to copy the bytes to follow ownership rules.
Some(msg) => Message::Binary(Bytes::copy_from_slice(msg)),
};
match ws_write.send(msg).await {
// Clears the stored message from receiver buffer.
// TODO: Move after response ack from Core.
Ok(_) => response_receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send response | {e:?}");
let _ = ws_write.close().await;
break;
}
}
}
};
tokio::select! {
_ = ws_read => {},
_ = ws_write => {}
_ = forward_writes => {},
_ = handle_reads => {},
};
}))
}