core public keys improve refresh

This commit is contained in:
mbecker20
2025-10-07 00:27:47 -07:00
parent 845e8780c7
commit c8c62ea562
9 changed files with 90 additions and 86 deletions

View File

@@ -341,7 +341,7 @@ impl PeripheryConnection {
match ws_write.send(message).await {
Ok(_) => receiver.clear_buffer(),
Err(e) => {
self.set_error(e.into()).await;
self.set_error(e).await;
break;
}
}
@@ -363,7 +363,7 @@ impl PeripheryConnection {
break;
}
Err(e) => {
self.set_error(e.into()).await;
self.set_error(e).await;
}
};
}

View File

@@ -44,10 +44,7 @@ impl PeripheryClient {
return Err(anyhow!("Server {id} is not connected"));
}
let channels = args
.spawn_client_connection(
id.clone(),
insecure_tls,
)
.spawn_client_connection(id.clone(), insecure_tls)
.await?;
return Ok(PeripheryClient { id, channels });
};
@@ -77,10 +74,7 @@ impl PeripheryClient {
} else {
// Core -> Periphery connection
let channels = args
.spawn_client_connection(
id.clone(),
insecure_tls,
)
.spawn_client_connection(id.clone(), insecure_tls)
.await?;
Ok(PeripheryClient { id, channels })
}

View File

@@ -9,8 +9,9 @@ use periphery_client::api::keys::{
};
use resolver_api::Resolve;
use crate::config::{
core_public_keys, periphery_config, periphery_private_key,
use crate::{
config::{periphery_config, periphery_private_key},
connection::core_public_keys,
};
//
@@ -44,25 +45,25 @@ impl Resolve<super::Args> for RotatePrivateKey {
impl Resolve<super::Args> for RotateCorePublicKey {
async fn resolve(self, _: &super::Args) -> serror::Result<NoData> {
let config = periphery_config();
let (Some(core_public_keys_spec), Some(core_public_keys)) =
(config.core_public_keys.as_ref(), core_public_keys())
let Some(core_public_keys_spec) =
config.core_public_keys.as_ref()
else {
return Ok(NoData {});
};
let Some(core_public_key_path) = core_public_keys_spec
.first()
.and_then(|key| key.strip_prefix("file:"))
else {
return Ok(NoData {});
};
let public_key = SpkiPublicKey::from(self.public_key);
public_key.write_pem(core_public_key_path)?;
let mut new_core_public_keys = vec![public_key];
// This only replaces the first, extend the rest.
new_core_public_keys
.extend_from_slice(&core_public_keys.load().as_slice()[1..]);
// Store the new keys for the next auth
core_public_keys.store(Arc::new(new_core_public_keys));
SpkiPublicKey::from(self.public_key)
.write_pem(core_public_key_path)?;
core_public_keys().refresh();
Ok(NoData {})
}
}

View File

@@ -1,5 +1,4 @@
use std::{
fs::read_to_string,
path::PathBuf,
sync::{Arc, OnceLock},
};
@@ -52,39 +51,6 @@ pub fn periphery_public_key() -> &'static ArcSwap<SpkiPublicKey> {
})
}
pub fn core_public_keys()
-> Option<&'static ArcSwap<Vec<SpkiPublicKey>>> {
static CORE_PUBLIC_KEYS: OnceLock<
Option<ArcSwap<Vec<SpkiPublicKey>>>,
> = OnceLock::new();
CORE_PUBLIC_KEYS
.get_or_init(|| {
periphery_config().core_public_keys.as_ref().map(
|public_keys| {
let public_keys = public_keys
.iter()
.map(|public_key| {
let maybe_pem = if let Some(path) =
public_key.strip_prefix("file:")
{
read_to_string(path)
.with_context(|| {
format!("Failed to read public key at {path:?}")
})
.unwrap()
} else {
public_key.clone()
};
SpkiPublicKey::from_maybe_pem(&maybe_pem).unwrap()
})
.collect::<Vec<_>>();
ArcSwap::new(Arc::new(public_keys))
},
)
})
.as_ref()
}
pub fn periphery_args() -> &'static CliArgs {
static PERIPHERY_ARGS: OnceLock<CliArgs> = OnceLock::new();
PERIPHERY_ARGS.get_or_init(CliArgs::parse)

View File

@@ -15,7 +15,7 @@ use transport::{
use crate::{
api::Args,
config::{periphery_config, periphery_public_key},
connection::{CorePublicKeyValidator, core_channels},
connection::{core_channels, core_public_keys},
};
pub async fn handler(address: &str) -> anyhow::Result<()> {
@@ -169,7 +169,7 @@ async fn handle_onboarding(
ClientLoginFlow::login(LoginFlowArgs {
private_key: onboarding_key,
identifiers,
public_key_validator: CorePublicKeyValidator,
public_key_validator: core_public_keys(),
socket: &mut socket,
})
.await?;

View File

@@ -1,11 +1,14 @@
use std::{
fs::read_to_string,
sync::{Arc, OnceLock},
time::Duration,
};
use anyhow::anyhow;
use anyhow::{Context as _, anyhow};
use arc_swap::ArcSwap;
use bytes::Bytes;
use cache::CloneCache;
use noise::key::SpkiPublicKey;
use resolver_api::Resolve;
use response::JsonBytes;
use transport::{
@@ -21,9 +24,7 @@ use uuid::Uuid;
use crate::{
api::{Args, PeripheryRequest},
config::{
core_public_keys, periphery_config, periphery_private_key,
},
config::{periphery_config, periphery_private_key},
};
pub mod client;
@@ -37,24 +38,70 @@ pub fn core_channels() -> &'static CoreChannels {
CORE_CHANNELS.get_or_init(Default::default)
}
pub struct CorePublicKeyValidator;
pub fn core_public_keys() -> &'static CorePublicKeys {
static CORE_PUBLIC_KEYS: OnceLock<CorePublicKeys> = OnceLock::new();
CORE_PUBLIC_KEYS.get_or_init(CorePublicKeys::default)
}
impl PublicKeyValidator for CorePublicKeyValidator {
pub struct CorePublicKeys(ArcSwap<Vec<SpkiPublicKey>>);
impl Default for CorePublicKeys {
fn default() -> Self {
let keys = CorePublicKeys(Default::default());
keys.refresh();
keys
}
}
impl CorePublicKeys {
pub fn is_valid(&self, public_key: &str) -> bool {
let keys = self.0.load();
keys.is_empty() || keys.iter().any(|pk| pk.as_str() == public_key)
}
pub fn refresh(&self) {
let Some(core_public_keys) =
periphery_config().core_public_keys.as_ref()
else {
return;
};
let core_public_keys = core_public_keys
.iter()
.flat_map(|public_key| {
let maybe_pem =
if let Some(path) = public_key.strip_prefix("file:") {
read_to_string(path)
.with_context(|| {
format!("Failed to read public key at {path:?}")
})
.inspect_err(|e| warn!("{e:#}"))
.ok()?
} else {
public_key.clone()
};
SpkiPublicKey::from_maybe_pem(&maybe_pem)
.inspect_err(|e| warn!("{e:#}"))
.ok()
})
.collect::<Vec<_>>();
self.0.store(Arc::new(core_public_keys));
}
}
impl PublicKeyValidator for &CorePublicKeys {
type ValidationResult = ();
async fn validate(&self, public_key: String) -> anyhow::Result<()> {
if let Some(public_keys) = core_public_keys()
&& public_keys
.load()
.iter()
.all(|expected| public_key != expected.as_str())
let keys = self.0.load();
if keys.is_empty()
|| keys.iter().any(|pk| pk.as_str() == public_key)
{
Ok(())
} else {
Err(
anyhow!("{public_key} is invalid")
.context("Ensure public key matches one of the 'core_public_keys' in periphery config (PERIPHERY_CORE_PUBLIC_KEYS)")
.context("Periphery failed to validate Core public key"),
)
} else {
Ok(())
}
}
}
@@ -67,7 +114,7 @@ async fn handle_login<W: Websocket, L: LoginFlow>(
socket,
identifiers,
private_key: periphery_private_key().load().as_str(),
public_key_validator: CorePublicKeyValidator,
public_key_validator: core_public_keys(),
})
.await
}

View File

@@ -30,9 +30,7 @@ use transport::{
};
use crate::{
api::Args,
config::{core_public_keys, periphery_config},
connection::core_channels,
api::Args, config::periphery_config, connection::core_channels,
};
pub async fn run() -> anyhow::Result<()> {
@@ -160,7 +158,8 @@ async fn handle_login(
socket: &mut AxumWebsocket,
identifiers: ConnectionIdentifiers<'_>,
) -> anyhow::Result<()> {
match (core_public_keys(), &periphery_config().passkeys) {
let config = periphery_config();
match (&config.core_public_keys, &config.passkeys) {
(Some(_), _) | (_, None) => {
// Send login type [0] (Noise auth)
socket

View File

@@ -41,8 +41,7 @@ impl<
tokio::time::timeout(timeout, self.inner).map(|res| {
res
.context("Timed out waiting for message.")
.map(|inner| inner.map_err(Into::into))
.flatten()
.and_then(|inner| inner.map_err(Into::into))
})
}
}

View File

@@ -84,7 +84,7 @@ pub trait Websocket: Send {
MaybeWithTimeout::new(
self
.recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()),
.map(|res| res.and_then(|message| message.into_parts())),
)
}
@@ -96,12 +96,10 @@ pub trait Websocket: Send {
impl Future<Output = anyhow::Result<Bytes>> + Send,
> {
MaybeWithTimeout::new(self.recv_parts().map(|res| {
res
.map(|(data, _, state)| match state {
MessageState::Successful => Ok(data),
_ => Err(deserialize_error_bytes(&data)),
})
.flatten()
res.and_then(|(data, _, state)| match state {
MessageState::Successful => Ok(data),
_ => Err(deserialize_error_bytes(&data)),
})
}))
}
}
@@ -154,7 +152,7 @@ pub trait WebsocketReceiver: Send {
MaybeWithTimeout::new(
self
.recv_message()
.map(|res| res.map(|message| message.into_parts()).flatten()),
.map(|res| res.and_then(|message| message.into_parts())),
)
}
}