mirror of
https://github.com/moghtech/komodo.git
synced 2026-04-28 11:49:39 -05:00
add login draft for transport
This commit is contained in:
@@ -1,21 +1,14 @@
|
||||
#[macro_use]
|
||||
extern crate tracing;
|
||||
|
||||
//
|
||||
use std::{net::SocketAddr, str::FromStr};
|
||||
|
||||
use anyhow::Context;
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use config::periphery_config;
|
||||
|
||||
mod api;
|
||||
mod config;
|
||||
mod connection;
|
||||
mod docker;
|
||||
mod helpers;
|
||||
mod router;
|
||||
mod server;
|
||||
mod stats;
|
||||
mod terminal;
|
||||
mod connection;
|
||||
|
||||
async fn app() -> anyhow::Result<()> {
|
||||
dotenvy::dotenv().ok();
|
||||
@@ -24,7 +17,7 @@ async fn app() -> anyhow::Result<()> {
|
||||
|
||||
info!("Komodo Periphery version: v{}", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
if periphery_config().pretty_startup_config {
|
||||
if config.pretty_startup_config {
|
||||
info!("{:#?}", config.sanitized());
|
||||
} else {
|
||||
info!("{:?}", config.sanitized());
|
||||
@@ -34,41 +27,7 @@ async fn app() -> anyhow::Result<()> {
|
||||
docker::stats::spawn_polling_thread();
|
||||
connection::init_response_channel();
|
||||
|
||||
let addr = format!(
|
||||
"{}:{}",
|
||||
config::periphery_config().bind_ip,
|
||||
config::periphery_config().port
|
||||
);
|
||||
|
||||
let socket_addr = SocketAddr::from_str(&addr)
|
||||
.context("failed to parse listen address")?;
|
||||
|
||||
let app = router::router()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
|
||||
if config.ssl_enabled {
|
||||
info!("🔒 Periphery SSL Enabled");
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("failed to install default rustls CryptoProvider");
|
||||
helpers::ensure_ssl_certs().await;
|
||||
info!("Komodo Periphery starting on https://{}", socket_addr);
|
||||
let ssl_config = RustlsConfig::from_pem_file(
|
||||
config.ssl_cert_file(),
|
||||
config.ssl_key_file(),
|
||||
)
|
||||
.await
|
||||
.context("Invalid ssl cert / key")?;
|
||||
axum_server::bind_rustls(socket_addr, ssl_config)
|
||||
.serve(app)
|
||||
.await?
|
||||
} else {
|
||||
info!("🔓 Periphery SSL Disabled");
|
||||
info!("Komodo Periphery starting on http://{}", socket_addr);
|
||||
axum_server::bind(socket_addr).serve(app).await?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
server::run_connection_server().await
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
|
||||
@@ -8,15 +8,51 @@ use axum::{
|
||||
response::Response,
|
||||
routing::get,
|
||||
};
|
||||
use axum_server::tls_rustls::RustlsConfig;
|
||||
use serror::{AddStatusCode, AddStatusCodeError};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::{
|
||||
net::{IpAddr, SocketAddr},
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
use crate::config::periphery_config;
|
||||
|
||||
pub fn router() -> Router {
|
||||
Router::new()
|
||||
pub async fn run_connection_server() -> anyhow::Result<()> {
|
||||
let config = periphery_config();
|
||||
|
||||
let addr = format!("{}:{}", config.bind_ip, config.port);
|
||||
|
||||
let socket_addr = SocketAddr::from_str(&addr)
|
||||
.context("failed to parse listen address")?;
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(crate::connection::inbound_connection))
|
||||
.layer(middleware::from_fn(guard_request_by_ip))
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
|
||||
if config.ssl_enabled {
|
||||
info!("🔒 Periphery SSL Enabled");
|
||||
rustls::crypto::ring::default_provider()
|
||||
.install_default()
|
||||
.expect("failed to install default rustls CryptoProvider");
|
||||
crate::helpers::ensure_ssl_certs().await;
|
||||
info!("Komodo Periphery starting on https://{}", socket_addr);
|
||||
let ssl_config = RustlsConfig::from_pem_file(
|
||||
config.ssl_cert_file(),
|
||||
config.ssl_key_file(),
|
||||
)
|
||||
.await
|
||||
.context("Invalid ssl cert / key")?;
|
||||
axum_server::bind_rustls(socket_addr, ssl_config)
|
||||
.serve(app)
|
||||
.await?
|
||||
} else {
|
||||
info!("🔓 Periphery SSL Disabled");
|
||||
info!("Komodo Periphery starting on http://{}", socket_addr);
|
||||
axum_server::bind(socket_addr).serve(app).await?
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn guard_request_by_ip(
|
||||
@@ -6,9 +6,9 @@ use std::{
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
use rustls::{ClientConfig, client::danger::ServerCertVerifier};
|
||||
use tokio::{net::TcpStream, sync::RwLock};
|
||||
use tokio_tungstenite::{
|
||||
@@ -17,7 +17,9 @@ use tokio_tungstenite::{
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{TransportHandler, channel::BufferedReceiver};
|
||||
use crate::{
|
||||
MessageState, TransportHandler, channel::BufferedReceiver,
|
||||
};
|
||||
|
||||
/// Handles client side / outbound connection
|
||||
pub async fn handle_client_connection<
|
||||
@@ -29,7 +31,7 @@ pub async fn handle_client_connection<
|
||||
write_receiver: &mut BufferedReceiver<Bytes>,
|
||||
) {
|
||||
loop {
|
||||
let socket = match connect_websocket(address).await {
|
||||
let mut socket = match connect_websocket(address).await {
|
||||
Ok(socket) => socket,
|
||||
Err(e) => {
|
||||
connection.set_error(e).await;
|
||||
@@ -39,6 +41,15 @@ pub async fn handle_client_connection<
|
||||
};
|
||||
|
||||
info!("Connected to {address}");
|
||||
|
||||
if let Err(e) = handle_login(&mut socket, Bytes::new()).await {
|
||||
connection.set_error(e).await;
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
continue;
|
||||
};
|
||||
|
||||
info!("Logged into {address}");
|
||||
|
||||
connection.connected.store(true, atomic::Ordering::Relaxed);
|
||||
connection.clear_error().await;
|
||||
|
||||
@@ -116,6 +127,44 @@ pub async fn handle_client_connection<
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_login(
|
||||
socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
credentials: Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
socket
|
||||
.send(Message::Binary(credentials))
|
||||
.await
|
||||
.context("Failed to send login credentials")?;
|
||||
|
||||
loop {
|
||||
let response = socket
|
||||
.try_next()
|
||||
.await
|
||||
.context("Failed to receive login response")?
|
||||
.context("Stream broken before login response received")?;
|
||||
let bytes = match &response {
|
||||
Message::Text(text) => text.as_bytes(),
|
||||
Message::Binary(bytes) => &bytes,
|
||||
Message::Close(frame) => {
|
||||
return Err(anyhow!(
|
||||
"Websocket close frame received during login | frame: {frame:?}"
|
||||
));
|
||||
}
|
||||
// Ignore others
|
||||
_ => continue,
|
||||
};
|
||||
let state = bytes
|
||||
.first()
|
||||
.map(|b| MessageState::from_byte(*b))
|
||||
.context("Login response is empty")?;
|
||||
if matches!(state, MessageState::Successful) {
|
||||
return Ok(());
|
||||
} else {
|
||||
return Err(anyhow!("Failed to login | Invalid credentails"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ClientConnection {
|
||||
connected: AtomicBool,
|
||||
|
||||
@@ -14,11 +14,15 @@ pub enum MessageState {
|
||||
Terminal,
|
||||
}
|
||||
|
||||
impl From<MessageState> for Bytes {
|
||||
fn from(value: MessageState) -> Self {
|
||||
Bytes::from_owner([value.as_byte()])
|
||||
}
|
||||
}
|
||||
|
||||
pub trait TransportHandler {
|
||||
fn handle_incoming_bytes(
|
||||
&self,
|
||||
bytes: Bytes,
|
||||
) -> impl Future<Output = ()> + Send;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::{
|
||||
extract::{WebSocketUpgrade, ws::Message},
|
||||
extract::{
|
||||
WebSocketUpgrade,
|
||||
ws::{Message, WebSocket},
|
||||
},
|
||||
http::StatusCode,
|
||||
response::Response,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use futures_util::{SinkExt, StreamExt, TryStreamExt};
|
||||
use serror::AddStatusCode;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use crate::{TransportHandler, channel::BufferedReceiver};
|
||||
use crate::{
|
||||
MessageState, TransportHandler, channel::BufferedReceiver,
|
||||
};
|
||||
|
||||
/// Handles server side / inbound connection
|
||||
pub fn handle_server_connection<
|
||||
@@ -24,8 +30,11 @@ pub fn handle_server_connection<
|
||||
.try_lock()
|
||||
.status_code(StatusCode::FORBIDDEN)?;
|
||||
|
||||
Ok(ws.on_upgrade(|socket| async move {
|
||||
// TODO: Handle authentication exchange.
|
||||
Ok(ws.on_upgrade(|mut socket| async move {
|
||||
if let Err(e) = handle_login(&mut socket, |b| true).await {
|
||||
warn!("Client failed to login | {e:#}");
|
||||
return;
|
||||
};
|
||||
|
||||
let (mut ws_write, mut ws_read) = socket.split();
|
||||
|
||||
@@ -81,3 +90,44 @@ pub fn handle_server_connection<
|
||||
};
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_login(
|
||||
socket: &mut WebSocket,
|
||||
validate: impl Fn(&[u8]) -> bool,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
// Poll for next message
|
||||
let msg = socket
|
||||
.try_next()
|
||||
.await
|
||||
.context("Failed to receive login credentials")?
|
||||
.context("Stream broken before login credentials received")?;
|
||||
// Treat first message as credentials
|
||||
let credentials = match &msg {
|
||||
Message::Text(text) => text.as_bytes(),
|
||||
Message::Binary(bytes) => &bytes,
|
||||
Message::Close(frame) => {
|
||||
return Err(anyhow!(
|
||||
"Websocket close frame received during login | frame: {frame:?}"
|
||||
));
|
||||
}
|
||||
// Ignore others
|
||||
_ => continue,
|
||||
};
|
||||
// Validate
|
||||
if validate(credentials) {
|
||||
// Send login confirmation
|
||||
socket
|
||||
.send(Message::Binary(MessageState::Successful.into()))
|
||||
.await?;
|
||||
return Ok(());
|
||||
} else {
|
||||
// Send login failure
|
||||
socket
|
||||
.send(Message::Binary(MessageState::Failed.into()))
|
||||
.await?;
|
||||
let _ = socket.close().await;
|
||||
return Err(anyhow!("Received invalid credentials"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user