diff --git a/Cargo.lock b/Cargo.lock index cf4057053..78e578374 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,7 +104,7 @@ dependencies = [ "sha-1", "sync_wrapper", "tokio", - "tokio-tungstenite", + "tokio-tungstenite 0.17.2", "tower", "tower-http", "tower-layer", @@ -401,6 +401,7 @@ dependencies = [ "sha2", "slack_client_rs", "tokio", + "tokio-tungstenite 0.18.0", "tokio-util", "tower", "tower-http", @@ -1297,7 +1298,7 @@ dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1389,6 +1390,7 @@ dependencies = [ "daemonize", "dotenv", "envy", + "futures-util", "helpers", "monitor_types", "run_command", @@ -1397,6 +1399,7 @@ dependencies = [ "serde_json", "sysinfo", "tokio", + "tokio-util", "toml", "tower", ] @@ -1420,9 +1423,9 @@ dependencies = [ [[package]] name = "mungos" -version = "0.2.26" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac86e7d26ca046ef914a069d1185af3a82471d2aac3c8390e0cd5ece11ace41" +checksum = "f51f1a0a42db4291d00a7d0a269fb19f11784fadc6b335637a2ab8f6702aa8c3" dependencies = [ "anyhow", "futures", @@ -1641,7 +1644,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -1664,11 +1667,14 @@ name = "periphery_client" version = "0.1.0" dependencies = [ "anyhow", + "futures-util", "helpers", "monitor_types", "reqwest", "serde", "serde_json", + "tokio", + "tokio-tungstenite 0.18.0", ] [[package]] @@ -1965,7 +1971,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d6731146462ea25d9244b2ed5fd1d716d25c52e4d54aa4fb0f3c4e9854dbe2" dependencies = [ "lazy_static", - "windows-sys", + "windows-sys 0.36.1", ] [[package]] @@ -2117,6 +2123,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.6" @@ -2372,9 +2389,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.22.0" +version = "1.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76ce4a75fb488c605c54bf610f221cea8b0dafb53333c1a67e8ee199dcd2ae3" +checksum = "eab6d665857cc6ca78d6e80303a02cea7a7851e85dfbd77cbdc09bd129f1ef46" dependencies = [ "autocfg", "bytes", @@ -2387,7 +2404,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "winapi", + "windows-sys 0.42.0", ] [[package]] @@ -2431,7 +2448,21 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.17.3", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite 0.18.0", ] [[package]] @@ -2606,6 +2637,26 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "native-tls", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + [[package]] name = "typed-builder" version = "0.10.0" @@ -2895,43 +2946,100 @@ version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" dependencies = [ - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_msvc", + "windows_aarch64_msvc 0.36.1", + "windows_i686_gnu 0.36.1", + "windows_i686_msvc 0.36.1", + "windows_x86_64_gnu 0.36.1", + "windows_x86_64_msvc 0.36.1", ] +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc 0.42.0", + "windows_i686_gnu 0.42.0", + "windows_i686_msvc 0.42.0", + "windows_x86_64_gnu 0.42.0", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc 0.42.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + [[package]] name = "windows_aarch64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + [[package]] name = "windows_i686_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" +[[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + [[package]] name = "windows_i686_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" +[[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + [[package]] name = "windows_x86_64_gnu" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + [[package]] name = "windows_x86_64_msvc" version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + [[package]] name = "winreg" version = "0.7.0" diff --git a/core/Cargo.toml b/core/Cargo.toml index b7c03fbeb..57da8a00c 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -12,13 +12,14 @@ types = { package = "monitor_types", path = "../lib/types" } periphery = { package = "periphery_client", path = "../lib/periphery_client" } axum_oauth2 = { path = "../lib/axum_oauth2" } tokio = { version = "1.21", features = ["full"] } +tokio-tungstenite = "0.18" tokio-util = "0.7" axum = { version = "0.6", features = ["ws", "json"] } axum-extra = { version = "0.4", features = ["spa"] } tower = { version = "0.4", features = ["full"] } tower-http = { version = "0.3", features = ["cors"] } slack = { package = "slack_client_rs", version = "0.0.7" } -mungos = "0.2.26" +mungos = "0.2.27" serde = "1.0" serde_derive = "1.0" serde_json = "1.0" diff --git a/core/src/api/server.rs b/core/src/api/server.rs index b46b22706..13c523a4d 100644 --- a/core/src/api/server.rs +++ b/core/src/api/server.rs @@ -1,12 +1,16 @@ use anyhow::Context; use axum::{ - extract::{Path, Query}, + extract::{ws::Message as AxumMessage, Path, Query, WebSocketUpgrade}, + response::IntoResponse, routing::{delete, get, patch, post}, Extension, Json, Router, }; -use futures_util::future::join_all; +use futures_util::{future::join_all, SinkExt, StreamExt}; use helpers::handle_anyhow_error; use mungos::{Deserialize, Document, Serialize}; +use tokio::select; +use tokio_tungstenite::tungstenite::Message; +use tokio_util::sync::CancellationToken; use types::{ traits::Permissioned, BasicContainerInfo, ImageSummary, Network, PermissionLevel, Server, ServerActionState, ServerStatus, ServerWithStatus, SystemStats, SystemStatsQuery, @@ -132,6 +136,22 @@ pub fn router() -> Router { }, ), ) + .route( + "/:id/stats/ws", + get( + |Extension(state): StateExtension, + Extension(user): RequestUserExtension, + Path(ServerId { id }): Path, + Query(query): Query, + ws: WebSocketUpgrade| async move { + let connection = state + .subscribe_to_stats_ws(&id, &user, &query, ws) + .await + .map_err(handle_anyhow_error)?; + response!(connection) + }, + ), + ) .route( "/:id/networks", get( @@ -337,6 +357,51 @@ impl State { Ok(stats) } + async fn subscribe_to_stats_ws( + &self, + server_id: &str, + user: &RequestUser, + query: &SystemStatsQuery, + ws: WebSocketUpgrade, + ) -> anyhow::Result { + let server = self + .get_server_check_permissions(server_id, user, PermissionLevel::Read) + .await?; + let mut stats_reciever = self.periphery.subscribe_to_stats_ws(&server, query).await?; + let upgrade = ws.on_upgrade(|socket| async move { + let (mut ws_sender, mut ws_recv) = socket.split(); + let cancel = CancellationToken::new(); + let cancel_clone = cancel.clone(); + tokio::spawn(async move { + loop { + let stats = select! { + _ = cancel_clone.cancelled() => break, + stats = stats_reciever.next() => stats + }; + if let Some(Ok(Message::Text(msg))) = stats { + let _ = ws_sender.send(AxumMessage::Text(msg)).await; + } + } + }); + while let Some(msg) = ws_recv.next().await { + match msg { + Ok(msg) => match msg { + AxumMessage::Close(_) => { + cancel.cancel(); + return; + } + _ => {} + }, + Err(_) => { + cancel.cancel(); + return; + } + } + } + }); + Ok(upgrade) + } + async fn get_networks( &self, server_id: &str, diff --git a/core/src/ws/mod.rs b/core/src/ws/mod.rs index 9b6ae9f54..c55770e92 100644 --- a/core/src/ws/mod.rs +++ b/core/src/ws/mod.rs @@ -1,7 +1,67 @@ -use axum::{routing::get, Router}; +use axum::{ + extract::ws::{Message, WebSocket}, + routing::get, + Router, +}; +use crate::{ + auth::{JwtClient, RequestUser}, + state::State, +}; + +mod stats; pub mod update; pub fn router() -> Router { - Router::new().route("/update", get(update::ws_handler)) + Router::new() + .route("/update", get(update::ws_handler)) + .route("/stats/:id", get(stats::ws_handler)) +} + +impl State { + pub async fn ws_login( + &self, + mut socket: WebSocket, + jwt_client: &JwtClient, + ) -> Option<(WebSocket, RequestUser)> { + if let Some(jwt) = socket.recv().await { + match jwt { + Ok(jwt) => match jwt { + Message::Text(jwt) => { + match jwt_client.auth_jwt_check_enabled(&jwt, self).await { + Ok(user) => { + let _ = socket.send(Message::Text("LOGGED_IN".to_string())).await; + Some((socket, user)) + } + Err(e) => { + let _ = socket + .send(Message::Text(format!( + "failed to authenticate user | {e:#?}" + ))) + .await; + let _ = socket.close().await; + None + } + } + } + msg => { + let _ = socket + .send(Message::Text(format!("invalid login msg: {msg:#?}"))) + .await; + let _ = socket.close().await; + None + } + }, + Err(e) => { + let _ = socket + .send(Message::Text(format!("failed to get jwt message: {e:#?}"))) + .await; + let _ = socket.close().await; + None + } + } + } else { + None + } + } } diff --git a/core/src/ws/stats.rs b/core/src/ws/stats.rs new file mode 100644 index 000000000..8ca47c659 --- /dev/null +++ b/core/src/ws/stats.rs @@ -0,0 +1,96 @@ +use std::sync::Arc; + +use axum::{ + extract::{ws::Message as AxumMessage, Path, Query, WebSocketUpgrade}, + http::StatusCode, + response::IntoResponse, +}; +use futures_util::{SinkExt, StreamExt}; +use helpers::handle_anyhow_error; +use mungos::Deserialize; +use tokio::select; +use tokio_tungstenite::tungstenite::Message; +use tokio_util::sync::CancellationToken; +use types::{traits::Permissioned, PermissionLevel, SystemStatsQuery}; + +use crate::{auth::JwtExtension, state::StateExtension}; + +#[derive(Deserialize)] +pub struct ServerId { + id: String, +} + +pub async fn ws_handler( + state: StateExtension, + jwt_client: JwtExtension, + path: Path, + query: Query, + ws: WebSocketUpgrade, +) -> Result { + let server = state + .db + .get_server(&path.id) + .await + .map_err(handle_anyhow_error)?; + let query = Arc::new(query); + let upgrade = ws.on_upgrade(|socket| async move { + let login_res = state.ws_login(socket, &jwt_client).await; + if login_res.is_none() { + return; + } + let (mut socket, user) = login_res.unwrap(); + if !user.is_admin && server.get_user_permissions(&user.id) < PermissionLevel::Read { + let _ = socket + .send(AxumMessage::Text( + "permission denied. user must have at least read permissions on this server" + .to_string(), + )) + .await; + return; + } + let (mut ws_sender, mut ws_reciever) = socket.split(); + let res = state.periphery.subscribe_to_stats_ws(&server, &query).await; + if let Err(e) = &res { + let _ = ws_sender + .send(AxumMessage::Text(format!("ERROR: {e}"))) + .await; + return; + } + let mut stats_recv = res.unwrap(); + let cancel = CancellationToken::new(); + let cancel_clone = cancel.clone(); + tokio::spawn(async move { + loop { + let stats = select! { + _ = cancel_clone.cancelled() => { + let _ = stats_recv.close(None).await; + break + }, + stats = stats_recv.next() => stats, + }; + if let Some(Ok(Message::Text(msg))) = stats { + let _ = ws_sender.send(AxumMessage::Text(msg)).await; + } else { + let _ = stats_recv.close(None).await; + break; + } + } + }); + while let Some(msg) = ws_reciever.next().await { + match msg { + Ok(msg) => match msg { + AxumMessage::Close(_) => { + cancel.cancel(); + return; + } + _ => {} + }, + Err(_) => { + cancel.cancel(); + return; + } + } + } + }); + Ok(upgrade) +} diff --git a/core/src/ws/update.rs b/core/src/ws/update.rs index b2803aab0..75ced9fa4 100644 --- a/core/src/ws/update.rs +++ b/core/src/ws/update.rs @@ -1,13 +1,7 @@ -use std::sync::Arc; - use anyhow::anyhow; use axum::{ - extract::{ - ws::{Message, WebSocket}, - WebSocketUpgrade, - }, + extract::{ws::Message, WebSocketUpgrade}, response::IntoResponse, - Extension, }; use db::DbClient; use futures_util::{SinkExt, StreamExt}; @@ -22,10 +16,7 @@ use tokio::{ use tokio_util::sync::CancellationToken; use types::{PermissionLevel, Update, UpdateTarget, User}; -use crate::{ - auth::{JwtClient, JwtExtension}, - state::{State, StateExtension}, -}; +use crate::{auth::JwtExtension, state::StateExtension}; pub type UpdateWsSender = Mutex>; @@ -46,19 +37,18 @@ impl UpdateWsChannel { } pub async fn ws_handler( - Extension(jwt_client): JwtExtension, - Extension(state): StateExtension, + jwt_client: JwtExtension, + state: StateExtension, ws: WebSocketUpgrade, ) -> impl IntoResponse { let mut reciever = state.update.reciever.resubscribe(); ws.on_upgrade(|socket| async move { - let login_res = login(socket, &jwt_client, &state).await; + let login_res = state.ws_login(socket, &jwt_client).await; if login_res.is_none() { return; } - let (socket, user_id) = login_res.unwrap(); - let (ws_sender, mut ws_reciever) = socket.split(); - let ws_sender = Arc::new(Mutex::new(ws_sender)); + let (socket, user) = login_res.unwrap(); + let (mut ws_sender, mut ws_reciever) = socket.split(); let cancel = CancellationToken::new(); let cancel_clone = cancel.clone(); tokio::spawn(async move { @@ -67,25 +57,21 @@ pub async fn ws_handler( _ = cancel_clone.cancelled() => break, update = reciever.recv() => {update.expect("failed to recv update msg")} }; - let user = state.db.users.find_one_by_id(&user_id).await; + let user = state.db.users.find_one_by_id(&user.id).await; if user.is_err() || user.as_ref().unwrap().is_none() || !user.as_ref().unwrap().as_ref().unwrap().enabled { let _ = ws_sender - .lock() - .await .send(Message::Text(json!({ "type": "INVALID_USER" }).to_string())) .await; - let _ = ws_sender.lock().await.close().await; + let _ = ws_sender.close().await; return; } let user = user.unwrap().unwrap(); // already handle cases where this panics in the above early return - match user_can_see_update(&user, &user_id, &update.target, &state.db).await { + match user_can_see_update(&user, &user.id, &update.target, &state.db).await { Ok(_) => { let _ = ws_sender - .lock() - .await .send(Message::Text(serde_json::to_string(&update).unwrap())) .await; } @@ -114,50 +100,6 @@ pub async fn ws_handler( }) } -async fn login( - mut socket: WebSocket, - jwt_client: &JwtClient, - state: &State, -) -> Option<(WebSocket, String)> { - if let Some(jwt) = socket.recv().await { - match jwt { - Ok(jwt) => match jwt { - Message::Text(jwt) => match jwt_client.auth_jwt_check_enabled(&jwt, state).await { - Ok(user) => { - let _ = socket.send(Message::Text("LOGGED_IN".to_string())).await; - Some((socket, user.id)) - } - Err(e) => { - let _ = socket - .send(Message::Text(format!( - "failed to authenticate user | {e:#?}" - ))) - .await; - let _ = socket.close().await; - None - } - }, - msg => { - let _ = socket - .send(Message::Text(format!("invalid login msg: {msg:#?}"))) - .await; - let _ = socket.close().await; - None - } - }, - Err(e) => { - let _ = socket - .send(Message::Text(format!("failed to get jwt message: {e:#?}"))) - .await; - let _ = socket.close().await; - None - } - } - } else { - None - } -} - async fn user_can_see_update( user: &User, user_id: &str, diff --git a/lib/periphery_client/Cargo.toml b/lib/periphery_client/Cargo.toml index 11c96ac3f..3b93147a9 100644 --- a/lib/periphery_client/Cargo.toml +++ b/lib/periphery_client/Cargo.toml @@ -8,7 +8,10 @@ edition = "2021" [dependencies] types = { package = "monitor_types", path = "../types" } helpers = { path = "../helpers" } +tokio-tungstenite = { version = "0.18", features=["native-tls"] } +tokio = "1.23" reqwest = { version = "0.11", features = ["json"] } serde = "1.0" serde_json = "1.0" -anyhow = "1.0" \ No newline at end of file +anyhow = "1.0" +futures-util = "0.3" \ No newline at end of file diff --git a/lib/periphery_client/src/lib.rs b/lib/periphery_client/src/lib.rs index 833b91a6b..34fdfe5e6 100644 --- a/lib/periphery_client/src/lib.rs +++ b/lib/periphery_client/src/lib.rs @@ -1,6 +1,8 @@ use anyhow::{anyhow, Context}; use reqwest::StatusCode; use serde::{de::DeserializeOwned, Serialize}; +use tokio::net::TcpStream; +use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream}; use types::{Server, SystemStats, SystemStatsQuery}; mod build; @@ -41,7 +43,7 @@ impl PeripheryClient { self.get_json( server, &format!( - "/stats/system?networks={}&components={}&processes={}", + "/stats?networks={}&components={}&processes={}", query.networks, query.components, query.processes ), ) @@ -49,6 +51,24 @@ impl PeripheryClient { .context("failed to get system stats from periphery") } + pub async fn subscribe_to_stats_ws( + &self, + server: &Server, + query: &SystemStatsQuery, + ) -> anyhow::Result>> { + let ws_url = format!( + "{}/stats/ws?networks={}&components={}&processes={}", + server.address.replace("http", "ws"), + query.networks, + query.components, + query.processes + ); + let (socket, _) = connect_async(ws_url) + .await + .context("failed to connect to periphery stats ws")?; + Ok(socket) + } + async fn get_text(&self, server: &Server, endpoint: &str) -> anyhow::Result { let res = self .http_client diff --git a/lib/types/src/lib.rs b/lib/types/src/lib.rs index 8556b670d..67e33e0ec 100644 --- a/lib/types/src/lib.rs +++ b/lib/types/src/lib.rs @@ -239,6 +239,12 @@ pub enum Timelength { ThirtyDays, } +impl Default for Timelength { + fn default() -> Timelength { + Timelength::FiveMinutes + } +} + pub fn monitor_timestamp() -> String { Utc::now().to_rfc3339_opts(SecondsFormat::Millis, false) } diff --git a/lib/types/src/server.rs b/lib/types/src/server.rs index 0e20d5b02..6ac3897c3 100644 --- a/lib/types/src/server.rs +++ b/lib/types/src/server.rs @@ -129,7 +129,7 @@ pub enum ServerStatus { } #[typeshare] -#[derive(Serialize, Deserialize, Debug, Clone, Default)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)] pub struct SystemStatsQuery { #[serde(default)] pub networks: bool, @@ -150,7 +150,7 @@ impl SystemStatsQuery { } #[typeshare] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default, Clone)] pub struct SystemStats { pub cpu_perc: f32, // in % pub mem_used_gb: f64, // in GB @@ -165,7 +165,7 @@ pub struct SystemStats { } #[typeshare] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Default, Clone)] pub struct DiskUsage { pub used_gb: f64, // in GB pub total_gb: f64, // in GB @@ -175,7 +175,7 @@ pub struct DiskUsage { } #[typeshare] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct SingleDiskUsage { pub mount: PathBuf, pub used_gb: f64, // in GB @@ -183,7 +183,7 @@ pub struct SingleDiskUsage { } #[typeshare] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct SystemNetwork { pub name: String, pub recieved_kb: f64, // in kB @@ -191,7 +191,7 @@ pub struct SystemNetwork { } #[typeshare] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct SystemComponent { pub label: String, pub temp: f32, @@ -201,7 +201,7 @@ pub struct SystemComponent { } #[typeshare] -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct SystemProcess { pub pid: u32, pub name: String, diff --git a/periphery/Cargo.toml b/periphery/Cargo.toml index f370244c1..f0496a771 100644 --- a/periphery/Cargo.toml +++ b/periphery/Cargo.toml @@ -15,7 +15,7 @@ types = { package = "monitor_types", path = "../lib/types" } run_command = { version = "0.0.5", features = ["async_tokio"] } async_timing_util = "0.1.12" tokio = { version = "1.21", features = ["full"] } -axum = { version = "0.6" } +axum = { version = "0.6", features = ["ws"] } tower = { version = "0.4", features = ["full"] } dotenv = "0.15" serde = "1.0" @@ -28,3 +28,5 @@ sysinfo = "0.27.2" toml = "0.5" daemonize = "0.4" clap = { version = "4.0", features = ["derive"] } +futures-util = "0.3" +tokio-util = "0.7" diff --git a/periphery/src/api/stats.rs b/periphery/src/api/stats.rs index f3a3aa46e..dbd589167 100644 --- a/periphery/src/api/stats.rs +++ b/periphery/src/api/stats.rs @@ -1,8 +1,19 @@ use std::sync::{Arc, RwLock}; use async_timing_util::wait_until_timelength; -use axum::{extract::Query, routing::get, Extension, Json, Router}; +use axum::{ + extract::{ws::Message, Query, WebSocketUpgrade}, + response::IntoResponse, + routing::get, + Extension, Json, Router, +}; +use futures_util::{SinkExt, StreamExt}; use sysinfo::{ComponentExt, CpuExt, DiskExt, NetworkExt, PidExt, ProcessExt, SystemExt}; +use tokio::{ + select, + sync::broadcast::{self, Receiver}, +}; +use tokio_util::sync::CancellationToken; use types::{ DiskUsage, SingleDiskUsage, SystemComponent, SystemNetwork, SystemProcess, SystemStats, SystemStatsQuery, Timelength, @@ -11,14 +22,24 @@ use types::{ pub fn router(stats_polling_rate: Timelength) -> Router { Router::new() .route( - "/system", + "/", get( |Extension(sys): StatsExtension, Query(query): Query| async move { - let stats = sys.read().unwrap().get_stats(query); + let stats = sys.read().unwrap().get_cached_stats(query); Json(stats) }, ), ) + .route( + "/ws", + get( + |Extension(sys): StatsExtension, + Query(query): Query, + ws: WebSocketUpgrade| async move { + sys.read().unwrap().ws_subscribe(ws, Arc::new(query)) + }, + ), + ) .layer(StatsClient::extension(stats_polling_rate)) } @@ -26,9 +47,11 @@ type StatsExtension = Extension>>; struct StatsClient { sys: sysinfo::System, + cache: SystemStats, polling_rate: Timelength, refresh_ts: u128, refresh_list_ts: u128, + receiver: Receiver, } const BYTES_PER_GB: f64 = 1073741824.0; @@ -36,12 +59,15 @@ const BYTES_PER_MB: f64 = 1048576.0; const BYTES_PER_KB: f64 = 1024.0; impl StatsClient { - pub fn extension(polling_rate: Timelength) -> StatsExtension { + fn extension(polling_rate: Timelength) -> StatsExtension { + let (sender, receiver) = broadcast::channel::(10); let client = StatsClient { sys: sysinfo::System::new_all(), + cache: SystemStats::default(), polling_rate, refresh_ts: 0, refresh_list_ts: 0, + receiver, }; let client = Arc::new(RwLock::new(client)); let clone = client.clone(); @@ -49,9 +75,15 @@ impl StatsClient { let polling_rate = polling_rate.to_string().parse().unwrap(); loop { let ts = wait_until_timelength(polling_rate, 0).await; - let mut client = clone.write().unwrap(); - client.refresh(); - client.refresh_ts = ts; + { + let mut client = clone.write().unwrap(); + client.refresh(); + client.refresh_ts = ts; + client.cache = client.get_stats(); + } + sender + .send(clone.read().unwrap().cache.clone()) + .expect("failed to broadcast new stats to reciever"); } }); let clone = client.clone(); @@ -66,6 +98,61 @@ impl StatsClient { Extension(client) } + fn ws_subscribe( + &self, + ws: WebSocketUpgrade, + query: Arc, + ) -> impl IntoResponse { + // println!("client subscribe"); + let mut reciever = self.get_receiver(); + ws.on_upgrade(|socket| async move { + let (mut ws_sender, mut ws_reciever) = socket.split(); + let cancel = CancellationToken::new(); + let cancel_clone = cancel.clone(); + tokio::spawn(async move { + loop { + let mut stats = select! { + _ = cancel_clone.cancelled() => break, + stats = reciever.recv() => { stats.expect("failed to recv stats msg") } + }; + if !query.components { + stats.components = vec![] + } + if !query.networks { + stats.networks = vec![] + } + if !query.processes { + stats.processes = vec![] + } + let _ = ws_sender + .send(Message::Text(serde_json::to_string(&stats).unwrap())) + .await; + } + }); + while let Some(msg) = ws_reciever.next().await { + match msg { + Ok(msg) => match msg { + Message::Close(_) => { + // println!("client CLOSE"); + cancel.cancel(); + return; + } + _ => {} + }, + Err(_) => { + // println!("client CLOSE"); + cancel.cancel(); + return; + } + } + } + }) + } + + fn get_receiver(&self) -> Receiver { + self.receiver.resubscribe() + } + fn refresh(&mut self) { self.sys.refresh_cpu(); self.sys.refresh_memory(); @@ -81,27 +168,29 @@ impl StatsClient { self.sys.refresh_components_list(); } - pub fn get_stats(&self, query: SystemStatsQuery) -> SystemStats { + fn get_cached_stats(&self, query: SystemStatsQuery) -> SystemStats { + let mut stats = self.cache.clone(); + if !query.networks { + stats.networks = Vec::new(); + } + if !query.components { + stats.components = Vec::new(); + } + if !query.processes { + stats.processes = Vec::new(); + } + stats + } + + fn get_stats(&self) -> SystemStats { SystemStats { cpu_perc: self.sys.global_cpu_info().cpu_usage(), mem_used_gb: self.sys.used_memory() as f64 / BYTES_PER_GB, mem_total_gb: self.sys.total_memory() as f64 / BYTES_PER_GB, disk: self.get_disk_usage(), - networks: if query.networks { - self.get_networks() - } else { - vec![] - }, - components: if query.components { - self.get_components() - } else { - vec![] - }, - processes: if query.processes { - self.get_processes() - } else { - vec![] - }, + networks: self.get_networks(), + components: self.get_components(), + processes: self.get_processes(), polling_rate: self.polling_rate, refresh_ts: self.refresh_ts, refresh_list_ts: self.refresh_list_ts,