diff --git a/Cargo.lock b/Cargo.lock index f7a8310c9..580a845d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4153,19 +4153,10 @@ version = "2.0.0-dev-4" version = "2.0.0-dev-5" >>>>>>> 3edfbe8b (deploy 2.0.0-dev-5) dependencies = [ - "anyhow", - "bytes", - "cache", - "futures-util", "komodo_client", "resolver_api", "serde", - "serde_json", "serror", - "tokio", - "tokio-util", - "tracing", - "transport", "uuid", ] diff --git a/bin/core/src/api/write/build.rs b/bin/core/src/api/write/build.rs index 067d6bb3d..9aec2a7a7 100644 --- a/bin/core/src/api/write/build.rs +++ b/bin/core/src/api/write/build.rs @@ -25,16 +25,14 @@ use komodo_client::{ use octorust::types::{ ReposCreateWebhookRequest, ReposCreateWebhookRequestConfig, }; -use periphery_client::{ - PeripheryClient, - api::build::{ - GetDockerfileContentsOnHost, WriteDockerfileContentsToHost, - }, +use periphery_client::api::build::{ + GetDockerfileContentsOnHost, WriteDockerfileContentsToHost, }; use resolver_api::Resolve; use tokio::fs; use crate::connection::client::spawn_client_connection; +use crate::periphery::PeripheryClient; use crate::{ config::core_config, helpers::{ diff --git a/bin/core/src/connection/client.rs b/bin/core/src/connection/client.rs index d13d69458..adffced85 100644 --- a/bin/core/src/connection/client.rs +++ b/bin/core/src/connection/client.rs @@ -1,16 +1,9 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; -use crate::{ - all_server_channels, - config::core_config, - connection::{MessageHandler, PeripheryConnection}, -}; use anyhow::Context; use axum::http::HeaderValue; use komodo_client::entities::{optional_string, server::Server}; -use periphery_client::{ - CONNECTION_RETRY_SECONDS, periphery_connections, -}; +use periphery_client::CONNECTION_RETRY_SECONDS; use rustls::{ClientConfig, client::danger::ServerCertVerifier}; use tokio_tungstenite::Connector; use tracing::{info, warn}; @@ -20,6 +13,13 @@ use transport::{ websocket::tungstenite::TungsteniteWebsocket, }; +use crate::{ + config::core_config, + connection::MessageHandler, + periphery::PeripheryConnection, + state::{all_server_channels, periphery_connections}, +}; + /// Managed connections to exactly those specified by specs (ServerId -> Address) pub async fn manage_client_connections(servers: &[Server]) { let periphery_connections = periphery_connections(); diff --git a/bin/core/src/connection/mod.rs b/bin/core/src/connection/mod.rs index c8e699b54..c09c980ec 100644 --- a/bin/core/src/connection/mod.rs +++ b/bin/core/src/connection/mod.rs @@ -3,7 +3,6 @@ use std::sync::Arc; use anyhow::anyhow; use bytes::Bytes; use cache::CloneCache; -use periphery_client::{PeripheryConnection, all_server_channels}; use tokio::sync::mpsc::Sender; use tokio_util::sync::CancellationToken; use tracing::warn; @@ -18,6 +17,10 @@ use transport::{ }; use uuid::Uuid; +use crate::{ + periphery::PeripheryConnection, state::all_server_channels, +}; + pub mod client; pub mod server; diff --git a/bin/core/src/connection/server.rs b/bin/core/src/connection/server.rs index 086834045..fdc4ff587 100644 --- a/bin/core/src/connection/server.rs +++ b/bin/core/src/connection/server.rs @@ -5,7 +5,6 @@ use axum::{ response::Response, }; use komodo_client::entities::server::Server; -use periphery_client::periphery_connections; use serror::{AddStatusCode, AddStatusCodeError}; use transport::{ PeripheryConnectionQuery, @@ -16,6 +15,7 @@ use transport::{ use crate::{ config::core_config, connection::{MessageHandler, PeripheryConnection}, + state::periphery_connections, }; pub async fn handler( @@ -49,13 +49,12 @@ pub async fn handler( // Ensure connected server can't get bumped off the connection. // Treat this as authorization issue. if let Some(existing_connection) = connections.get(&server.id).await + && existing_connection.connected() { - if existing_connection.connected() { - return Err( - anyhow!("A Server '{_server}' is already connected") - .status_code(StatusCode::UNAUTHORIZED), - ); - } + return Err( + anyhow!("A Server '{_server}' is already connected") + .status_code(StatusCode::UNAUTHORIZED), + ); } let expected_public_key = if server.config.public_key.is_empty() { diff --git a/bin/core/src/helpers/builder.rs b/bin/core/src/helpers/builder.rs index ceeb09861..15d76ffaf 100644 --- a/bin/core/src/helpers/builder.rs +++ b/bin/core/src/helpers/builder.rs @@ -10,22 +10,14 @@ use komodo_client::entities::{ server::Server, update::{Log, Update}, }; -use periphery_client::{ - PeripheryClient, - api::{self, GetVersionResponse}, -}; +use periphery_client::api::{self, GetVersionResponse}; use crate::{ cloud::{ - BuildCleanupData, aws::ec2::{ - Ec2Instance, launch_ec2_instance, - terminate_ec2_instance_with_retry, - }, - }, - connection::client::spawn_client_connection, - helpers::update::update_update, - resource, + launch_ec2_instance, terminate_ec2_instance_with_retry, Ec2Instance + }, BuildCleanupData + }, connection::client::spawn_client_connection, helpers::update::update_update, periphery::PeripheryClient, resource }; use super::periphery_client; diff --git a/bin/core/src/helpers/mod.rs b/bin/core/src/helpers/mod.rs index affe132c2..98d68839c 100644 --- a/bin/core/src/helpers/mod.rs +++ b/bin/core/src/helpers/mod.rs @@ -15,10 +15,11 @@ use komodo_client::entities::{ stack::Stack, user::User, }; -use periphery_client::PeripheryClient; use rand::Rng; -use crate::{config::core_config, state::db_client}; +use crate::{ + config::core_config, periphery::PeripheryClient, state::db_client, +}; pub mod action_state; pub mod all_resources; diff --git a/bin/core/src/main.rs b/bin/core/src/main.rs index 435b15013..1f1a72cb9 100644 --- a/bin/core/src/main.rs +++ b/bin/core/src/main.rs @@ -8,7 +8,6 @@ use std::{net::SocketAddr, str::FromStr}; use anyhow::Context; use axum::Router; use axum_server::{Handle, tls_rustls::RustlsConfig}; -use periphery_client::all_server_channels; use tower_http::{ cors::{Any, CorsLayer}, services::{ServeDir, ServeFile}, @@ -26,6 +25,7 @@ mod helpers; mod listener; mod monitor; mod network; +mod periphery; mod permission; mod resource; mod schedule; diff --git a/bin/core/src/monitor/lists.rs b/bin/core/src/monitor/lists.rs index ffd5fe42f..e622a5621 100644 --- a/bin/core/src/monitor/lists.rs +++ b/bin/core/src/monitor/lists.rs @@ -5,10 +5,9 @@ use komodo_client::entities::{ }, stack::ComposeProject, }; -use periphery_client::{ - PeripheryClient, - api::{GetDockerLists, GetDockerListsResponse}, -}; +use periphery_client::api::{GetDockerLists, GetDockerListsResponse}; + +use crate::periphery::PeripheryClient; pub async fn get_docker_lists( periphery: &PeripheryClient, diff --git a/bin/core/src/periphery/mod.rs b/bin/core/src/periphery/mod.rs new file mode 100644 index 000000000..d13f16498 --- /dev/null +++ b/bin/core/src/periphery/mod.rs @@ -0,0 +1,254 @@ +use std::{ + sync::{ + Arc, + atomic::{self, AtomicBool}, + }, + time::Duration, +}; + +use anyhow::{Context, anyhow}; +use bytes::Bytes; +use periphery_client::api; +use resolver_api::HasResponse; +use serde::{Serialize, de::DeserializeOwned}; +use serde_json::json; +use serror::{deserialize_error_bytes, serror_into_anyhow_error}; +use tokio::sync::{ + RwLock, + mpsc::{self, Sender, error::SendError}, +}; +use tokio_util::sync::CancellationToken; +use tracing::warn; +use transport::{ + MessageState, + bytes::{from_transport_bytes, to_transport_bytes}, + channel::{BufferedReceiver, buffered_channel}, + fix_ws_address, +}; +use uuid::Uuid; + +use crate::state::{ + ServerChannels, all_server_channels, periphery_connections, +}; + +pub mod terminal; + +pub struct PeripheryClient { + pub server_id: String, + channels: Arc, +} + +impl PeripheryClient { + pub async fn new(server_id: String) -> PeripheryClient { + PeripheryClient { + channels: all_server_channels() + .get_or_insert_default(&server_id) + .await, + server_id, + } + } + + pub async fn new_with_spawned_client_connection< + F: Future>, + >( + server_id: String, + address: &str, + // (Server id, address) + spawn: impl FnOnce(String, String) -> F, + ) -> anyhow::Result { + if address.is_empty() { + return Err(anyhow!("Server address cannot be empty")); + } + let periphery = PeripheryClient::new(server_id.clone()).await; + spawn(server_id, fix_ws_address(address)).await?; + Ok(periphery) + } + + #[tracing::instrument(level = "debug", skip(self))] + pub async fn health_check(&self) -> anyhow::Result<()> { + self.request(api::GetHealth {}).await?; + Ok(()) + } + + #[tracing::instrument( + name = "PeripheryRequest", + skip(self), + level = "debug" + )] + pub async fn request( + &self, + request: T, + ) -> anyhow::Result + where + T: std::fmt::Debug + Serialize + HasResponse, + T::Response: DeserializeOwned, + { + let connection = periphery_connections() + .get(&self.server_id) + .await + .with_context(|| { + format!("No connection found for server {}", self.server_id) + })?; + + // Polls connected 3 times before bailing + connection.bail_if_not_connected().await?; + + let id = Uuid::new_v4(); + let (response_sender, mut response_receiever) = + mpsc::channel(1000); + self.channels.insert(id, response_sender).await; + + let req_type = T::req_type(); + let data = serde_json::to_vec(&json!({ + "type": req_type, + "params": request + })) + .context("Failed to serialize request to bytes")?; + + if let Err(e) = connection + .send(to_transport_bytes(data, id, MessageState::Request)) + .await + .context("Failed to send request over channel") + { + // cleanup + self.channels.remove(&id).await; + return Err(e); + } + + // Poll for the associated response + loop { + let next = tokio::select! { + msg = response_receiever.recv() => msg, + // Periphery will send InProgress every 5s to avoid timeout + _ = tokio::time::sleep(Duration::from_secs(10)) => { + return Err(anyhow!("Response timed out")); + } + }; + + let bytes = match next { + Some(bytes) => bytes, + None => { + return Err(anyhow!( + "Sender dropped before response was recieved" + )); + } + }; + + let (state, data) = match from_transport_bytes(bytes) { + Ok((data, _, state)) if !data.is_empty() => (state, data), + // Ignore no data cases + Ok(_) => continue, + Err(e) => { + warn!( + "Server {} | Received invalid message | {e:#}", + self.server_id + ); + continue; + } + }; + match state { + // TODO: improve the allocation in .to_vec + MessageState::Successful => { + // cleanup + self.channels.remove(&id).await; + return serde_json::from_slice(&data) + .context("Failed to parse successful response"); + } + MessageState::Failed => { + // cleanup + self.channels.remove(&id).await; + return Err(deserialize_error_bytes(&data)); + } + MessageState::InProgress => continue, + // Shouldn't be received by this receiver + other => { + // TODO: delete log + warn!( + "Server {} | Got other message over over response channel: {other:?}", + self.server_id + ); + continue; + } + } + } + } +} + +#[derive(Debug)] +pub struct PeripheryConnection { + // Inbound connections have this as None + pub address: Option, + pub write_sender: Sender, + pub connected: AtomicBool, + pub error: RwLock>, + pub cancel: CancellationToken, +} + +impl PeripheryConnection { + pub fn new( + address: Option, + ) -> (Arc, BufferedReceiver) { + let (write_sender, write_receiver) = buffered_channel(1000); + ( + PeripheryConnection { + address, + write_sender, + connected: AtomicBool::new(false), + error: RwLock::new(None), + cancel: CancellationToken::new(), + } + .into(), + write_receiver, + ) + } + + pub async fn send( + &self, + value: Bytes, + ) -> Result<(), SendError> { + self.write_sender.send(value).await + } + + pub fn set_connected(&self, connected: bool) { + self.connected.store(connected, atomic::Ordering::Relaxed); + } + + pub fn connected(&self) -> bool { + self.connected.load(atomic::Ordering::Relaxed) + } + + /// Polls connected 3 times (1s in between) before bailing. + pub async fn bail_if_not_connected(&self) -> anyhow::Result<()> { + for i in 0..3 { + if self.connected() { + return Ok(()); + } + if i < 2 { + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + if let Some(e) = self.error().await { + Err(serror_into_anyhow_error(e)) + } else { + Err(anyhow!("Server is not currently connected")) + } + } + + pub async fn error(&self) -> Option { + self.error.read().await.clone() + } + + pub async fn set_error(&self, e: anyhow::Error) { + let mut error = self.error.write().await; + *error = Some(e.into()); + } + + pub async fn clear_error(&self) { + let mut error = self.error.write().await; + *error = None; + } + + pub fn cancel(&self) { + self.cancel.cancel(); + } +} diff --git a/client/periphery/rs/src/terminal.rs b/bin/core/src/periphery/terminal.rs similarity index 95% rename from client/periphery/rs/src/terminal.rs rename to bin/core/src/periphery/terminal.rs index 593c5ef8c..ad4c10cef 100644 --- a/client/periphery/rs/src/terminal.rs +++ b/bin/core/src/periphery/terminal.rs @@ -7,14 +7,18 @@ use std::{ use anyhow::Context; use bytes::Bytes; use cache::CloneCache; -use futures_util::Stream; +use futures::Stream; +use periphery_client::api::terminal::{ + ConnectContainerExec, ConnectTerminal, END_OF_OUTPUT, + ExecuteContainerExec, ExecuteTerminal, +}; use tokio::sync::mpsc::{Receiver, Sender, channel}; use transport::bytes::data_from_transport_bytes; use uuid::Uuid; use crate::{ - PeripheryClient, all_server_channels, api::terminal::*, - periphery_connections, + periphery::PeripheryClient, + state::{all_server_channels, periphery_connections}, }; impl PeripheryClient { @@ -171,10 +175,6 @@ impl PeripheryClient { } } -/// Execute Sentinels -pub const START_OF_OUTPUT: &str = "__KOMODO_START_OF_OUTPUT__"; -pub const END_OF_OUTPUT: &str = "__KOMODO_END_OF_OUTPUT__"; - pub struct ReceiverStream { id: Uuid, channels: Arc>>, diff --git a/bin/core/src/resource/server.rs b/bin/core/src/resource/server.rs index d0cae9ee1..aefa15520 100644 --- a/bin/core/src/resource/server.rs +++ b/bin/core/src/resource/server.rs @@ -12,13 +12,15 @@ use komodo_client::entities::{ update::Update, user::User, }; -use periphery_client::all_server_channels; use crate::{ config::core_config, helpers::query::get_system_info, monitor::update_cache_for_server, - state::{action_states, db_client, server_status_cache}, + state::{ + action_states, all_server_channels, db_client, + server_status_cache, + }, }; impl super::KomodoResource for Server { diff --git a/bin/core/src/stack/execute.rs b/bin/core/src/stack/execute.rs index a67d41eb4..c82894ca3 100644 --- a/bin/core/src/stack/execute.rs +++ b/bin/core/src/stack/execute.rs @@ -7,11 +7,12 @@ use komodo_client::{ user::User, }, }; -use periphery_client::{PeripheryClient, api::compose::*}; +use periphery_client::api::compose::*; use crate::{ helpers::{periphery_client, update::update_update}, monitor::update_cache_for_server, + periphery::PeripheryClient, state::action_states, }; diff --git a/bin/core/src/state.rs b/bin/core/src/state.rs index 039946c52..44068b369 100644 --- a/bin/core/src/state.rs +++ b/bin/core/src/state.rs @@ -5,6 +5,7 @@ use std::{ use anyhow::Context; use arc_swap::ArcSwap; +use bytes::Bytes; use cache::CloneCache; use database::Client; use komodo_client::entities::{ @@ -19,6 +20,8 @@ use komodo_client::entities::{ use octorust::auth::{ Credentials, InstallationTokenGenerator, JWTCredentials, }; +use tokio::sync::mpsc::Sender; +use uuid::Uuid; use crate::{ auth::jwt::JwtClient, @@ -29,7 +32,7 @@ use crate::{ monitor::{ CachedDeploymentStatus, CachedRepoStatus, CachedServerStatus, CachedStackStatus, History, - }, + }, periphery::PeripheryConnection, }; static DB_CLIENT: OnceLock = OnceLock::new(); @@ -46,7 +49,9 @@ pub async fn init_db_client() { .await .context("failed to initialize database client") .unwrap(); - DB_CLIENT.set(client).expect("db_client initialized more than once"); + DB_CLIENT + .set(client) + .expect("db_client initialized more than once"); } pub fn jwt_client() -> &'static JwtClient { @@ -210,3 +215,22 @@ pub fn all_resources_cache() -> &'static ArcSwap { OnceLock::new(); ALL_RESOURCES.get_or_init(Default::default) } + +pub type ServerChannels = CloneCache>; +// Server id => ServerChannel +pub type AllServerChannels = CloneCache>; + +pub fn all_server_channels() -> &'static AllServerChannels { + static CHANNELS: OnceLock = OnceLock::new(); + CHANNELS.get_or_init(Default::default) +} + +/// server id => connection +pub type PeripheryConnections = + CloneCache>; + +pub fn periphery_connections() -> &'static PeripheryConnections { + static CONNECTIONS: OnceLock = + OnceLock::new(); + CONNECTIONS.get_or_init(Default::default) +} diff --git a/bin/core/src/ws/mod.rs b/bin/core/src/ws/mod.rs index 025384fd1..abf414831 100644 --- a/bin/core/src/ws/mod.rs +++ b/bin/core/src/ws/mod.rs @@ -1,6 +1,8 @@ use crate::{ auth::{auth_api_key_check_enabled, auth_jwt_check_enabled}, helpers::query::get_user, + periphery::PeripheryClient, + state::all_server_channels, }; use anyhow::anyhow; use axum::{ @@ -14,10 +16,7 @@ use komodo_client::{ entities::{server::Server, user::User}, ws::WsLoginMessage, }; -use periphery_client::{ - PeripheryClient, all_server_channels, - api::terminal::DisconnectTerminal, -}; +use periphery_client::api::terminal::DisconnectTerminal; use tokio::sync::mpsc::{Receiver, Sender}; use tokio_util::sync::CancellationToken; use transport::{ diff --git a/bin/periphery/src/api/terminal.rs b/bin/periphery/src/api/terminal.rs index 7e7638aee..07663129b 100644 --- a/bin/periphery/src/api/terminal.rs +++ b/bin/periphery/src/api/terminal.rs @@ -8,10 +8,7 @@ use komodo_client::{ api::write::TerminalRecreateMode, entities::{KOMODO_EXIT_CODE, NoData, server::TerminalInfo}, }; -use periphery_client::{ - api::terminal::*, - terminal::{END_OF_OUTPUT, START_OF_OUTPUT}, -}; +use periphery_client::api::terminal::*; use resolver_api::Resolve; use serror::AddStatusCodeError; use tokio_util::{codec::LinesCodecError, sync::CancellationToken}; diff --git a/client/periphery/rs/Cargo.toml b/client/periphery/rs/Cargo.toml index 35d8a2afc..334e3c040 100644 --- a/client/periphery/rs/Cargo.toml +++ b/client/periphery/rs/Cargo.toml @@ -12,18 +12,9 @@ repository.workspace = true [dependencies] # local komodo_client.workspace = true -transport.workspace = true -cache.workspace = true # mogh resolver_api.workspace = true serror.workspace = true # external -futures-util.workspace = true -tokio-util.workspace = true -serde_json.workspace = true -tracing.workspace = true -anyhow.workspace = true -bytes.workspace = true -tokio.workspace = true serde.workspace = true uuid.workspace = true diff --git a/client/periphery/rs/src/api/terminal.rs b/client/periphery/rs/src/api/terminal.rs index f17b69f72..cfc4caa49 100644 --- a/client/periphery/rs/src/api/terminal.rs +++ b/client/periphery/rs/src/api/terminal.rs @@ -6,6 +6,10 @@ use resolver_api::Resolve; use serde::{Deserialize, Serialize}; use uuid::Uuid; +/// Execute Sentinels +pub const START_OF_OUTPUT: &str = "__KOMODO_START_OF_OUTPUT__"; +pub const END_OF_OUTPUT: &str = "__KOMODO_END_OF_OUTPUT__"; + #[derive(Serialize, Deserialize, Debug, Clone, Resolve)] #[response(Vec)] #[error(serror::Error)] diff --git a/client/periphery/rs/src/lib.rs b/client/periphery/rs/src/lib.rs index f828ec921..8501d0f8c 100644 --- a/client/periphery/rs/src/lib.rs +++ b/client/periphery/rs/src/lib.rs @@ -1,272 +1,3 @@ -use std::{ - sync::{ - Arc, OnceLock, - atomic::{self, AtomicBool}, - }, - time::Duration, -}; - -use anyhow::{Context, anyhow}; -use bytes::Bytes; -use cache::CloneCache; -use resolver_api::HasResponse; -use serde::{Serialize, de::DeserializeOwned}; -use serde_json::json; -use serror::{deserialize_error_bytes, serror_into_anyhow_error}; -use tokio::sync::{ - RwLock, - mpsc::{self, Sender, error::SendError}, -}; -use tokio_util::sync::CancellationToken; -use tracing::warn; -use transport::{ - MessageState, - bytes::{from_transport_bytes, to_transport_bytes}, - channel::{BufferedReceiver, buffered_channel}, - fix_ws_address, -}; -use uuid::Uuid; - pub mod api; -pub mod terminal; - -pub type ServerChannels = CloneCache>; -// Server id => ServerChannel -pub type AllServerChannels = CloneCache>; - -pub fn all_server_channels() -> &'static AllServerChannels { - static CHANNELS: OnceLock = OnceLock::new(); - CHANNELS.get_or_init(Default::default) -} - -pub struct PeripheryClient { - pub server_id: String, - channels: Arc, -} - -impl PeripheryClient { - pub async fn new(server_id: String) -> PeripheryClient { - PeripheryClient { - channels: all_server_channels() - .get_or_insert_default(&server_id) - .await, - server_id, - } - } - - pub async fn new_with_spawned_client_connection< - F: Future>, - >( - server_id: String, - address: &str, - // (Server id, address) - spawn: impl FnOnce(String, String) -> F, - ) -> anyhow::Result { - if address.is_empty() { - return Err(anyhow!("Server address cannot be empty")); - } - let periphery = PeripheryClient::new(server_id.clone()).await; - spawn(server_id, fix_ws_address(address)).await?; - Ok(periphery) - } - - #[tracing::instrument(level = "debug", skip(self))] - pub async fn health_check(&self) -> anyhow::Result<()> { - self.request(api::GetHealth {}).await?; - Ok(()) - } - - #[tracing::instrument( - name = "PeripheryRequest", - skip(self), - level = "debug" - )] - pub async fn request( - &self, - request: T, - ) -> anyhow::Result - where - T: std::fmt::Debug + Serialize + HasResponse, - T::Response: DeserializeOwned, - { - let connection = periphery_connections() - .get(&self.server_id) - .await - .with_context(|| { - format!("No connection found for server {}", self.server_id) - })?; - - // Polls connected 3 times before bailing - connection.bail_if_not_connected().await?; - - let id = Uuid::new_v4(); - let (response_sender, mut response_receiever) = - mpsc::channel(1000); - self.channels.insert(id, response_sender).await; - - let req_type = T::req_type(); - let data = serde_json::to_vec(&json!({ - "type": req_type, - "params": request - })) - .context("Failed to serialize request to bytes")?; - - if let Err(e) = connection - .send(to_transport_bytes(data, id, MessageState::Request)) - .await - .context("Failed to send request over channel") - { - // cleanup - self.channels.remove(&id).await; - return Err(e); - } - - // Poll for the associated response - loop { - let next = tokio::select! { - msg = response_receiever.recv() => msg, - // Periphery will send InProgress every 5s to avoid timeout - _ = tokio::time::sleep(Duration::from_secs(10)) => { - return Err(anyhow!("Response timed out")); - } - }; - - let bytes = match next { - Some(bytes) => bytes, - None => { - return Err(anyhow!( - "Sender dropped before response was recieved" - )); - } - }; - - let (state, data) = match from_transport_bytes(bytes) { - Ok((data, _, state)) if !data.is_empty() => (state, data), - // Ignore no data cases - Ok(_) => continue, - Err(e) => { - warn!( - "Server {} | Received invalid message | {e:#}", - self.server_id - ); - continue; - } - }; - match state { - // TODO: improve the allocation in .to_vec - MessageState::Successful => { - // cleanup - self.channels.remove(&id).await; - return serde_json::from_slice(&data) - .context("Failed to parse successful response"); - } - MessageState::Failed => { - // cleanup - self.channels.remove(&id).await; - return Err(deserialize_error_bytes(&data)); - } - MessageState::InProgress => continue, - // Shouldn't be received by this receiver - other => { - // TODO: delete log - warn!( - "Server {} | Got other message over over response channel: {other:?}", - self.server_id - ); - continue; - } - } - } - } -} - -/// server id => connection -pub type PeripheryConnections = - CloneCache>; - -pub fn periphery_connections() -> &'static PeripheryConnections { - static CONNECTIONS: OnceLock = - OnceLock::new(); - CONNECTIONS.get_or_init(Default::default) -} pub const CONNECTION_RETRY_SECONDS: u64 = 5; - -#[derive(Debug)] -pub struct PeripheryConnection { - // Inbound connections have this as None - pub address: Option, - pub write_sender: Sender, - pub connected: AtomicBool, - pub error: RwLock>, - pub cancel: CancellationToken, -} - -impl PeripheryConnection { - pub fn new( - address: Option, - ) -> (Arc, BufferedReceiver) { - let (write_sender, write_receiver) = buffered_channel(1000); - ( - PeripheryConnection { - address, - write_sender, - connected: AtomicBool::new(false), - error: RwLock::new(None), - cancel: CancellationToken::new(), - } - .into(), - write_receiver, - ) - } - - pub async fn send( - &self, - value: Bytes, - ) -> Result<(), SendError> { - self.write_sender.send(value).await - } - - pub fn set_connected(&self, connected: bool) { - self.connected.store(connected, atomic::Ordering::Relaxed); - } - - pub fn connected(&self) -> bool { - self.connected.load(atomic::Ordering::Relaxed) - } - - /// Polls connected 3 times (1s in between) before bailing. - pub async fn bail_if_not_connected(&self) -> anyhow::Result<()> { - for i in 0..3 { - if self.connected() { - return Ok(()); - } - if i < 2 { - tokio::time::sleep(Duration::from_secs(1)).await; - } - } - if let Some(e) = self.error().await { - Err(serror_into_anyhow_error(e)) - } else { - Err(anyhow!("Server is not currently connected")) - } - } - - pub async fn error(&self) -> Option { - self.error.read().await.clone() - } - - pub async fn set_error(&self, e: anyhow::Error) { - let mut error = self.error.write().await; - *error = Some(e.into()); - } - - pub async fn clear_error(&self) { - let mut error = self.error.write().await; - *error = None; - } - - pub fn cancel(&self) { - self.cancel.cancel(); - } -}