mirror of
https://github.com/moghtech/komodo.git
synced 2026-03-11 17:44:19 -05:00
353 lines
9.1 KiB
Rust
353 lines
9.1 KiB
Rust
use std::{
|
|
sync::{
|
|
Arc,
|
|
atomic::{self, AtomicBool},
|
|
},
|
|
time::Duration,
|
|
};
|
|
|
|
use anyhow::{Context, anyhow};
|
|
use bytes::Bytes;
|
|
use cache::CloneCache;
|
|
use serror::serror_into_anyhow_error;
|
|
use tokio::sync::{
|
|
RwLock,
|
|
mpsc::{Sender, error::SendError},
|
|
};
|
|
use tokio_util::sync::CancellationToken;
|
|
use transport::{
|
|
auth::{ConnectionIdentifiers, LoginFlow, PublicKeyValidator},
|
|
bytes::id_from_transport_bytes,
|
|
channel::{BufferedReceiver, buffered_channel},
|
|
websocket::{
|
|
Websocket, WebsocketMessage, WebsocketReceiver as _,
|
|
WebsocketSender as _,
|
|
},
|
|
};
|
|
|
|
use crate::{config::core_config, periphery::ConnectionChannels};
|
|
|
|
pub mod client;
|
|
pub mod server;
|
|
|
|
pub struct WebsocketHandler<'a, W> {
|
|
pub socket: W,
|
|
pub connection_identifiers: ConnectionIdentifiers<'a>,
|
|
pub write_receiver: &'a mut BufferedReceiver<Bytes>,
|
|
pub connection: &'a PeripheryConnection,
|
|
}
|
|
|
|
impl<W: Websocket> WebsocketHandler<'_, W> {
|
|
async fn handle<L: LoginFlow>(self) -> anyhow::Result<()> {
|
|
let WebsocketHandler {
|
|
mut socket,
|
|
connection_identifiers,
|
|
write_receiver,
|
|
connection,
|
|
} = self;
|
|
|
|
let private_key = if connection.private_key.is_empty() {
|
|
&core_config().private_key
|
|
} else {
|
|
&connection.private_key
|
|
};
|
|
|
|
let expected_public_key = if !connection
|
|
.expected_public_key
|
|
.is_empty()
|
|
{
|
|
Some(connection.expected_public_key.as_str())
|
|
} else if connection.address.is_empty() {
|
|
// Only force periphery public key for Periphery -> Core connections
|
|
Some(
|
|
core_config()
|
|
.periphery_public_key
|
|
.as_deref()
|
|
.context("Must either configure Server 'Periphery Public Key' or set KOMODO_PERIPHERY_PUBLIC_KEY")?
|
|
)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
L::login(
|
|
&mut socket,
|
|
connection_identifiers,
|
|
private_key,
|
|
&PeripheryPublicKeyValidator {
|
|
expected: expected_public_key,
|
|
},
|
|
)
|
|
.await?;
|
|
|
|
let handler_cancel = CancellationToken::new();
|
|
|
|
connection.set_connected(true);
|
|
connection.clear_error().await;
|
|
|
|
let (mut ws_write, mut ws_read) = socket.split();
|
|
|
|
let forward_writes = async {
|
|
loop {
|
|
let next = tokio::select! {
|
|
next = write_receiver.recv() => next,
|
|
_ = connection.cancel.cancelled() => break,
|
|
_ = handler_cancel.cancelled() => break,
|
|
};
|
|
|
|
let message = match next {
|
|
Some(request) => Bytes::copy_from_slice(request),
|
|
// Sender Dropped (shouldn't happen, a reference is held on 'connection').
|
|
None => break,
|
|
};
|
|
|
|
match ws_write.send(message).await {
|
|
Ok(_) => write_receiver.clear_buffer(),
|
|
Err(e) => {
|
|
connection.set_error(e.into()).await;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
// Cancel again if not already
|
|
let _ = ws_write.close(None).await;
|
|
handler_cancel.cancel();
|
|
};
|
|
|
|
let handle_reads = async {
|
|
loop {
|
|
let next = tokio::select! {
|
|
next = ws_read.recv() => next,
|
|
_ = connection.cancel.cancelled() => break,
|
|
_ = handler_cancel.cancelled() => break,
|
|
};
|
|
|
|
match next {
|
|
Ok(WebsocketMessage::Binary(bytes)) => {
|
|
connection.handle_incoming_bytes(bytes).await
|
|
}
|
|
Ok(WebsocketMessage::Close(_))
|
|
| Ok(WebsocketMessage::Closed) => {
|
|
connection.set_error(anyhow!("Connection closed")).await;
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
connection.set_error(e.into()).await;
|
|
}
|
|
};
|
|
}
|
|
// Cancel again if not already
|
|
handler_cancel.cancel();
|
|
};
|
|
|
|
tokio::join!(forward_writes, handle_reads);
|
|
|
|
connection.set_connected(false);
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub struct PeripheryPublicKeyValidator<'a> {
|
|
/// If None, ignore public key.
|
|
pub expected: Option<&'a str>,
|
|
}
|
|
impl PublicKeyValidator for PeripheryPublicKeyValidator<'_> {
|
|
fn validate(&self, public_key: String) -> anyhow::Result<()> {
|
|
if let Some(expected) = self.expected
|
|
&& public_key != expected
|
|
{
|
|
Err(
|
|
anyhow!("Got invalid public key: {public_key}")
|
|
.context("Ensure public key matches configured Periphery Public Key")
|
|
.context("Core failed to validate Periphery public key"),
|
|
)
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub struct PeripheryConnections(
|
|
CloneCache<String, Arc<PeripheryConnection>>,
|
|
);
|
|
|
|
impl PeripheryConnections {
|
|
pub async fn insert(
|
|
&self,
|
|
server_id: String,
|
|
args: PeripheryConnectionArgs<'_>,
|
|
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
|
|
let channels = if let Some(existing_connection) =
|
|
self.0.remove(&server_id).await
|
|
{
|
|
existing_connection.cancel();
|
|
// Keep the same channels so requests
|
|
// can handle disconnects while processing.
|
|
existing_connection.channels.clone()
|
|
} else {
|
|
Default::default()
|
|
};
|
|
|
|
let (connection, receiver) =
|
|
PeripheryConnection::new(args, channels);
|
|
|
|
self.0.insert(server_id, connection.clone()).await;
|
|
|
|
(connection, receiver)
|
|
}
|
|
|
|
pub async fn get(
|
|
&self,
|
|
server_id: &String,
|
|
) -> Option<Arc<PeripheryConnection>> {
|
|
self.0.get(server_id).await
|
|
}
|
|
|
|
/// Remove and cancel connection
|
|
pub async fn remove(
|
|
&self,
|
|
server_id: &String,
|
|
) -> Option<Arc<PeripheryConnection>> {
|
|
self
|
|
.0
|
|
.remove(server_id)
|
|
.await
|
|
.inspect(|connection| connection.cancel())
|
|
}
|
|
}
|
|
|
|
/// The user-configurable parameters of a Periphery connection.
///
/// Borrowed and `Copy`, so it can be passed around freely. Compare it with a
/// live [`PeripheryConnection`] via `matches`.
#[derive(Clone, Copy)]
pub struct PeripheryConnectionArgs<'a> {
  /// Address Core dials for outbound connections; empty for inbound ones.
  pub address: &'a str,
  /// Private key for this connection; empty means use the Core private key.
  pub private_key: &'a str,
  /// Public key Periphery must present; empty means none set per-server.
  pub expected_public_key: &'a str,
}
|
|
|
|
impl PeripheryConnectionArgs<'_> {
|
|
pub fn matches(&self, connection: &PeripheryConnection) -> bool {
|
|
self.address == connection.address
|
|
&& self.private_key == connection.private_key
|
|
&& self.expected_public_key == connection.expected_public_key
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct PeripheryConnection {
|
|
/// Specify outbound connection address.
|
|
/// Inbound connections have this as empty string
|
|
pub address: String,
|
|
/// The private key to use, or empty for core private key
|
|
pub private_key: String,
|
|
/// The public key to expect Periphery to have.
|
|
/// Required non-empty for inbound connection.
|
|
pub expected_public_key: String,
|
|
/// Whether Periphery is currently connected.
|
|
pub connected: AtomicBool,
|
|
/// Stores latest connection error
|
|
pub error: RwLock<Option<serror::Serror>>,
|
|
/// Cancel the connection
|
|
pub cancel: CancellationToken,
|
|
/// Send bytes to Periphery
|
|
pub sender: Sender<Bytes>,
|
|
/// Send bytes from Periphery to channel handlers.
|
|
/// Must be maintained if new connection replaces old
|
|
/// at the same server id.
|
|
pub channels: Arc<ConnectionChannels>,
|
|
}
|
|
|
|
impl PeripheryConnection {
|
|
pub fn new(
|
|
args: PeripheryConnectionArgs<'_>,
|
|
channels: Arc<ConnectionChannels>,
|
|
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
|
|
let (sender, receiver) = buffered_channel();
|
|
(
|
|
PeripheryConnection {
|
|
address: args.address.to_string(),
|
|
private_key: args.private_key.to_string(),
|
|
expected_public_key: args.expected_public_key.to_string(),
|
|
sender,
|
|
channels,
|
|
connected: AtomicBool::new(false),
|
|
error: RwLock::new(None),
|
|
cancel: CancellationToken::new(),
|
|
}
|
|
.into(),
|
|
receiver,
|
|
)
|
|
}
|
|
|
|
pub async fn handle_incoming_bytes(&self, bytes: Bytes) {
|
|
let id = match id_from_transport_bytes(&bytes) {
|
|
Ok(res) => res,
|
|
Err(e) => {
|
|
// TODO: handle better
|
|
warn!("Failed to read id | {e:#}");
|
|
return;
|
|
}
|
|
};
|
|
let Some(channel) = self.channels.get(&id).await else {
|
|
// TODO: handle better
|
|
warn!("Failed to send response | No response channel found");
|
|
return;
|
|
};
|
|
if let Err(e) = channel.send(bytes).await {
|
|
// TODO: handle better
|
|
warn!("Failed to send response | Channel failure | {e:#}");
|
|
}
|
|
}
|
|
|
|
pub async fn send(
|
|
&self,
|
|
value: Bytes,
|
|
) -> Result<(), SendError<Bytes>> {
|
|
self.sender.send(value).await
|
|
}
|
|
|
|
pub fn set_connected(&self, connected: bool) {
|
|
self.connected.store(connected, atomic::Ordering::Relaxed);
|
|
}
|
|
|
|
pub fn connected(&self) -> bool {
|
|
self.connected.load(atomic::Ordering::Relaxed)
|
|
}
|
|
|
|
/// Polls connected 3 times (500ms in between) before bailing.
|
|
pub async fn bail_if_not_connected(&self) -> anyhow::Result<()> {
|
|
const POLL_TIMES: usize = 3;
|
|
for i in 0..POLL_TIMES {
|
|
if self.connected() {
|
|
return Ok(());
|
|
}
|
|
if i < POLL_TIMES - 1 {
|
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
|
}
|
|
}
|
|
if let Some(e) = self.error().await {
|
|
Err(serror_into_anyhow_error(e))
|
|
} else {
|
|
Err(anyhow!("Server is not currently connected"))
|
|
}
|
|
}
|
|
|
|
pub async fn error(&self) -> Option<serror::Serror> {
|
|
self.error.read().await.clone()
|
|
}
|
|
|
|
pub async fn set_error(&self, e: anyhow::Error) {
|
|
let mut error = self.error.write().await;
|
|
*error = Some(e.into());
|
|
}
|
|
|
|
pub async fn clear_error(&self) {
|
|
let mut error = self.error.write().await;
|
|
*error = None;
|
|
}
|
|
|
|
pub fn cancel(&self) {
|
|
self.cancel.cancel();
|
|
}
|
|
}
|