outbound connection mode working

This commit is contained in:
mbecker20
2025-09-19 20:09:39 -07:00
parent 58c1afb8ef
commit d5de338561
27 changed files with 1231 additions and 925 deletions

View File

@@ -8,15 +8,12 @@ repository.workspace = true
homepage.workspace = true
[dependencies]
serror.workspace = true
#
tokio-tungstenite.workspace = true
futures-util.workspace = true
tokio-util.workspace = true
tracing.workspace = true
anyhow.workspace = true
rustls.workspace = true
bytes.workspace = true
tokio.workspace = true
serde.workspace = true
axum.workspace = true
uuid.workspace = true

96
lib/transport/src/auth.rs Normal file
View File

@@ -0,0 +1,96 @@
use anyhow::{Context, anyhow};
use bytes::Bytes;
use futures_util::{SinkExt, TryStreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{
MaybeTlsStream, WebSocketStream, tungstenite,
};
use tracing::{info, warn};
use crate::MessageState;
pub async fn handle_client_side_login(
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
credentials: Bytes,
) -> anyhow::Result<()> {
socket
.send(tungstenite::Message::Binary(credentials))
.await
.context("Failed to send login credentials")?;
loop {
let response = socket
.try_next()
.await
.context("Failed to receive login response")?
.context("Stream broken before login response received")?;
let bytes = match &response {
tungstenite::Message::Text(text) => text.as_bytes(),
tungstenite::Message::Binary(bytes) => &bytes,
tungstenite::Message::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
// Ignore others
_ => continue,
};
let state = bytes
.first()
.map(|b| MessageState::from_byte(*b))
.context("Login response is empty")?;
if matches!(state, MessageState::Successful) {
return Ok(());
} else {
return Err(anyhow!("Failed to login | Invalid credentails"));
}
}
}
pub async fn handle_server_side_login(
socket: &mut axum::extract::ws::WebSocket,
validate_credentials: impl Fn(&[u8]) -> bool,
) -> anyhow::Result<()> {
loop {
// Poll for next message
let msg = socket
.try_next()
.await
.context("Failed to receive login credentials")?
.context("Stream broken before login credentials received")?;
// Treat first message as credentials
let credentials = match &msg {
axum::extract::ws::Message::Text(text) => text.as_bytes(),
axum::extract::ws::Message::Binary(bytes) => &bytes,
axum::extract::ws::Message::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
// Ignore others
_ => continue,
};
// Validate
if validate_credentials(credentials) {
// Send login confirmation
// TODO: remove / edit logs
info!("Client logged in");
socket
.send(axum::extract::ws::Message::Binary(
MessageState::Successful.into(),
))
.await?;
return Ok(());
} else {
// Send login failure
warn!("Client failed to log in");
socket
.send(axum::extract::ws::Message::Binary(
MessageState::Failed.into(),
))
.await?;
let _ = socket.close().await;
return Err(anyhow!("Received invalid credentials"));
}
}
}

View File

@@ -1,294 +0,0 @@
use std::{
sync::{
Arc,
atomic::{self, AtomicBool},
},
time::Duration,
};
use anyhow::{Context, anyhow};
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use rustls::{ClientConfig, client::danger::ServerCertVerifier};
use tokio::{net::TcpStream, sync::RwLock};
use tokio_tungstenite::{
Connector, MaybeTlsStream, WebSocketStream, tungstenite::Message,
};
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::{
MessageState, TransportHandler, channel::BufferedReceiver,
};
/// Handles client side / outbound connection
pub async fn handle_client_connection<
T: TransportHandler + Send + Sync + 'static,
>(
address: &str,
connection: &ClientConnection,
transport: &T,
write_receiver: &mut BufferedReceiver<Bytes>,
) {
loop {
let mut socket = match connect_websocket(address).await {
Ok(socket) => socket,
Err(e) => {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
info!("Connected to {address}");
if let Err(e) = handle_login(&mut socket, Bytes::new()).await {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
};
info!("Logged into {address}");
connection.connected.store(true, atomic::Ordering::Relaxed);
connection.clear_error().await;
let (mut ws_write, mut ws_read) = socket.split();
let forward_writes = async {
loop {
let next = tokio::select! {
next = write_receiver.recv() => next,
_ = connection.cancel.cancelled() => break,
};
let message = match next {
None => {
info!("Got None over request reciever for {address}");
break;
}
Some(request) => {
Message::Binary(Bytes::copy_from_slice(request))
}
};
match ws_write.send(message).await {
Ok(_) => write_receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send request to {address} | {e:#}");
break;
}
}
}
// Cancel again if not already
let _ = ws_write.close().await;
connection.cancel();
};
let handle_reads = async {
loop {
let next = tokio::select! {
next = ws_read.next() => next,
_ = connection.cancel.cancelled() => break,
};
match next {
Some(Ok(Message::Binary(bytes))) => {
transport.handle_incoming_bytes(bytes).await
}
Some(Ok(Message::Close(frame))) => {
warn!(
"Connection to {address} broken with frame: {frame:?}"
);
break;
}
Some(Err(e)) => {
warn!("Connection to {address} broken with error: {e:?}");
break;
}
None => {
warn!("Connection to {address} closed");
break;
}
// Can ignore other message types
Some(Ok(_)) => {
continue;
}
};
}
// Cancel again if not already
connection.cancel();
};
tokio::join!(forward_writes, handle_reads);
warn!("Disconnnected from {address}");
connection.connected.store(false, atomic::Ordering::Relaxed);
}
}
async fn handle_login(
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
credentials: Bytes,
) -> anyhow::Result<()> {
socket
.send(Message::Binary(credentials))
.await
.context("Failed to send login credentials")?;
loop {
let response = socket
.try_next()
.await
.context("Failed to receive login response")?
.context("Stream broken before login response received")?;
let bytes = match &response {
Message::Text(text) => text.as_bytes(),
Message::Binary(bytes) => &bytes,
Message::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
// Ignore others
_ => continue,
};
let state = bytes
.first()
.map(|b| MessageState::from_byte(*b))
.context("Login response is empty")?;
if matches!(state, MessageState::Successful) {
return Ok(());
} else {
return Err(anyhow!("Failed to login | Invalid credentails"));
}
}
}
#[derive(Debug)]
pub struct ClientConnection {
connected: AtomicBool,
error: RwLock<Option<serror::Serror>>,
cancel: CancellationToken,
}
impl ClientConnection {
pub fn new() -> ClientConnection {
ClientConnection {
connected: AtomicBool::new(false),
error: RwLock::new(None),
cancel: CancellationToken::new(),
}
}
pub fn connected(&self) -> bool {
self.connected.load(atomic::Ordering::Relaxed)
}
pub async fn error(&self) -> Option<serror::Serror> {
self.error.read().await.clone()
}
pub async fn set_error(&self, e: anyhow::Error) {
let mut error = self.error.write().await;
*error = Some(e.into());
}
pub async fn clear_error(&self) {
let mut error = self.error.write().await;
*error = None;
}
pub fn cancel(&self) {
self.cancel.cancel();
}
}
pub async fn connect_websocket(
url: &str,
) -> anyhow::Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let (stream, _) = if url.starts_with("wss") {
tokio_tungstenite::connect_async_tls_with_config(
url,
None,
false,
Some(Connector::Rustls(Arc::new(
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(
InsecureVerifier,
))
.with_no_client_auth(),
))),
)
.await
.with_context(|| {
format!("failed to connect to websocket | url: {url}")
})?
} else {
tokio_tungstenite::connect_async(url).await.with_context(
|| format!("failed to connect to websocket | url: {url}"),
)?
};
Ok(stream)
}
#[derive(Debug)]
struct InsecureVerifier;
impl ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error>
{
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<
rustls::client::danger::HandshakeSignatureValid,
rustls::Error,
> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<
rustls::client::danger::HandshakeSignatureValid,
rustls::Error,
> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA1,
rustls::SignatureScheme::ECDSA_SHA1_Legacy,
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::RSA_PKCS1_SHA384,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::RSA_PKCS1_SHA512,
rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
rustls::SignatureScheme::RSA_PSS_SHA256,
rustls::SignatureScheme::RSA_PSS_SHA384,
rustls::SignatureScheme::RSA_PSS_SHA512,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::ED448,
]
}
}

View File

@@ -1,9 +1,9 @@
use ::bytes::Bytes;
use serde::{Deserialize, Serialize};
pub mod auth;
pub mod bytes;
pub mod channel;
pub mod client;
pub mod server;
#[derive(Debug, Clone, Copy)]
pub enum MessageState {
@@ -26,3 +26,26 @@ pub trait TransportHandler {
bytes: Bytes,
) -> impl Future<Output = ()> + Send;
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PeripheryConnectionQuery {
/// Server Id or name
pub server: String,
}
/// - Fixes ws addresses:
/// - `server.domain` => `wss://server.domain`
/// - `http://server.domain` => `ws://server.domain`
/// - `https://server.domain` => `wss://server.domain`
pub fn fix_ws_address(address: &str) -> String {
if address.starts_with("ws://") || address.starts_with("wss://") {
return address.to_string();
}
if address.starts_with("http://") {
return address.replace("http://", "ws://");
}
if address.starts_with("https://") {
return address.replace("https://", "wss://");
}
format!("wss://{address}")
}

View File

@@ -1,133 +0,0 @@
use anyhow::{Context, anyhow};
use axum::{
extract::{
WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::StatusCode,
response::Response,
};
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use serror::AddStatusCode;
use tokio::sync::Mutex;
use tracing::{error, warn};
use crate::{
MessageState, TransportHandler, channel::BufferedReceiver,
};
/// Handles server side / inbound connection
pub fn handle_server_connection<
T: TransportHandler + Send + Sync + 'static,
>(
ws: WebSocketUpgrade,
transport: T,
write_receiver: &'static Mutex<BufferedReceiver<Bytes>>,
) -> serror::Result<Response> {
// Limits to only one active websocket connection.
let mut write_receiver = write_receiver
.try_lock()
.status_code(StatusCode::FORBIDDEN)?;
Ok(ws.on_upgrade(|mut socket| async move {
if let Err(e) = handle_login(&mut socket, |b| true).await {
warn!("Client failed to login | {e:#}");
return;
};
let (mut ws_write, mut ws_read) = socket.split();
let forward_writes = async {
loop {
let msg = match write_receiver.recv().await {
// Sender Dropped (shouldn't happen, it is static).
None => break,
// This has to copy the bytes to follow ownership rules.
Some(msg) => Message::Binary(Bytes::copy_from_slice(msg)),
};
match ws_write.send(msg).await {
// Clears the stored message from receiver buffer.
// TODO: Move after response ack.
Ok(_) => write_receiver.clear_buffer(),
Err(e) => {
warn!("Failed to send response | {e:?}");
let _ = ws_write.close().await;
break;
}
}
}
};
let handle_reads = async {
loop {
match ws_read.next().await {
// Incoming core msg
Some(Ok(Message::Binary(bytes))) => {
transport.handle_incoming_bytes(bytes).await
}
// Disconnection cases.
Some(Ok(Message::Close(frame))) => {
warn!("Connection closed with frame: {frame:?}");
break;
}
None => break,
Some(Err(e)) => {
error!("Failed to read websocket message | {e:?}");
break;
}
// Can ignore the rest
_ => {
continue;
}
};
}
};
tokio::select! {
_ = forward_writes => {},
_ = handle_reads => {},
};
}))
}
async fn handle_login(
socket: &mut WebSocket,
validate: impl Fn(&[u8]) -> bool,
) -> anyhow::Result<()> {
loop {
// Poll for next message
let msg = socket
.try_next()
.await
.context("Failed to receive login credentials")?
.context("Stream broken before login credentials received")?;
// Treat first message as credentials
let credentials = match &msg {
Message::Text(text) => text.as_bytes(),
Message::Binary(bytes) => &bytes,
Message::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
// Ignore others
_ => continue,
};
// Validate
if validate(credentials) {
// Send login confirmation
socket
.send(Message::Binary(MessageState::Successful.into()))
.await?;
return Ok(());
} else {
// Send login failure
socket
.send(Message::Binary(MessageState::Failed.into()))
.await?;
let _ = socket.close().await;
return Err(anyhow!("Received invalid credentials"));
}
}
}