mirror of
https://github.com/moghtech/komodo.git
synced 2026-04-29 21:27:26 -05:00
249 lines
6.3 KiB
Rust
249 lines
6.3 KiB
Rust
use std::{
|
|
fs::read_to_string,
|
|
sync::{Arc, OnceLock},
|
|
time::Duration,
|
|
};
|
|
|
|
use anyhow::{Context as _, anyhow};
|
|
use arc_swap::ArcSwap;
|
|
use bytes::Bytes;
|
|
use cache::CloneCache;
|
|
use noise::key::SpkiPublicKey;
|
|
use resolver_api::Resolve;
|
|
use response::JsonBytes;
|
|
use transport::{
|
|
auth::{
|
|
ConnectionIdentifiers, LoginFlow, LoginFlowArgs,
|
|
PublicKeyValidator,
|
|
},
|
|
channel::{BufferedChannel, BufferedReceiver, Sender},
|
|
message::{Message, MessageState},
|
|
websocket::{Websocket, WebsocketReceiver, WebsocketSender as _},
|
|
};
|
|
use uuid::Uuid;
|
|
|
|
use crate::{
|
|
api::{Args, PeripheryRequest},
|
|
config::{periphery_config, periphery_private_key},
|
|
};
|
|
|
|
pub mod client;
|
|
pub mod server;
|
|
|
|
// Core Address / Host -> Channel
|
|
/// Cache keyed by Core address / host; each value is the shared
/// (`Arc`) [`BufferedChannel`] stored for that key.
pub type CoreChannels = CloneCache<String, Arc<BufferedChannel>>;
|
|
|
|
pub fn core_channels() -> &'static CoreChannels {
|
|
static CORE_CHANNELS: OnceLock<CoreChannels> = OnceLock::new();
|
|
CORE_CHANNELS.get_or_init(Default::default)
|
|
}
|
|
|
|
pub fn core_public_keys() -> &'static CorePublicKeys {
|
|
static CORE_PUBLIC_KEYS: OnceLock<CorePublicKeys> = OnceLock::new();
|
|
CORE_PUBLIC_KEYS.get_or_init(CorePublicKeys::default)
|
|
}
|
|
|
|
/// The list of Core public keys that incoming keys are compared against.
///
/// The list sits behind an `ArcSwap`: `refresh` builds a new `Vec`
/// and `store`s it, while the validation paths `load` the current one.
pub struct CorePublicKeys(ArcSwap<Vec<SpkiPublicKey>>);
|
|
|
|
impl Default for CorePublicKeys {
|
|
fn default() -> Self {
|
|
let keys = CorePublicKeys(Default::default());
|
|
keys.refresh();
|
|
keys
|
|
}
|
|
}
|
|
|
|
impl CorePublicKeys {
|
|
pub fn is_valid(&self, public_key: &str) -> bool {
|
|
let keys = self.0.load();
|
|
keys.is_empty() || keys.iter().any(|pk| pk.as_str() == public_key)
|
|
}
|
|
|
|
pub fn refresh(&self) {
|
|
let config = periphery_config();
|
|
let Some(core_public_keys) = config.core_public_keys.as_ref()
|
|
else {
|
|
return;
|
|
};
|
|
let core_public_keys = core_public_keys
|
|
.iter()
|
|
.flat_map(|public_key| {
|
|
let res = || {
|
|
if let Some(path) = public_key.strip_prefix("file:") {
|
|
let contents =
|
|
read_to_string(path).with_context(|| {
|
|
format!("Failed to read public key at {path:?}")
|
|
})?;
|
|
SpkiPublicKey::from_maybe_pem(&contents)
|
|
} else {
|
|
SpkiPublicKey::from_maybe_pem(&public_key)
|
|
}
|
|
};
|
|
match (res(), config.server_enabled) {
|
|
(Ok(public_key), _) => Some(public_key),
|
|
(Err(e), false) => {
|
|
// If only outbound connections, only warn.
|
|
// It will be written the next time `RotateCoreKeys` is executed.
|
|
warn!("{e:#}");
|
|
None
|
|
}
|
|
(Err(e), true) => {
|
|
// This is too dangerous to allow if server_enabled.
|
|
panic!("{e:#}");
|
|
}
|
|
}
|
|
})
|
|
.collect::<Vec<_>>();
|
|
self.0.store(Arc::new(core_public_keys));
|
|
}
|
|
}
|
|
|
|
impl PublicKeyValidator for &CorePublicKeys {
|
|
type ValidationResult = ();
|
|
async fn validate(&self, public_key: String) -> anyhow::Result<()> {
|
|
let keys = self.0.load();
|
|
if keys.is_empty()
|
|
|| keys.iter().any(|pk| pk.as_str() == public_key)
|
|
{
|
|
Ok(())
|
|
} else {
|
|
Err(
|
|
anyhow!("{public_key} is invalid")
|
|
.context("Ensure public key matches one of the 'core_public_keys' in periphery config (PERIPHERY_CORE_PUBLIC_KEYS)")
|
|
.context("Periphery failed to validate Core public key"),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn handle_login<W: Websocket, L: LoginFlow>(
|
|
socket: &mut W,
|
|
identifiers: ConnectionIdentifiers<'_>,
|
|
) -> anyhow::Result<()> {
|
|
L::login(LoginFlowArgs {
|
|
socket,
|
|
identifiers,
|
|
private_key: periphery_private_key().load().as_str(),
|
|
public_key_validator: core_public_keys(),
|
|
})
|
|
.await
|
|
}
|
|
|
|
async fn handle_socket<W: Websocket>(
|
|
socket: W,
|
|
args: &Arc<Args>,
|
|
sender: &Sender,
|
|
receiver: &mut BufferedReceiver,
|
|
) {
|
|
let config = periphery_config();
|
|
info!(
|
|
"Logged in to Komodo Core {} websocket{}",
|
|
args.core,
|
|
if !config.core_addresses.is_empty()
|
|
&& !config.connect_as.is_empty()
|
|
{
|
|
format!(" as Server {}", config.connect_as)
|
|
} else {
|
|
String::new()
|
|
}
|
|
);
|
|
|
|
let (mut ws_write, mut ws_read) = socket.split();
|
|
|
|
let forward_writes = async {
|
|
loop {
|
|
let message = match receiver.recv().await {
|
|
Ok(message) => message,
|
|
Err(e) => {
|
|
warn!("{e:#}");
|
|
break;
|
|
}
|
|
};
|
|
match ws_write.send(message).await {
|
|
// Clears the stored message from receiver buffer.
|
|
Ok(_) => receiver.clear_buffer(),
|
|
Err(e) => {
|
|
warn!("Failed to send response | {e:?}");
|
|
let _ = ws_write.close(None).await;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let handle_reads = async {
|
|
loop {
|
|
let (data, channel, state) = match ws_read.recv_parts().await {
|
|
Ok(res) => res,
|
|
Err(e) => {
|
|
warn!("{e:#}");
|
|
break;
|
|
}
|
|
};
|
|
match state {
|
|
MessageState::Request => {
|
|
handle_request(args.clone(), sender.clone(), channel, data)
|
|
}
|
|
MessageState::Terminal => {
|
|
crate::terminal::handle_message(channel, data).await
|
|
}
|
|
// Rest shouldn't be received by Periphery
|
|
_ => {}
|
|
}
|
|
}
|
|
};
|
|
|
|
tokio::select! {
|
|
_ = forward_writes => {},
|
|
_ = handle_reads => {},
|
|
}
|
|
}
|
|
|
|
fn handle_request(
|
|
args: Arc<Args>,
|
|
sender: Sender,
|
|
req_id: Uuid,
|
|
request: Bytes,
|
|
) {
|
|
tokio::spawn(async move {
|
|
let request =
|
|
match serde_json::from_slice::<PeripheryRequest>(&request) {
|
|
Ok(req) => req,
|
|
Err(e) => {
|
|
// TODO: handle:
|
|
warn!("Failed to parse transport bytes | {e:#}");
|
|
return;
|
|
}
|
|
};
|
|
|
|
let resolve_response = async {
|
|
let message: Message = match request.resolve(&args).await {
|
|
Ok(JsonBytes::Ok(res)) => {
|
|
(res, req_id, MessageState::Successful).into()
|
|
}
|
|
Ok(JsonBytes::Err(e)) => (&e.into(), req_id).into(),
|
|
Err(e) => (&e.error, req_id).into(),
|
|
};
|
|
if let Err(e) = sender.send(message).await {
|
|
error!("Failed to send response over channel | {e:?}");
|
|
}
|
|
};
|
|
|
|
let ping_in_progress = async {
|
|
loop {
|
|
tokio::time::sleep(Duration::from_secs(5)).await;
|
|
if let Err(e) =
|
|
sender.send((req_id, MessageState::InProgress)).await
|
|
{
|
|
error!("Failed to ping in progress over channel | {e:?}");
|
|
}
|
|
}
|
|
};
|
|
|
|
tokio::select! {
|
|
_ = resolve_response => {},
|
|
_ = ping_in_progress => {},
|
|
}
|
|
});
|
|
}
|