implement the periphery request guard

This commit is contained in:
mbecker20
2023-06-08 06:38:36 +00:00
parent 44a16dd214
commit 20d496e617
16 changed files with 269 additions and 72 deletions

18
Cargo.lock generated
View File

@@ -1201,6 +1201,17 @@ version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]]
name = "periphery_client"
version = "1.0.0"
dependencies = [
"anyhow",
"log",
"monitor_types",
"reqwest",
"serde_json",
]
[[package]]
name = "pin-project"
version = "1.1.0"
@@ -1653,6 +1664,13 @@ dependencies = [
[[package]]
name = "tests"
version = "0.1.0"
dependencies = [
"anyhow",
"monitor_client",
"monitor_types",
"periphery_client",
"tokio",
]
[[package]]
name = "thiserror"

View File

@@ -13,6 +13,7 @@ license = "GPL-3.0-or-later"
monitor_helpers = { path = "lib/helpers" }
monitor_types = { path = "lib/types" }
monitor_client = { path = "lib/rs_client" }
periphery_client = { path = "lib/periphery_client" }
# external
tokio = { version = "1.28", features = ["full"] }
axum = { version = "0.6", features = ["ws", "json", "headers"] }

View File

@@ -1,4 +1,4 @@
use std::{os::linux::raw::stat, sync::Arc};
use std::sync::Arc;
pub struct AppState {}

View File

@@ -1,18 +0,0 @@
use serde::{de::DeserializeOwned, Serialize};
pub trait HasResponse: Serialize + std::fmt::Debug {
type Response: DeserializeOwned + std::fmt::Debug;
fn req_type() -> &'static str;
}
#[macro_export]
macro_rules! impl_has_response {
($req:ty, $res:ty) => {
impl $crate::HasResponse for $req {
type Response = $res;
fn req_type() -> &'static str {
stringify!($t)
}
}
};
}

View File

@@ -0,0 +1,14 @@
[package]
name = "periphery_client"
version.workspace = true
edition.workspace = true
license.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
monitor_types.workspace = true
reqwest.workspace = true
anyhow.workspace = true
serde_json.workspace = true
log.workspace = true

View File

@@ -0,0 +1,56 @@
#[macro_use]
extern crate log;
use anyhow::{anyhow, Context};
use monitor_types::{HasResponse, periphery_api::requests};
use reqwest::StatusCode;
use serde_json::json;
pub struct PeripheryClient {
reqwest: reqwest::Client,
address: String,
passkey: String,
}
impl PeripheryClient {
pub fn new(address: impl Into<String>, passkey: impl Into<String>) -> PeripheryClient {
PeripheryClient {
reqwest: Default::default(),
address: address.into(),
passkey: passkey.into(),
}
}
pub async fn request<T: HasResponse>(&self, request: T) -> anyhow::Result<T::Response> {
let req_type = T::req_type();
trace!("sending request | type: {req_type} | body: {request:?}");
let res = self
.reqwest
.post(&self.address)
.json(&json!({
"type": req_type,
"params": request
}))
.header("authorization", &self.passkey)
.send()
.await?;
let status = res.status();
debug!("got response | type: {req_type} | {status} | body: {res:?}",);
if status == StatusCode::OK {
res.json().await.context(format!(
"failed to parse response to json | type: {req_type} | body: {request:?}"
))
} else {
let text = res
.text()
.await
.context("failed to convert response to text")?;
Err(anyhow!("request failed | {status} | {text}"))
}
}
pub async fn health_check(&self) -> anyhow::Result<()> {
self.request(requests::GetHealth {}).await?;
Ok(())
}
}

View File

@@ -141,7 +141,7 @@ macro_rules! impl_has_response {
impl $crate::HasResponse for $req {
type Response = $res;
fn req_type() -> &'static str {
stringify!($t)
stringify!($req)
}
}
};

View File

@@ -1,46 +1,43 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use crate::SystemCommand;
use self::requests::{GetVersion, GetHealth};
use self::requests::{GetHealth, GetVersion};
pub mod requests;
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", content = "params")]
pub enum PeripheryRequest {
// GET
GetHealth(GetHealth),
GetVersion(GetVersion),
GetSystemInformation {},
GetSystemStats {},
GetAccounts {},
GetSecrets {},
GetContainerList {},
GetContainerLog {},
GetContainerStats {},
GetContainerStatsList {},
GetNetworkList {},
// GET
GetHealth(GetHealth),
GetVersion(GetVersion),
GetSystemInformation {},
GetSystemStats {},
GetAccounts {},
GetSecrets {},
GetContainerList {},
GetContainerLog {},
GetContainerStats {},
GetContainerStatsList {},
GetNetworkList {},
// ACTIONS
RunCommand(SystemCommand),
CloneRepo {},
PullRepo {},
DeleteRepo {},
Build {},
Deploy {},
StartContainer {},
StopContainer {},
RemoveContainer {},
RenameContainer {},
PruneContainers {},
// ACTIONS
RunCommand(SystemCommand),
CloneRepo {},
PullRepo {},
DeleteRepo {},
Build {},
Deploy {},
StartContainer {},
StopContainer {},
RemoveContainer {},
RenameContainer {},
PruneContainers {},
}
impl Default for PeripheryRequest {
fn default() -> PeripheryRequest {
PeripheryRequest::GetHealth(GetHealth {})
}
}
fn default() -> PeripheryRequest {
PeripheryRequest::GetHealth(GetHealth {})
}
}

View File

@@ -28,3 +28,15 @@ pub struct GetVersionResponse {
}
impl_has_response!(GetVersion, GetVersionResponse);
#[typeshare]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GetSystemInformation {}
#[typeshare]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GetSystemInformationResponse {
pub version: String,
}
impl_has_response!(GetSystemInformation, GetSystemInformationResponse);

View File

@@ -1,13 +1,14 @@
use anyhow::{anyhow, Context};
use monitor_types::periphery_api::{requests::GetVersionResponse, PeripheryRequest};
use crate::state::AppState;
use crate::state::State;
impl AppState {
impl State {
pub async fn handle_request(&self, request: PeripheryRequest) -> anyhow::Result<String> {
match request {
PeripheryRequest::GetHealth(_) => Ok(String::from("{}")),
PeripheryRequest::GetVersion(_) => AppState::get_version(),
PeripheryRequest::GetVersion(_) => State::get_version(),
PeripheryRequest::GetSystemInformation {} => todo!(),
_ => Err(anyhow!("not implemented")),
}
}

View File

@@ -37,6 +37,7 @@ pub struct Env {
config_paths: String,
#[serde(default)]
config_keywords: String,
port: Option<u16>,
}
impl Env {
@@ -87,13 +88,16 @@ impl PeripheryConfig {
.unwrap_or(&env_match_keywords)
.iter()
.map(|kw| kw.as_str());
let config = parse_config_paths::<PeripheryConfig>(
let mut config = parse_config_paths::<PeripheryConfig>(
config_paths,
match_keywords,
args.merge_nested_config,
args.extend_config_arrays,
)
.expect("failed at parsing config from paths");
if let Some(port) = env.port {
config.port = port;
}
Ok(config)
}
}

95
periphery/src/guard.rs Normal file
View File

@@ -0,0 +1,95 @@
use std::{net::SocketAddr, sync::Arc};
use axum::{
body::Body,
extract::ConnectInfo,
http::{Request, StatusCode},
middleware::Next,
response::Response,
Json, RequestExt,
};
use serde_json::Value;
use crate::state::State;
pub async fn guard_request_by_passkey(
req: Request<Body>,
next: Next<Body>,
) -> Result<Response, (StatusCode, String)> {
let state = req.extensions().get::<Arc<State>>().ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"could not get state extension".to_string(),
))?;
if state.config.passkeys.is_empty() {
return Ok(next.run(req).await);
}
let req_passkey = req.headers().get("authorization");
if req_passkey.is_none() {
return Err((
StatusCode::UNAUTHORIZED,
String::from("request was not sent with passkey"),
));
}
let req_passkey = req_passkey
.unwrap()
.to_str()
.map_err(|e| {
(
StatusCode::UNAUTHORIZED,
format!("failed to get passkey from authorization header as str: {e:?}"),
)
})?
.to_string();
if state.config.passkeys.contains(&req_passkey) {
Ok(next.run(req).await)
} else {
let ConnectInfo(socket_addr) =
req.extensions().get::<ConnectInfo<SocketAddr>>().ok_or((
StatusCode::UNAUTHORIZED,
"could not get socket addr of request".to_string(),
))?;
let ip = socket_addr.ip();
let body = req
.extract::<Json<Value>, _>()
.await
.ok()
.map(|Json(body)| body);
warn!("unauthorized request from {ip} (bad passkey) | body: {body:?}");
Err((
StatusCode::UNAUTHORIZED,
String::from("request passkey invalid"),
))
}
}
pub async fn guard_request_by_ip(
req: Request<Body>,
next: Next<Body>,
) -> Result<Response, (StatusCode, String)> {
let state = req.extensions().get::<Arc<State>>().ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"could not get state extension".to_string(),
))?;
if state.config.allowed_ips.is_empty() {
return Ok(next.run(req).await);
}
let ConnectInfo(socket_addr) = req.extensions().get::<ConnectInfo<SocketAddr>>().ok_or((
StatusCode::UNAUTHORIZED,
"could not get socket addr of request".to_string(),
))?;
let ip = socket_addr.ip();
if state.config.allowed_ips.contains(&ip) {
Ok(next.run(req).await)
} else {
let body = req
.extract::<Json<Value>, _>()
.await
.ok()
.map(|Json(body)| body);
warn!("unauthorized request from {ip} (bad passkey) | body: {body:?}");
Err((
StatusCode::UNAUTHORIZED,
format!("requesting ip {ip} not allowed"),
))
}
}

View File

@@ -1,31 +1,32 @@
#[macro_use]
extern crate log;
use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};
use axum::{
extract::State, headers::ContentType, http::StatusCode, routing::post, Json, Router,
headers::ContentType, http::StatusCode, middleware, routing::post, Extension, Json, Router,
TypedHeader,
};
use monitor_types::periphery_api::PeripheryRequest;
use state::AppState;
use state::State;
use termination_signal::tokio::immediate_term_handle;
use uuid::Uuid;
mod api;
mod config;
mod guard;
mod state;
async fn app() -> anyhow::Result<()> {
let state = AppState::load().await?;
let state = State::load().await?;
let socket_addr = state.socket_addr()?;
let app = Router::new()
.route(
"/api",
"/",
post(
|state: State<Arc<AppState>>, Json(request): Json<PeripheryRequest>| async move {
|state: Extension<Arc<State>>, Json(request): Json<PeripheryRequest>| async move {
let req_id = Uuid::new_v4();
info!("request {req_id}: {:?}", request);
let res = state
@@ -41,12 +42,14 @@ async fn app() -> anyhow::Result<()> {
},
),
)
.with_state(state);
.layer(middleware::from_fn(guard::guard_request_by_ip))
.layer(middleware::from_fn(guard::guard_request_by_passkey))
.layer(Extension(state));
info!("starting server on {}", socket_addr);
axum::Server::bind(&socket_addr)
.serve(app.into_make_service())
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.await?;
Ok(())

View File

@@ -6,12 +6,12 @@ use simple_logger::SimpleLogger;
use crate::config::{CliArgs, Env, PeripheryConfig};
pub struct AppState {
pub struct State {
pub config: PeripheryConfig,
}
impl AppState {
pub async fn load() -> anyhow::Result<Arc<AppState>> {
impl State {
pub async fn load() -> anyhow::Result<Arc<State>> {
let env = Env::load()?;
let args = CliArgs::parse();
SimpleLogger::new()
@@ -23,7 +23,7 @@ impl AppState {
.context("failed to configure logger")?;
info!("version: {}", env!("CARGO_PKG_VERSION"));
let config = PeripheryConfig::load(&env, &args)?;
let state = AppState { config };
let state = State { config };
Ok(state.into())
}

View File

@@ -6,3 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
monitor_types.workspace = true
monitor_client.workspace = true
periphery_client.workspace = true
tokio.workspace = true
anyhow.workspace = true

View File

@@ -1,3 +1,12 @@
fn main() {
println!("Hello, world!");
use monitor_types::periphery_api::requests;
use periphery_client::PeripheryClient;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let periphery = PeripheryClient::new("http://localhost:9001", "monitor_passkey");
let version = periphery.request(requests::GetVersion {}).await?;
println!("{version:?}");
Ok(())
}