slowly better ergonomics

This commit is contained in:
mbecker20
2025-10-09 17:29:05 -07:00
parent dd8ac67c72
commit deaa8754f3
24 changed files with 384 additions and 281 deletions

View File

@@ -9,7 +9,7 @@ use transport::{
},
fix_ws_address,
message::{
Encode, Message,
Encode, TransportMessage,
login::{LoginMessage, LoginWebsocketExt},
},
websocket::{
@@ -158,7 +158,7 @@ async fn handle_passkey_login(
.await;
if let Err(e) = res {
if let Err(e) = socket
.send(Message::Login((&e).encode()))
.send(TransportMessage::Login((&e).encode()))
.await
.context("Failed to send login failed to client")
{

View File

@@ -24,9 +24,9 @@ use transport::{
},
channel::{BufferedReceiver, Sender, buffered_channel},
message::{
CastBytes, Decode, Encode, Message, MessageBytes,
json::JsonMessageBytes,
wrappers::{OptionWrapper, ResultWrapper, WithChannel},
CastBytes, Decode, Encode, TransportMessage, EncodedTransportMessage,
json::EncodedJsonMessage,
wrappers::{EncodedOption, EncodedResult, WithChannel},
},
websocket::{
Websocket, WebsocketMessage, WebsocketReceiver as _,
@@ -56,7 +56,7 @@ impl PeripheryConnections {
&self,
server_id: String,
args: PeripheryConnectionArgs<'_>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<MessageBytes>) {
) -> (Arc<PeripheryConnection>, BufferedReceiver<EncodedTransportMessage>) {
let (connection, receiver) = if let Some(existing_connection) =
self.0.remove(&server_id).await
{
@@ -250,7 +250,7 @@ impl<'a> From<&'a OwnedPeripheryConnectionArgs>
/// Sends None as InProgress ping.
pub type ResponseChannels = CloneCache<
Uuid,
Sender<OptionWrapper<ResultWrapper<JsonMessageBytes>>>,
Sender<EncodedOption<EncodedResult<EncodedJsonMessage>>>,
>;
pub type TerminalChannels = CloneCache<Uuid, Sender<Vec<u8>>>;
@@ -260,7 +260,7 @@ pub struct PeripheryConnection {
/// The connection args
pub args: OwnedPeripheryConnectionArgs,
/// Send and receive bytes over the connection socket.
pub sender: Sender<MessageBytes>,
pub sender: Sender<EncodedTransportMessage>,
/// Cancel the connection
pub cancel: CancellationToken,
/// Whether Periphery is currently connected.
@@ -278,7 +278,7 @@ pub struct PeripheryConnection {
impl PeripheryConnection {
pub fn new(
args: impl Into<OwnedPeripheryConnectionArgs>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<MessageBytes>) {
) -> (Arc<PeripheryConnection>, BufferedReceiver<EncodedTransportMessage>) {
let (sender, receiever) = buffered_channel();
(
PeripheryConnection {
@@ -298,7 +298,7 @@ impl PeripheryConnection {
pub fn with_new_args(
&self,
args: impl Into<OwnedPeripheryConnectionArgs>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<MessageBytes>) {
) -> (Arc<PeripheryConnection>, BufferedReceiver<EncodedTransportMessage>) {
// Ensure this connection is cancelled.
self.cancel();
let (sender, receiever) = buffered_channel();
@@ -337,7 +337,7 @@ impl PeripheryConnection {
pub async fn handle_socket<W: Websocket>(
&self,
socket: W,
receiver: &mut BufferedReceiver<MessageBytes>,
receiver: &mut BufferedReceiver<EncodedTransportMessage>,
) {
let cancel = self.cancel.child_token();
@@ -392,8 +392,8 @@ impl PeripheryConnection {
self.set_connected(false);
}
pub async fn handle_incoming_message(&self, message: MessageBytes) {
let message: Message = match message.decode() {
pub async fn handle_incoming_message(&self, message: EncodedTransportMessage) {
let message: TransportMessage = match message.decode() {
Ok(res) => res,
Err(e) => {
warn!("Failed to parse Message bytes | {e:#}");
@@ -401,7 +401,7 @@ impl PeripheryConnection {
}
};
match message {
Message::Response(data) => match data.decode() {
TransportMessage::Response(data) => match data.decode() {
Ok(WithChannel {
channel: channel_id,
data,
@@ -423,7 +423,7 @@ impl PeripheryConnection {
warn!("Failed to read Response message | {e:#}");
}
},
Message::Terminal(data) => match data.decode() {
TransportMessage::Terminal(data) => match data.decode() {
Ok(WithChannel {
channel: channel_id,
data,
@@ -454,7 +454,7 @@ impl PeripheryConnection {
pub async fn send(
&self,
message: impl Encode<MessageBytes>,
message: impl Encode<EncodedTransportMessage>,
) -> anyhow::Result<()> {
self.sender.send_message(message).await
}

View File

@@ -26,7 +26,7 @@ use transport::{
PublicKeyValidator, ServerLoginFlow,
},
message::{
Encode, Message,
Encode, TransportMessage,
login::{LoginMessage, LoginWebsocketExt},
},
websocket::{Websocket, WebsocketExt as _, axum::AxumWebsocket},
@@ -202,7 +202,7 @@ async fn onboard_server_handler(
Err(e) => {
warn!("{e:#}");
if let Err(e) = socket
.send(Message::Login((&e).encode()))
.send(TransportMessage::Login((&e).encode()))
.await
.context("Failed to send Server creation failed to client")
{

View File

@@ -8,7 +8,7 @@ use serde_json::json;
use transport::{
channel::channel,
message::{
Decode, Encode, Message, json::JsonMessage, wrappers::WithChannel,
Decode, Encode, TransportMessage, json::JsonMessage, wrappers::WithChannel,
},
};
use uuid::Uuid;
@@ -122,7 +122,7 @@ impl PeripheryClient {
.encode()?;
if let Err(e) = connection
.send(Message::Request(
.send(TransportMessage::Request(
WithChannel {
channel: channel_id,
data,

View File

@@ -14,7 +14,7 @@ use periphery_client::api::terminal::{
};
use transport::{
channel::{Receiver, Sender, channel},
message::MessageBytes,
message::EncodedTransportMessage,
};
use uuid::Uuid;
@@ -26,8 +26,11 @@ impl PeripheryClient {
pub async fn connect_terminal(
&self,
terminal: String,
) -> anyhow::Result<(Uuid, Sender<MessageBytes>, Receiver<Vec<u8>>)>
{
) -> anyhow::Result<(
Uuid,
Sender<EncodedTransportMessage>,
Receiver<Vec<u8>>,
)> {
tracing::trace!(
"request | type: ConnectTerminal | terminal name: {terminal}",
);
@@ -60,8 +63,11 @@ impl PeripheryClient {
&self,
container: String,
shell: String,
) -> anyhow::Result<(Uuid, Sender<MessageBytes>, Receiver<Vec<u8>>)>
{
) -> anyhow::Result<(
Uuid,
Sender<EncodedTransportMessage>,
Receiver<Vec<u8>>,
)> {
tracing::trace!(
"request | type: ConnectContainerExec | container name: {container} | shell: {shell}",
);

View File

@@ -20,7 +20,7 @@ use periphery_client::api::terminal::DisconnectTerminal;
use tokio_util::sync::CancellationToken;
use transport::{
channel::{Receiver, Sender},
message::MessageBytes,
message::EncodedTransportMessage,
};
use uuid::Uuid;
@@ -185,7 +185,7 @@ async fn forward_ws_channel(
periphery: PeripheryClient,
client_socket: axum::extract::ws::WebSocket,
periphery_connection_id: Uuid,
periphery_sender: Sender<MessageBytes>,
periphery_sender: Sender<EncodedTransportMessage>,
mut periphery_receiver: Receiver<Vec<u8>>,
) {
let (mut client_send, mut client_receive) = client_socket.split();

View File

@@ -14,7 +14,7 @@ use periphery_client::api::{
use resolver_api::Resolve;
use serde::{Deserialize, Serialize};
use transport::message::{
json::JsonMessageBytes, wrappers::ResultWrapper,
json::EncodedJsonMessage, wrappers::EncodedResult,
};
use crate::{
@@ -43,7 +43,7 @@ pub struct Args {
Serialize, Deserialize, Debug, Clone, Resolve, EnumVariants,
)]
#[args(Args)]
#[response(ResultWrapper<JsonMessageBytes>)]
#[response(EncodedResult<EncodedJsonMessage>)]
#[error(anyhow::Error)]
#[variant_derive(Debug)]
#[serde(tag = "type", content = "params")]

View File

@@ -9,7 +9,7 @@ use komodo_client::{
use periphery_client::api::terminal::*;
use resolver_api::Resolve;
use tokio_util::{codec::LinesCodecError, sync::CancellationToken};
use transport::{channel::Sender, message::MessageBytes};
use transport::{channel::Sender, message::EncodedTransportMessage};
use uuid::Uuid;
use crate::{
@@ -262,7 +262,7 @@ impl Resolve<super::Args> for ExecuteContainerExec {
}
async fn handle_terminal_forwarding(
sender: &Sender<MessageBytes>,
sender: &Sender<EncodedTransportMessage>,
channel: Uuid,
terminal: Arc<Terminal>,
) {
@@ -414,7 +414,7 @@ async fn setup_execute_command_on_terminal(
}
async fn forward_execute_command_on_terminal_response(
sender: &Sender<MessageBytes>,
sender: &Sender<EncodedTransportMessage>,
channel: Uuid,
mut stdout: impl Stream<Item = Result<String, LinesCodecError>> + Unpin,
) {

View File

@@ -15,9 +15,9 @@ use transport::{
},
channel::{BufferedChannel, BufferedReceiver, Sender},
message::{
CastBytes, Decode, Encode, Message, MessageBytes,
json::JsonMessageBytes,
wrappers::{ChannelWrapper, WithChannel},
CastBytes, Decode, Encode, EncodedTransportMessage,
TransportMessage, json::EncodedJsonMessage,
wrappers::EncodedChannel,
},
websocket::{
Websocket, WebsocketReceiverExt as _, WebsocketSender as _,
@@ -34,7 +34,7 @@ pub mod server;
// Core Address / Host -> Channel
pub type CoreChannels =
CloneCache<String, Arc<BufferedChannel<MessageBytes>>>;
CloneCache<String, Arc<BufferedChannel<EncodedTransportMessage>>>;
pub fn core_channels() -> &'static CoreChannels {
static CORE_CHANNELS: OnceLock<CoreChannels> = OnceLock::new();
@@ -130,8 +130,8 @@ async fn handle_login<W: Websocket, L: LoginFlow>(
async fn handle_socket<W: Websocket>(
socket: W,
args: &Arc<Args>,
sender: &Sender<MessageBytes>,
receiver: &mut BufferedReceiver<MessageBytes>,
sender: &Sender<EncodedTransportMessage>,
receiver: &mut BufferedReceiver<EncodedTransportMessage>,
) {
let config = periphery_config();
info!(
@@ -179,11 +179,11 @@ async fn handle_socket<W: Websocket>(
}
};
match message {
Message::Request(message) => {
handle_request(args.clone(), sender.clone(), message)
TransportMessage::Request(message) => {
handle_request(args.clone(), sender.clone(), message.0)
}
Message::Terminal(message) => {
crate::terminal::handle_message(message).await
TransportMessage::Terminal(message) => {
crate::terminal::handle_message(message.0).await
}
// Rest shouldn't be received by Periphery
_ => {}
@@ -199,8 +199,8 @@ async fn handle_socket<W: Websocket>(
fn handle_request(
args: Arc<Args>,
sender: Sender<MessageBytes>,
message: ChannelWrapper<JsonMessageBytes>,
sender: Sender<EncodedTransportMessage>,
message: EncodedChannel<EncodedJsonMessage>,
) {
tokio::spawn(async move {
let (channel, request): (_, PeripheryRequest) = match message
@@ -216,18 +216,11 @@ fn handle_request(
};
let resolve_response = async {
let data = match request.resolve(&args).await {
let response = match request.resolve(&args).await {
Ok(res) => res,
Err(e) => (&e).encode(),
};
let message = WithChannel {
channel,
data: Some(data).encode(),
}
.encode();
if let Err(e) =
sender.send_message(Message::Response(message)).await
{
if let Err(e) = sender.send_response(channel, response).await {
error!("Failed to send response over channel | {e:?}");
}
};

View File

@@ -25,10 +25,7 @@ use transport::{
ConnectionIdentifiers, HeaderConnectionIdentifiers,
ServerLoginFlow,
},
message::{
Encode, Message,
login::{LoginMessage, LoginWebsocketExt},
},
message::login::{LoginMessage, LoginWebsocketExt},
websocket::{Websocket, WebsocketExt, axum::AxumWebsocket},
};
@@ -210,7 +207,7 @@ async fn handle_passkey_login(
} else {
let e = anyhow!("Invalid passkey");
if let Err(e) = socket
.send(Message::Login((&e).encode()))
.send_login_error(&e)
.await
.context("Failed to send login failed")
{
@@ -226,7 +223,7 @@ async fn handle_passkey_login(
.await;
if let Err(e) = res {
if let Err(e) = socket
.send(Message::Login((&e).encode()))
.send_login_error(&e)
.await
.context("Failed to send login failed to client")
{

View File

@@ -15,7 +15,7 @@ use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use transport::message::{
Decode,
wrappers::{ChannelWrapper, WithChannel},
wrappers::{EncodedChannel, WithChannel},
};
use uuid::Uuid;
@@ -99,7 +99,7 @@ impl TerminalTriggers {
}
}
pub async fn handle_message(message: ChannelWrapper<Vec<u8>>) {
pub async fn handle_message(message: EncodedChannel<Vec<u8>>) {
let WithChannel {
channel: channel_id,
mut data,

View File

@@ -7,8 +7,10 @@ use uuid::Uuid;
use crate::{
message::{
Encode, Message, MessageBytes, json::JsonMessage,
wrappers::WithChannel,
DecodedTransportMessage, Encode, EncodedResponseMessage,
EncodedTransportMessage, TransportMessage,
json::{EncodedJsonMessage, JsonMessage},
wrappers::{EncodedResult, WithChannel},
},
timeout::MaybeWithTimeout,
};
@@ -70,10 +72,10 @@ impl<T> Sender<T> {
}
}
impl Sender<MessageBytes> {
impl Sender<EncodedTransportMessage> {
pub async fn send_message(
&self,
message: impl Encode<MessageBytes>,
message: impl Encode<EncodedTransportMessage>,
) -> anyhow::Result<()> {
self.send(message.encode()).await
}
@@ -88,7 +90,7 @@ impl Sender<MessageBytes> {
{
let data = JsonMessage(request).encode()?;
let message =
Message::Request(WithChannel { channel, data }.encode());
DecodedTransportMessage::Request(WithChannel { channel, data });
self.send_message(message).await
}
@@ -96,34 +98,25 @@ impl Sender<MessageBytes> {
&self,
channel: Uuid,
) -> anyhow::Result<()> {
let message = Message::Response(
WithChannel {
channel,
data: None.encode(),
}
.encode(),
);
let message = DecodedTransportMessage::Response(WithChannel {
channel,
data: None,
});
self.send_message(message).await
}
pub async fn send_response<'a, T: Serialize + Send>(
pub async fn send_response(
&self,
channel: Uuid,
response: anyhow::Result<&'a T>,
) -> anyhow::Result<()>
where
&'a T: Send,
{
let data = response
.and_then(|json| JsonMessage(json).encode())
.encode();
let message = Message::Response(
response: EncodedResult<EncodedJsonMessage>,
) -> anyhow::Result<()> {
let message = TransportMessage::Response(EncodedResponseMessage(
WithChannel {
channel,
data: Some(data).encode(),
data: Some(response).encode(),
}
.encode(),
);
));
self.send_message(message).await
}
@@ -132,13 +125,10 @@ impl Sender<MessageBytes> {
channel: Uuid,
data: impl Into<Vec<u8>>,
) -> anyhow::Result<()> {
let message = Message::Terminal(
WithChannel {
channel,
data: data.into(),
}
.encode(),
);
let message = DecodedTransportMessage::Terminal(WithChannel {
channel,
data: data.into(),
});
self.send_message(message).await
}
}

View File

@@ -2,7 +2,7 @@ use anyhow::Context;
use serde::{Serialize, de::DeserializeOwned};
use crate::message::{
CastBytes, Decode, Encode, wrappers::ResultWrapper,
CastBytes, Decode, Encode, wrappers::EncodedResult,
};
/// ```markdown
@@ -10,9 +10,9 @@ use crate::message::{
/// | <JSON BYTES> |
/// ```
#[derive(Clone, Debug)]
pub struct JsonMessageBytes(Vec<u8>);
pub struct EncodedJsonMessage(Vec<u8>);
impl CastBytes for JsonMessageBytes {
impl CastBytes for EncodedJsonMessage {
fn from_vec(vec: Vec<u8>) -> Self {
Self(vec)
}
@@ -23,19 +23,19 @@ impl CastBytes for JsonMessageBytes {
pub struct JsonMessage<'a, T>(pub &'a T);
impl<'a, T: Serialize + Send> Encode<anyhow::Result<JsonMessageBytes>>
for JsonMessage<'a, T>
impl<'a, T: Serialize + Send>
Encode<anyhow::Result<EncodedJsonMessage>> for JsonMessage<'a, T>
where
&'a T: Send,
{
fn encode(self) -> anyhow::Result<JsonMessageBytes> {
fn encode(self) -> anyhow::Result<EncodedJsonMessage> {
let bytes = serde_json::to_vec(self.0)
.context("Failed to serialize data to bytes")?;
Ok(JsonMessageBytes(bytes))
Ok(EncodedJsonMessage(bytes))
}
}
impl<T: DeserializeOwned> Decode<T> for JsonMessageBytes {
impl<T: DeserializeOwned> Decode<T> for EncodedJsonMessage {
fn decode(self) -> anyhow::Result<T> {
serde_json::from_slice(&self.0)
.context("Failed to parse JSON bytes")
@@ -43,11 +43,11 @@ impl<T: DeserializeOwned> Decode<T> for JsonMessageBytes {
}
impl<T: Serialize + Send> From<T>
for ResultWrapper<JsonMessageBytes>
for EncodedResult<EncodedJsonMessage>
{
fn from(value: T) -> Self {
serde_json::to_vec(&value)
.map(JsonMessageBytes::from_vec)
.map(EncodedJsonMessage::from_vec)
.context("Failed to serialize data to bytes")
.encode()
}

View File

@@ -4,7 +4,10 @@ use noise::key::SpkiPublicKey;
use crate::{
auth::AUTH_TIMEOUT,
message::{CastBytes, Decode, Encode, Message},
message::{
CastBytes, Decode, DecodedTransportMessage, Encode,
EncodedLoginMessage, TransportMessage,
},
websocket::{Websocket, WebsocketExt},
};
@@ -13,7 +16,8 @@ pub trait LoginWebsocketExt: WebsocketExt {
&mut self,
e: &anyhow::Error,
) -> impl Future<Output = anyhow::Result<()>> + Send {
let message = Message::Login(e.encode());
let message =
TransportMessage::Login(EncodedLoginMessage(e.encode()));
self.send(message)
}
@@ -21,14 +25,14 @@ pub trait LoginWebsocketExt: WebsocketExt {
&mut self,
) -> impl Future<Output = anyhow::Result<LoginMessage>> + Send {
async {
let Message::Login(message) =
let TransportMessage::Login(message) =
self.recv().with_timeout(AUTH_TIMEOUT).await?
else {
return Err(anyhow!(
"Expected Login message, got other message type"
));
};
message.decode_into()
message.0.decode_into()
}
}
@@ -139,9 +143,9 @@ pub trait LoginWebsocketExt: WebsocketExt {
impl<W: Websocket> LoginWebsocketExt for W {}
impl From<LoginMessage> for Message {
impl From<LoginMessage> for TransportMessage {
fn from(value: LoginMessage) -> Self {
Self::Login(Ok(value.encode()).encode())
DecodedTransportMessage::Login(Ok(value)).encode()
}
}
@@ -150,9 +154,9 @@ impl From<LoginMessage> for Message {
/// | <CONTENTS> | LoginMessageVariant |
/// ```
#[derive(Clone, Debug)]
pub struct LoginMessageBytes(Vec<u8>);
pub struct InnerEncodedLoginMessage(Vec<u8>);
impl CastBytes for LoginMessageBytes {
impl CastBytes for InnerEncodedLoginMessage {
fn from_vec(vec: Vec<u8>) -> Self {
Self(vec)
}
@@ -191,8 +195,8 @@ pub enum LoginMessage {
V1Passkey(Vec<u8>),
}
impl Encode<LoginMessageBytes> for LoginMessage {
fn encode(self) -> LoginMessageBytes {
impl Encode<InnerEncodedLoginMessage> for LoginMessage {
fn encode(self) -> InnerEncodedLoginMessage {
let variant_byte = self.extract_variant().as_byte();
let mut bytes = match self {
LoginMessage::Success => Vec::new(),
@@ -212,11 +216,11 @@ impl Encode<LoginMessageBytes> for LoginMessage {
LoginMessage::V1Passkey(bytes) => bytes,
};
bytes.push(variant_byte);
LoginMessageBytes(bytes)
InnerEncodedLoginMessage(bytes)
}
}
impl Decode<LoginMessage> for LoginMessageBytes {
impl Decode<LoginMessage> for InnerEncodedLoginMessage {
/// Parses login messages, performing various validations.
fn decode(self) -> anyhow::Result<LoginMessage> {
let mut bytes = self.0;

View File

@@ -4,8 +4,8 @@ pub mod json;
pub mod login;
pub mod wrappers;
mod root;
pub use root::*;
mod transport;
pub use transport::*;
pub trait Encode<Target>: Sized + Send {
fn encode(self) -> Target;

View File

@@ -1,92 +0,0 @@
use anyhow::{Context as _, anyhow};
use derive_variants::{EnumVariants, ExtractVariant as _};
use crate::message::{
CastBytes, Decode, Encode,
json::JsonMessageBytes,
login::LoginMessageBytes,
wrappers::{ChannelWrapper, OptionWrapper, ResultWrapper},
};
#[derive(Debug, Clone)]
pub struct MessageBytes(Vec<u8>);
impl CastBytes for MessageBytes {
fn from_vec(vec: Vec<u8>) -> Self {
Self(vec)
}
fn into_vec(self) -> Vec<u8> {
self.0
}
}
#[derive(Debug, EnumVariants)]
#[variant_derive(Debug, Clone, Copy)]
pub enum Message {
Login(ResultWrapper<LoginMessageBytes>),
Request(ChannelWrapper<JsonMessageBytes>),
Response(
ChannelWrapper<OptionWrapper<ResultWrapper<JsonMessageBytes>>>,
),
Terminal(ChannelWrapper<Vec<u8>>),
}
impl<T: Into<Message> + Send> Encode<MessageBytes> for T {
fn encode(self) -> MessageBytes {
let message = self.into();
let variant_byte = message.extract_variant().as_byte();
let mut bytes = match message {
Message::Login(data) => data.into_vec(),
Message::Request(data) => data.into_vec(),
Message::Response(data) => data.into_vec(),
Message::Terminal(data) => data.into_vec(),
};
bytes.push(variant_byte);
MessageBytes(bytes)
}
}
impl<T: From<Message>> Decode<T> for MessageBytes {
fn decode(self) -> anyhow::Result<T> {
let mut bytes = self.0;
let variant_byte = bytes
.pop()
.context("Failed to decode message | bytes are empty")?;
use MessageVariant::*;
let message = match MessageVariant::from_byte(variant_byte)? {
Login => Message::Login(ResultWrapper::from_vec(bytes)),
Request => Message::Request(ChannelWrapper::from_vec(bytes)),
Response => Message::Response(ChannelWrapper::from_vec(bytes)),
Terminal => Message::Terminal(ChannelWrapper::from_vec(bytes)),
};
Ok(message.into())
}
}
impl MessageVariant {
pub fn from_byte(byte: u8) -> anyhow::Result<Self> {
use MessageVariant::*;
let variant = match byte {
0 => Login,
1 => Request,
2 => Response,
3 => Terminal,
other => {
return Err(anyhow!(
"Got unrecognized MessageVariant byte: {other}"
));
}
};
Ok(variant)
}
pub fn as_byte(self) -> u8 {
use MessageVariant::*;
match self {
Login => 0,
Request => 1,
Response => 2,
Terminal => 3,
}
}
}

View File

@@ -0,0 +1,194 @@
use anyhow::{Context as _, anyhow};
use derive_variants::{EnumVariants, ExtractVariant as _};
use crate::message::{
CastBytes, Decode, Encode,
json::EncodedJsonMessage,
login::{InnerEncodedLoginMessage, LoginMessage},
wrappers::{
EncodedChannel, EncodedOption, EncodedResult, WithChannel,
},
};
#[derive(Debug, Clone)]
pub struct EncodedTransportMessage(Vec<u8>);
impl CastBytes for EncodedTransportMessage {
fn from_vec(vec: Vec<u8>) -> Self {
Self(vec)
}
fn into_vec(self) -> Vec<u8> {
self.0
}
}
#[derive(Debug)]
pub struct EncodedLoginMessage(
pub EncodedResult<InnerEncodedLoginMessage>,
);
#[derive(Debug)]
pub struct EncodedRequestMessage(
pub EncodedChannel<EncodedJsonMessage>,
);
#[derive(Debug)]
pub struct EncodedResponseMessage(
pub EncodedChannel<EncodedOption<EncodedResult<EncodedJsonMessage>>>,
);
#[derive(Debug)]
pub struct EncodedTerminalMessage(pub EncodedChannel<Vec<u8>>);
#[derive(Debug, EnumVariants)]
#[variant_derive(Debug, Clone, Copy)]
pub enum TransportMessage {
Login(EncodedLoginMessage),
Request(EncodedRequestMessage),
Response(EncodedResponseMessage),
Terminal(EncodedTerminalMessage),
}
impl<T: Into<TransportMessage> + Send> Encode<EncodedTransportMessage>
for T
{
fn encode(self) -> EncodedTransportMessage {
let message = self.into();
let variant_byte = message.extract_variant().as_byte();
let mut bytes = match message {
TransportMessage::Login(data) => data.0.into_vec(),
TransportMessage::Request(data) => data.0.into_vec(),
TransportMessage::Response(data) => data.0.into_vec(),
TransportMessage::Terminal(data) => data.0.into_vec(),
};
bytes.push(variant_byte);
EncodedTransportMessage(bytes)
}
}
impl<T: From<TransportMessage>> Decode<T>
for EncodedTransportMessage
{
fn decode(self) -> anyhow::Result<T> {
let mut bytes = self.0;
let variant_byte = bytes
.pop()
.context("Failed to decode message | bytes are empty")?;
use TransportMessageVariant::*;
let message =
match TransportMessageVariant::from_byte(variant_byte)? {
Login => TransportMessage::Login(EncodedLoginMessage(
EncodedResult::from_vec(bytes),
)),
Request => TransportMessage::Request(EncodedRequestMessage(
EncodedChannel::from_vec(bytes),
)),
Response => TransportMessage::Response(
EncodedResponseMessage(EncodedChannel::from_vec(bytes)),
),
Terminal => TransportMessage::Terminal(
EncodedTerminalMessage(EncodedChannel::from_vec(bytes)),
),
};
Ok(message.into())
}
}
pub enum DecodedTransportMessage {
Login(anyhow::Result<LoginMessage>),
Request(WithChannel<EncodedJsonMessage>),
Response(
WithChannel<Option<anyhow::Result<EncodedJsonMessage>>>, // EncodedChannel<EncodedOption<EncodedResult<EncodedJsonMessage>>>,
),
Terminal(WithChannel<Vec<u8>>),
}
impl<T: Into<DecodedTransportMessage> + Send> Encode<TransportMessage>
for T
{
fn encode(self) -> TransportMessage {
use DecodedTransportMessage::*;
match self.into() {
Login(res) => TransportMessage::Login(EncodedLoginMessage(
res.map(LoginMessage::encode).encode(),
)),
Request(data) => TransportMessage::Request(
EncodedRequestMessage(data.encode()),
),
Response(data) => {
TransportMessage::Response(EncodedResponseMessage(
data
.map(|data| data.map(|data| data.encode()).encode())
.encode(),
))
}
Terminal(data) => TransportMessage::Terminal(
EncodedTerminalMessage(data.encode()),
),
}
}
}
impl Into<TransportMessage> for DecodedTransportMessage {
fn into(self) -> TransportMessage {
self.encode()
}
}
impl<T: From<DecodedTransportMessage>> Decode<T>
for TransportMessage
{
fn decode(self) -> anyhow::Result<T> {
let message = match self {
TransportMessage::Login(encoded_result) => {
let res =
encoded_result.0.decode().and_then(|msg| msg.decode());
DecodedTransportMessage::Login(res)
}
TransportMessage::Request(encoded_channel) => {
DecodedTransportMessage::Request(encoded_channel.0.decode()?)
}
TransportMessage::Response(encoded_channel) => {
let WithChannel { channel, data } =
encoded_channel.0.decode()?;
let data = data.decode()?.map(|data| data.decode());
DecodedTransportMessage::Response(WithChannel {
channel,
data,
})
}
TransportMessage::Terminal(encoded_channel) => {
DecodedTransportMessage::Terminal(encoded_channel.0.decode()?)
}
};
Ok(message.into())
}
}
impl TransportMessageVariant {
pub fn from_byte(byte: u8) -> anyhow::Result<Self> {
use TransportMessageVariant::*;
let variant = match byte {
0 => Login,
1 => Request,
2 => Response,
3 => Terminal,
other => {
return Err(anyhow!(
"Got unrecognized MessageVariant byte: {other}"
));
}
};
Ok(variant)
}
pub fn as_byte(self) -> u8 {
use TransportMessageVariant::*;
match self {
Login => 0,
Request => 1,
Response => 2,
Terminal => 3,
}
}
}

View File

@@ -11,15 +11,15 @@ use crate::message::{CastBytes, Decode, Encode};
/// | <CONTENTS> | Channel Uuid |
/// ```
#[derive(Clone, Debug)]
pub struct ChannelWrapper<T>(T);
pub struct EncodedChannel<T>(T);
impl<T> From<T> for ChannelWrapper<T> {
impl<T> From<T> for EncodedChannel<T> {
fn from(value: T) -> Self {
Self(value)
}
}
impl<T: CastBytes> CastBytes for ChannelWrapper<T> {
impl<T: CastBytes> CastBytes for EncodedChannel<T> {
fn from_bytes(bytes: Bytes) -> Self {
Self(T::from_bytes(bytes))
}
@@ -39,17 +39,35 @@ pub struct WithChannel<T> {
pub data: T,
}
impl<T: CastBytes + Send> Encode<ChannelWrapper<T>>
for WithChannel<T>
{
fn encode(self) -> ChannelWrapper<T> {
let mut bytes = self.data.into_vec();
bytes.extend(self.channel.into_bytes());
ChannelWrapper(T::from_vec(bytes))
impl<T> WithChannel<T> {
pub fn map<R>(self, map: impl FnOnce(T) -> R) -> WithChannel<R> {
WithChannel {
channel: self.channel,
data: map(self.data),
}
}
}
impl<T: CastBytes> Decode<WithChannel<T>> for ChannelWrapper<T> {
impl<T, E: Encode<T>> Encode<WithChannel<T>> for WithChannel<E> {
fn encode(self) -> WithChannel<T> {
WithChannel {
channel: self.channel,
data: self.data.encode(),
}
}
}
impl<T: CastBytes + Send> Encode<EncodedChannel<T>>
for WithChannel<T>
{
fn encode(self) -> EncodedChannel<T> {
let mut bytes = self.data.into_vec();
bytes.extend(self.channel.into_bytes());
EncodedChannel(T::from_vec(bytes))
}
}
impl<T: CastBytes> Decode<WithChannel<T>> for EncodedChannel<T> {
fn decode(self) -> anyhow::Result<WithChannel<T>> {
let mut bytes = self.0.into_vec();
let len = bytes.len();

View File

@@ -2,6 +2,6 @@ mod channel;
mod option;
mod result;
pub use channel::{ChannelWrapper, WithChannel};
pub use option::OptionWrapper;
pub use result::ResultWrapper;
pub use channel::{EncodedChannel, WithChannel};
pub use option::EncodedOption;
pub use result::EncodedResult;

View File

@@ -10,15 +10,15 @@ use crate::message::{CastBytes, Decode, Encode};
/// | <CONTENTS> | 0: Ok or _: Err |
/// ```
#[derive(Clone, Debug)]
pub struct OptionWrapper<T>(T);
pub struct EncodedOption<T>(T);
impl<T> From<T> for OptionWrapper<T> {
impl<T> From<T> for EncodedOption<T> {
fn from(value: T) -> Self {
Self(value)
}
}
impl<T: CastBytes> CastBytes for OptionWrapper<T> {
impl<T: CastBytes> CastBytes for EncodedOption<T> {
fn from_bytes(bytes: Bytes) -> Self {
Self(T::from_bytes(bytes))
}
@@ -33,20 +33,20 @@ impl<T: CastBytes> CastBytes for OptionWrapper<T> {
}
}
impl<T: CastBytes + Send> Encode<OptionWrapper<T>> for Option<T> {
fn encode(self) -> OptionWrapper<T> {
impl<T: CastBytes + Send> Encode<EncodedOption<T>> for Option<T> {
fn encode(self) -> EncodedOption<T> {
match self {
Some(data) => {
let mut bytes = data.into_vec();
bytes.push(0);
OptionWrapper(T::from_vec(bytes))
EncodedOption(T::from_vec(bytes))
}
None => OptionWrapper(T::from_vec(vec![1])),
None => EncodedOption(T::from_vec(vec![1])),
}
}
}
impl<T: CastBytes> Decode<Option<T>> for OptionWrapper<T> {
impl<T: CastBytes> Decode<Option<T>> for EncodedOption<T> {
fn decode(self) -> anyhow::Result<Option<T>> {
let mut bytes = self.0.into_vec();
let option_byte =

View File

@@ -11,15 +11,15 @@ use crate::message::{CastBytes, Decode, Encode};
/// | <CONTENTS> | 0: Ok or _: Err |
/// ```
#[derive(Clone, Debug)]
pub struct ResultWrapper<T>(T);
pub struct EncodedResult<T>(T);
impl<T> From<T> for ResultWrapper<T> {
impl<T> From<T> for EncodedResult<T> {
fn from(value: T) -> Self {
Self(value)
}
}
impl<T: CastBytes> CastBytes for ResultWrapper<T> {
impl<T: CastBytes> CastBytes for EncodedResult<T> {
fn from_bytes(bytes: Bytes) -> Self {
Self(T::from_bytes(bytes))
}
@@ -34,10 +34,10 @@ impl<T: CastBytes> CastBytes for ResultWrapper<T> {
}
}
impl<T: CastBytes + Send> Encode<ResultWrapper<T>>
impl<T: CastBytes + Send> Encode<EncodedResult<T>>
for anyhow::Result<T>
{
fn encode(self) -> ResultWrapper<T> {
fn encode(self) -> EncodedResult<T> {
let bytes = match self {
Ok(data) => {
let mut bytes = data.into_vec();
@@ -50,21 +50,21 @@ impl<T: CastBytes + Send> Encode<ResultWrapper<T>>
bytes
}
};
ResultWrapper(T::from_vec(bytes))
EncodedResult(T::from_vec(bytes))
}
}
impl<T: CastBytes + Send> Encode<ResultWrapper<T>>
impl<T: CastBytes + Send> Encode<EncodedResult<T>>
for &anyhow::Error
{
fn encode(self) -> ResultWrapper<T> {
fn encode(self) -> EncodedResult<T> {
let mut bytes = serialize_error_bytes(self);
bytes.push(1);
ResultWrapper::from_vec(bytes)
EncodedResult::from_vec(bytes)
}
}
impl<T: CastBytes> Decode<T> for ResultWrapper<T> {
impl<T: CastBytes> Decode<T> for EncodedResult<T> {
fn decode(self) -> anyhow::Result<T> {
let mut bytes = self.0.into_vec();
let result_byte =

View File

@@ -8,7 +8,7 @@ use futures_util::{
use tokio_util::sync::CancellationToken;
use crate::{
message::{CastBytes, MessageBytes},
message::{CastBytes, EncodedTransportMessage},
timeout::MaybeWithTimeout,
};
@@ -87,13 +87,13 @@ where
match stream.try_next().await? {
Some(axum::extract::ws::Message::Binary(bytes)) => {
return Ok(WebsocketMessage::Message(
MessageBytes::from_vec(bytes.into()),
EncodedTransportMessage::from_vec(bytes.into()),
));
}
Some(axum::extract::ws::Message::Text(text)) => {
let bytes: Bytes = text.into();
return Ok(WebsocketMessage::Message(
MessageBytes::from_vec(bytes.into()),
EncodedTransportMessage::from_vec(bytes.into()),
));
}
Some(axum::extract::ws::Message::Close(frame)) => {

View File

@@ -8,8 +8,11 @@ use uuid::Uuid;
use crate::{
message::{
CastBytes, Decode, Encode, Message, MessageBytes,
json::JsonMessage, wrappers::WithChannel,
CastBytes, Decode, DecodedTransportMessage, Encode,
EncodedResponseMessage, EncodedTransportMessage,
TransportMessage,
json::{EncodedJsonMessage, JsonMessage},
wrappers::{EncodedResult, WithChannel},
},
timeout::MaybeWithTimeout,
};
@@ -21,7 +24,7 @@ pub mod tungstenite;
/// for easier handling.
pub enum WebsocketMessage<CloseFrame> {
/// Standard message
Message(MessageBytes),
Message(EncodedTransportMessage),
/// Graceful close message
Close(Option<CloseFrame>),
/// Stream closed
@@ -59,7 +62,7 @@ pub trait Websocket: Send {
pub trait WebsocketExt: Websocket {
fn send(
&mut self,
message: impl Encode<MessageBytes>,
message: impl Encode<EncodedTransportMessage>,
) -> impl Future<Output = anyhow::Result<()>> + Send {
self.send_inner(message.encode().into_vec().into())
}
@@ -68,7 +71,7 @@ pub trait WebsocketExt: Websocket {
fn recv(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
impl Future<Output = anyhow::Result<TransportMessage>> + Send,
> {
MaybeWithTimeout::new(async {
match self.recv_inner().await? {
@@ -103,7 +106,7 @@ pub trait WebsocketSender {
pub trait WebsocketSenderExt: WebsocketSender + Send {
fn send(
&mut self,
message: impl Encode<MessageBytes>,
message: impl Encode<EncodedTransportMessage>,
) -> impl Future<Output = anyhow::Result<()>> + Send {
self.send_inner(message.encode().into_vec().into())
}
@@ -118,8 +121,10 @@ pub trait WebsocketSenderExt: WebsocketSender + Send {
{
async move {
let data = JsonMessage(request).encode()?;
let message =
Message::Request(WithChannel { channel, data }.encode());
let message = DecodedTransportMessage::Request(WithChannel {
channel,
data,
});
self.send(message).await
}
}
@@ -128,34 +133,25 @@ pub trait WebsocketSenderExt: WebsocketSender + Send {
&mut self,
channel: Uuid,
) -> impl Future<Output = anyhow::Result<()>> + Send {
let message = Message::Response(
WithChannel {
channel,
data: None.encode(),
}
.encode(),
);
let message = DecodedTransportMessage::Response(WithChannel {
channel,
data: None,
});
self.send(message)
}
fn send_response<'a, T: Serialize + Send>(
fn send_response(
&mut self,
channel: Uuid,
response: anyhow::Result<&'a T>,
) -> impl Future<Output = anyhow::Result<()>> + Send
where
&'a T: Send,
{
let data = response
.and_then(|json| JsonMessage(json).encode())
.encode();
let message = Message::Response(
response: EncodedResult<EncodedJsonMessage>,
) -> impl Future<Output = anyhow::Result<()>> + Send {
let message = TransportMessage::Response(EncodedResponseMessage(
WithChannel {
channel,
data: Some(data).encode(),
data: Some(response).encode(),
}
.encode(),
);
));
self.send(message)
}
@@ -164,13 +160,10 @@ pub trait WebsocketSenderExt: WebsocketSender + Send {
channel: Uuid,
data: impl Into<Vec<u8>>,
) -> impl Future<Output = anyhow::Result<()>> + Send {
let message = Message::Terminal(
WithChannel {
channel,
data: data.into(),
}
.encode(),
);
let message = DecodedTransportMessage::Terminal(WithChannel {
channel,
data: data.into(),
});
self.send(message)
}
}
@@ -198,7 +191,7 @@ pub trait WebsocketReceiverExt: WebsocketReceiver {
fn recv(
&mut self,
) -> MaybeWithTimeout<
impl Future<Output = anyhow::Result<Message>> + Send,
impl Future<Output = anyhow::Result<TransportMessage>> + Send,
> {
MaybeWithTimeout::new(async {
match self

View File

@@ -19,7 +19,7 @@ use tokio_tungstenite::{
use tokio_util::sync::CancellationToken;
use crate::{
message::{CastBytes, MessageBytes},
message::{CastBytes, EncodedTransportMessage},
timeout::MaybeWithTimeout,
};
@@ -105,13 +105,13 @@ where
match stream.try_next().await? {
Some(tungstenite::Message::Binary(bytes)) => {
return Ok(WebsocketMessage::Message(
MessageBytes::from_vec(bytes.into()),
EncodedTransportMessage::from_vec(bytes.into()),
));
}
Some(tungstenite::Message::Text(text)) => {
let bytes: Bytes = text.into();
return Ok(WebsocketMessage::Message(
MessageBytes::from_vec(bytes.into()),
EncodedTransportMessage::from_vec(bytes.into()),
));
}
Some(tungstenite::Message::Close(frame)) => {