forked from github-starred/komodo
multiplex requests + terminal over single WS
This commit is contained in:
@@ -8,8 +8,6 @@ repository.workspace = true
|
||||
homepage.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cache.workspace = true
|
||||
#
|
||||
serror.workspace = true
|
||||
#
|
||||
tokio-tungstenite.workspace = true
|
||||
|
||||
@@ -4,65 +4,93 @@ use uuid::Uuid;
|
||||
|
||||
use crate::MessageState;
|
||||
|
||||
/// Serializes channel id + data to byte vec.
|
||||
/// The first 16 bytes are the Uuid, followed by the json serialized data bytes.
|
||||
/// Serializes data + channel id + state to byte vec.
|
||||
/// The last byte is the State, and the 16 before that is the Uuid.
|
||||
pub fn to_transport_bytes(
|
||||
mut data: Vec<u8>,
|
||||
id: Uuid,
|
||||
state: MessageState,
|
||||
data: &[u8],
|
||||
) -> Bytes {
|
||||
// Index 0..15
|
||||
let mut res = id.into_bytes().to_vec();
|
||||
// Index 16
|
||||
res.push(state.as_byte());
|
||||
// Index 17..end
|
||||
res.extend_from_slice(data);
|
||||
|
||||
Bytes::from(res)
|
||||
data.extend(id.into_bytes());
|
||||
data.push(state.as_byte());
|
||||
data.into()
|
||||
}
|
||||
|
||||
/// Deserializes channel id from
|
||||
/// incoming transport bytes.
|
||||
pub fn id_from_transport_bytes(bytes: &[u8]) -> anyhow::Result<Uuid> {
|
||||
if bytes.len() < 16 {
|
||||
return Err(anyhow!("Transport bytes too short to include uuid"));
|
||||
let len = bytes.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid + state"
|
||||
));
|
||||
}
|
||||
Uuid::from_slice(&bytes[..16]).context("Invalid Uuid bytes")
|
||||
Uuid::from_slice(&bytes[(len - 17)..(len - 1)])
|
||||
.context("Invalid Uuid bytes")
|
||||
}
|
||||
|
||||
/// Deserializes channel id from
|
||||
/// incoming transport bytes.
|
||||
pub fn id_state_from_transport_bytes(
|
||||
bytes: &[u8],
|
||||
) -> anyhow::Result<(Uuid, MessageState)> {
|
||||
let len = bytes.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid + state"
|
||||
));
|
||||
}
|
||||
let uuid = Uuid::from_slice(&bytes[(len - 17)..(len - 1)])
|
||||
.context("Invalid Uuid bytes")?;
|
||||
let state = MessageState::from_byte(bytes[len - 1]);
|
||||
Ok((uuid, state))
|
||||
}
|
||||
|
||||
/// extracts data from incoming transport bytes,
|
||||
/// consuming bytes in the process.
|
||||
pub fn data_from_transport_bytes(
|
||||
bytes: Bytes,
|
||||
) -> anyhow::Result<Bytes> {
|
||||
let len = bytes.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid + state + data"
|
||||
));
|
||||
}
|
||||
let mut res: Vec<u8> = bytes.into();
|
||||
res.drain((len - 17)..);
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
/// Deserializes channel id + data from
|
||||
/// incoming transport bytes.
|
||||
pub fn from_transport_bytes(
|
||||
bytes: &[u8],
|
||||
) -> anyhow::Result<(Uuid, MessageState, Option<&[u8]>)> {
|
||||
if bytes.len() < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid and state"
|
||||
));
|
||||
}
|
||||
let (id, state, data) = (&bytes[..16], bytes[16], bytes.get(17..));
|
||||
let id = Uuid::from_slice(id).context("Invalid Uuid bytes")?;
|
||||
let state = MessageState::from_byte(state);
|
||||
Ok((id, state, data))
|
||||
bytes: Bytes,
|
||||
) -> anyhow::Result<(Uuid, MessageState, Bytes)> {
|
||||
let (id, state) = id_state_from_transport_bytes(&bytes)?;
|
||||
let mut res: Vec<u8> = bytes.into();
|
||||
res.drain((res.len() - 17)..);
|
||||
Ok((id, state, res.into()))
|
||||
}
|
||||
|
||||
impl MessageState {
|
||||
/// - 0 => Successful
|
||||
/// - 1 => Failed
|
||||
/// - other => InProgress
|
||||
pub fn from_byte(byte: u8) -> MessageState {
|
||||
match byte {
|
||||
0 => MessageState::Successful,
|
||||
1 => MessageState::Failed,
|
||||
0 => MessageState::Request,
|
||||
1 => MessageState::Terminal,
|
||||
2 => MessageState::Successful,
|
||||
3 => MessageState::Failed,
|
||||
_ => MessageState::InProgress,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_byte(&self) -> u8 {
|
||||
match self {
|
||||
MessageState::Successful => 0,
|
||||
MessageState::Failed => 1,
|
||||
MessageState::InProgress => 2,
|
||||
MessageState::Request => 0,
|
||||
MessageState::Terminal => 1,
|
||||
MessageState::Successful => 2,
|
||||
MessageState::Failed => 3,
|
||||
MessageState::InProgress => 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,136 +8,128 @@ use std::{
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use cache::CloneCache;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use rustls::{ClientConfig, client::danger::ServerCertVerifier};
|
||||
use tokio::{
|
||||
net::TcpStream,
|
||||
sync::{RwLock, mpsc::Sender},
|
||||
};
|
||||
use tokio::{net::TcpStream, sync::RwLock};
|
||||
use tokio_tungstenite::{
|
||||
Connector, MaybeTlsStream, WebSocketStream, tungstenite::Message,
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
bytes::id_from_transport_bytes, channel::BufferedReceiver,
|
||||
};
|
||||
use crate::{TransportHandler, channel::BufferedReceiver};
|
||||
|
||||
pub fn spawn_reconnecting_websocket(
|
||||
address: String,
|
||||
connection: Arc<ClientConnection>,
|
||||
mut request_receiver: BufferedReceiver<Bytes>,
|
||||
response_channels: Arc<CloneCache<Uuid, Sender<Bytes>>>,
|
||||
/// Fixes server addresses:
|
||||
/// server.domain => wss://server.domain
|
||||
/// http://server.domain => ws://server.domain
|
||||
/// https://server.domain => wss://server.domain
|
||||
pub fn fix_ws_address(address: &str) -> String {
|
||||
if address.starts_with("ws://") || address.starts_with("wss://") {
|
||||
return address.to_string();
|
||||
}
|
||||
if address.starts_with("http://") {
|
||||
return address.replace("http://", "ws://");
|
||||
}
|
||||
if address.starts_with("https://") {
|
||||
return address.replace("https://", "wss://");
|
||||
}
|
||||
format!("wss://{address}")
|
||||
}
|
||||
|
||||
pub async fn handle_reconnecting_websocket<
|
||||
T: TransportHandler + Send + Sync + 'static,
|
||||
>(
|
||||
address: &str,
|
||||
connection: &ClientConnection,
|
||||
transport: &T,
|
||||
write_receiver: &mut BufferedReceiver<Bytes>,
|
||||
) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let socket = match connect_websocket(&address).await {
|
||||
Ok(socket) => socket,
|
||||
Err(e) => {
|
||||
connection.set_error(e).await;
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
loop {
|
||||
let socket = match connect_websocket(address).await {
|
||||
Ok(socket) => socket,
|
||||
Err(e) => {
|
||||
connection.set_error(e).await;
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Connected to {address}");
|
||||
connection.connected.store(true, atomic::Ordering::Relaxed);
|
||||
connection.clear_error().await;
|
||||
info!("Connected to {address}");
|
||||
connection.connected.store(true, atomic::Ordering::Relaxed);
|
||||
connection.clear_error().await;
|
||||
|
||||
let (mut ws_write, mut ws_read) = socket.split();
|
||||
let (mut ws_write, mut ws_read) = socket.split();
|
||||
|
||||
let forward_requests = async {
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
next = request_receiver.recv() => next,
|
||||
_ = connection.cancel.cancelled() => {
|
||||
let _ = ws_write.close().await;
|
||||
break;
|
||||
}
|
||||
};
|
||||
let forward_writes = async {
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
next = write_receiver.recv() => next,
|
||||
_ = connection.cancel.cancelled() => break,
|
||||
};
|
||||
|
||||
let message = match next {
|
||||
None => {
|
||||
info!("Got None over request reciever for {address}");
|
||||
break;
|
||||
}
|
||||
Some(request) => {
|
||||
Message::Binary(Bytes::copy_from_slice(request))
|
||||
}
|
||||
};
|
||||
let message = match next {
|
||||
None => {
|
||||
info!("Got None over request reciever for {address}");
|
||||
break;
|
||||
}
|
||||
Some(request) => {
|
||||
Message::Binary(Bytes::copy_from_slice(request))
|
||||
}
|
||||
};
|
||||
|
||||
match ws_write.send(message).await {
|
||||
Ok(_) => request_receiver.clear_buffer(),
|
||||
Err(e) => {
|
||||
warn!("Failed to send request to {address} | {e:#}");
|
||||
break;
|
||||
}
|
||||
match ws_write.send(message).await {
|
||||
Ok(_) => write_receiver.clear_buffer(),
|
||||
Err(e) => {
|
||||
warn!("Failed to send request to {address} | {e:#}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
// Cancel again if not already
|
||||
let _ = ws_write.close().await;
|
||||
connection.cancel();
|
||||
};
|
||||
|
||||
let read_responses = async {
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
next = ws_read.next() => next,
|
||||
_ = connection.cancel.cancelled() => break,
|
||||
};
|
||||
let handle_reads = async {
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
next = ws_read.next() => next,
|
||||
_ = connection.cancel.cancelled() => break,
|
||||
};
|
||||
|
||||
let bytes = match next {
|
||||
Some(Ok(Message::Binary(bytes))) => bytes,
|
||||
Some(Ok(Message::Close(frame))) => {
|
||||
warn!(
|
||||
"Connection to {address} broken with frame: {frame:?}"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!(
|
||||
"Connection to {address} broken with error: {e:?}"
|
||||
);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
warn!("Connection to {address} closed");
|
||||
break;
|
||||
}
|
||||
// Can ignore other message types
|
||||
Some(Ok(_)) => {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let id = match id_from_transport_bytes(&bytes) {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
warn!("Failed to read id from {address} | {e:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let Some(channel) = response_channels.get(&id).await else {
|
||||
match next {
|
||||
Some(Ok(Message::Binary(bytes))) => {
|
||||
transport.handle_incoming_bytes(bytes).await
|
||||
}
|
||||
Some(Ok(Message::Close(frame))) => {
|
||||
warn!(
|
||||
"Failed to send response for {address} | No response channel found"
|
||||
"Connection to {address} broken with frame: {frame:?}"
|
||||
);
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("Connection to {address} broken with error: {e:?}");
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
warn!("Connection to {address} closed");
|
||||
break;
|
||||
}
|
||||
// Can ignore other message types
|
||||
Some(Ok(_)) => {
|
||||
continue;
|
||||
};
|
||||
if let Err(e) = channel.send(bytes).await {
|
||||
warn!(
|
||||
"Failed to send response for {address} | Channel failure | {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
// Cancel again if not already
|
||||
connection.cancel();
|
||||
};
|
||||
|
||||
tokio::join!(forward_requests, read_responses);
|
||||
tokio::join!(forward_writes, handle_reads);
|
||||
|
||||
warn!("Disconnnected from {address}");
|
||||
connection.connected.store(false, atomic::Ordering::Relaxed);
|
||||
}
|
||||
});
|
||||
warn!("Disconnnected from {address}");
|
||||
connection.connected.store(false, atomic::Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use ::bytes::Bytes;
|
||||
|
||||
pub mod bytes;
|
||||
pub mod channel;
|
||||
pub mod client;
|
||||
@@ -7,5 +9,14 @@ pub mod server;
|
||||
pub enum MessageState {
|
||||
Successful,
|
||||
Failed,
|
||||
Request,
|
||||
InProgress,
|
||||
Terminal,
|
||||
}
|
||||
|
||||
pub trait TransportHandler {
|
||||
fn handle_incoming_bytes(
|
||||
&self,
|
||||
bytes: Bytes,
|
||||
) -> impl Future<Output = ()> + Send;
|
||||
}
|
||||
|
||||
@@ -9,15 +9,17 @@ use serror::AddStatusCode;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use crate::channel::BufferedReceiver;
|
||||
use crate::{TransportHandler, channel::BufferedReceiver};
|
||||
|
||||
pub fn inbound_connection(
|
||||
pub fn inbound_connection<
|
||||
T: TransportHandler + Send + Sync + 'static,
|
||||
>(
|
||||
ws: WebSocketUpgrade,
|
||||
handle_read: impl Fn(Bytes) + Send + Sync + 'static,
|
||||
transport: T,
|
||||
write_receiver: &'static Mutex<BufferedReceiver<Bytes>>,
|
||||
) -> serror::Result<Response> {
|
||||
// Limits to only one active websocket connection.
|
||||
let mut response_receiver = write_receiver
|
||||
let mut write_receiver = write_receiver
|
||||
.try_lock()
|
||||
.status_code(StatusCode::FORBIDDEN)?;
|
||||
|
||||
@@ -26,12 +28,34 @@ pub fn inbound_connection(
|
||||
|
||||
let (mut ws_write, mut ws_read) = socket.split();
|
||||
|
||||
// Handle incoming messages
|
||||
let ws_read = async {
|
||||
let forward_writes = async {
|
||||
loop {
|
||||
let msg = match write_receiver.recv().await {
|
||||
// Sender Dropped (shouldn't happen, it is static).
|
||||
None => break,
|
||||
// This has to copy the bytes to follow ownership rules.
|
||||
Some(msg) => Message::Binary(Bytes::copy_from_slice(msg)),
|
||||
};
|
||||
match ws_write.send(msg).await {
|
||||
// Clears the stored message from receiver buffer.
|
||||
// TODO: Move after response ack.
|
||||
Ok(_) => write_receiver.clear_buffer(),
|
||||
Err(e) => {
|
||||
warn!("Failed to send response | {e:?}");
|
||||
let _ = ws_write.close().await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let handle_reads = async {
|
||||
loop {
|
||||
match ws_read.next().await {
|
||||
// Incoming core msg
|
||||
Some(Ok(Message::Binary(msg))) => handle_read(msg),
|
||||
Some(Ok(Message::Binary(bytes))) => {
|
||||
transport.handle_incoming_bytes(bytes).await
|
||||
}
|
||||
// Disconnection cases.
|
||||
Some(Ok(Message::Close(frame))) => {
|
||||
warn!("Connection closed with frame: {frame:?}");
|
||||
@@ -50,31 +74,9 @@ pub fn inbound_connection(
|
||||
}
|
||||
};
|
||||
|
||||
// Handle outgoing messages
|
||||
let ws_write = async {
|
||||
loop {
|
||||
let msg = match response_receiver.recv().await {
|
||||
// Sender Dropped (shouldn't happen, it is static).
|
||||
None => break,
|
||||
// This has to copy the bytes to follow ownership rules.
|
||||
Some(msg) => Message::Binary(Bytes::copy_from_slice(msg)),
|
||||
};
|
||||
match ws_write.send(msg).await {
|
||||
// Clears the stored message from receiver buffer.
|
||||
// TODO: Move after response ack from Core.
|
||||
Ok(_) => response_receiver.clear_buffer(),
|
||||
Err(e) => {
|
||||
warn!("Failed to send response | {e:?}");
|
||||
let _ = ws_write.close().await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
_ = ws_read => {},
|
||||
_ = ws_write => {}
|
||||
_ = forward_writes => {},
|
||||
_ = handle_reads => {},
|
||||
};
|
||||
}))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user