abstract websocket handling implementations on both sides

This commit is contained in:
mbecker20
2025-09-20 22:43:15 -07:00
parent 2d83105500
commit 951ff34a9e
15 changed files with 468 additions and 509 deletions

View File

@@ -21,4 +21,4 @@ rand.workspace = true
snow.workspace = true
sha1.workspace = true
sha2.workspace = true
uuid.workspace = true
uuid.workspace = true

View File

@@ -6,15 +6,16 @@
//! This is trivial for Periphery -> Core connection, but presents a challenge
//! for Core -> Periphery, where untrusted TLS certs are being used.
use anyhow::Context;
use anyhow::{Context, anyhow};
use base64::{Engine, prelude::BASE64_STANDARD};
use bytes::Bytes;
use rand::RngCore;
use sha2::{Digest, Sha256};
use tracing::warn;
use crate::websocket::{
Websocket, axum::AxumWebsocket, tungstenite::TungsteniteWebsocket,
};
use crate::{MessageState, websocket::Websocket};
const NOISE_XX_PARAMS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
pub struct ConnectionIdentifiers<'a> {
/// Server hostname
@@ -25,130 +26,178 @@ pub struct ConnectionIdentifiers<'a> {
pub accept: &'a [u8],
}
#[derive(Debug, Clone, Copy)]
pub enum AuthType {
Passkey = 0,
Noise = 1,
pub trait LoginFlow {
fn login(
socket: &mut impl Websocket,
connection_identifiers: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> impl Future<Output = anyhow::Result<()>>;
}
const NOISE_XX_PARAMS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
pub struct ServerLoginFlow;
pub async fn handle_server_side_login(
socket: &mut AxumWebsocket,
id: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> anyhow::Result<()> {
// Server generates random nonce and sends to client
let nonce = nonce();
socket
.send(Bytes::from_owner(nonce))
.await
.context("Failed to send connection nonce")?;
impl LoginFlow for ServerLoginFlow {
async fn login(
socket: &mut impl Websocket,
connection_identifiers: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> anyhow::Result<()> {
let res = async {
// Server generates random nonce and sends to client
let nonce = nonce();
socket
.send(Bytes::from_owner(nonce))
.await
.context("Failed to send connection nonce")?;
// Build the handshake using the unique prologue hash.
// The prologue must be the same on both sides of connection.
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.prologue(&id.hash(&nonce))?
.build_responder()?;
// Build the handshake using the connection-unique prologue hash.
// The prologue must be the same on both sides of connection.
let mut handshake =
snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.prologue(&connection_identifiers.hash(&nonce))?
.build_responder()?;
// Receive and read handshake_m1
let handshake_m1 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m1")?;
handshake
.read_message(&handshake_m1, &mut [])
.context("Failed to read handshake_m1")?;
// Receive and read handshake_m1
let handshake_m1 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m1")?;
handshake
.read_message(&handshake_m1, &mut [])
.context("Failed to read handshake_m1")?;
// Send handshake_m2
let mut handshake_m2 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m2)
.context("Failed to write handshake_m2")?;
socket
.send(Bytes::copy_from_slice(&handshake_m2[..written]))
.await
.context("Failed to send handshake_m2")?;
// Send handshake_m2
let mut handshake_m2 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m2)
.context("Failed to write handshake_m2")?;
socket
.send(Bytes::copy_from_slice(&handshake_m2[..written]))
.await
.context("Failed to send handshake_m2")?;
// Receive and read handshake_m3
let handshake_m3 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m3")?;
handshake
.read_message(&handshake_m3, &mut [])
.context("Failed to read handshake_m3")?;
// Receive and read handshake_m3
let handshake_m3 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m3")?;
handshake
.read_message(&handshake_m3, &mut [])
.context("Failed to read handshake_m3")?;
// Server now has client public key
let client_public_key = handshake
.get_remote_static()
.context("Failed to get remote public key")?;
println!(
"Server got client public key: {}",
BASE64_STANDARD.encode(client_public_key)
);
// Server now has client public key
let client_public_key = handshake
.get_remote_static()
.context("Failed to get remote public key")?;
println!(
"Server got client public key: {}",
BASE64_STANDARD.encode(client_public_key)
);
Ok(())
anyhow::Ok(())
}
.await;
match res {
Ok(_) => {
socket
.send(MessageState::Successful.into())
.await
.context("Failed to send login successful to client")?;
Ok(())
}
Err(e) => {
if let Err(e) = socket
.send(MessageState::Successful.into())
.await
.context("Failed to send login successful to client")
{
// Log additional error
warn!("{e:#}");
// Close socket
let _ = socket.close(None).await;
}
// Return the original error
Err(e)
}
}
}
}
pub async fn handle_client_side_login(
socket: &mut TungsteniteWebsocket,
id: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> anyhow::Result<()> {
// Receive nonce from server
let nonce = socket
.recv_bytes()
.await
.context("Failed to receive connection nonce")?;
pub struct ClientLoginFlow;
// Build the handshake using the unique prologue hash.
// The prologue must be the same on both sides of connection.
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.prologue(&id.hash(&nonce))?
.build_initiator()?;
impl LoginFlow for ClientLoginFlow {
async fn login(
socket: &mut impl Websocket,
connection_identifiers: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> anyhow::Result<()> {
// Receive nonce from server
let nonce = socket
.recv_bytes()
.await
.context("Failed to receive connection nonce")?;
// Send handshake_m1
let mut handshake_m1 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m1)
.context("Failed to write handshake_m1")?;
socket
.send(Bytes::copy_from_slice(&handshake_m1[..written]))
.await
.context("Failed to send handshake_m1")?;
// Build the handshake using the connection-unique prologue hash.
// The prologue must be the same on both sides of connection.
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.prologue(&connection_identifiers.hash(&nonce))?
.build_initiator()?;
// Receive and read handshake_m2
let handshake_m2 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m2")?;
handshake
.read_message(&handshake_m2, &mut [])
.context("Failed to read handshake_m2")?;
// Send handshake_m1
let mut handshake_m1 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m1)
.context("Failed to write handshake_m1")?;
socket
.send(Bytes::copy_from_slice(&handshake_m1[..written]))
.await
.context("Failed to send handshake_m1")?;
// Client now has server public key
let server_public_key = handshake
.get_remote_static()
.context("Failed to get remote public key")?;
println!(
"Client got server public key: {}",
BASE64_STANDARD.encode(server_public_key)
);
// Receive and read handshake_m2
let handshake_m2 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m2")?;
handshake
.read_message(&handshake_m2, &mut [])
.context("Failed to read handshake_m2")?;
// Send handshake_m3
let mut handshake_m3 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m3)
.context("Failed to write handshake_m3")?;
socket
.send(Bytes::copy_from_slice(&handshake_m3[..written]))
.await
.context("Failed to send handshake_m3")?;
// Client now has server public key
let server_public_key = handshake
.get_remote_static()
.context("Failed to get remote public key")?;
println!(
"Client got server public key: {}",
BASE64_STANDARD.encode(server_public_key)
);
Ok(())
// Send handshake_m3
let mut handshake_m3 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m3)
.context("Failed to write handshake_m3")?;
socket
.send(Bytes::copy_from_slice(&handshake_m3[..written]))
.await
.context("Failed to send handshake_m3")?;
// Receive login state message and return based on value
let state = socket
.recv_bytes()
.await
.context("Failed to receive authentication state message")?;
let state = state.first().context(
"Authentication state message did not contain state byte",
)?;
match MessageState::from_byte(*state) {
MessageState::Successful => Ok(()),
// Todo: More descriptive error?
_ => Err(anyhow!("Authentication failed")),
}
}
}
fn nonce() -> [u8; 32] {
@@ -182,3 +231,39 @@ pub fn compute_accept(sec_websocket_key: &[u8]) -> String {
let digest = sha1.finalize();
BASE64_STANDARD.encode(digest)
}
/// This completes an example handshake
/// to produce the correct public key
/// for a given private key.
pub fn generage_public_key(
private_key: &[u8],
) -> anyhow::Result<Bytes> {
let mut ch = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(b"ANY")?
.build_initiator()?;
// Use the target private key with server handshake,
// since its public key is the first available in the flow.
let mut sh = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.build_responder()?;
// write m1
let mut m1 = [0u8; 1024];
let written = ch
.write_message(&[], &mut m1)
.context("CLIENT: failed to write m1")?;
// read m1
sh.read_message(&m1[..written], &mut [])
.context("SERVER: failed to read m1")?;
// write m2
let mut m2 = [0u8; 1024];
let written = sh
.write_message(&[], &mut m2)
.context("SERVER: failed to write m2")?;
// read m2
ch.read_message(&m2[..written], &mut [])
.context("CLIENT: failed to read m2")?;
// client now has server public key
Ok(Bytes::copy_from_slice(
ch.get_remote_static().context("Failed to get public key")?,
))
}

View File

@@ -2,7 +2,7 @@ use anyhow::{Context, anyhow};
use bytes::Bytes;
use uuid::Uuid;
use crate::{MessageState, auth::AuthType};
use crate::MessageState;
/// Serializes data + channel id + state to byte vec.
/// The last byte is the State, and the 16 before that is the Uuid.
@@ -94,19 +94,3 @@ impl MessageState {
}
}
}
impl AuthType {
pub fn from_byte(byte: u8) -> AuthType {
match byte {
0 => AuthType::Passkey,
_ => AuthType::Noise,
}
}
pub fn as_byte(&self) -> u8 {
match self {
AuthType::Passkey => 0,
AuthType::Noise => 1,
}
}
}

View File

@@ -14,6 +14,11 @@ impl Websocket for AxumWebsocket {
type CloseFrame = CloseFrame;
type Error = axum::Error;
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
let (tx, rx) = self.0.split();
(AxumWebsocketSender(tx), AxumWebsocketReceiver(rx))
}
async fn recv(
&mut self,
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
@@ -35,13 +40,6 @@ impl Websocket for AxumWebsocket {
}
}
impl AxumWebsocket {
pub fn split(self) -> (AxumWebsocketSender, AxumWebsocketReceiver) {
let (tx, rx) = self.0.split();
(AxumWebsocketSender(tx), AxumWebsocketReceiver(rx))
}
}
pub type InnerWebsocketReceiver =
SplitStream<axum::extract::ws::WebSocket>;

View File

@@ -20,9 +20,12 @@ pub enum WebsocketMessage<CloseFrame> {
/// Standard traits for websocket
pub trait Websocket {
type CloseFrame: std::fmt::Debug;
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
/// Abstraction over websocket splitting
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver);
/// Looping receiver for websocket messages which only returns
/// on significant messages.
fn recv(
@@ -64,7 +67,7 @@ pub trait Websocket {
/// Traits for split websocket receiver
pub trait WebsocketReceiver {
type CloseFrame: std::fmt::Debug;
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
/// Looping receiver for websocket messages which only returns
@@ -73,23 +76,24 @@ pub trait WebsocketReceiver {
&mut self,
) -> impl Future<
Output = Result<WebsocketMessage<Self::CloseFrame>, Self::Error>,
>;
> + Send
+ Sync;
}
/// Traits for split websocket receiver
pub trait WebsocketSender {
type CloseFrame: std::fmt::Debug;
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
type Error: std::error::Error + Send + Sync + 'static;
/// Streamlined sending on bytes
fn send(
&mut self,
bytes: Bytes,
) -> impl Future<Output = Result<(), Self::Error>>;
) -> impl Future<Output = Result<(), Self::Error>> + Send + Sync;
/// Send close message
fn close(
&mut self,
frame: Option<Self::CloseFrame>,
) -> impl Future<Output = Result<(), Self::Error>>;
) -> impl Future<Output = Result<(), Self::Error>> + Send + Sync;
}

View File

@@ -20,6 +20,14 @@ impl Websocket for TungsteniteWebsocket {
type CloseFrame = CloseFrame;
type Error = tungstenite::Error;
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
let (tx, rx) = self.0.split();
(
TungsteniteWebsocketSender(tx),
TungsteniteWebsocketReceiver(rx),
)
}
async fn recv(
&mut self,
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
@@ -41,18 +49,6 @@ impl Websocket for TungsteniteWebsocket {
}
}
impl TungsteniteWebsocket {
pub fn split(
self,
) -> (TungsteniteWebsocketSender, TungsteniteWebsocketReceiver) {
let (tx, rx) = self.0.split();
(
TungsteniteWebsocketSender(tx),
TungsteniteWebsocketReceiver(rx),
)
}
}
pub type InnerWebsocketReceiver =
SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;