store connection channels under the connection

This commit is contained in:
mbecker20
2025-09-23 01:10:03 -07:00
parent 6b26cd120c
commit 0da5718991
11 changed files with 177 additions and 175 deletions

View File

@@ -13,17 +13,11 @@ use transport::{
websocket::tungstenite::TungsteniteWebsocket,
};
use crate::{
config::core_config,
connection::MessageHandler,
periphery::PeripheryConnection,
state::{all_server_channels, periphery_connections},
};
use crate::{config::core_config, state::periphery_connections};
/// Managed connections to exactly those specified by specs (ServerId -> Address)
pub async fn manage_client_connections(servers: &[Server]) {
let periphery_connections = periphery_connections();
let periphery_channels = all_server_channels();
let specs = servers
.iter()
@@ -41,16 +35,12 @@ pub async fn manage_client_connections(servers: &[Server]) {
.collect::<HashMap<_, _>>();
// Clear non specced / enabled server connections
for (server_id, connection) in
periphery_connections.get_entries().await
{
for server_id in periphery_connections.get_keys().await {
if !specs.contains_key(&server_id) {
info!(
"Specs do not container {server_id}, cancelling connection"
"Specs do not contain {server_id}, cancelling connection"
);
connection.cancel();
periphery_connections.remove(&server_id).await;
periphery_channels.remove(&server_id).await;
}
}
@@ -122,17 +112,9 @@ pub async fn spawn_client_connection(
host.push_str(&port.to_string());
}
let handler = MessageHandler::new(&server_id).await;
let (connection, mut write_receiver) =
PeripheryConnection::new(address.clone().into());
if let Some(existing_connection) = periphery_connections()
.insert(server_id, connection.clone())
.await
{
existing_connection.cancel();
}
let (connection, mut write_receiver) = periphery_connections()
.insert(server_id, address.clone().into())
.await;
let config = core_config();
let private_key = if private_key.is_empty() {
@@ -175,7 +157,6 @@ pub async fn spawn_client_connection(
expected_public_key: expected_public_key.as_deref(),
write_receiver: &mut write_receiver,
connection: &connection,
handler: &handler,
};
if let Err(e) = handler.handle::<ClientLoginFlow>().await {

View File

@@ -1,25 +1,16 @@
use std::sync::Arc;
use anyhow::anyhow;
use bytes::Bytes;
use cache::CloneCache;
use tokio::sync::mpsc::Sender;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use transport::{
auth::{ConnectionIdentifiers, LoginFlow, PublicKeyValidator},
bytes::id_from_transport_bytes,
channel::BufferedReceiver,
websocket::{
Websocket, WebsocketMessage, WebsocketReceiver as _,
WebsocketSender as _,
},
};
use uuid::Uuid;
use crate::{
periphery::PeripheryConnection, state::all_server_channels,
};
use crate::periphery::PeripheryConnection;
pub mod client;
pub mod server;
@@ -31,7 +22,7 @@ pub struct WebsocketHandler<'a, W> {
pub expected_public_key: Option<&'a str>,
pub write_receiver: &'a mut BufferedReceiver<Bytes>,
pub connection: &'a PeripheryConnection,
pub handler: &'a MessageHandler,
// pub handler: &'a MessageHandler,
}
impl<W: Websocket> WebsocketHandler<'_, W> {
@@ -43,7 +34,7 @@ impl<W: Websocket> WebsocketHandler<'_, W> {
expected_public_key,
write_receiver,
connection,
handler,
// handler,
} = self;
L::login(
@@ -100,7 +91,7 @@ impl<W: Websocket> WebsocketHandler<'_, W> {
match next {
Ok(WebsocketMessage::Binary(bytes)) => {
handler.handle_incoming_bytes(bytes).await
connection.handle_incoming_bytes(bytes).await
}
Ok(WebsocketMessage::Close(_))
| Ok(WebsocketMessage::Closed) => {
@@ -143,37 +134,3 @@ impl PublicKeyValidator for PeripheryPublicKeyValidator<'_> {
}
}
}
pub struct MessageHandler {
channels: Arc<CloneCache<Uuid, Sender<Bytes>>>,
}
impl MessageHandler {
pub async fn new(server_id: &String) -> MessageHandler {
MessageHandler {
channels: all_server_channels()
.get_or_insert_default(server_id)
.await,
}
}
async fn handle_incoming_bytes(&self, bytes: Bytes) {
let id = match id_from_transport_bytes(&bytes) {
Ok(res) => res,
Err(e) => {
// TODO: handle better
warn!("Failed to read id | {e:#}");
return;
}
};
let Some(channel) = self.channels.get(&id).await else {
// TODO: handle better
warn!("Failed to send response | No response channel found");
return;
};
if let Err(e) = channel.send(bytes).await {
// TODO: handle better
warn!("Failed to send response | Channel failure | {e:#}");
}
}
}

View File

@@ -12,11 +12,7 @@ use transport::{
websocket::axum::AxumWebsocket,
};
use crate::{
config::core_config,
connection::{MessageHandler, PeripheryConnection},
state::periphery_connections,
};
use crate::{config::core_config, state::periphery_connections};
pub async fn handler(
Query(PeripheryConnectionQuery { server: _server }): Query<
@@ -58,23 +54,17 @@ pub async fn handler(
}
let expected_public_key = if server.config.public_key.is_empty() {
core_config().periphery_public_key.clone().context("Must either configure Server 'Periphery Public Key' or set KOMODO_PERIPHERY_PUBLIC_KEY")?
core_config()
.periphery_public_key
.clone()
.context("Must either configure Server 'Periphery Public Key' or set KOMODO_PERIPHERY_PUBLIC_KEY")?
} else {
server.config.public_key
};
let handler = MessageHandler::new(&server.id).await;
let (connection, mut write_receiver) =
PeripheryConnection::new(None);
if let Some(existing_connection) = connections
.insert(server.id.clone(), connection.clone())
.await
{
// This case shouldn't be reached from above but doesn't hurt to handle
existing_connection.cancel();
}
let (connection, mut write_receiver) = periphery_connections()
.insert(server.id.clone(), None)
.await;
Ok(ws.on_upgrade(|socket| async move {
let query = format!("server={}", urlencoding::encode(&_server));
@@ -89,7 +79,6 @@ pub async fn handler(
expected_public_key: Some(&expected_public_key),
write_receiver: &mut write_receiver,
connection: &connection,
handler: &handler,
};
if let Err(e) = handler.handle::<ServerLoginFlow>().await {

View File

@@ -191,7 +191,7 @@ pub async fn periphery_client(
if !server.config.enabled {
return Err(anyhow!("server not enabled"));
}
Ok(PeripheryClient::new(server.id.clone()).await)
PeripheryClient::new(server.id.clone()).await
}
#[instrument]

View File

@@ -8,6 +8,7 @@ use std::{
use anyhow::{Context, anyhow};
use bytes::Bytes;
use cache::CloneCache;
use periphery_client::api;
use resolver_api::HasResponse;
use serde::{Serialize, de::DeserializeOwned};
@@ -21,31 +22,38 @@ use tokio_util::sync::CancellationToken;
use tracing::warn;
use transport::{
MessageState,
bytes::{from_transport_bytes, to_transport_bytes},
bytes::{
from_transport_bytes, id_from_transport_bytes, to_transport_bytes,
},
channel::{BufferedReceiver, buffered_channel},
fix_ws_address,
};
use uuid::Uuid;
use crate::state::{
ServerChannels, all_server_channels, periphery_connections,
};
use crate::state::periphery_connections;
pub mod terminal;
pub type ConnectionChannels = CloneCache<Uuid, Sender<Bytes>>;
pub struct PeripheryClient {
pub server_id: String,
channels: Arc<ServerChannels>,
channels: Arc<ConnectionChannels>,
}
impl PeripheryClient {
pub async fn new(server_id: String) -> PeripheryClient {
PeripheryClient {
channels: all_server_channels()
.get_or_insert_default(&server_id)
.await,
pub async fn new(
server_id: String,
) -> anyhow::Result<PeripheryClient> {
Ok(PeripheryClient {
channels: periphery_connections()
.get(&server_id)
.await
.context("Periphery not connected")?
.channels
.clone(),
server_id,
}
})
}
pub async fn new_with_spawned_client_connection<
@@ -59,9 +67,8 @@ impl PeripheryClient {
if address.is_empty() {
return Err(anyhow!("Server address cannot be empty"));
}
let periphery = PeripheryClient::new(server_id.clone()).await;
spawn(server_id, fix_ws_address(address)).await?;
Ok(periphery)
spawn(server_id.clone(), fix_ws_address(address)).await?;
PeripheryClient::new(server_id).await
}
#[tracing::instrument(level = "debug", skip(self))]
@@ -174,39 +181,124 @@ impl PeripheryClient {
}
}
#[derive(Default)]
pub struct PeripheryConnections(
CloneCache<String, Arc<PeripheryConnection>>,
);
impl PeripheryConnections {
pub async fn insert(
&self,
server_id: String,
address: Option<String>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
let channels = if let Some(existing_connection) =
self.0.remove(&server_id).await
{
existing_connection.cancel();
// Keep the same channels so requests
// can handle disconnects while processing.
existing_connection.channels.clone()
} else {
Default::default()
};
let (connection, receiver) =
PeripheryConnection::new(address, channels);
self.0.insert(server_id, connection.clone()).await;
(connection, receiver)
}
pub async fn get(
&self,
server_id: &String,
) -> Option<Arc<PeripheryConnection>> {
self.0.get(server_id).await
}
/// Remove and cancel connection
pub async fn remove(
&self,
server_id: &String,
) -> Option<Arc<PeripheryConnection>> {
self
.0
.remove(server_id)
.await
.inspect(|connection| connection.cancel())
}
pub async fn get_keys(&self) -> Vec<String> {
self.0.get_keys().await
}
}
#[derive(Debug)]
pub struct PeripheryConnection {
// Inbound connections have this as None
/// Specify outbound connection address.
/// Inbound connections have this as None
pub address: Option<String>,
pub write_sender: Sender<Bytes>,
/// Whether Periphery is currently connected.
pub connected: AtomicBool,
/// Stores latest connection error
pub error: RwLock<Option<serror::Serror>>,
/// Cancel the connection
pub cancel: CancellationToken,
/// Send bytes to Periphery
pub sender: Sender<Bytes>,
/// Send bytes from Periphery to channel handlers.
/// Must be maintained if new connection replaces old
/// at the same server id.
pub channels: Arc<ConnectionChannels>,
}
impl PeripheryConnection {
pub fn new(
address: Option<String>,
channels: Arc<ConnectionChannels>,
) -> (Arc<PeripheryConnection>, BufferedReceiver<Bytes>) {
let (write_sender, write_receiver) = buffered_channel(1000);
let (sender, receiver) = buffered_channel();
(
PeripheryConnection {
address,
write_sender,
sender,
channels,
connected: AtomicBool::new(false),
error: RwLock::new(None),
cancel: CancellationToken::new(),
}
.into(),
write_receiver,
receiver,
)
}
pub async fn handle_incoming_bytes(&self, bytes: Bytes) {
let id = match id_from_transport_bytes(&bytes) {
Ok(res) => res,
Err(e) => {
// TODO: handle better
warn!("Failed to read id | {e:#}");
return;
}
};
let Some(channel) = self.channels.get(&id).await else {
// TODO: handle better
warn!("Failed to send response | No response channel found");
return;
};
if let Err(e) = channel.send(bytes).await {
// TODO: handle better
warn!("Failed to send response | Channel failure | {e:#}");
}
}
pub async fn send(
&self,
value: Bytes,
) -> Result<(), SendError<Bytes>> {
self.write_sender.send(value).await
self.sender.send(value).await
}
pub fn set_connected(&self, connected: bool) {

View File

@@ -17,8 +17,7 @@ use transport::bytes::data_from_transport_bytes;
use uuid::Uuid;
use crate::{
periphery::PeripheryClient,
state::{all_server_channels, periphery_connections},
periphery::PeripheryClient, state::periphery_connections,
};
impl PeripheryClient {
@@ -42,13 +41,10 @@ impl PeripheryClient {
.await
.context("Failed to create terminal connection")?;
let response_channels = all_server_channels()
.get_or_insert_default(&self.server_id)
.await;
let (response_sender, response_receiever) = channel(1000);
response_channels.insert(id, response_sender).await;
let (sender, receiever) = channel(1024);
connection.channels.insert(id, sender).await;
Ok((id, connection.write_sender.clone(), response_receiever))
Ok((id, connection.sender.clone(), receiever))
}
pub async fn connect_container_exec(
@@ -72,13 +68,10 @@ impl PeripheryClient {
.await
.context("Failed to create container exec connection")?;
let response_channels = all_server_channels()
.get_or_insert_default(&self.server_id)
.await;
let (response_sender, response_receiever) = channel(1000);
response_channels.insert(id, response_sender).await;
let (sender, receiever) = channel(1000);
connection.channels.insert(id, sender).await;
Ok((id, connection.write_sender.clone(), response_receiever))
Ok((id, connection.sender.clone(), receiever))
}
/// Executes command on specified terminal,
@@ -106,23 +99,26 @@ impl PeripheryClient {
"sending request | type: ExecuteTerminal | terminal name: {terminal} | command: {command}",
);
let connection = periphery_connections()
.get(&self.server_id)
.await
.with_context(|| {
format!("No connection found for server {}", self.server_id)
})?;
let id = self
.request(ExecuteTerminal { terminal, command })
.await
.context("Failed to create execute terminal connection")?;
let response_channels = all_server_channels()
.get_or_insert_default(&self.server_id)
.await;
let (sender, receiver) = channel(1000);
let (response_sender, response_receiever) = channel(1000);
response_channels.insert(id, response_sender).await;
connection.channels.insert(id, sender).await;
Ok(ReceiverStream {
id,
channels: response_channels,
receiver: response_receiever,
receiver,
channels: connection.channels.clone(),
})
}
@@ -150,6 +146,13 @@ impl PeripheryClient {
"sending request | type: ExecuteContainerExec | container: {container} | shell: {shell} | command: {command}",
);
let connection = periphery_connections()
.get(&self.server_id)
.await
.with_context(|| {
format!("No connection found for server {}", self.server_id)
})?;
let id = self
.request(ExecuteContainerExec {
container,
@@ -159,18 +162,14 @@ impl PeripheryClient {
.await
.context("Failed to create execute terminal connection")?;
let response_channels = all_server_channels()
.get_or_insert_default(&self.server_id)
.await;
let (sender, receiver) = channel(1000);
let (response_sender, response_receiever) = channel(1000);
response_channels.insert(id, response_sender).await;
connection.channels.insert(id, sender).await;
Ok(ReceiverStream {
id,
channels: response_channels,
receiver: response_receiever,
receiver,
channels: connection.channels.clone(),
})
}
}

View File

@@ -18,7 +18,7 @@ use crate::{
helpers::query::get_system_info,
monitor::update_cache_for_server,
state::{
action_states, all_server_channels, db_client,
action_states, db_client, periphery_connections,
server_status_cache,
},
};
@@ -224,7 +224,7 @@ impl super::KomodoResource for Server {
) -> anyhow::Result<()> {
tokio::join!(
server_status_cache().remove(&resource.id),
all_server_channels().remove(&resource.id),
periphery_connections().remove(&resource.id),
);
Ok(())
}

View File

@@ -5,9 +5,7 @@ use std::{
use anyhow::Context;
use arc_swap::ArcSwap;
use bytes::Bytes;
use cache::CloneCache;
use database::Client;
use komodo_client::entities::{
action::ActionState,
build::BuildState,
@@ -20,8 +18,6 @@ use komodo_client::entities::{
use octorust::auth::{
Credentials, InstallationTokenGenerator, JWTCredentials,
};
use tokio::sync::mpsc::Sender;
use uuid::Uuid;
use crate::{
auth::jwt::JwtClient,
@@ -32,12 +28,13 @@ use crate::{
monitor::{
CachedDeploymentStatus, CachedRepoStatus, CachedServerStatus,
CachedStackStatus, History,
}, periphery::PeripheryConnection,
},
periphery::PeripheryConnections,
};
static DB_CLIENT: OnceLock<Client> = OnceLock::new();
static DB_CLIENT: OnceLock<database::Client> = OnceLock::new();
pub fn db_client() -> &'static Client {
pub fn db_client() -> &'static database::Client {
DB_CLIENT
.get()
.expect("db_client accessed before initialized")
@@ -45,7 +42,7 @@ pub fn db_client() -> &'static Client {
/// Must be called in app startup sequence.
pub async fn init_db_client() {
let client = Client::new(&core_config().database)
let client = database::Client::new(&core_config().database)
.await
.context("failed to initialize database client")
.unwrap();
@@ -216,19 +213,7 @@ pub fn all_resources_cache() -> &'static ArcSwap<AllResourcesById> {
ALL_RESOURCES.get_or_init(Default::default)
}
pub type ServerChannels = CloneCache<Uuid, Sender<Bytes>>;
// Server id => ServerChannel
pub type AllServerChannels = CloneCache<String, Arc<ServerChannels>>;
pub fn all_server_channels() -> &'static AllServerChannels {
static CHANNELS: OnceLock<AllServerChannels> = OnceLock::new();
CHANNELS.get_or_init(Default::default)
}
/// server id => connection
pub type PeripheryConnections =
CloneCache<String, Arc<PeripheryConnection>>;
pub fn periphery_connections() -> &'static PeripheryConnections {
static CONNECTIONS: OnceLock<PeripheryConnections> =
OnceLock::new();

View File

@@ -2,7 +2,7 @@ use crate::{
auth::{auth_api_key_check_enabled, auth_jwt_check_enabled},
helpers::query::get_user,
periphery::PeripheryClient,
state::all_server_channels,
state::periphery_connections,
};
use anyhow::anyhow;
use axum::{
@@ -294,9 +294,9 @@ async fn forward_ws_channel(
"Failed to disconnect Periphery terminal forwarding | {e:#}",
)
}
if let Some(response_channels) =
all_server_channels().get(&periphery.server_id).await
if let Some(connection) =
periphery_connections().get(&periphery.server_id).await
{
response_channels.remove(&periphery_connection_id).await;
connection.channels.remove(&periphery_connection_id).await;
}
}

View File

@@ -47,11 +47,9 @@ fn ws_receiver() -> &'static Mutex<BufferedReceiver<Bytes>> {
.expect("response_receiver accessed before initialized")
}
const RESPONSE_BUFFER_MAX_LEN: usize = 1_024;
/// Must call in startup sequence
pub fn init_response_channel() {
let (sender, receiver) = buffered_channel(RESPONSE_BUFFER_MAX_LEN);
let (sender, receiver) = buffered_channel();
WS_SENDER
.set(sender)
.expect("response_sender initialized more than once");

View File

@@ -2,11 +2,12 @@ use std::ops::Deref;
use tokio::sync::mpsc;
const RESPONSE_BUFFER_MAX_LEN: usize = 1_024;
/// Create a buffered channel
pub fn buffered_channel<T: Deref>(
buffer: usize,
) -> (mpsc::Sender<T>, BufferedReceiver<T>) {
let (sender, receiver) = mpsc::channel(buffer);
pub fn buffered_channel<T: Deref>()
-> (mpsc::Sender<T>, BufferedReceiver<T>) {
let (sender, receiver) = mpsc::channel(RESPONSE_BUFFER_MAX_LEN);
(sender, BufferedReceiver::new(receiver))
}