mirror of
https://github.com/moghtech/komodo.git
synced 2026-05-03 12:03:21 -05:00
379 lines
11 KiB
Rust
379 lines
11 KiB
Rust
//! Implements both sides of the Noise handshake
|
|
//! using asymmetric private-public key authentication.
|
|
//!
|
|
//! TODO: Revisit
|
|
//! Note: this relies on the server being behind a trusted TLS connection.
|
|
//! This is trivial for Periphery -> Core connection, but presents a challenge
|
|
//! for Core -> Periphery, where untrusted TLS certs are being used.
|
|
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Context;
|
|
use axum::http::{HeaderMap, HeaderValue};
|
|
use base64::{Engine, prelude::BASE64_STANDARD};
|
|
use bytes::Bytes;
|
|
use noise::{NoiseHandshake, key::SpkiPublicKey};
|
|
use rand::RngCore;
|
|
use serror::{deserialize_error_bytes, serialize_error_bytes};
|
|
use sha2::{Digest, Sha256};
|
|
use tracing::warn;
|
|
|
|
use crate::{MessageState, websocket::Websocket};
|
|
|
|
/// Decides whether the peer's public key is accepted during login.
///
/// `public_key` is the value produced by
/// `SpkiPublicKey::from_raw_bytes(..).into_inner()` on the key revealed by the
/// Noise handshake; its exact text encoding is defined by the `noise` crate
/// (TODO confirm). Returning `Err` rejects the peer: both login flows in this
/// module serialize the error, send it to the peer, and close the socket.
pub trait PublicKeyValidator {
  fn validate(&self, public_key: String) -> anyhow::Result<()>;
}
|
|
|
|
/// Inputs for one run of a [`LoginFlow`].
pub struct LoginFlowArgs<'a, 's, V, W> {
  /// Connection identifiers hashed together with the connection nonce into
  /// the Noise handshake prologue; both ends must build identical values.
  pub identifiers: ConnectionIdentifiers<'a>,
  /// This side's private key, passed directly to `NoiseHandshake::new_*`
  /// (expected encoding is defined by the `noise` crate — TODO confirm).
  pub private_key: &'a str,
  /// Checks the peer's public key once the handshake has revealed it.
  pub public_key_validator: V,
  /// Websocket the handshake runs over. The flows close it on failure.
  pub socket: &'s mut W,
}
|
|
|
|
/// One side (client or server) of the authenticated login handshake.
pub trait LoginFlow {
  /// Runs the login over `args.socket`.
  ///
  /// Resolves to `Ok(())` once the Noise handshake has completed and the
  /// peer's public key has passed `args.public_key_validator`. Both
  /// implementations in this module send a serialized error to the peer and
  /// close the socket before returning `Err`.
  fn login<'a, 's, V: PublicKeyValidator, W: Websocket>(
    args: LoginFlowArgs<'a, 's, V, W>,
  ) -> impl Future<Output = anyhow::Result<()>>;
}
|
|
|
|
/// Timeout applied to each individual receive during login (every
/// `recv_bytes_with_timeout` call), not to the login as a whole.
const AUTH_TIMEOUT: Duration = Duration::from_secs(2);
|
|
|
|
/// [`LoginFlow`] for the Noise responder: generates and sends the connection
/// nonce, then validates the client's public key.
pub struct ServerLoginFlow;
|
|
|
|
impl LoginFlow for ServerLoginFlow {
|
|
async fn login<'a, 's, V: PublicKeyValidator, W: Websocket>(
|
|
LoginFlowArgs {
|
|
identifiers,
|
|
private_key,
|
|
public_key_validator,
|
|
socket,
|
|
}: LoginFlowArgs<'a, 's, V, W>,
|
|
) -> anyhow::Result<()> {
|
|
let res = async {
|
|
// Server generates random nonce and sends to client
|
|
let nonce = nonce();
|
|
socket
|
|
.send(Bytes::from_owner(nonce))
|
|
.await
|
|
.context("Failed to send connection nonce")?;
|
|
|
|
let mut handshake = NoiseHandshake::new_responder(
|
|
private_key,
|
|
// Builds the handshake using the connection-unique prologue hash.
|
|
// The prologue must be the same on both sides of connection.
|
|
&identifiers.hash(&nonce),
|
|
)
|
|
.context("Failed to inialize handshake")?;
|
|
|
|
// Receive and read handshake_m1
|
|
let handshake_m1 = socket
|
|
.recv_bytes_with_timeout(AUTH_TIMEOUT)
|
|
.await
|
|
.context("Failed to get handshake_m1")?;
|
|
match MessageState::from_byte(
|
|
*handshake_m1.last().context("handshake_m1 is empty")?,
|
|
) {
|
|
MessageState::Successful => handshake
|
|
.read_message(&handshake_m1[..(handshake_m1.len() - 1)])
|
|
.context("Failed to read handshake_m1")?,
|
|
_ => {
|
|
return Err(deserialize_error_bytes(
|
|
&handshake_m1[..(handshake_m1.len() - 1)],
|
|
));
|
|
}
|
|
}
|
|
|
|
// Send handshake_m2
|
|
let mut handshake_m2 = handshake
|
|
.next_message()
|
|
.context("Failed to write handshake_m2")?;
|
|
handshake_m2.push(MessageState::Successful.as_byte());
|
|
socket
|
|
.send(handshake_m2.into())
|
|
.await
|
|
.context("Failed to send handshake_m2")?;
|
|
|
|
// Receive and read handshake_m3
|
|
let handshake_m3 = socket
|
|
.recv_bytes_with_timeout(AUTH_TIMEOUT)
|
|
.await
|
|
.context("Failed to get handshake_m3")?;
|
|
match MessageState::from_byte(
|
|
*handshake_m3.last().context("handshake_m3 is empty")?,
|
|
) {
|
|
MessageState::Successful => handshake
|
|
.read_message(&handshake_m3[..(handshake_m3.len() - 1)])
|
|
.context("Failed to read handshake_m3")?,
|
|
_ => {
|
|
return Err(deserialize_error_bytes(
|
|
&handshake_m3[..(handshake_m3.len() - 1)],
|
|
));
|
|
}
|
|
}
|
|
|
|
// Server now has client public key
|
|
let public_key =
|
|
SpkiPublicKey::from_raw_bytes(handshake.remote_public_key()?)
|
|
.context("Invalid public key")?
|
|
.into_inner();
|
|
|
|
public_key_validator.validate(public_key)
|
|
}
|
|
.await;
|
|
|
|
match res {
|
|
Ok(_) => {
|
|
socket
|
|
.send(MessageState::Successful.into())
|
|
.await
|
|
.context("Failed to send login successful to client")?;
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
let mut bytes = serialize_error_bytes(&e);
|
|
bytes.push(MessageState::Failed.as_byte());
|
|
if let Err(e) = socket
|
|
.send(bytes.into())
|
|
.await
|
|
.context("Failed to send login failed to client")
|
|
{
|
|
// Log additional error
|
|
warn!("{e:#}");
|
|
}
|
|
// Close socket
|
|
let _ = socket.close(None).await;
|
|
// Return the original error
|
|
Err(e)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// [`LoginFlow`] for the Noise initiator: receives the server's connection
/// nonce and validates the server's public key before completing the handshake.
pub struct ClientLoginFlow;
|
|
|
|
impl LoginFlow for ClientLoginFlow {
|
|
async fn login<'a, 's, V: PublicKeyValidator, W: Websocket>(
|
|
LoginFlowArgs {
|
|
identifiers,
|
|
private_key,
|
|
public_key_validator,
|
|
socket,
|
|
}: LoginFlowArgs<'a, 's, V, W>,
|
|
) -> anyhow::Result<()> {
|
|
let res = async {
|
|
// Receive nonce from server
|
|
let nonce = socket
|
|
.recv_bytes_with_timeout(AUTH_TIMEOUT)
|
|
.await
|
|
.context("Failed to receive connection nonce")?;
|
|
|
|
let mut handshake = NoiseHandshake::new_initiator(
|
|
private_key,
|
|
// Builds the handshake using the connection-unique prologue hash.
|
|
// The prologue must be the same on both sides of connection.
|
|
&identifiers.hash(&nonce),
|
|
)
|
|
.context("Failed to inialize handshake")?;
|
|
|
|
// Send handshake_m1
|
|
let mut handshake_m1 = handshake
|
|
.next_message()
|
|
.context("Failed to write handshake m1")?;
|
|
handshake_m1.push(MessageState::Successful.as_byte());
|
|
socket
|
|
.send(handshake_m1.into())
|
|
.await
|
|
.context("Failed to send handshake_m1")?;
|
|
|
|
// Receive and read handshake_m2
|
|
let handshake_m2 = socket
|
|
.recv_bytes_with_timeout(AUTH_TIMEOUT)
|
|
.await
|
|
.context("Failed to get handshake_m2")?;
|
|
match MessageState::from_byte(
|
|
*handshake_m2.last().context("handshake_m2 is empty")?,
|
|
) {
|
|
MessageState::Successful => handshake
|
|
.read_message(&handshake_m2[..(handshake_m2.len() - 1)])
|
|
.context("Failed to read handshake_m2")?,
|
|
_ => {
|
|
return Err(deserialize_error_bytes(
|
|
&handshake_m2[..(handshake_m2.len() - 1)],
|
|
));
|
|
}
|
|
}
|
|
|
|
// Client now has server public key.
|
|
// Perform validation before proceeding.
|
|
let public_key =
|
|
SpkiPublicKey::from_raw_bytes(handshake.remote_public_key()?)
|
|
.context("Invalid public key")?
|
|
.into_inner();
|
|
public_key_validator.validate(public_key)?;
|
|
|
|
// Send handshake_m3
|
|
let mut handshake_m3 = handshake
|
|
.next_message()
|
|
.context("Failed to write handshake_m3")?;
|
|
handshake_m3.push(MessageState::Successful.as_byte());
|
|
socket
|
|
.send(handshake_m3.into())
|
|
.await
|
|
.context("Failed to send handshake_m3")?;
|
|
|
|
// Receive login state message and return based on value
|
|
let state_msg = socket
|
|
.recv_bytes_with_timeout(AUTH_TIMEOUT)
|
|
.await
|
|
.context("Failed to receive authentication state message")?;
|
|
let state = state_msg.last().context(
|
|
"Authentication state message did not contain state byte",
|
|
)?;
|
|
match MessageState::from_byte(*state) {
|
|
MessageState::Successful => anyhow::Ok(()),
|
|
_ => Err(deserialize_error_bytes(
|
|
&state_msg[..(state_msg.len() - 1)],
|
|
)),
|
|
}
|
|
}
|
|
.await;
|
|
|
|
if let Err(e) = res {
|
|
let mut bytes = serialize_error_bytes(&e);
|
|
bytes.push(MessageState::Failed.as_byte());
|
|
if let Err(e) = socket
|
|
.send(bytes.into())
|
|
.await
|
|
.context("Failed to send login failed to client")
|
|
{
|
|
// Log additional error
|
|
warn!("{e:#}");
|
|
}
|
|
// Close socket
|
|
let _ = socket.close(None).await;
|
|
// Return the original error
|
|
Err(e)
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
fn nonce() -> [u8; 32] {
|
|
let mut out = [0u8; 32];
|
|
rand::rng().fill_bytes(&mut out);
|
|
out
|
|
}
|
|
|
|
#[derive(Clone, Copy)]
|
|
pub struct ConnectionIdentifiers<'a> {
|
|
/// Server hostname
|
|
pub host: &'a [u8],
|
|
/// Query: 'server=<SERVER>'
|
|
pub query: &'a [u8],
|
|
/// Sec-Websocket-Accept, unique for each connection
|
|
pub accept: &'a [u8],
|
|
}
|
|
|
|
impl ConnectionIdentifiers<'_> {
|
|
/// nonce: Server computed random connection nonce, sent to client before auth handshake
|
|
pub fn hash(&self, nonce: &[u8]) -> [u8; 32] {
|
|
let mut hash = Sha256::new();
|
|
hash.update(b"noise-wss-v1|");
|
|
hash.update(self.host);
|
|
hash.update(b"|");
|
|
hash.update(self.query);
|
|
hash.update(b"|");
|
|
hash.update(self.accept);
|
|
hash.update(b"|");
|
|
hash.update(nonce);
|
|
hash.finalize().into()
|
|
}
|
|
}
|
|
|
|
pub struct AddressConnectionIdentifiers {
|
|
host: String,
|
|
}
|
|
|
|
impl AddressConnectionIdentifiers {
|
|
pub fn extract(address: &str) -> anyhow::Result<Self> {
|
|
let url = ::url::Url::parse(address)
|
|
.context("Failed to parse server address")?;
|
|
let mut host = url.host().context("url has no host")?.to_string();
|
|
if let Some(port) = url.port() {
|
|
host.push(':');
|
|
host.push_str(&port.to_string());
|
|
};
|
|
Ok(Self { host })
|
|
}
|
|
|
|
pub fn host(&self) -> &String {
|
|
&self.host
|
|
}
|
|
|
|
pub fn build<'a>(
|
|
&'a self,
|
|
accept: &'a [u8],
|
|
query: &'a [u8],
|
|
) -> ConnectionIdentifiers<'a> {
|
|
ConnectionIdentifiers {
|
|
host: self.host.as_bytes(),
|
|
query,
|
|
accept,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Used to extract owned connection identifier
|
|
/// in server side connection handler.
|
|
pub struct HeaderConnectionIdentifiers {
|
|
host: HeaderValue,
|
|
accept: String,
|
|
}
|
|
|
|
impl HeaderConnectionIdentifiers {
|
|
pub fn extract(
|
|
headers: &mut HeaderMap,
|
|
) -> anyhow::Result<HeaderConnectionIdentifiers> {
|
|
let host = headers
|
|
.remove("x-forwarded-host")
|
|
.or(headers.remove("host"))
|
|
.context("Failed to get connection host")?;
|
|
let key = headers
|
|
.remove("sec-websocket-key")
|
|
.context("Headers do not contain Sec-Websocket-Key")?;
|
|
let accept = compute_accept(key.as_bytes());
|
|
Ok(HeaderConnectionIdentifiers { host, accept })
|
|
}
|
|
|
|
pub fn host(&self) -> anyhow::Result<String> {
|
|
self
|
|
.host
|
|
.to_str()
|
|
.map(str::to_string)
|
|
.context("Failed to parse header host to string")
|
|
}
|
|
|
|
pub fn build<'a>(
|
|
&'a self,
|
|
query: &'a [u8],
|
|
) -> ConnectionIdentifiers<'a> {
|
|
ConnectionIdentifiers {
|
|
host: self.host.as_bytes(),
|
|
accept: self.accept.as_bytes(),
|
|
query,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn compute_accept(sec_websocket_key: &[u8]) -> String {
|
|
// This is standard GUID to compute Sec-Websocket-Accept
|
|
const GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
|
let mut sha1 = sha1::Sha1::new();
|
|
sha1.update(sec_websocket_key);
|
|
sha1.update(GUID);
|
|
let digest = sha1.finalize();
|
|
BASE64_STANDARD.encode(digest)
|
|
}
|