diff --git a/bin/core/src/connection/mod.rs b/bin/core/src/connection/mod.rs index d88c547de..43e05a088 100644 --- a/bin/core/src/connection/mod.rs +++ b/bin/core/src/connection/mod.rs @@ -341,7 +341,7 @@ impl PeripheryConnection { match ws_write.send(message).await { Ok(_) => receiver.clear_buffer(), Err(e) => { - self.set_error(e.into()).await; + self.set_error(e).await; break; } } @@ -363,7 +363,7 @@ impl PeripheryConnection { break; } Err(e) => { - self.set_error(e.into()).await; + self.set_error(e).await; } }; } diff --git a/bin/core/src/periphery/mod.rs b/bin/core/src/periphery/mod.rs index 731df156f..540fe76f3 100644 --- a/bin/core/src/periphery/mod.rs +++ b/bin/core/src/periphery/mod.rs @@ -44,10 +44,7 @@ impl PeripheryClient { return Err(anyhow!("Server {id} is not connected")); } let channels = args - .spawn_client_connection( - id.clone(), - insecure_tls, - ) + .spawn_client_connection(id.clone(), insecure_tls) .await?; return Ok(PeripheryClient { id, channels }); }; @@ -77,10 +74,7 @@ impl PeripheryClient { } else { // Core -> Periphery connection let channels = args - .spawn_client_connection( - id.clone(), - insecure_tls, - ) + .spawn_client_connection(id.clone(), insecure_tls) .await?; Ok(PeripheryClient { id, channels }) } diff --git a/bin/periphery/src/api/keys.rs b/bin/periphery/src/api/keys.rs index 7e39b8d18..07e2a4769 100644 --- a/bin/periphery/src/api/keys.rs +++ b/bin/periphery/src/api/keys.rs @@ -9,8 +9,9 @@ use periphery_client::api::keys::{ }; use resolver_api::Resolve; -use crate::config::{ - core_public_keys, periphery_config, periphery_private_key, +use crate::{ + config::{periphery_config, periphery_private_key}, + connection::core_public_keys, }; // @@ -44,25 +45,25 @@ impl Resolve for RotatePrivateKey { impl Resolve for RotateCorePublicKey { async fn resolve(self, _: &super::Args) -> serror::Result { let config = periphery_config(); - let (Some(core_public_keys_spec), Some(core_public_keys)) = - (config.core_public_keys.as_ref(), core_public_keys()) + + let Some(core_public_keys_spec) = + config.core_public_keys.as_ref() else { return Ok(NoData {}); }; + let Some(core_public_key_path) = core_public_keys_spec .first() .and_then(|key| key.strip_prefix("file:")) else { return Ok(NoData {}); }; - let public_key = SpkiPublicKey::from(self.public_key); - public_key.write_pem(core_public_key_path)?; - let mut new_core_public_keys = vec![public_key]; - // This only replaces the first, extend the rest. - new_core_public_keys - .extend_from_slice(&core_public_keys.load().as_slice()[1..]); - // Store the new keys for the next auth - core_public_keys.store(Arc::new(new_core_public_keys)); + + SpkiPublicKey::from(self.public_key) + .write_pem(core_public_key_path)?; + + core_public_keys().refresh(); + Ok(NoData {}) } } diff --git a/bin/periphery/src/config.rs b/bin/periphery/src/config.rs index 94d016773..9cb7eb377 100644 --- a/bin/periphery/src/config.rs +++ b/bin/periphery/src/config.rs @@ -1,5 +1,4 @@ use std::{ - fs::read_to_string, path::PathBuf, sync::{Arc, OnceLock}, }; @@ -52,39 +51,6 @@ pub fn periphery_public_key() -> &'static ArcSwap { }) } -pub fn core_public_keys() --> Option<&'static ArcSwap>> { - static CORE_PUBLIC_KEYS: OnceLock< - Option>>, - > = OnceLock::new(); - CORE_PUBLIC_KEYS - .get_or_init(|| { - periphery_config().core_public_keys.as_ref().map( - |public_keys| { - let public_keys = public_keys - .iter() - .map(|public_key| { - let maybe_pem = if let Some(path) = - public_key.strip_prefix("file:") - { - read_to_string(path) - .with_context(|| { - format!("Failed to read public key at {path:?}") - }) - .unwrap() - } else { - public_key.clone() - }; - SpkiPublicKey::from_maybe_pem(&maybe_pem).unwrap() - }) - .collect::>(); - ArcSwap::new(Arc::new(public_keys)) - }, - ) - }) - .as_ref() -} - pub fn periphery_args() -> &'static CliArgs { static PERIPHERY_ARGS: OnceLock = OnceLock::new(); PERIPHERY_ARGS.get_or_init(CliArgs::parse) diff --git a/bin/periphery/src/connection/client.rs b/bin/periphery/src/connection/client.rs index c88f5b134..3998efd54 100644 --- a/bin/periphery/src/connection/client.rs +++ b/bin/periphery/src/connection/client.rs @@ -15,7 +15,7 @@ use transport::{ use crate::{ api::Args, config::{periphery_config, periphery_public_key}, - connection::{CorePublicKeyValidator, core_channels}, + connection::{core_channels, core_public_keys}, }; pub async fn handler(address: &str) -> anyhow::Result<()> { @@ -169,7 +169,7 @@ async fn handle_onboarding( ClientLoginFlow::login(LoginFlowArgs { private_key: onboarding_key, identifiers, - public_key_validator: CorePublicKeyValidator, + public_key_validator: core_public_keys(), socket: &mut socket, }) .await?; diff --git a/bin/periphery/src/connection/mod.rs b/bin/periphery/src/connection/mod.rs index 16312e3f0..60e37eb57 100644 --- a/bin/periphery/src/connection/mod.rs +++ b/bin/periphery/src/connection/mod.rs @@ -1,11 +1,14 @@ use std::{ + fs::read_to_string, sync::{Arc, OnceLock}, time::Duration, }; -use anyhow::anyhow; +use anyhow::{Context as _, anyhow}; +use arc_swap::ArcSwap; use bytes::Bytes; use cache::CloneCache; +use noise::key::SpkiPublicKey; use resolver_api::Resolve; use response::JsonBytes; use transport::{ @@ -21,9 +24,7 @@ use uuid::Uuid; use crate::{ api::{Args, PeripheryRequest}, - config::{ - core_public_keys, periphery_config, periphery_private_key, - }, + config::{periphery_config, periphery_private_key}, }; pub mod client; @@ -37,24 +38,70 @@ pub fn core_channels() -> &'static CoreChannels { CORE_CHANNELS.get_or_init(Default::default) } -pub struct CorePublicKeyValidator; +pub fn core_public_keys() -> &'static CorePublicKeys { + static CORE_PUBLIC_KEYS: OnceLock = OnceLock::new(); + CORE_PUBLIC_KEYS.get_or_init(CorePublicKeys::default) +} -impl PublicKeyValidator for CorePublicKeyValidator { +pub struct CorePublicKeys(ArcSwap>); + +impl Default for CorePublicKeys { + fn default() -> Self { + let keys = CorePublicKeys(Default::default()); + keys.refresh(); + keys + } +} + +impl CorePublicKeys { + pub fn is_valid(&self, public_key: &str) -> bool { + let keys = self.0.load(); + keys.is_empty() || keys.iter().any(|pk| pk.as_str() == public_key) + } + + pub fn refresh(&self) { + let Some(core_public_keys) = + periphery_config().core_public_keys.as_ref() + else { + return; + }; + let core_public_keys = core_public_keys + .iter() + .flat_map(|public_key| { + let maybe_pem = + if let Some(path) = public_key.strip_prefix("file:") { + read_to_string(path) + .with_context(|| { + format!("Failed to read public key at {path:?}") + }) + .inspect_err(|e| warn!("{e:#}")) + .ok()? + } else { + public_key.clone() + }; + SpkiPublicKey::from_maybe_pem(&maybe_pem) + .inspect_err(|e| warn!("{e:#}")) + .ok() + }) + .collect::>(); + self.0.store(Arc::new(core_public_keys)); + } +} + +impl PublicKeyValidator for &CorePublicKeys { type ValidationResult = (); async fn validate(&self, public_key: String) -> anyhow::Result<()> { - if let Some(public_keys) = core_public_keys() - && public_keys - .load() - .iter() - .all(|expected| public_key != expected.as_str()) + let keys = self.0.load(); + if keys.is_empty() + || keys.iter().any(|pk| pk.as_str() == public_key) { + Ok(()) + } else { Err( anyhow!("{public_key} is invalid") .context("Ensure public key matches one of the 'core_public_keys' in periphery config (PERIPHERY_CORE_PUBLIC_KEYS)") .context("Periphery failed to validate Core public key"), ) - } else { - Ok(()) } } } @@ -67,7 +114,7 @@ async fn handle_login( socket, identifiers, private_key: periphery_private_key().load().as_str(), - public_key_validator: CorePublicKeyValidator, + public_key_validator: core_public_keys(), }) .await } diff --git a/bin/periphery/src/connection/server.rs b/bin/periphery/src/connection/server.rs index de0bf5566..c7257e4f9 100644 --- a/bin/periphery/src/connection/server.rs +++ b/bin/periphery/src/connection/server.rs @@ -30,9 +30,7 @@ use transport::{ }; use crate::{ - api::Args, - config::{core_public_keys, periphery_config}, - connection::core_channels, + api::Args, config::periphery_config, connection::core_channels, }; pub async fn run() -> anyhow::Result<()> { @@ -160,7 +158,8 @@ async fn handle_login( socket: &mut AxumWebsocket, identifiers: ConnectionIdentifiers<'_>, ) -> anyhow::Result<()> { - match (core_public_keys(), &periphery_config().passkeys) { + let config = periphery_config(); + match (&config.core_public_keys, &config.passkeys) { (Some(_), _) | (_, None) => { // Send login type [0] (Noise auth) socket diff --git a/lib/transport/src/timeout.rs b/lib/transport/src/timeout.rs index 464c65c34..7e39e5c77 100644 --- a/lib/transport/src/timeout.rs +++ b/lib/transport/src/timeout.rs @@ -41,8 +41,7 @@ impl< tokio::time::timeout(timeout, self.inner).map(|res| { res .context("Timed out waiting for message.") - .map(|inner| inner.map_err(Into::into)) - .flatten() + .and_then(|inner| inner.map_err(Into::into)) }) } } diff --git a/lib/transport/src/websocket/mod.rs b/lib/transport/src/websocket/mod.rs index 33bc90262..c0eda5d99 100644 --- a/lib/transport/src/websocket/mod.rs +++ b/lib/transport/src/websocket/mod.rs @@ -84,7 +84,7 @@ pub trait Websocket: Send { MaybeWithTimeout::new( self .recv_message() - .map(|res| res.map(|message| message.into_parts()).flatten()), + .map(|res| res.and_then(|message| message.into_parts())), ) } @@ -96,12 +96,10 @@ pub trait Websocket: Send { impl Future> + Send, > { MaybeWithTimeout::new(self.recv_parts().map(|res| { - res - .map(|(data, _, state)| match state { - MessageState::Successful => Ok(data), - _ => Err(deserialize_error_bytes(&data)), - }) - .flatten() + res.and_then(|(data, _, state)| match state { + MessageState::Successful => Ok(data), + _ => Err(deserialize_error_bytes(&data)), + }) })) } } @@ -154,7 +152,7 @@ pub trait WebsocketReceiver: Send { MaybeWithTimeout::new( self .recv_message() - .map(|res| res.map(|message| message.into_parts()).flatten()), + .map(|res| res.and_then(|message| message.into_parts())), ) } }