add login draft for transport

This commit is contained in:
mbecker20
2025-09-19 01:48:09 -07:00
parent 230f357b5a
commit 58c1afb8ef
5 changed files with 157 additions and 59 deletions

View File

@@ -1,21 +1,14 @@
#[macro_use]
extern crate tracing;
//
use std::{net::SocketAddr, str::FromStr};
use anyhow::Context;
use axum_server::tls_rustls::RustlsConfig;
use config::periphery_config;
mod api;
mod config;
mod connection;
mod docker;
mod helpers;
mod router;
mod server;
mod stats;
mod terminal;
mod connection;
async fn app() -> anyhow::Result<()> {
dotenvy::dotenv().ok();
@@ -24,7 +17,7 @@ async fn app() -> anyhow::Result<()> {
info!("Komodo Periphery version: v{}", env!("CARGO_PKG_VERSION"));
if periphery_config().pretty_startup_config {
if config.pretty_startup_config {
info!("{:#?}", config.sanitized());
} else {
info!("{:?}", config.sanitized());
@@ -34,41 +27,7 @@ async fn app() -> anyhow::Result<()> {
docker::stats::spawn_polling_thread();
connection::init_response_channel();
let addr = format!(
"{}:{}",
config::periphery_config().bind_ip,
config::periphery_config().port
);
let socket_addr = SocketAddr::from_str(&addr)
.context("failed to parse listen address")?;
let app = router::router()
.into_make_service_with_connect_info::<SocketAddr>();
if config.ssl_enabled {
info!("🔒 Periphery SSL Enabled");
rustls::crypto::ring::default_provider()
.install_default()
.expect("failed to install default rustls CryptoProvider");
helpers::ensure_ssl_certs().await;
info!("Komodo Periphery starting on https://{}", socket_addr);
let ssl_config = RustlsConfig::from_pem_file(
config.ssl_cert_file(),
config.ssl_key_file(),
)
.await
.context("Invalid ssl cert / key")?;
axum_server::bind_rustls(socket_addr, ssl_config)
.serve(app)
.await?
} else {
info!("🔓 Periphery SSL Disabled");
info!("Komodo Periphery starting on http://{}", socket_addr);
axum_server::bind(socket_addr).serve(app).await?
}
Ok(())
server::run_connection_server().await
}
#[tokio::main]

View File

@@ -8,15 +8,51 @@ use axum::{
response::Response,
routing::get,
};
use axum_server::tls_rustls::RustlsConfig;
use serror::{AddStatusCode, AddStatusCodeError};
use std::net::{IpAddr, SocketAddr};
use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
};
use crate::config::periphery_config;
pub fn router() -> Router {
Router::new()
pub async fn run_connection_server() -> anyhow::Result<()> {
let config = periphery_config();
let addr = format!("{}:{}", config.bind_ip, config.port);
let socket_addr = SocketAddr::from_str(&addr)
.context("failed to parse listen address")?;
let app = Router::new()
.route("/", get(crate::connection::inbound_connection))
.layer(middleware::from_fn(guard_request_by_ip))
.into_make_service_with_connect_info::<SocketAddr>();
if config.ssl_enabled {
info!("🔒 Periphery SSL Enabled");
rustls::crypto::ring::default_provider()
.install_default()
.expect("failed to install default rustls CryptoProvider");
crate::helpers::ensure_ssl_certs().await;
info!("Komodo Periphery starting on https://{}", socket_addr);
let ssl_config = RustlsConfig::from_pem_file(
config.ssl_cert_file(),
config.ssl_key_file(),
)
.await
.context("Invalid ssl cert / key")?;
axum_server::bind_rustls(socket_addr, ssl_config)
.serve(app)
.await?
} else {
info!("🔓 Periphery SSL Disabled");
info!("Komodo Periphery starting on http://{}", socket_addr);
axum_server::bind(socket_addr).serve(app).await?
}
Ok(())
}
async fn guard_request_by_ip(

View File

@@ -6,9 +6,9 @@ use std::{
time::Duration,
};
use anyhow::Context;
use anyhow::{Context, anyhow};
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use rustls::{ClientConfig, client::danger::ServerCertVerifier};
use tokio::{net::TcpStream, sync::RwLock};
use tokio_tungstenite::{
@@ -17,7 +17,9 @@ use tokio_tungstenite::{
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use crate::{TransportHandler, channel::BufferedReceiver};
use crate::{
MessageState, TransportHandler, channel::BufferedReceiver,
};
/// Handles client side / outbound connection
pub async fn handle_client_connection<
@@ -29,7 +31,7 @@ pub async fn handle_client_connection<
write_receiver: &mut BufferedReceiver<Bytes>,
) {
loop {
let socket = match connect_websocket(address).await {
let mut socket = match connect_websocket(address).await {
Ok(socket) => socket,
Err(e) => {
connection.set_error(e).await;
@@ -39,6 +41,15 @@ pub async fn handle_client_connection<
};
info!("Connected to {address}");
if let Err(e) = handle_login(&mut socket, Bytes::new()).await {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
};
info!("Logged into {address}");
connection.connected.store(true, atomic::Ordering::Relaxed);
connection.clear_error().await;
@@ -116,6 +127,44 @@ pub async fn handle_client_connection<
}
}
async fn handle_login(
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
credentials: Bytes,
) -> anyhow::Result<()> {
socket
.send(Message::Binary(credentials))
.await
.context("Failed to send login credentials")?;
loop {
let response = socket
.try_next()
.await
.context("Failed to receive login response")?
.context("Stream broken before login response received")?;
let bytes = match &response {
Message::Text(text) => text.as_bytes(),
Message::Binary(bytes) => &bytes,
Message::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
// Ignore others
_ => continue,
};
let state = bytes
.first()
.map(|b| MessageState::from_byte(*b))
.context("Login response is empty")?;
if matches!(state, MessageState::Successful) {
return Ok(());
} else {
return Err(anyhow!("Failed to login | Invalid credentails"));
}
}
}
#[derive(Debug)]
pub struct ClientConnection {
connected: AtomicBool,

View File

@@ -14,11 +14,15 @@ pub enum MessageState {
Terminal,
}
impl From<MessageState> for Bytes {
fn from(value: MessageState) -> Self {
Bytes::from_owner([value.as_byte()])
}
}
pub trait TransportHandler {
fn handle_incoming_bytes(
&self,
bytes: Bytes,
) -> impl Future<Output = ()> + Send;
}

View File

@@ -1,15 +1,21 @@
use anyhow::{Context, anyhow};
use axum::{
extract::{WebSocketUpgrade, ws::Message},
extract::{
WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::StatusCode,
response::Response,
};
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use serror::AddStatusCode;
use tokio::sync::Mutex;
use tracing::{error, warn};
use crate::{TransportHandler, channel::BufferedReceiver};
use crate::{
MessageState, TransportHandler, channel::BufferedReceiver,
};
/// Handles server side / inbound connection
pub fn handle_server_connection<
@@ -24,8 +30,11 @@ pub fn handle_server_connection<
.try_lock()
.status_code(StatusCode::FORBIDDEN)?;
Ok(ws.on_upgrade(|socket| async move {
// TODO: Handle authentication exchange.
Ok(ws.on_upgrade(|mut socket| async move {
if let Err(e) = handle_login(&mut socket, |b| true).await {
warn!("Client failed to login | {e:#}");
return;
};
let (mut ws_write, mut ws_read) = socket.split();
@@ -81,3 +90,44 @@ pub fn handle_server_connection<
};
}))
}
async fn handle_login(
socket: &mut WebSocket,
validate: impl Fn(&[u8]) -> bool,
) -> anyhow::Result<()> {
loop {
// Poll for next message
let msg = socket
.try_next()
.await
.context("Failed to receive login credentials")?
.context("Stream broken before login credentials received")?;
// Treat first message as credentials
let credentials = match &msg {
Message::Text(text) => text.as_bytes(),
Message::Binary(bytes) => &bytes,
Message::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
// Ignore others
_ => continue,
};
// Validate
if validate(credentials) {
// Send login confirmation
socket
.send(Message::Binary(MessageState::Successful.into()))
.await?;
return Ok(());
} else {
// Send login failure
socket
.send(Message::Binary(MessageState::Failed.into()))
.await?;
let _ = socket.close().await;
return Err(anyhow!("Received invalid credentials"));
}
}
}