implement noise auth basic

This commit is contained in:
mbecker20
2025-09-20 15:44:50 -07:00
parent 01de8c4a9b
commit 3d455f5142
16 changed files with 563 additions and 67 deletions

154
Cargo.lock generated
View File

@@ -17,6 +17,41 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "ahash"
version = "0.8.12"
@@ -776,6 +811,15 @@ dependencies = [
"wyz",
]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -941,6 +985,30 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "chacha20"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "chacha20poly1305"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35"
dependencies = [
"aead",
"chacha20",
"cipher",
"poly1305",
"zeroize",
]
[[package]]
name = "chrono"
version = "0.4.42"
@@ -973,6 +1041,7 @@ checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
"zeroize",
]
[[package]]
@@ -1246,6 +1315,15 @@ dependencies = [
"typenum",
]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
@@ -1933,6 +2011,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "gimli"
version = "0.31.1"
@@ -2785,6 +2873,7 @@ dependencies = [
"serde_json",
"serde_yaml_ng",
"serror",
"sha1",
"sha2",
"slack_client_rs",
"svi",
@@ -2846,6 +2935,7 @@ dependencies = [
"tokio-util",
"tracing",
"transport",
"url",
"urlencoding",
"uuid",
]
@@ -3399,6 +3489,12 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openidconnect"
version = "4.0.1"
@@ -3685,6 +3781,7 @@ version = "1.19.5"
dependencies = [
"anyhow",
"axum",
"base64 0.22.1",
"bytes",
"cache",
"futures-util",
@@ -3694,11 +3791,13 @@ dependencies = [
"serde",
"serde_json",
"serror",
"sha1",
"tokio",
"tokio-tungstenite 0.27.0",
"tokio-util",
"tracing",
"transport",
"url",
"uuid",
]
@@ -3773,6 +3872,29 @@ dependencies = [
"spki",
]
[[package]]
name = "poly1305"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf"
dependencies = [
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-pty"
version = "0.9.0"
@@ -4908,6 +5030,23 @@ version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "snow"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "599b506ccc4aff8cf7844bc42cf783009a434c1e26c964432560fb6d6ad02d82"
dependencies = [
"aes-gcm",
"blake2",
"chacha20poly1305",
"curve25519-dalek",
"getrandom 0.3.3",
"ring",
"rustc_version",
"sha2",
"subtle",
]
[[package]]
name = "socket2"
version = "0.5.10"
@@ -5552,9 +5691,14 @@ version = "1.19.4"
dependencies = [
"anyhow",
"axum",
"base64 0.22.1",
"bytes",
"futures-util",
"rand 0.9.2",
"serde",
"sha1",
"sha2",
"snow",
"tokio",
"tokio-tungstenite 0.27.0",
"tracing",
@@ -5696,6 +5840,16 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "unsafe-libyaml"
version = "0.2.11"

View File

@@ -69,7 +69,7 @@ axum = { version = "0.8.4", features = ["ws", "json", "macros"] }
# SER/DE
ipnetwork = { version = "0.21.1", features = ["serde"] }
indexmap = { version = "2.11.3", features = ["serde"] }
indexmap = { version = "2.11.4", features = ["serde"] }
serde = { version = "1.0.219", features = ["derive"] }
strum = { version = "0.27.2", features = ["derive"] }
bson = { version = "2.15.0" } # must keep in sync with mongodb version
@@ -77,6 +77,7 @@ serde_yaml_ng = "0.10.0"
serde_json = "1.0.145"
serde_qs = "0.15.0"
toml = "0.9.6"
url = "2.5.7"
# ERROR
anyhow = "1.0.99"
@@ -105,7 +106,9 @@ urlencoding = "2.1.3"
nom_pem = "4.0.0"
bcrypt = "0.17.1"
base64 = "0.22.1"
snow = "0.10.0"
hmac = "0.12.1"
sha1 = "0.10.6"
sha2 = "0.10.9"
rand = "0.9.2"
hex = "0.4.3"
@@ -129,6 +132,7 @@ croner = "3.0.0"
# MISC
async-compression = { version = "0.4.30", features = ["tokio", "gzip"] }
derive_builder = "0.20.2"
shell-escape = "0.1.5"
comfy-table = "7.2.1"
typeshare = "1.0.4"
octorust = "0.10.0"
@@ -136,5 +140,4 @@ dashmap = "6.1.0"
wildcard = "0.3.0"
colored = "3.0.0"
regex = "1.11.2"
bytes = "1.10.1"
shell-escape = "0.1.5"
bytes = "1.10.1"

View File

@@ -80,5 +80,6 @@ uuid.workspace = true
envy.workspace = true
rand.workspace = true
hmac.workspace = true
sha1.workspace = true
sha2.workspace = true
hex.workspace = true

View File

@@ -1,5 +1,6 @@
use axum::{
extract::{Query, WebSocketUpgrade},
http::HeaderMap,
response::Response,
};
use komodo_client::entities::server::Server;
@@ -9,8 +10,13 @@ pub async fn handler(
Query(PeripheryConnectionQuery { server }): Query<
PeripheryConnectionQuery,
>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> serror::Result<Response> {
let server = crate::resource::get::<Server>(&server).await?;
periphery_client::connection::server::handler(server.id, ws).await
let server_id = crate::resource::get::<Server>(&server).await?.id;
let query = format!("server={}", urlencoding::encode(&server));
periphery_client::connection::server::handler(
server_id, headers, query, ws,
)
.await
}

View File

@@ -59,4 +59,5 @@ bytes.workspace = true
axum.workspace = true
clap.workspace = true
envy.workspace = true
uuid.workspace = true
uuid.workspace = true
url.workspace = true

View File

@@ -1,9 +1,10 @@
use std::time::Duration;
use anyhow::Context;
use axum::http::HeaderValue;
use bytes::Bytes;
use transport::{
auth::handle_client_side_login,
auth2::{ConnectionIdentifiers, handle_client_side_login},
fix_ws_address,
websocket::{
WebsocketMessage, WebsocketReceiver, WebsocketSender,
@@ -33,16 +34,23 @@ pub async fn handler(
"{core_host}/ws/periphery?server={}",
urlencoding::encode(connect_as)
);
let parsed_url =
::url::Url::parse(&url).context("Failed to parse ws url")?;
let host: Vec<u8> = parsed_url
.host()
.context("url has no host")?
.to_string()
.into();
let query =
parsed_url.query().context("url has no query")?.as_bytes();
info!("Initiating outbound connection to {url}");
loop {
let mut socket = match tokio_tungstenite::connect_async(&url)
.await
{
Ok((socket, _)) => TungsteniteWebsocket(socket),
let (mut socket, accept) = match connect_websocket(&url).await {
Ok(res) => res,
Err(e) => {
warn!("failed to connect to websocket | url: {url} | {e:?}");
warn!("{e:#}");
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
@@ -50,8 +58,16 @@ pub async fn handler(
info!("Connected to core connection websocket");
let id = ConnectionIdentifiers {
host: &host,
accept: accept.as_bytes(),
query,
};
// TODO: source the pk
if let Err(e) =
handle_client_side_login(&mut socket, Bytes::new()).await
handle_client_side_login(&mut socket, id, b"RANDOM_PRIVATE_KEY")
.await
{
warn!("Failed to login | {e:#}");
tokio::time::sleep(Duration::from_secs(5)).await;
@@ -111,3 +127,16 @@ pub async fn handler(
};
}
}
async fn connect_websocket(
url: &str,
) -> anyhow::Result<(TungsteniteWebsocket, HeaderValue)> {
let (ws, mut response) = tokio_tungstenite::connect_async(url)
.await
.with_context(|| format!("Failed to connect to {url}"))?;
let accept = response
.headers_mut()
.remove("sec-websocket-accept")
.context("sec-websocket-accept")?;
Ok((TungsteniteWebsocket(ws), accept))
}

View File

@@ -8,7 +8,7 @@ use axum::{
Router,
body::Body,
extract::{ConnectInfo, WebSocketUpgrade},
http::{Request, StatusCode},
http::{HeaderMap, Request, StatusCode},
middleware::{self, Next},
response::Response,
routing::get,
@@ -17,7 +17,9 @@ use axum_server::tls_rustls::RustlsConfig;
use bytes::Bytes;
use serror::{AddStatusCode, AddStatusCodeError};
use transport::{
auth::handle_server_side_login,
auth2::{
ConnectionIdentifiers, compute_accept, handle_server_side_login,
},
websocket::{
WebsocketMessage, WebsocketReceiver, WebsocketSender,
axum::AxumWebsocket,
@@ -67,17 +69,38 @@ pub async fn run() -> anyhow::Result<()> {
Ok(())
}
async fn handler(ws: WebSocketUpgrade) -> serror::Result<Response> {
async fn handler(
mut headers: HeaderMap,
ws: WebSocketUpgrade,
) -> serror::Result<Response> {
// Limits to only one active websocket connection.
let mut write_receiver = ws_receiver()
.try_lock()
.status_code(StatusCode::FORBIDDEN)?;
let host = headers
.remove("x-forwarded-host")
.or(headers.remove("host"))
.context("Failed to get connection host")
.status_code(StatusCode::UNAUTHORIZED)?;
let ws_key = headers
.remove("sec-websocket-key")
.context("Headers do not contain Sec-Websocket-Key")
.status_code(StatusCode::UNAUTHORIZED)?;
let ws_accept = compute_accept(ws_key.as_bytes());
Ok(ws.on_upgrade(|socket| async move {
let mut socket = AxumWebsocket(socket);
let id = ConnectionIdentifiers {
host: host.as_bytes(),
query: &[],
accept: ws_accept.as_bytes(),
};
if let Err(e) =
handle_server_side_login(&mut socket, |b| true).await
handle_server_side_login(&mut socket, id, b"TEMP_SERVER_PK")
.await
{
warn!("Client failed to login | {e:#}");
return;

View File

@@ -25,8 +25,11 @@ serde_json.workspace = true
tracing.workspace = true
anyhow.workspace = true
rustls.workspace = true
base64.workspace = true
bytes.workspace = true
tokio.workspace = true
serde.workspace = true
axum.workspace = true
uuid.workspace = true
uuid.workspace = true
sha1.workspace = true
url.workspace = true

View File

@@ -1,14 +1,14 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use anyhow::Context;
use axum::http::HeaderValue;
use bytes::Bytes;
use komodo_client::entities::server::Server;
use rustls::{ClientConfig, client::danger::ServerCertVerifier};
use tokio::net::TcpStream;
use tokio_tungstenite::{Connector, MaybeTlsStream, WebSocketStream};
use tokio_tungstenite::Connector;
use tracing::{info, warn};
use transport::{
auth::handle_client_side_login,
auth2::{ConnectionIdentifiers, handle_client_side_login},
fix_ws_address,
websocket::{
WebsocketMessage, WebsocketReceiver as _, WebsocketSender as _,
@@ -95,6 +95,11 @@ async fn spawn_client_connection(
server_id: String,
address: String,
) -> anyhow::Result<()> {
let url = ::url::Url::parse(&address)
.context("Failed to parse server address")?;
let host: Vec<u8> =
url.host().context("url has no host")?.to_string().into();
info!("Spawning connection for {server_id}");
let handler = MessageHandler::new(&server_id).await;
@@ -111,19 +116,30 @@ async fn spawn_client_connection(
tokio::spawn(async move {
loop {
let mut socket = match connect_websocket(&address).await {
Ok(socket) => TungsteniteWebsocket(socket),
Err(e) => {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
let (mut socket, accept) =
match connect_websocket(&address).await {
Ok(res) => res,
Err(e) => {
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
};
info!("PERIPHERY: Connected to {address}");
if let Err(e) =
handle_client_side_login(&mut socket, Bytes::new()).await
let id = ConnectionIdentifiers {
host: &host,
accept: accept.as_bytes(),
query: &[],
};
if let Err(e) = handle_client_side_login(
&mut socket,
id,
b"RANDOM_PRIVATE_KEY",
)
.await
{
connection.set_error(e).await;
tokio::time::sleep(Duration::from_secs(5)).await;
@@ -207,8 +223,8 @@ async fn spawn_client_connection(
pub async fn connect_websocket(
url: &str,
) -> anyhow::Result<WebSocketStream<MaybeTlsStream<TcpStream>>> {
let (stream, _) = if url.starts_with("wss") {
) -> anyhow::Result<(TungsteniteWebsocket, HeaderValue)> {
let (ws, mut response) = if url.starts_with("wss") {
tokio_tungstenite::connect_async_tls_with_config(
url,
None,
@@ -232,7 +248,12 @@ pub async fn connect_websocket(
)?
};
Ok(stream)
let accept = response
.headers_mut()
.remove("sec-websocket-accept")
.context("sec-websocket-accept")?;
Ok((TungsteniteWebsocket(ws), accept))
}
#[derive(Debug)]

View File

@@ -1,8 +1,16 @@
use axum::{extract::WebSocketUpgrade, response::Response};
use anyhow::Context;
use axum::{
extract::WebSocketUpgrade,
http::{HeaderMap, StatusCode},
response::Response,
};
use bytes::Bytes;
use serror::AddStatusCode;
use tracing::{error, info, warn};
use transport::{
auth::handle_server_side_login,
auth2::{
ConnectionIdentifiers, compute_accept, handle_server_side_login,
},
websocket::{
WebsocketMessage, WebsocketReceiver, WebsocketSender,
axum::AxumWebsocket,
@@ -15,8 +23,21 @@ use crate::connection::{
pub async fn handler(
server_id: String,
mut headers: HeaderMap,
query: String,
ws: WebSocketUpgrade,
) -> serror::Result<Response> {
let host = headers
.remove("x-forwarded-host")
.or(headers.remove("host"))
.context("Failed to get connection host")
.status_code(StatusCode::UNAUTHORIZED)?;
let ws_key = headers
.get("sec-websocket-key")
.context("Headers do not contain Sec-Websocket-Key")
.status_code(StatusCode::UNAUTHORIZED)?;
let ws_accept = compute_accept(ws_key.as_bytes());
let handler = MessageHandler::new(&server_id).await;
let (connection, mut write_receiver) =
@@ -31,9 +52,17 @@ pub async fn handler(
Ok(ws.on_upgrade(|socket| async move {
let mut socket = AxumWebsocket(socket);
let id = ConnectionIdentifiers {
host: host.as_bytes(),
query: query.as_bytes(),
accept: ws_accept.as_bytes(),
};
// TODO: use proper private key
if let Err(e) =
handle_server_side_login(&mut socket, |b| true).await
handle_server_side_login(&mut socket, id, b"RANDOM_SERVER_PK")
.await
{
warn!("PERIPHERY: Client failed to login | {e:#}");
connection.set_error(e).await;

View File

@@ -12,8 +12,13 @@ tokio-tungstenite.workspace = true
futures-util.workspace = true
tracing.workspace = true
anyhow.workspace = true
base64.workspace = true
bytes.workspace = true
tokio.workspace = true
serde.workspace = true
axum.workspace = true
rand.workspace = true
snow.workspace = true
sha1.workspace = true
sha2.workspace = true
uuid.workspace = true

View File

@@ -1,5 +1,6 @@
use anyhow::{Context, anyhow};
use bytes::Bytes;
use rand::RngCore;
use tracing::{info, warn};
use crate::{
@@ -13,10 +14,50 @@ pub enum AuthType {
Noise = 1,
}
pub async fn handle_server_side_login(
socket: &mut impl Websocket,
validate_credentials: impl Fn(&[u8]) -> bool,
) -> anyhow::Result<()> {
// Server generates random nonce and sends to client
let nonce = nonce();
socket.send(Bytes::copy_from_slice(&nonce)).await?;
let credentials = match socket.recv().await? {
WebsocketMessage::Binary(bytes) => bytes,
WebsocketMessage::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
WebsocketMessage::Closed => {
return Err(anyhow!("Websocket closed during login"));
}
};
if validate_credentials(&credentials) {
// Send login confirmation
// TODO: remove / edit logs
info!("Client logged in");
socket.send(MessageState::Successful.into()).await?;
return Ok(());
} else {
// Send login failure
warn!("Client failed to log in");
socket.send(MessageState::Failed.into()).await?;
let _ = socket.close(None).await;
return Err(anyhow!("Received invalid credentials"));
}
}
pub async fn handle_client_side_login(
socket: &mut impl Websocket,
credentials: Bytes,
) -> anyhow::Result<()> {
let nonce = socket
.recv_bytes()
.await
.context("Failed to receive connection nonce")?;
socket
.send(credentials)
.await
@@ -44,32 +85,8 @@ pub async fn handle_client_side_login(
}
}
pub async fn handle_server_side_login(
socket: &mut impl Websocket,
validate_credentials: impl Fn(&[u8]) -> bool,
) -> anyhow::Result<()> {
let credentials = match socket.recv().await? {
WebsocketMessage::Binary(bytes) => bytes,
WebsocketMessage::Close(frame) => {
return Err(anyhow!(
"Websocket close frame received during login | frame: {frame:?}"
));
}
WebsocketMessage::Closed => {
return Err(anyhow!("Websocket closed during login"));
}
};
if validate_credentials(&credentials) {
// Send login confirmation
// TODO: remove / edit logs
info!("Client logged in");
socket.send(MessageState::Successful.into()).await?;
return Ok(());
} else {
// Send login failure
warn!("Client failed to log in");
socket.send(MessageState::Failed.into()).await?;
let _ = socket.close(None).await;
return Err(anyhow!("Received invalid credentials"));
}
fn nonce() -> [u8; 32] {
let mut out = [0u8; 32];
rand::rng().fill_bytes(&mut out);
out
}

184
lib/transport/src/auth2.rs Normal file
View File

@@ -0,0 +1,184 @@
//! Implementes both sides of Noise handshake
//! using asymmetric private-public key authentication.
//!
//! TODO: Revisit
//! Note. Relies on Server being behind trusted TLS connection.
//! This is trivial for Periphery -> Core connection, but presents a challenge
//! for Core -> Periphery, where untrusted TLS certs are being used.
use anyhow::Context;
use base64::{Engine, prelude::BASE64_STANDARD};
use bytes::Bytes;
use rand::RngCore;
use sha2::{Digest, Sha256};
use crate::websocket::{
Websocket, axum::AxumWebsocket, tungstenite::TungsteniteWebsocket,
};
pub struct ConnectionIdentifiers<'a> {
/// Server hostname
pub host: &'a [u8],
/// Query: 'server=<SERVER>'
pub query: &'a [u8],
/// Sec-Websocket-Accept, unique for each connection
pub accept: &'a [u8],
}
#[derive(Debug, Clone, Copy)]
pub enum AuthType {
Passkey = 0,
Noise = 1,
}
const NOISE_XX_PARAMS: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2s";
pub async fn handle_server_side_login(
socket: &mut AxumWebsocket,
id: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> anyhow::Result<()> {
// Server generates random nonce and sends to client
let nonce = nonce();
socket
.send(Bytes::from_owner(nonce))
.await
.context("Failed to send connection nonce")?;
// Build the handshake using the unique prologue hash.
// The prologue must be the same on both sides of connection.
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.prologue(&id.hash(&nonce))?
.build_responder()?;
// Receive and read handshake_m1
let handshake_m1 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m1")?;
handshake
.read_message(&handshake_m1, &mut [])
.context("Failed to read handshake_m1")?;
// Send handshake_m2
let mut handshake_m2 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m2)
.context("Failed to write handshake_m2")?;
socket
.send(Bytes::copy_from_slice(&handshake_m2[..written]))
.await
.context("Failed to send handshake_m2")?;
// Receive and read handshake_m3
let handshake_m3 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m3")?;
handshake
.read_message(&handshake_m3, &mut [])
.context("Failed to read handshake_m3")?;
// Server now has client public key
let client_public_key = handshake
.get_remote_static()
.context("Failed to get remote public key")?;
println!(
"Server got client public key: {}",
BASE64_STANDARD.encode(client_public_key)
);
Ok(())
}
pub async fn handle_client_side_login(
socket: &mut TungsteniteWebsocket,
id: ConnectionIdentifiers<'_>,
private_key: &[u8],
) -> anyhow::Result<()> {
// Receive nonce from server
let nonce = socket
.recv_bytes()
.await
.context("Failed to receive connection nonce")?;
// Build the handshake using the unique prologue hash.
// The prologue must be the same on both sides of connection.
let mut handshake = snow::Builder::new(NOISE_XX_PARAMS.parse()?)
.local_private_key(private_key)?
.prologue(&id.hash(&nonce))?
.build_initiator()?;
// Send handshake_m1
let mut handshake_m1 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m1)
.context("Failed to write handshake_m1")?;
socket
.send(Bytes::copy_from_slice(&handshake_m1[..written]))
.await
.context("Failed to send handshake_m1")?;
// Receive and read handshake_m2
let handshake_m2 = socket
.recv_bytes()
.await
.context("Failed to get handshake_m2")?;
handshake
.read_message(&handshake_m2, &mut [])
.context("Failed to read handshake_m2")?;
// Client now has server public key
let server_public_key = handshake
.get_remote_static()
.context("Failed to get remote public key")?;
println!(
"Client got server public key: {}",
BASE64_STANDARD.encode(server_public_key)
);
// Send handshake_m3
let mut handshake_m3 = [0u8; 1024];
let written = handshake
.write_message(&[], &mut handshake_m3)
.context("Failed to write handshake_m3")?;
socket
.send(Bytes::copy_from_slice(&handshake_m3[..written]))
.await
.context("Failed to send handshake_m3")?;
Ok(())
}
fn nonce() -> [u8; 32] {
let mut out = [0u8; 32];
rand::rng().fill_bytes(&mut out);
out
}
impl ConnectionIdentifiers<'_> {
/// nonce: Server computed random connection nonce, sent to client before auth handshake
pub fn hash(&self, nonce: &[u8]) -> [u8; 32] {
let mut hash = Sha256::new();
hash.update(b"noise-wss-v1|");
hash.update(self.host);
hash.update(b"|");
hash.update(self.query);
hash.update(b"|");
hash.update(self.accept);
hash.update(b"|");
hash.update(nonce);
hash.finalize().into()
}
}
pub fn compute_accept(sec_websocket_key: &[u8]) -> String {
// This is standard GUID to compute Sec-Websocket-Accept
const GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
let mut sha1 = sha1::Sha1::new();
sha1.update(sec_websocket_key);
sha1.update(GUID);
let digest = sha1.finalize();
BASE64_STANDARD.encode(digest)
}

View File

@@ -2,7 +2,7 @@ use anyhow::{Context, anyhow};
use bytes::Bytes;
use uuid::Uuid;
use crate::{MessageState, auth::AuthType};
use crate::{MessageState, auth2::AuthType};
/// Serializes data + channel id + state to byte vec.
/// The last byte is the State, and the 16 before that is the Uuid.

View File

@@ -2,6 +2,7 @@ use ::bytes::Bytes;
use serde::{Deserialize, Serialize};
pub mod auth;
pub mod auth2;
pub mod bytes;
pub mod channel;
pub mod websocket;

View File

@@ -1,6 +1,7 @@
//! Wrappers to normalize behavior of websockets between Tungstenite and Axum,
//! as well as streamline process of handling socket messages.
use anyhow::anyhow;
use bytes::Bytes;
pub mod axum;
@@ -30,6 +31,24 @@ pub trait Websocket {
Output = Result<WebsocketMessage<Self::CloseFrame>, Self::Error>,
>;
/// Looping receiver for websocket messages which only returns
/// on significant messages.
fn recv_bytes(
&mut self,
) -> impl Future<Output = Result<Bytes, anyhow::Error>> {
async {
match self.recv().await? {
WebsocketMessage::Binary(bytes) => Ok(bytes),
WebsocketMessage::Close(frame) => {
Err(anyhow!("Connection closed with framed: {frame:?}"))
}
WebsocketMessage::Closed => {
Err(anyhow!("Connection already closed"))
}
}
}
}
/// Streamlined sending on bytes
fn send(
&mut self,