forked from github-starred/komodo
fix passkey support
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5686,6 +5686,7 @@ dependencies = [
|
||||
"sha2",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
|
||||
@@ -439,7 +439,6 @@ async fn get_on_host_periphery(
|
||||
&config,
|
||||
),
|
||||
config.insecure_tls,
|
||||
&config.passkey,
|
||||
)
|
||||
.await?;
|
||||
// Poll for connection to be estalished
|
||||
|
||||
@@ -24,7 +24,6 @@ impl PeripheryConnectionArgs<'_> {
|
||||
self,
|
||||
id: String,
|
||||
insecure: bool,
|
||||
passkey: String,
|
||||
) -> anyhow::Result<Arc<ConnectionChannels>> {
|
||||
let Some(address) = self.address else {
|
||||
return Err(anyhow!(
|
||||
@@ -71,9 +70,8 @@ impl PeripheryConnectionArgs<'_> {
|
||||
core_connection_query().as_bytes(),
|
||||
);
|
||||
|
||||
if let Err(e) = connection
|
||||
.client_login(&mut socket, identifiers, &passkey)
|
||||
.await
|
||||
if let Err(e) =
|
||||
connection.client_login(&mut socket, identifiers).await
|
||||
{
|
||||
connection.set_error(e).await;
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
@@ -98,13 +96,11 @@ impl PeripheryConnection {
|
||||
&self,
|
||||
socket: &mut TungsteniteWebsocket,
|
||||
identifiers: ConnectionIdentifiers<'_>,
|
||||
// for legacy auth
|
||||
passkey: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
// Get the required auth type
|
||||
let bytes = socket
|
||||
.recv_result()
|
||||
.with_timeout(Duration::from_secs(2))
|
||||
.with_timeout(AUTH_TIMEOUT)
|
||||
.await
|
||||
.context("Failed to receive login type indicator")?;
|
||||
|
||||
@@ -116,7 +112,10 @@ impl PeripheryConnection {
|
||||
.await
|
||||
}
|
||||
// Passkey auth
|
||||
&[1] => handle_passkey_login(socket, passkey).await,
|
||||
&[1] => {
|
||||
handle_passkey_login(socket, self.args.passkey.as_deref())
|
||||
.await
|
||||
}
|
||||
other => Err(anyhow!(
|
||||
"Receieved invalid login type pattern: {other:?}"
|
||||
)),
|
||||
@@ -127,18 +126,18 @@ impl PeripheryConnection {
|
||||
async fn handle_passkey_login(
|
||||
socket: &mut TungsteniteWebsocket,
|
||||
// for legacy auth
|
||||
passkey: &str,
|
||||
passkey: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = async {
|
||||
let passkey = if passkey.is_empty() {
|
||||
let passkey = if let Some(passkey) = passkey {
|
||||
passkey.as_bytes().to_vec()
|
||||
} else {
|
||||
core_config()
|
||||
.passkey
|
||||
.as_deref()
|
||||
.context("Periphery requires passkey auth")?
|
||||
.as_bytes()
|
||||
.to_vec()
|
||||
} else {
|
||||
passkey.as_bytes().to_vec()
|
||||
};
|
||||
|
||||
socket
|
||||
|
||||
@@ -93,6 +93,9 @@ pub struct PeripheryConnectionArgs<'a> {
|
||||
pub id: &'a str,
|
||||
pub address: Option<&'a str>,
|
||||
periphery_public_key: Option<&'a str>,
|
||||
/// V1 legacy support.
|
||||
/// Only possible for Core -> Periphery.
|
||||
passkey: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl PublicKeyValidator for PeripheryConnectionArgs<'_> {
|
||||
@@ -145,6 +148,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
|
||||
id: &server.id,
|
||||
address: optional_str(&server.config.address),
|
||||
periphery_public_key: optional_str(&server.info.public_key),
|
||||
passkey: optional_str(&server.config.passkey),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +162,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
|
||||
periphery_public_key: optional_str(
|
||||
&config.periphery_public_key,
|
||||
),
|
||||
passkey: optional_str(&config.passkey),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,6 +177,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
|
||||
periphery_public_key: optional_str(
|
||||
&config.periphery_public_key,
|
||||
),
|
||||
passkey: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +188,7 @@ impl<'a> PeripheryConnectionArgs<'a> {
|
||||
periphery_public_key: self
|
||||
.periphery_public_key
|
||||
.map(str::to_string),
|
||||
passkey: self.passkey.map(str::to_string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,6 +211,9 @@ pub struct OwnedPeripheryConnectionArgs {
|
||||
/// If None, must have 'periphery_public_keys' set
|
||||
/// in Core config, or will error
|
||||
pub periphery_public_key: Option<String>,
|
||||
/// V1 legacy support.
|
||||
/// Only possible for Core -> Periphery connection.
|
||||
pub passkey: Option<String>,
|
||||
}
|
||||
|
||||
impl OwnedPeripheryConnectionArgs {
|
||||
@@ -212,6 +222,7 @@ impl OwnedPeripheryConnectionArgs {
|
||||
id: &self.id,
|
||||
address: self.address.as_deref(),
|
||||
periphery_public_key: self.periphery_public_key.as_deref(),
|
||||
passkey: self.passkey.as_deref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -319,19 +330,14 @@ impl PeripheryConnection {
|
||||
|
||||
let (mut ws_write, mut ws_read) = socket.split();
|
||||
|
||||
ws_read.set_cancel(cancel.clone());
|
||||
receiver.set_cancel(cancel.clone());
|
||||
|
||||
let forward_writes = async {
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
next = receiver.recv() => next,
|
||||
_ = cancel.cancelled() => break,
|
||||
let Ok(message) = receiver.recv().await else {
|
||||
break;
|
||||
};
|
||||
|
||||
let message = match next {
|
||||
Some(request) => request.to_message(),
|
||||
// Sender Dropped (shouldn't happen, a reference is held on 'connection').
|
||||
None => break,
|
||||
};
|
||||
|
||||
match ws_write.send(message).await {
|
||||
Ok(_) => receiver.clear_buffer(),
|
||||
Err(e) => {
|
||||
@@ -347,12 +353,7 @@ impl PeripheryConnection {
|
||||
|
||||
let handle_reads = async {
|
||||
loop {
|
||||
let next = tokio::select! {
|
||||
next = ws_read.recv() => next,
|
||||
_ = cancel.cancelled() => break,
|
||||
};
|
||||
|
||||
match next {
|
||||
match ws_read.recv().await {
|
||||
Ok(WebsocketMessage::Message(message)) => {
|
||||
self.handle_incoming_message(message).await
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{str::FromStr, time::Duration};
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::{
|
||||
@@ -21,8 +21,8 @@ use serror::{AddStatusCode, AddStatusCodeError};
|
||||
use transport::{
|
||||
PeripheryConnectionQuery,
|
||||
auth::{
|
||||
HeaderConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
||||
PublicKeyValidator, ServerLoginFlow,
|
||||
AUTH_TIMEOUT, HeaderConnectionIdentifiers, LoginFlow,
|
||||
LoginFlowArgs, PublicKeyValidator, ServerLoginFlow,
|
||||
},
|
||||
message::MessageState,
|
||||
websocket::{Websocket, axum::AxumWebsocket},
|
||||
@@ -176,7 +176,7 @@ async fn onboard_server_handler(
|
||||
// Post onboarding login 1: Receive public key
|
||||
let public_key = socket
|
||||
.recv_result()
|
||||
.with_timeout(Duration::from_secs(2))
|
||||
.with_timeout(AUTH_TIMEOUT)
|
||||
.await
|
||||
.and_then(|bytes| {
|
||||
String::from_utf8(bytes.into())
|
||||
|
||||
@@ -54,7 +54,6 @@ pub async fn get_builder_periphery(
|
||||
&config,
|
||||
),
|
||||
config.insecure_tls,
|
||||
&config.passkey,
|
||||
)
|
||||
.await?;
|
||||
periphery
|
||||
@@ -116,7 +115,6 @@ async fn get_aws_builder(
|
||||
&config,
|
||||
),
|
||||
config.insecure_tls,
|
||||
"",
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -195,7 +195,6 @@ pub async fn periphery_client(
|
||||
PeripheryClient::new(
|
||||
PeripheryConnectionArgs::from_server(server),
|
||||
server.config.insecure_tls,
|
||||
&server.config.passkey,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -33,8 +33,6 @@ impl PeripheryClient {
|
||||
pub async fn new(
|
||||
args: PeripheryConnectionArgs<'_>,
|
||||
insecure_tls: bool,
|
||||
// deprecated.
|
||||
passkey: &str,
|
||||
) -> anyhow::Result<PeripheryClient> {
|
||||
let connections = periphery_connections();
|
||||
|
||||
@@ -49,7 +47,6 @@ impl PeripheryClient {
|
||||
.spawn_client_connection(
|
||||
id.clone(),
|
||||
insecure_tls,
|
||||
passkey.to_string(),
|
||||
)
|
||||
.await?;
|
||||
return Ok(PeripheryClient { id, channels });
|
||||
@@ -83,7 +80,6 @@ impl PeripheryClient {
|
||||
.spawn_client_connection(
|
||||
id.clone(),
|
||||
insecure_tls,
|
||||
passkey.to_string(),
|
||||
)
|
||||
.await?;
|
||||
Ok(PeripheryClient { id, channels })
|
||||
|
||||
@@ -191,6 +191,8 @@ async fn forward_ws_channel(
|
||||
let (mut core_send, mut core_receive) = client_socket.split();
|
||||
let cancel = CancellationToken::new();
|
||||
|
||||
periphery_receiver.set_cancel(cancel.clone());
|
||||
|
||||
trace!("starting ws exchange");
|
||||
|
||||
let core_to_periphery = async {
|
||||
@@ -253,15 +255,13 @@ async fn forward_ws_channel(
|
||||
|
||||
let periphery_to_core = async {
|
||||
loop {
|
||||
let res = tokio::select! {
|
||||
res = periphery_receiver.recv() => res.map(TransportMessage::into_data),
|
||||
_ = cancel.cancelled() => {
|
||||
trace!("periphery to core read: cancelled from inside");
|
||||
break;
|
||||
}
|
||||
};
|
||||
// Already adheres to cancellation token
|
||||
let res = periphery_receiver
|
||||
.recv()
|
||||
.await
|
||||
.and_then(TransportMessage::into_data);
|
||||
match res {
|
||||
Some(Ok(bytes)) => {
|
||||
Ok(bytes) => {
|
||||
if let Err(e) = core_send.send(Message::Binary(bytes)).await
|
||||
{
|
||||
debug!("{e:?}");
|
||||
@@ -269,9 +269,7 @@ async fn forward_ws_channel(
|
||||
break;
|
||||
};
|
||||
}
|
||||
// No data, ignore
|
||||
Some(Err(_e)) => {}
|
||||
None => {
|
||||
Err(_) => {
|
||||
let _ = core_send.send(Message::text("STREAM EOF")).await;
|
||||
cancel.cancel();
|
||||
break;
|
||||
|
||||
@@ -8,7 +8,6 @@ use bytes::Bytes;
|
||||
use cache::CloneCache;
|
||||
use resolver_api::Resolve;
|
||||
use response::JsonBytes;
|
||||
use serror::serialize_error_bytes;
|
||||
use transport::{
|
||||
auth::{
|
||||
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
||||
@@ -16,10 +15,7 @@ use transport::{
|
||||
},
|
||||
channel::{BufferedChannel, BufferedReceiver, Sender},
|
||||
message::{Message, MessageState},
|
||||
websocket::{
|
||||
Websocket, WebsocketMessage, WebsocketReceiver,
|
||||
WebsocketSender as _,
|
||||
},
|
||||
websocket::{Websocket, WebsocketReceiver, WebsocketSender as _},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -99,15 +95,15 @@ async fn handle_socket<W: Websocket>(
|
||||
|
||||
let forward_writes = async {
|
||||
loop {
|
||||
let msg = match receiver.recv().await {
|
||||
// Sender Dropped (shouldn't happen, it is static).
|
||||
None => break,
|
||||
// This has to copy the bytes to follow ownership rules.
|
||||
Some(msg) => msg,
|
||||
let message = match receiver.recv().await {
|
||||
Ok(message) => message,
|
||||
Err(e) => {
|
||||
warn!("{e:#}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
match ws_write.send(msg.to_message()).await {
|
||||
match ws_write.send(message).await {
|
||||
// Clears the stored message from receiver buffer.
|
||||
// TODO: Move after response ack.
|
||||
Ok(_) => receiver.clear_buffer(),
|
||||
Err(e) => {
|
||||
warn!("Failed to send response | {e:?}");
|
||||
@@ -120,23 +116,23 @@ async fn handle_socket<W: Websocket>(
|
||||
|
||||
let handle_reads = async {
|
||||
loop {
|
||||
match ws_read.recv().await {
|
||||
Ok(WebsocketMessage::Message(message)) => {
|
||||
handle_incoming_message(args, sender, message).await
|
||||
}
|
||||
Ok(WebsocketMessage::Close(frame)) => {
|
||||
warn!("Connection closed with frame: {frame:?}");
|
||||
break;
|
||||
}
|
||||
Ok(WebsocketMessage::Closed) => {
|
||||
warn!("Connection already closed");
|
||||
break;
|
||||
}
|
||||
let (data, channel, state) = match ws_read.recv_parts().await {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
warn!("Failed to read websocket message | {e:?}");
|
||||
warn!("{e:#}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
match state {
|
||||
MessageState::Request => {
|
||||
handle_request(args.clone(), sender.clone(), channel, data)
|
||||
}
|
||||
MessageState::Terminal => {
|
||||
crate::terminal::handle_message(channel, data).await
|
||||
}
|
||||
// Rest shouldn't be received by Periphery
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -146,32 +142,6 @@ async fn handle_socket<W: Websocket>(
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_incoming_message(
|
||||
args: &Arc<Args>,
|
||||
sender: &Sender,
|
||||
message: Message,
|
||||
) {
|
||||
let (data, channel, state) = match message.into_parts() {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
warn!("Failed to parse transport bytes | {e:#}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
match state {
|
||||
MessageState::Request => {
|
||||
handle_request(args.clone(), sender.clone(), channel, data)
|
||||
}
|
||||
MessageState::Terminal => {
|
||||
crate::terminal::handle_message(channel, data).await
|
||||
}
|
||||
// Shouldn't be received by Periphery
|
||||
MessageState::InProgress => {}
|
||||
MessageState::Successful => {}
|
||||
MessageState::Failed => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
args: Arc<Args>,
|
||||
sender: Sender,
|
||||
@@ -190,20 +160,14 @@ fn handle_request(
|
||||
};
|
||||
|
||||
let resolve_response = async {
|
||||
let (state, data) = match request.resolve(&args).await {
|
||||
Ok(JsonBytes::Ok(res)) => (MessageState::Successful, res),
|
||||
Ok(JsonBytes::Err(e)) => (
|
||||
MessageState::Failed,
|
||||
serialize_error_bytes(
|
||||
&anyhow::Error::new(e)
|
||||
.context("Failed to serialize response body"),
|
||||
),
|
||||
),
|
||||
Err(e) => {
|
||||
(MessageState::Failed, serialize_error_bytes(&e.error))
|
||||
let message: Message = match request.resolve(&args).await {
|
||||
Ok(JsonBytes::Ok(res)) => {
|
||||
(res, req_id, MessageState::Successful).into()
|
||||
}
|
||||
Ok(JsonBytes::Err(e)) => (&e.into(), req_id).into(),
|
||||
Err(e) => (&e.error, req_id).into(),
|
||||
};
|
||||
if let Err(e) = sender.send((data, req_id, state)).await {
|
||||
if let Err(e) = sender.send(message).await {
|
||||
error!("Failed to send response over channel | {e:?}");
|
||||
}
|
||||
};
|
||||
|
||||
@@ -18,9 +18,7 @@ use axum::{
|
||||
routing::get,
|
||||
};
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use serror::{
|
||||
AddStatusCode, AddStatusCodeError, serialize_error_bytes,
|
||||
};
|
||||
use serror::{AddStatusCode, AddStatusCodeError};
|
||||
use transport::{
|
||||
CoreConnectionQuery,
|
||||
auth::{
|
||||
@@ -106,6 +104,28 @@ async fn handler(
|
||||
Ok(ws.on_upgrade(|socket| async move {
|
||||
let mut socket = AxumWebsocket(socket);
|
||||
|
||||
// Make sure receiver locked over the login.
|
||||
let mut receiver = match channel.receiver() {
|
||||
Ok(receiver) => receiver,
|
||||
Err(e) => {
|
||||
warn!("Failed to forward connection | {e:#}");
|
||||
|
||||
if let Err(e) = socket
|
||||
.send(&e)
|
||||
.await
|
||||
.context("Failed to send forward failed to client")
|
||||
{
|
||||
// Log additional error
|
||||
warn!("{e:#}");
|
||||
}
|
||||
|
||||
// Close socket
|
||||
let _ = socket.close(None).await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let query = format!("core={}", urlencoding::encode(&args.core));
|
||||
|
||||
if let Err(e) =
|
||||
@@ -124,30 +144,6 @@ async fn handler(
|
||||
already_logged_login_error()
|
||||
.store(false, atomic::Ordering::Relaxed);
|
||||
|
||||
let mut receiver = match channel.receiver() {
|
||||
Ok(receiver) => receiver,
|
||||
Err(e) => {
|
||||
warn!("Failed to forward connection | {e:#}");
|
||||
|
||||
let mut bytes = serialize_error_bytes(&e);
|
||||
bytes.push(MessageState::Failed.as_byte());
|
||||
|
||||
if let Err(e) = socket
|
||||
.send(&e)
|
||||
.await
|
||||
.context("Failed to send forward failed to client")
|
||||
{
|
||||
// Log additional error
|
||||
warn!("{e:#}");
|
||||
}
|
||||
|
||||
// Close socket
|
||||
let _ = socket.close(None).await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
super::handle_socket(
|
||||
socket,
|
||||
&args,
|
||||
@@ -184,9 +180,11 @@ async fn handle_passkey_login(
|
||||
socket: &mut AxumWebsocket,
|
||||
passkeys: &[String],
|
||||
) -> anyhow::Result<()> {
|
||||
warn!(
|
||||
"Authenticating using Passkeys. Set 'core_public_key' (PERIPHERY_CORE_PUBLIC_KEY) instead to enhance security."
|
||||
);
|
||||
if !already_logged_login_error().load(atomic::Ordering::Relaxed) {
|
||||
warn!(
|
||||
"Authenticating using Passkeys. Set 'core_public_key' (PERIPHERY_CORE_PUBLIC_KEY) instead to enhance security."
|
||||
);
|
||||
};
|
||||
let res = async {
|
||||
// Send login type
|
||||
socket
|
||||
|
||||
@@ -15,6 +15,7 @@ serror.workspace = true
|
||||
tokio-tungstenite.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
futures-util.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tracing.workspace = true
|
||||
anyhow.workspace = true
|
||||
base64.workspace = true
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use bytes::Bytes;
|
||||
use futures_util::FutureExt;
|
||||
use tokio::sync::{Mutex, MutexGuard, mpsc};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::message::{BorrowedMessage, Message, MessageState};
|
||||
use crate::{
|
||||
message::{Message, MessageState},
|
||||
timeout::MaybeWithTimeout,
|
||||
};
|
||||
|
||||
const RESPONSE_BUFFER_MAX_LEN: usize = 1_024;
|
||||
|
||||
@@ -37,7 +42,13 @@ impl BufferedChannel {
|
||||
/// Create a channel
|
||||
pub fn channel() -> (Sender, Receiver) {
|
||||
let (sender, receiver) = mpsc::channel(RESPONSE_BUFFER_MAX_LEN);
|
||||
(Sender(sender), Receiver(receiver))
|
||||
(
|
||||
Sender(sender),
|
||||
Receiver {
|
||||
receiver,
|
||||
cancel: None,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a buffered channel
|
||||
@@ -59,25 +70,60 @@ impl Sender {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Receiver(mpsc::Receiver<Message>);
|
||||
pub struct Receiver {
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
cancel: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl Receiver {
|
||||
pub async fn recv(&mut self) -> Option<Message> {
|
||||
self.0.recv().await
|
||||
}
|
||||
|
||||
pub async fn recv_parts(
|
||||
&mut self,
|
||||
) -> anyhow::Result<(Bytes, Uuid, MessageState)> {
|
||||
let message = self.recv().await.context("Channel is broken")?;
|
||||
message.into_parts()
|
||||
pub fn set_cancel(&mut self, cancel: CancellationToken) {
|
||||
self.cancel = Some(cancel);
|
||||
}
|
||||
|
||||
pub fn poll_recv(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Option<Message>> {
|
||||
self.0.poll_recv(cx)
|
||||
if let Some(cancel) = &self.cancel
|
||||
&& cancel.is_cancelled()
|
||||
{
|
||||
return std::task::Poll::Ready(None);
|
||||
}
|
||||
self.receiver.poll_recv(cx)
|
||||
}
|
||||
|
||||
pub fn recv(
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<Message>> + Send,
|
||||
> {
|
||||
MaybeWithTimeout::new(async {
|
||||
let recv = self
|
||||
.receiver
|
||||
.recv()
|
||||
.map(|res| res.context("Channel is permanently closed"));
|
||||
if let Some(cancel) = &self.cancel {
|
||||
tokio::select! {
|
||||
message = recv => message,
|
||||
_ = cancel.cancelled() => Err(anyhow!("Stream cancelled"))
|
||||
}
|
||||
} else {
|
||||
recv.await
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn recv_parts(
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
|
||||
+ Send,
|
||||
> {
|
||||
MaybeWithTimeout::new(self.recv().map(|res| {
|
||||
res
|
||||
.context("Channel is permanently closed.")
|
||||
.and_then(Message::into_parts)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,17 +142,41 @@ impl BufferedReceiver {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_cancel(&mut self, cancel: CancellationToken) {
|
||||
self.receiver.set_cancel(cancel);
|
||||
}
|
||||
|
||||
/// - If 'buffer: Some(bytes)':
|
||||
/// - Immediately returns borrow of buffer.
|
||||
/// - Else:
|
||||
/// - Wait for next item.
|
||||
/// - store in buffer.
|
||||
/// - return borrow of buffer.
|
||||
pub async fn recv(&mut self) -> Option<BorrowedMessage<'_>> {
|
||||
if self.buffer.is_none() {
|
||||
self.buffer = Some(self.receiver.recv().await?);
|
||||
}
|
||||
self.buffer.as_ref().map(Message::borrow)
|
||||
pub fn recv(
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<Message>> + Send,
|
||||
> {
|
||||
MaybeWithTimeout::new(async {
|
||||
if let Some(buffer) = self.buffer.clone() {
|
||||
Ok(buffer)
|
||||
} else {
|
||||
let message = self.receiver.recv().await?;
|
||||
self.buffer = Some(message.clone());
|
||||
Ok(message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn recv_parts(
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
|
||||
+ Send,
|
||||
> {
|
||||
MaybeWithTimeout::new(
|
||||
self.recv().map(|res| res.and_then(Message::into_parts)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Clears buffer.
|
||||
|
||||
@@ -4,6 +4,7 @@ use serde::Deserialize;
|
||||
pub mod auth;
|
||||
pub mod channel;
|
||||
pub mod message;
|
||||
pub mod timeout;
|
||||
pub mod websocket;
|
||||
|
||||
pub trait TransportHandler {
|
||||
|
||||
@@ -81,7 +81,7 @@ impl From<&anyhow::Error> for Message {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Default, Clone, Debug)]
|
||||
pub struct Message(Bytes);
|
||||
|
||||
impl Message {
|
||||
@@ -101,71 +101,11 @@ impl Message {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub fn borrow(&self) -> BorrowedMessage<'_> {
|
||||
BorrowedMessage(&self.0)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.borrow().is_empty()
|
||||
}
|
||||
|
||||
pub fn state(&self) -> anyhow::Result<MessageState> {
|
||||
self.borrow().state()
|
||||
}
|
||||
|
||||
pub fn channel(&self) -> anyhow::Result<Uuid> {
|
||||
self.borrow().channel()
|
||||
}
|
||||
|
||||
pub fn channel_and_state(
|
||||
&self,
|
||||
) -> anyhow::Result<(Uuid, MessageState)> {
|
||||
self.borrow().channel_and_state()
|
||||
}
|
||||
|
||||
pub fn data(&self) -> anyhow::Result<&[u8]> {
|
||||
// Does not work with .borrow() due to lifetime issue
|
||||
let len = self.0.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid + state + data"
|
||||
));
|
||||
}
|
||||
Ok(&self.0[..(len - 17)])
|
||||
}
|
||||
|
||||
pub fn into_data(self) -> anyhow::Result<Bytes> {
|
||||
let len = self.0.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid + state + data"
|
||||
));
|
||||
}
|
||||
let mut res: Vec<u8> = self.0.into();
|
||||
res.drain((len - 17)..);
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
pub fn into_parts(
|
||||
self,
|
||||
) -> anyhow::Result<(Bytes, Uuid, MessageState)> {
|
||||
let (channel, state) = self.channel_and_state()?;
|
||||
let data = self.into_data()?;
|
||||
Ok((data, channel, state))
|
||||
}
|
||||
}
|
||||
|
||||
// Borrowed Message
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct BorrowedMessage<'a>(&'a [u8]);
|
||||
|
||||
impl BorrowedMessage<'_> {
|
||||
pub fn is_empty(self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
pub fn state(self) -> anyhow::Result<MessageState> {
|
||||
pub fn state(&self) -> anyhow::Result<MessageState> {
|
||||
self
|
||||
.0
|
||||
.last()
|
||||
@@ -200,6 +140,7 @@ impl BorrowedMessage<'_> {
|
||||
}
|
||||
|
||||
pub fn data(&self) -> anyhow::Result<&[u8]> {
|
||||
// Does not work with .borrow() due to lifetime issue
|
||||
let len = self.0.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
@@ -209,20 +150,27 @@ impl BorrowedMessage<'_> {
|
||||
Ok(&self.0[..(len - 17)])
|
||||
}
|
||||
|
||||
/// This will clone the bytes
|
||||
pub fn to_message(self) -> Message {
|
||||
Message(Bytes::copy_from_slice(self.0))
|
||||
pub fn into_data(self) -> anyhow::Result<Bytes> {
|
||||
let len = self.0.len();
|
||||
if len < 17 {
|
||||
return Err(anyhow!(
|
||||
"Transport bytes too short to include uuid + state + data"
|
||||
));
|
||||
}
|
||||
let mut res: Vec<u8> = self.0.into();
|
||||
res.drain((len - 17)..);
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
pub fn into_parts(
|
||||
self,
|
||||
) -> anyhow::Result<(Bytes, Uuid, MessageState)> {
|
||||
let (channel, state) = self.channel_and_state()?;
|
||||
let data = self.into_data()?;
|
||||
Ok((data, channel, state))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> BorrowedMessage<'a> {
|
||||
pub fn into_inner(self) -> &'a [u8] {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
// Message State
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum MessageState {
|
||||
Successful = 0,
|
||||
|
||||
48
lib/transport/src/timeout.rs
Normal file
48
lib/transport/src/timeout.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use futures_util::FutureExt as _;
|
||||
use pin_project_lite::pin_project;
|
||||
|
||||
pin_project! {
|
||||
pub struct MaybeWithTimeout<F> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> MaybeWithTimeout<F> {
|
||||
pub fn new(inner: F) -> MaybeWithTimeout<F> {
|
||||
MaybeWithTimeout { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future> Future for MaybeWithTimeout<F> {
|
||||
type Output = F::Output;
|
||||
fn poll(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Self::Output> {
|
||||
let mut inner = self.project().inner;
|
||||
inner.as_mut().poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
O,
|
||||
E: Into<anyhow::Error>,
|
||||
F: Future<Output = Result<O, E>> + Send,
|
||||
> MaybeWithTimeout<F>
|
||||
{
|
||||
pub fn with_timeout(
|
||||
self,
|
||||
timeout: Duration,
|
||||
) -> impl Future<Output = anyhow::Result<O>> + Send {
|
||||
tokio::time::timeout(timeout, self.inner).map(|res| {
|
||||
res
|
||||
.context("Timed out waiting for message.")
|
||||
.map(|inner| inner.map_err(Into::into))
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,75 +1,98 @@
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::extract::ws::CloseFrame;
|
||||
use futures_util::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt,
|
||||
stream::{SplitSink, SplitStream},
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::message::Message;
|
||||
use crate::{message::Message, timeout::MaybeWithTimeout};
|
||||
|
||||
use super::{
|
||||
MaybeWithTimeout, Websocket, WebsocketMessage, WebsocketReceiver,
|
||||
WebsocketSender,
|
||||
Websocket, WebsocketMessage, WebsocketReceiver, WebsocketSender,
|
||||
};
|
||||
|
||||
pub struct AxumWebsocket(pub axum::extract::ws::WebSocket);
|
||||
|
||||
impl Websocket for AxumWebsocket {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = axum::Error;
|
||||
|
||||
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
|
||||
let (tx, rx) = self.0.split();
|
||||
(AxumWebsocketSender(tx), AxumWebsocketReceiver(rx))
|
||||
(AxumWebsocketSender(tx), AxumWebsocketReceiver::new(rx))
|
||||
}
|
||||
|
||||
fn recv(
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<
|
||||
Output = Result<
|
||||
WebsocketMessage<Self::CloseFrame>,
|
||||
Self::Error,
|
||||
>,
|
||||
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
|
||||
>,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: try_next(&mut self.0),
|
||||
}
|
||||
MaybeWithTimeout::new(try_next(&mut self.0))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&mut self,
|
||||
message: impl Into<Message>,
|
||||
) -> Result<(), Self::Error> {
|
||||
) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(axum::extract::ws::Message::Binary(
|
||||
message.into().into_inner(),
|
||||
))
|
||||
.await
|
||||
.context("Failed to send message bytes over websocket")
|
||||
}
|
||||
|
||||
async fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.0.send(axum::extract::ws::Message::Close(frame)).await
|
||||
) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(axum::extract::ws::Message::Close(frame))
|
||||
.await
|
||||
.context("Failed to send websocket close frame")
|
||||
}
|
||||
}
|
||||
|
||||
pub type InnerWebsocketReceiver =
|
||||
SplitStream<axum::extract::ws::WebSocket>;
|
||||
|
||||
pub struct AxumWebsocketReceiver(pub InnerWebsocketReceiver);
|
||||
pub struct AxumWebsocketReceiver {
|
||||
receiver: InnerWebsocketReceiver,
|
||||
cancel: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl AxumWebsocketReceiver {
|
||||
pub fn new(receiver: InnerWebsocketReceiver) -> Self {
|
||||
Self {
|
||||
receiver,
|
||||
cancel: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebsocketReceiver for AxumWebsocketReceiver {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = axum::Error;
|
||||
|
||||
fn set_cancel(&mut self, cancel: CancellationToken) {
|
||||
self.cancel = Some(cancel);
|
||||
}
|
||||
|
||||
async fn recv(
|
||||
&mut self,
|
||||
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
|
||||
try_next(&mut self.0).await
|
||||
) -> anyhow::Result<WebsocketMessage<Self::CloseFrame>> {
|
||||
let fut = try_next(&mut self.receiver);
|
||||
if let Some(cancel) = &self.cancel {
|
||||
tokio::select! {
|
||||
res = fut => res,
|
||||
_ = cancel.cancelled() => Err(anyhow!("Cancelled before receive"))
|
||||
}
|
||||
} else {
|
||||
fut.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,29 +103,30 @@ pub struct AxumWebsocketSender(pub InnerWebsocketSender);
|
||||
|
||||
impl WebsocketSender for AxumWebsocketSender {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = axum::Error;
|
||||
|
||||
async fn send(
|
||||
&mut self,
|
||||
message: Message,
|
||||
) -> Result<(), Self::Error> {
|
||||
async fn send(&mut self, message: Message) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(axum::extract::ws::Message::Binary(message.into_inner()))
|
||||
.await
|
||||
.context("Failed to send message over websocket")
|
||||
}
|
||||
|
||||
async fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.0.send(axum::extract::ws::Message::Close(frame)).await
|
||||
) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(axum::extract::ws::Message::Close(frame))
|
||||
.await
|
||||
.context("Failed to send websocket close frame")
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_next<S>(
|
||||
stream: &mut S,
|
||||
) -> Result<WebsocketMessage<CloseFrame>, axum::Error>
|
||||
) -> anyhow::Result<WebsocketMessage<CloseFrame>>
|
||||
where
|
||||
S: Stream<Item = Result<axum::extract::ws::Message, axum::Error>>
|
||||
+ Unpin,
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
//! Wrappers to normalize behavior of websockets between Tungstenite and Axum,
|
||||
//! as well as streamline process of handling socket messages.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use bytes::Bytes;
|
||||
use futures_util::FutureExt;
|
||||
use pin_project_lite::pin_project;
|
||||
use serror::deserialize_error_bytes;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::message::{Message, MessageState};
|
||||
use crate::{
|
||||
message::{Message, MessageState},
|
||||
timeout::MaybeWithTimeout,
|
||||
};
|
||||
|
||||
pub mod axum;
|
||||
pub mod tungstenite;
|
||||
@@ -29,7 +30,6 @@ pub enum WebsocketMessage<CloseFrame> {
|
||||
/// Standard traits for websocket
|
||||
pub trait Websocket: Send {
|
||||
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Abstraction over websocket splitting
|
||||
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver);
|
||||
@@ -40,23 +40,20 @@ pub trait Websocket: Send {
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<
|
||||
Output = Result<
|
||||
WebsocketMessage<Self::CloseFrame>,
|
||||
Self::Error,
|
||||
>,
|
||||
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
|
||||
> + Send,
|
||||
>;
|
||||
|
||||
fn send(
|
||||
&mut self,
|
||||
message: impl Into<Message>,
|
||||
) -> impl Future<Output = Result<(), Self::Error>>;
|
||||
) -> impl Future<Output = anyhow::Result<()>>;
|
||||
|
||||
/// Send close message
|
||||
fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> impl Future<Output = Result<(), Self::Error>>;
|
||||
) -> impl Future<Output = anyhow::Result<()>>;
|
||||
|
||||
/// Looping receiver for websocket messages which only returns on messages.
|
||||
fn recv_message(
|
||||
@@ -64,19 +61,17 @@ pub trait Websocket: Send {
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<Message>> + Send,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: async {
|
||||
match self.recv().await? {
|
||||
WebsocketMessage::Message(message) => Ok(message),
|
||||
WebsocketMessage::Close(frame) => {
|
||||
Err(anyhow!("Connection closed with framed: {frame:?}"))
|
||||
}
|
||||
WebsocketMessage::Closed => {
|
||||
Err(anyhow!("Connection already closed"))
|
||||
}
|
||||
MaybeWithTimeout::new(async {
|
||||
match self.recv().await? {
|
||||
WebsocketMessage::Message(message) => Ok(message),
|
||||
WebsocketMessage::Close(frame) => {
|
||||
Err(anyhow!("Connection closed with framed: {frame:?}"))
|
||||
}
|
||||
},
|
||||
}
|
||||
WebsocketMessage::Closed => {
|
||||
Err(anyhow!("Connection already closed"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Receive message + message.into_parts
|
||||
@@ -86,11 +81,11 @@ pub trait Websocket: Send {
|
||||
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
|
||||
+ Send,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: self
|
||||
MaybeWithTimeout::new(
|
||||
self
|
||||
.recv_message()
|
||||
.map(|res| res.map(|message| message.into_parts()).flatten()),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Auto deserializes non-successful message errors.
|
||||
@@ -100,30 +95,30 @@ pub trait Websocket: Send {
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<Bytes>> + Send,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: self.recv_parts().map(|res| {
|
||||
res
|
||||
.map(|(data, _, state)| match state {
|
||||
MessageState::Successful => Ok(data),
|
||||
_ => Err(deserialize_error_bytes(&data)),
|
||||
})
|
||||
.flatten()
|
||||
}),
|
||||
}
|
||||
MaybeWithTimeout::new(self.recv_parts().map(|res| {
|
||||
res
|
||||
.map(|(data, _, state)| match state {
|
||||
MessageState::Successful => Ok(data),
|
||||
_ => Err(deserialize_error_bytes(&data)),
|
||||
})
|
||||
.flatten()
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Traits for split websocket receiver
|
||||
pub trait WebsocketReceiver: Send {
|
||||
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Cancellation sensitive receive.
|
||||
fn set_cancel(&mut self, _cancel: CancellationToken);
|
||||
|
||||
/// Looping receiver for websocket messages which only returns
|
||||
/// on significant messages.
|
||||
/// on significant messages. Must implement cancel support.
|
||||
fn recv(
|
||||
&mut self,
|
||||
) -> impl Future<
|
||||
Output = Result<WebsocketMessage<Self::CloseFrame>, Self::Error>,
|
||||
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
|
||||
> + Send;
|
||||
|
||||
/// Looping receiver for websocket messages which only returns on messages.
|
||||
@@ -132,19 +127,21 @@ pub trait WebsocketReceiver: Send {
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<Output = anyhow::Result<Message>> + Send,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: async {
|
||||
match self.recv().await? {
|
||||
WebsocketMessage::Message(message) => Ok(message),
|
||||
WebsocketMessage::Close(frame) => {
|
||||
Err(anyhow!("Connection closed with framed: {frame:?}"))
|
||||
}
|
||||
WebsocketMessage::Closed => {
|
||||
Err(anyhow!("Connection already closed"))
|
||||
}
|
||||
MaybeWithTimeout::new(async {
|
||||
match self
|
||||
.recv()
|
||||
.await
|
||||
.context("Failed to read websocket message")?
|
||||
{
|
||||
WebsocketMessage::Message(message) => Ok(message),
|
||||
WebsocketMessage::Close(frame) => {
|
||||
Err(anyhow!("Connection closed with framed: {frame:?}"))
|
||||
}
|
||||
},
|
||||
}
|
||||
WebsocketMessage::Closed => {
|
||||
Err(anyhow!("Connection already closed"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Receive message + message.into_parts
|
||||
@@ -154,65 +151,27 @@ pub trait WebsocketReceiver: Send {
|
||||
impl Future<Output = anyhow::Result<(Bytes, Uuid, MessageState)>>
|
||||
+ Send,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: self
|
||||
MaybeWithTimeout::new(
|
||||
self
|
||||
.recv_message()
|
||||
.map(|res| res.map(|message| message.into_parts()).flatten()),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Traits for split websocket receiver
|
||||
pub trait WebsocketSender {
|
||||
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Streamlined sending on bytes
|
||||
fn send(
|
||||
&mut self,
|
||||
message: Message,
|
||||
) -> impl Future<Output = Result<(), Self::Error>> + Send;
|
||||
) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
|
||||
/// Send close message
|
||||
fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> impl Future<Output = Result<(), Self::Error>> + Send;
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct MaybeWithTimeout<F> {
|
||||
#[pin]
|
||||
inner: F,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future> Future for MaybeWithTimeout<F> {
|
||||
type Output = F::Output;
|
||||
fn poll(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Self::Output> {
|
||||
let mut inner = self.project().inner;
|
||||
inner.as_mut().poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
O,
|
||||
E: Into<anyhow::Error>,
|
||||
F: Future<Output = Result<O, E>> + Send,
|
||||
> MaybeWithTimeout<F>
|
||||
{
|
||||
pub fn with_timeout(
|
||||
self,
|
||||
timeout: Duration,
|
||||
) -> impl Future<Output = anyhow::Result<O>> + Send {
|
||||
tokio::time::timeout(timeout, self.inner).map(|res| {
|
||||
res
|
||||
.context("Timed out waiting for message.")
|
||||
.map(|inner| inner.map_err(Into::into))
|
||||
.flatten()
|
||||
})
|
||||
}
|
||||
) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::http::HeaderValue;
|
||||
use futures_util::{
|
||||
SinkExt, Stream, StreamExt, TryStreamExt,
|
||||
@@ -15,12 +15,12 @@ use tokio_tungstenite::{
|
||||
self, handshake::client::Response, protocol::CloseFrame,
|
||||
},
|
||||
};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::message::Message;
|
||||
use crate::{message::Message, timeout::MaybeWithTimeout};
|
||||
|
||||
use super::{
|
||||
MaybeWithTimeout, Websocket, WebsocketMessage, WebsocketReceiver,
|
||||
WebsocketSender,
|
||||
Websocket, WebsocketMessage, WebsocketReceiver, WebsocketSender,
|
||||
};
|
||||
|
||||
pub type InnerWebsocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
@@ -29,13 +29,12 @@ pub struct TungsteniteWebsocket(pub InnerWebsocket);
|
||||
|
||||
impl Websocket for TungsteniteWebsocket {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = tungstenite::Error;
|
||||
|
||||
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
|
||||
let (tx, rx) = self.0.split();
|
||||
(
|
||||
TungsteniteWebsocketSender(tx),
|
||||
TungsteniteWebsocketReceiver(rx),
|
||||
TungsteniteWebsocketReceiver::new(rx),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -43,48 +42,71 @@ impl Websocket for TungsteniteWebsocket {
|
||||
&mut self,
|
||||
) -> MaybeWithTimeout<
|
||||
impl Future<
|
||||
Output = Result<
|
||||
WebsocketMessage<Self::CloseFrame>,
|
||||
Self::Error,
|
||||
>,
|
||||
Output = anyhow::Result<WebsocketMessage<Self::CloseFrame>>,
|
||||
>,
|
||||
> {
|
||||
MaybeWithTimeout {
|
||||
inner: try_next(&mut self.0),
|
||||
}
|
||||
MaybeWithTimeout::new(try_next(&mut self.0))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&mut self,
|
||||
message: impl Into<Message>,
|
||||
) -> Result<(), Self::Error> {
|
||||
) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(tungstenite::Message::Binary(message.into().into_inner()))
|
||||
.await
|
||||
.context("Failed to send message over websocket")
|
||||
}
|
||||
|
||||
async fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.0.close(frame).await
|
||||
) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.close(frame)
|
||||
.await
|
||||
.context("Failed to send websocket close frame")
|
||||
}
|
||||
}
|
||||
|
||||
pub type InnerWebsocketReceiver =
|
||||
SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
|
||||
|
||||
pub struct TungsteniteWebsocketReceiver(pub InnerWebsocketReceiver);
|
||||
pub struct TungsteniteWebsocketReceiver {
|
||||
receiver: InnerWebsocketReceiver,
|
||||
cancel: Option<CancellationToken>,
|
||||
}
|
||||
|
||||
impl TungsteniteWebsocketReceiver {
|
||||
pub fn new(receiver: InnerWebsocketReceiver) -> Self {
|
||||
Self {
|
||||
receiver,
|
||||
cancel: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WebsocketReceiver for TungsteniteWebsocketReceiver {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = tungstenite::Error;
|
||||
|
||||
fn set_cancel(&mut self, cancel: CancellationToken) {
|
||||
self.cancel = Some(cancel);
|
||||
}
|
||||
|
||||
async fn recv(
|
||||
&mut self,
|
||||
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
|
||||
try_next(&mut self.0).await
|
||||
) -> anyhow::Result<WebsocketMessage<Self::CloseFrame>> {
|
||||
let fut = try_next(&mut self.receiver);
|
||||
if let Some(cancel) = &self.cancel {
|
||||
tokio::select! {
|
||||
res = fut => res,
|
||||
_ = cancel.cancelled() => Err(anyhow!("Cancelled before receive"))
|
||||
}
|
||||
} else {
|
||||
fut.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,29 +119,30 @@ pub struct TungsteniteWebsocketSender(pub InnerWebsocketSender);
|
||||
|
||||
impl WebsocketSender for TungsteniteWebsocketSender {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = tungstenite::Error;
|
||||
|
||||
async fn send(
|
||||
&mut self,
|
||||
message: Message,
|
||||
) -> Result<(), Self::Error> {
|
||||
async fn send(&mut self, message: Message) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(tungstenite::Message::Binary(message.into_inner()))
|
||||
.await
|
||||
.context("Failed to send message over websocket")
|
||||
}
|
||||
|
||||
async fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> Result<(), Self::Error> {
|
||||
self.0.send(tungstenite::Message::Close(frame)).await
|
||||
) -> anyhow::Result<()> {
|
||||
self
|
||||
.0
|
||||
.send(tungstenite::Message::Close(frame))
|
||||
.await
|
||||
.context("Failed to send websocket close frame")
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_next<S>(
|
||||
stream: &mut S,
|
||||
) -> Result<WebsocketMessage<CloseFrame>, tungstenite::Error>
|
||||
) -> anyhow::Result<WebsocketMessage<CloseFrame>>
|
||||
where
|
||||
S: Stream<Item = Result<tungstenite::Message, tungstenite::Error>>
|
||||
+ Unpin,
|
||||
|
||||
Reference in New Issue
Block a user