core public keys improve refresh

This commit is contained in:
mbecker20
2025-10-07 00:27:47 -07:00
parent 845e8780c7
commit c8c62ea562
9 changed files with 90 additions and 86 deletions

View File

@@ -341,7 +341,7 @@ impl PeripheryConnection {
match ws_write.send(message).await { match ws_write.send(message).await {
Ok(_) => receiver.clear_buffer(), Ok(_) => receiver.clear_buffer(),
Err(e) => { Err(e) => {
self.set_error(e.into()).await; self.set_error(e).await;
break; break;
} }
} }
@@ -363,7 +363,7 @@ impl PeripheryConnection {
break; break;
} }
Err(e) => { Err(e) => {
self.set_error(e.into()).await; self.set_error(e).await;
} }
}; };
} }

View File

@@ -44,10 +44,7 @@ impl PeripheryClient {
return Err(anyhow!("Server {id} is not connected")); return Err(anyhow!("Server {id} is not connected"));
} }
let channels = args let channels = args
.spawn_client_connection( .spawn_client_connection(id.clone(), insecure_tls)
id.clone(),
insecure_tls,
)
.await?; .await?;
return Ok(PeripheryClient { id, channels }); return Ok(PeripheryClient { id, channels });
}; };
@@ -77,10 +74,7 @@ impl PeripheryClient {
} else { } else {
// Core -> Periphery connection // Core -> Periphery connection
let channels = args let channels = args
.spawn_client_connection( .spawn_client_connection(id.clone(), insecure_tls)
id.clone(),
insecure_tls,
)
.await?; .await?;
Ok(PeripheryClient { id, channels }) Ok(PeripheryClient { id, channels })
} }

View File

@@ -9,8 +9,9 @@ use periphery_client::api::keys::{
}; };
use resolver_api::Resolve; use resolver_api::Resolve;
use crate::config::{ use crate::{
core_public_keys, periphery_config, periphery_private_key, config::{periphery_config, periphery_private_key},
connection::core_public_keys,
}; };
// //
@@ -44,25 +45,25 @@ impl Resolve<super::Args> for RotatePrivateKey {
impl Resolve<super::Args> for RotateCorePublicKey { impl Resolve<super::Args> for RotateCorePublicKey {
async fn resolve(self, _: &super::Args) -> serror::Result<NoData> { async fn resolve(self, _: &super::Args) -> serror::Result<NoData> {
let config = periphery_config(); let config = periphery_config();
let (Some(core_public_keys_spec), Some(core_public_keys)) =
(config.core_public_keys.as_ref(), core_public_keys()) let Some(core_public_keys_spec) =
config.core_public_keys.as_ref()
else { else {
return Ok(NoData {}); return Ok(NoData {});
}; };
let Some(core_public_key_path) = core_public_keys_spec let Some(core_public_key_path) = core_public_keys_spec
.first() .first()
.and_then(|key| key.strip_prefix("file:")) .and_then(|key| key.strip_prefix("file:"))
else { else {
return Ok(NoData {}); return Ok(NoData {});
}; };
let public_key = SpkiPublicKey::from(self.public_key);
public_key.write_pem(core_public_key_path)?; SpkiPublicKey::from(self.public_key)
let mut new_core_public_keys = vec![public_key]; .write_pem(core_public_key_path)?;
// This only replaces the first, extend the rest.
new_core_public_keys core_public_keys().refresh();
.extend_from_slice(&core_public_keys.load().as_slice()[1..]);
// Store the new keys for the next auth
core_public_keys.store(Arc::new(new_core_public_keys));
Ok(NoData {}) Ok(NoData {})
} }
} }

View File

@@ -1,5 +1,4 @@
use std::{ use std::{
fs::read_to_string,
path::PathBuf, path::PathBuf,
sync::{Arc, OnceLock}, sync::{Arc, OnceLock},
}; };
@@ -52,39 +51,6 @@ pub fn periphery_public_key() -> &'static ArcSwap<SpkiPublicKey> {
}) })
} }
pub fn core_public_keys()
-> Option<&'static ArcSwap<Vec<SpkiPublicKey>>> {
static CORE_PUBLIC_KEYS: OnceLock<
Option<ArcSwap<Vec<SpkiPublicKey>>>,
> = OnceLock::new();
CORE_PUBLIC_KEYS
.get_or_init(|| {
periphery_config().core_public_keys.as_ref().map(
|public_keys| {
let public_keys = public_keys
.iter()
.map(|public_key| {
let maybe_pem = if let Some(path) =
public_key.strip_prefix("file:")
{
read_to_string(path)
.with_context(|| {
format!("Failed to read public key at {path:?}")
})
.unwrap()
} else {
public_key.clone()
};
SpkiPublicKey::from_maybe_pem(&maybe_pem).unwrap()
})
.collect::<Vec<_>>();
ArcSwap::new(Arc::new(public_keys))
},
)
})
.as_ref()
}
pub fn periphery_args() -> &'static CliArgs { pub fn periphery_args() -> &'static CliArgs {
static PERIPHERY_ARGS: OnceLock<CliArgs> = OnceLock::new(); static PERIPHERY_ARGS: OnceLock<CliArgs> = OnceLock::new();
PERIPHERY_ARGS.get_or_init(CliArgs::parse) PERIPHERY_ARGS.get_or_init(CliArgs::parse)

View File

@@ -15,7 +15,7 @@ use transport::{
use crate::{ use crate::{
api::Args, api::Args,
config::{periphery_config, periphery_public_key}, config::{periphery_config, periphery_public_key},
connection::{CorePublicKeyValidator, core_channels}, connection::{core_channels, core_public_keys},
}; };
pub async fn handler(address: &str) -> anyhow::Result<()> { pub async fn handler(address: &str) -> anyhow::Result<()> {
@@ -169,7 +169,7 @@ async fn handle_onboarding(
ClientLoginFlow::login(LoginFlowArgs { ClientLoginFlow::login(LoginFlowArgs {
private_key: onboarding_key, private_key: onboarding_key,
identifiers, identifiers,
public_key_validator: CorePublicKeyValidator, public_key_validator: core_public_keys(),
socket: &mut socket, socket: &mut socket,
}) })
.await?; .await?;

View File

@@ -1,11 +1,14 @@
use std::{ use std::{
fs::read_to_string,
sync::{Arc, OnceLock}, sync::{Arc, OnceLock},
time::Duration, time::Duration,
}; };
use anyhow::anyhow; use anyhow::{Context as _, anyhow};
use arc_swap::ArcSwap;
use bytes::Bytes; use bytes::Bytes;
use cache::CloneCache; use cache::CloneCache;
use noise::key::SpkiPublicKey;
use resolver_api::Resolve; use resolver_api::Resolve;
use response::JsonBytes; use response::JsonBytes;
use transport::{ use transport::{
@@ -21,9 +24,7 @@ use uuid::Uuid;
use crate::{ use crate::{
api::{Args, PeripheryRequest}, api::{Args, PeripheryRequest},
config::{ config::{periphery_config, periphery_private_key},
core_public_keys, periphery_config, periphery_private_key,
},
}; };
pub mod client; pub mod client;
@@ -37,24 +38,70 @@ pub fn core_channels() -> &'static CoreChannels {
CORE_CHANNELS.get_or_init(Default::default) CORE_CHANNELS.get_or_init(Default::default)
} }
pub struct CorePublicKeyValidator; pub fn core_public_keys() -> &'static CorePublicKeys {
static CORE_PUBLIC_KEYS: OnceLock<CorePublicKeys> = OnceLock::new();
CORE_PUBLIC_KEYS.get_or_init(CorePublicKeys::default)
}
impl PublicKeyValidator for CorePublicKeyValidator { pub struct CorePublicKeys(ArcSwap<Vec<SpkiPublicKey>>);
impl Default for CorePublicKeys {
fn default() -> Self {
let keys = CorePublicKeys(Default::default());
keys.refresh();
keys
}
}
impl CorePublicKeys {
pub fn is_valid(&self, public_key: &str) -> bool {
let keys = self.0.load();
keys.is_empty() || keys.iter().any(|pk| pk.as_str() == public_key)
}
pub fn refresh(&self) {
let Some(core_public_keys) =
periphery_config().core_public_keys.as_ref()
else {
return;
};
let core_public_keys = core_public_keys
.iter()
.flat_map(|public_key| {
let maybe_pem =
if let Some(path) = public_key.strip_prefix("file:") {
read_to_string(path)
.with_context(|| {
format!("Failed to read public key at {path:?}")
})
.inspect_err(|e| warn!("{e:#}"))
.ok()?
} else {
public_key.clone()
};
SpkiPublicKey::from_maybe_pem(&maybe_pem)
.inspect_err(|e| warn!("{e:#}"))
.ok()
})
.collect::<Vec<_>>();
self.0.store(Arc::new(core_public_keys));
}
}
impl PublicKeyValidator for &CorePublicKeys {
type ValidationResult = (); type ValidationResult = ();
async fn validate(&self, public_key: String) -> anyhow::Result<()> { async fn validate(&self, public_key: String) -> anyhow::Result<()> {
if let Some(public_keys) = core_public_keys() let keys = self.0.load();
&& public_keys if keys.is_empty()
.load() || keys.iter().any(|pk| pk.as_str() == public_key)
.iter()
.all(|expected| public_key != expected.as_str())
{ {
Ok(())
} else {
Err( Err(
anyhow!("{public_key} is invalid") anyhow!("{public_key} is invalid")
.context("Ensure public key matches one of the 'core_public_keys' in periphery config (PERIPHERY_CORE_PUBLIC_KEYS)") .context("Ensure public key matches one of the 'core_public_keys' in periphery config (PERIPHERY_CORE_PUBLIC_KEYS)")
.context("Periphery failed to validate Core public key"), .context("Periphery failed to validate Core public key"),
) )
} else {
Ok(())
} }
} }
} }
@@ -67,7 +114,7 @@ async fn handle_login<W: Websocket, L: LoginFlow>(
socket, socket,
identifiers, identifiers,
private_key: periphery_private_key().load().as_str(), private_key: periphery_private_key().load().as_str(),
public_key_validator: CorePublicKeyValidator, public_key_validator: core_public_keys(),
}) })
.await .await
} }

View File

@@ -30,9 +30,7 @@ use transport::{
}; };
use crate::{ use crate::{
api::Args, api::Args, config::periphery_config, connection::core_channels,
config::{core_public_keys, periphery_config},
connection::core_channels,
}; };
pub async fn run() -> anyhow::Result<()> { pub async fn run() -> anyhow::Result<()> {
@@ -160,7 +158,8 @@ async fn handle_login(
socket: &mut AxumWebsocket, socket: &mut AxumWebsocket,
identifiers: ConnectionIdentifiers<'_>, identifiers: ConnectionIdentifiers<'_>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
match (core_public_keys(), &periphery_config().passkeys) { let config = periphery_config();
match (&config.core_public_keys, &config.passkeys) {
(Some(_), _) | (_, None) => { (Some(_), _) | (_, None) => {
// Send login type [0] (Noise auth) // Send login type [0] (Noise auth)
socket socket

View File

@@ -41,8 +41,7 @@ impl<
tokio::time::timeout(timeout, self.inner).map(|res| { tokio::time::timeout(timeout, self.inner).map(|res| {
res res
.context("Timed out waiting for message.") .context("Timed out waiting for message.")
.map(|inner| inner.map_err(Into::into)) .and_then(|inner| inner.map_err(Into::into))
.flatten()
}) })
} }
} }

View File

@@ -84,7 +84,7 @@ pub trait Websocket: Send {
MaybeWithTimeout::new( MaybeWithTimeout::new(
self self
.recv_message() .recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()), .map(|res| res.and_then(|message| message.into_parts())),
) )
} }
@@ -96,12 +96,10 @@ pub trait Websocket: Send {
impl Future<Output = anyhow::Result<Bytes>> + Send, impl Future<Output = anyhow::Result<Bytes>> + Send,
> { > {
MaybeWithTimeout::new(self.recv_parts().map(|res| { MaybeWithTimeout::new(self.recv_parts().map(|res| {
res res.and_then(|(data, _, state)| match state {
.map(|(data, _, state)| match state { MessageState::Successful => Ok(data),
MessageState::Successful => Ok(data), _ => Err(deserialize_error_bytes(&data)),
_ => Err(deserialize_error_bytes(&data)), })
})
.flatten()
})) }))
} }
} }
@@ -154,7 +152,7 @@ pub trait WebsocketReceiver: Send {
MaybeWithTimeout::new( MaybeWithTimeout::new(
self self
.recv_message() .recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()), .map(|res| res.and_then(|message| message.into_parts())),
) )
} }
} }