forked from github-starred/komodo
abstract websocket handling implementations on both sides
This commit is contained in:
@@ -21,4 +21,4 @@ rand.workspace = true
|
||||
snow.workspace = true
|
||||
sha1.workspace = true
|
||||
sha2.workspace = true
|
||||
uuid.workspace = true
|
||||
uuid.workspace = true
|
||||
|
||||
@@ -6,15 +6,16 @@
|
||||
//! This is trivial for Periphery -> Core connection, but presents a challenge
|
||||
//! for Core -> Periphery, where untrusted TLS certs are being used.
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::{Context, anyhow};
|
||||
use base64::{Engine, prelude::BASE64_STANDARD};
|
||||
use bytes::Bytes;
|
||||
use rand::RngCore;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::websocket::{
|
||||
Websocket, axum::AxumWebsocket, tungstenite::TungsteniteWebsocket,
|
||||
};
|
||||
use crate::{MessageState, websocket::Websocket};
|
||||
|
||||
const NOISE_XX_PARAMS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
|
||||
|
||||
pub struct ConnectionIdentifiers<'a> {
|
||||
/// Server hostname
|
||||
@@ -25,130 +26,178 @@ pub struct ConnectionIdentifiers<'a> {
|
||||
pub accept: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum AuthType {
|
||||
Passkey = 0,
|
||||
Noise = 1,
|
||||
pub trait LoginFlow {
|
||||
fn login(
|
||||
socket: &mut impl Websocket,
|
||||
connection_identifiers: ConnectionIdentifiers<'_>,
|
||||
private_key: &[u8],
|
||||
) -> impl Future<Output = anyhow::Result<()>>;
|
||||
}
|
||||
|
||||
const NOISE_XX_PARAMS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
|
||||
pub struct ServerLoginFlow;
|
||||
|
||||
pub async fn handle_server_side_login(
|
||||
socket: &mut AxumWebsocket,
|
||||
id: ConnectionIdentifiers<'_>,
|
||||
private_key: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
// Server generates random nonce and sends to client
|
||||
let nonce = nonce();
|
||||
socket
|
||||
.send(Bytes::from_owner(nonce))
|
||||
.await
|
||||
.context("Failed to send connection nonce")?;
|
||||
impl LoginFlow for ServerLoginFlow {
|
||||
async fn login(
|
||||
socket: &mut impl Websocket,
|
||||
connection_identifiers: ConnectionIdentifiers<'_>,
|
||||
private_key: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
let res = async {
|
||||
// Server generates random nonce and sends to client
|
||||
let nonce = nonce();
|
||||
socket
|
||||
.send(Bytes::from_owner(nonce))
|
||||
.await
|
||||
.context("Failed to send connection nonce")?;
|
||||
|
||||
// Build the handshake using the unique prologue hash.
|
||||
// The prologue must be the same on both sides of connection.
|
||||
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
|
||||
.local_private_key(private_key)?
|
||||
.prologue(&id.hash(&nonce))?
|
||||
.build_responder()?;
|
||||
// Build the handshake using the connection-unique prologue hash.
|
||||
// The prologue must be the same on both sides of connection.
|
||||
let mut handshake =
|
||||
snow::Builder::new(NOISE_XX_PARAMS.parse()?)
|
||||
.local_private_key(private_key)?
|
||||
.prologue(&connection_identifiers.hash(&nonce))?
|
||||
.build_responder()?;
|
||||
|
||||
// Receive and read handshake_m1
|
||||
let handshake_m1 = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to get handshake_m1")?;
|
||||
handshake
|
||||
.read_message(&handshake_m1, &mut [])
|
||||
.context("Failed to read handshake_m1")?;
|
||||
// Receive and read handshake_m1
|
||||
let handshake_m1 = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to get handshake_m1")?;
|
||||
handshake
|
||||
.read_message(&handshake_m1, &mut [])
|
||||
.context("Failed to read handshake_m1")?;
|
||||
|
||||
// Send handshake_m2
|
||||
let mut handshake_m2 = [0u8; 1024];
|
||||
let written = handshake
|
||||
.write_message(&[], &mut handshake_m2)
|
||||
.context("Failed to write handshake_m2")?;
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(&handshake_m2[..written]))
|
||||
.await
|
||||
.context("Failed to send handshake_m2")?;
|
||||
// Send handshake_m2
|
||||
let mut handshake_m2 = [0u8; 1024];
|
||||
let written = handshake
|
||||
.write_message(&[], &mut handshake_m2)
|
||||
.context("Failed to write handshake_m2")?;
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(&handshake_m2[..written]))
|
||||
.await
|
||||
.context("Failed to send handshake_m2")?;
|
||||
|
||||
// Receive and read handshake_m3
|
||||
let handshake_m3 = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to get handshake_m3")?;
|
||||
handshake
|
||||
.read_message(&handshake_m3, &mut [])
|
||||
.context("Failed to read handshake_m3")?;
|
||||
// Receive and read handshake_m3
|
||||
let handshake_m3 = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to get handshake_m3")?;
|
||||
handshake
|
||||
.read_message(&handshake_m3, &mut [])
|
||||
.context("Failed to read handshake_m3")?;
|
||||
|
||||
// Server now has client public key
|
||||
let client_public_key = handshake
|
||||
.get_remote_static()
|
||||
.context("Failed to get remote public key")?;
|
||||
println!(
|
||||
"Server got client public key: {}",
|
||||
BASE64_STANDARD.encode(client_public_key)
|
||||
);
|
||||
// Server now has client public key
|
||||
let client_public_key = handshake
|
||||
.get_remote_static()
|
||||
.context("Failed to get remote public key")?;
|
||||
println!(
|
||||
"Server got client public key: {}",
|
||||
BASE64_STANDARD.encode(client_public_key)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(_) => {
|
||||
socket
|
||||
.send(MessageState::Successful.into())
|
||||
.await
|
||||
.context("Failed to send login successful to client")?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
if let Err(e) = socket
|
||||
.send(MessageState::Successful.into())
|
||||
.await
|
||||
.context("Failed to send login successful to client")
|
||||
{
|
||||
// Log additional error
|
||||
warn!("{e:#}");
|
||||
// Close socket
|
||||
let _ = socket.close(None).await;
|
||||
}
|
||||
// Return the original error
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_client_side_login(
|
||||
socket: &mut TungsteniteWebsocket,
|
||||
id: ConnectionIdentifiers<'_>,
|
||||
private_key: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
// Receive nonce from server
|
||||
let nonce = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to receive connection nonce")?;
|
||||
pub struct ClientLoginFlow;
|
||||
|
||||
// Build the handshake using the unique prologue hash.
|
||||
// The prologue must be the same on both sides of connection.
|
||||
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
|
||||
.local_private_key(private_key)?
|
||||
.prologue(&id.hash(&nonce))?
|
||||
.build_initiator()?;
|
||||
impl LoginFlow for ClientLoginFlow {
|
||||
async fn login(
|
||||
socket: &mut impl Websocket,
|
||||
connection_identifiers: ConnectionIdentifiers<'_>,
|
||||
private_key: &[u8],
|
||||
) -> anyhow::Result<()> {
|
||||
// Receive nonce from server
|
||||
let nonce = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to receive connection nonce")?;
|
||||
|
||||
// Send handshake_m1
|
||||
let mut handshake_m1 = [0u8; 1024];
|
||||
let written = handshake
|
||||
.write_message(&[], &mut handshake_m1)
|
||||
.context("Failed to write handshake_m1")?;
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(&handshake_m1[..written]))
|
||||
.await
|
||||
.context("Failed to send handshake_m1")?;
|
||||
// Build the handshake using the connection-unique prologue hash.
|
||||
// The prologue must be the same on both sides of connection.
|
||||
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
|
||||
.local_private_key(private_key)?
|
||||
.prologue(&connection_identifiers.hash(&nonce))?
|
||||
.build_initiator()?;
|
||||
|
||||
// Receive and read handshake_m2
|
||||
let handshake_m2 = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to get handshake_m2")?;
|
||||
handshake
|
||||
.read_message(&handshake_m2, &mut [])
|
||||
.context("Failed to read handshake_m2")?;
|
||||
// Send handshake_m1
|
||||
let mut handshake_m1 = [0u8; 1024];
|
||||
let written = handshake
|
||||
.write_message(&[], &mut handshake_m1)
|
||||
.context("Failed to write handshake_m1")?;
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(&handshake_m1[..written]))
|
||||
.await
|
||||
.context("Failed to send handshake_m1")?;
|
||||
|
||||
// Client now has server public key
|
||||
let server_public_key = handshake
|
||||
.get_remote_static()
|
||||
.context("Failed to get remote public key")?;
|
||||
println!(
|
||||
"Client got server public key: {}",
|
||||
BASE64_STANDARD.encode(server_public_key)
|
||||
);
|
||||
// Receive and read handshake_m2
|
||||
let handshake_m2 = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to get handshake_m2")?;
|
||||
handshake
|
||||
.read_message(&handshake_m2, &mut [])
|
||||
.context("Failed to read handshake_m2")?;
|
||||
|
||||
// Send handshake_m3
|
||||
let mut handshake_m3 = [0u8; 1024];
|
||||
let written = handshake
|
||||
.write_message(&[], &mut handshake_m3)
|
||||
.context("Failed to write handshake_m3")?;
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(&handshake_m3[..written]))
|
||||
.await
|
||||
.context("Failed to send handshake_m3")?;
|
||||
// Client now has server public key
|
||||
let server_public_key = handshake
|
||||
.get_remote_static()
|
||||
.context("Failed to get remote public key")?;
|
||||
println!(
|
||||
"Client got server public key: {}",
|
||||
BASE64_STANDARD.encode(server_public_key)
|
||||
);
|
||||
|
||||
Ok(())
|
||||
// Send handshake_m3
|
||||
let mut handshake_m3 = [0u8; 1024];
|
||||
let written = handshake
|
||||
.write_message(&[], &mut handshake_m3)
|
||||
.context("Failed to write handshake_m3")?;
|
||||
socket
|
||||
.send(Bytes::copy_from_slice(&handshake_m3[..written]))
|
||||
.await
|
||||
.context("Failed to send handshake_m3")?;
|
||||
|
||||
// Receive login state message and return based on value
|
||||
let state = socket
|
||||
.recv_bytes()
|
||||
.await
|
||||
.context("Failed to receive authentication state message")?;
|
||||
let state = state.first().context(
|
||||
"Authentication state message did not contain state byte",
|
||||
)?;
|
||||
match MessageState::from_byte(*state) {
|
||||
MessageState::Successful => Ok(()),
|
||||
// Todo: More descriptive error?
|
||||
_ => Err(anyhow!("Authentication failed")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn nonce() -> [u8; 32] {
|
||||
@@ -182,3 +231,39 @@ pub fn compute_accept(sec_websocket_key: &[u8]) -> String {
|
||||
let digest = sha1.finalize();
|
||||
BASE64_STANDARD.encode(digest)
|
||||
}
|
||||
|
||||
/// This completes an example handshake
|
||||
/// to produce the correct public key
|
||||
/// for a given private key.
|
||||
pub fn generage_public_key(
|
||||
private_key: &[u8],
|
||||
) -> anyhow::Result<Bytes> {
|
||||
let mut ch = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
|
||||
.local_private_key(b"ANY")?
|
||||
.build_initiator()?;
|
||||
// Use the target private key with server handshake,
|
||||
// since its public key is the first available in the flow.
|
||||
let mut sh = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
|
||||
.local_private_key(private_key)?
|
||||
.build_responder()?;
|
||||
// write m1
|
||||
let mut m1 = [0u8; 1024];
|
||||
let written = ch
|
||||
.write_message(&[], &mut m1)
|
||||
.context("CLIENT: failed to write m1")?;
|
||||
// read m1
|
||||
sh.read_message(&m1[..written], &mut [])
|
||||
.context("SERVER: failed to read m1")?;
|
||||
// write m2
|
||||
let mut m2 = [0u8; 1024];
|
||||
let written = sh
|
||||
.write_message(&[], &mut m2)
|
||||
.context("SERVER: failed to write m2")?;
|
||||
// read m2
|
||||
ch.read_message(&m2[..written], &mut [])
|
||||
.context("CLIENT: failed to read m2")?;
|
||||
// client now has server public key
|
||||
Ok(Bytes::copy_from_slice(
|
||||
ch.get_remote_static().context("Failed to get public key")?,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use anyhow::{Context, anyhow};
|
||||
use bytes::Bytes;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{MessageState, auth::AuthType};
|
||||
use crate::MessageState;
|
||||
|
||||
/// Serializes data + channel id + state to byte vec.
|
||||
/// The last byte is the State, and the 16 before that is the Uuid.
|
||||
@@ -94,19 +94,3 @@ impl MessageState {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AuthType {
|
||||
pub fn from_byte(byte: u8) -> AuthType {
|
||||
match byte {
|
||||
0 => AuthType::Passkey,
|
||||
_ => AuthType::Noise,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_byte(&self) -> u8 {
|
||||
match self {
|
||||
AuthType::Passkey => 0,
|
||||
AuthType::Noise => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,11 @@ impl Websocket for AxumWebsocket {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = axum::Error;
|
||||
|
||||
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
|
||||
let (tx, rx) = self.0.split();
|
||||
(AxumWebsocketSender(tx), AxumWebsocketReceiver(rx))
|
||||
}
|
||||
|
||||
async fn recv(
|
||||
&mut self,
|
||||
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
|
||||
@@ -35,13 +40,6 @@ impl Websocket for AxumWebsocket {
|
||||
}
|
||||
}
|
||||
|
||||
impl AxumWebsocket {
|
||||
pub fn split(self) -> (AxumWebsocketSender, AxumWebsocketReceiver) {
|
||||
let (tx, rx) = self.0.split();
|
||||
(AxumWebsocketSender(tx), AxumWebsocketReceiver(rx))
|
||||
}
|
||||
}
|
||||
|
||||
pub type InnerWebsocketReceiver =
|
||||
SplitStream<axum::extract::ws::WebSocket>;
|
||||
|
||||
|
||||
@@ -20,9 +20,12 @@ pub enum WebsocketMessage<CloseFrame> {
|
||||
|
||||
/// Standard traits for websocket
|
||||
pub trait Websocket {
|
||||
type CloseFrame: std::fmt::Debug;
|
||||
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Abstraction over websocket splitting
|
||||
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver);
|
||||
|
||||
/// Looping receiver for websocket messages which only returns
|
||||
/// on significant messages.
|
||||
fn recv(
|
||||
@@ -64,7 +67,7 @@ pub trait Websocket {
|
||||
|
||||
/// Traits for split websocket receiver
|
||||
pub trait WebsocketReceiver {
|
||||
type CloseFrame: std::fmt::Debug;
|
||||
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Looping receiver for websocket messages which only returns
|
||||
@@ -73,23 +76,24 @@ pub trait WebsocketReceiver {
|
||||
&mut self,
|
||||
) -> impl Future<
|
||||
Output = Result<WebsocketMessage<Self::CloseFrame>, Self::Error>,
|
||||
>;
|
||||
> + Send
|
||||
+ Sync;
|
||||
}
|
||||
|
||||
/// Traits for split websocket receiver
|
||||
pub trait WebsocketSender {
|
||||
type CloseFrame: std::fmt::Debug;
|
||||
type CloseFrame: std::fmt::Debug + Send + Sync + 'static;
|
||||
type Error: std::error::Error + Send + Sync + 'static;
|
||||
|
||||
/// Streamlined sending on bytes
|
||||
fn send(
|
||||
&mut self,
|
||||
bytes: Bytes,
|
||||
) -> impl Future<Output = Result<(), Self::Error>>;
|
||||
) -> impl Future<Output = Result<(), Self::Error>> + Send + Sync;
|
||||
|
||||
/// Send close message
|
||||
fn close(
|
||||
&mut self,
|
||||
frame: Option<Self::CloseFrame>,
|
||||
) -> impl Future<Output = Result<(), Self::Error>>;
|
||||
) -> impl Future<Output = Result<(), Self::Error>> + Send + Sync;
|
||||
}
|
||||
|
||||
@@ -20,6 +20,14 @@ impl Websocket for TungsteniteWebsocket {
|
||||
type CloseFrame = CloseFrame;
|
||||
type Error = tungstenite::Error;
|
||||
|
||||
fn split(self) -> (impl WebsocketSender, impl WebsocketReceiver) {
|
||||
let (tx, rx) = self.0.split();
|
||||
(
|
||||
TungsteniteWebsocketSender(tx),
|
||||
TungsteniteWebsocketReceiver(rx),
|
||||
)
|
||||
}
|
||||
|
||||
async fn recv(
|
||||
&mut self,
|
||||
) -> Result<WebsocketMessage<Self::CloseFrame>, Self::Error> {
|
||||
@@ -41,18 +49,6 @@ impl Websocket for TungsteniteWebsocket {
|
||||
}
|
||||
}
|
||||
|
||||
impl TungsteniteWebsocket {
|
||||
pub fn split(
|
||||
self,
|
||||
) -> (TungsteniteWebsocketSender, TungsteniteWebsocketReceiver) {
|
||||
let (tx, rx) = self.0.split();
|
||||
(
|
||||
TungsteniteWebsocketSender(tx),
|
||||
TungsteniteWebsocketReceiver(rx),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub type InnerWebsocketReceiver =
|
||||
SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user