diff --git a/bin/core/src/api/auth.rs b/bin/core/src/api/auth.rs index cbf535ad3..fb053d5ab 100644 --- a/bin/core/src/api/auth.rs +++ b/bin/core/src/api/auth.rs @@ -1,6 +1,15 @@ -use std::{sync::OnceLock, time::Instant}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::OnceLock, + time::Instant, +}; -use axum::{Router, extract::Path, http::HeaderMap, routing::post}; +use axum::{ + Router, + extract::{ConnectInfo, Path}, + http::HeaderMap, + routing::post, +}; use derive_variants::{EnumVariants, ExtractVariant}; use komodo_client::{api::auth::*, entities::user::User}; use rate_limit::WithFailureRateLimit; @@ -27,9 +36,11 @@ use crate::{ use super::Variant; -#[derive(Default)] pub struct AuthArgs { pub headers: HeaderMap, + /// Prefer extracting IP from headers. + /// This IP will be the IP of reverse proxy itself. + pub ip: IpAddr, } #[typeshare] @@ -79,6 +90,7 @@ pub fn router() -> Router { async fn variant_handler( headers: HeaderMap, + info: ConnectInfo, Path(Variant { variant }): Path, Json(params): Json, ) -> serror::Result { @@ -86,11 +98,12 @@ async fn variant_handler( "type": variant, "params": params, }))?; - handler(headers, Json(req)).await + handler(headers, info, Json(req)).await } async fn handler( headers: HeaderMap, + ConnectInfo(info): ConnectInfo, Json(request): Json, ) -> serror::Result { let timer = Instant::now(); @@ -99,7 +112,12 @@ async fn handler( "/auth request {req_id} | METHOD: {:?}", request.extract_variant() ); - let res = request.resolve(&AuthArgs { headers }).await; + let res = request + .resolve(&AuthArgs { + headers, + ip: info.ip(), + }) + .await; if let Err(e) = &res { debug!("/auth request {req_id} | error: {:#}", e.error); } @@ -136,13 +154,14 @@ impl Resolve for GetLoginOptions { impl Resolve for ExchangeForJwt { async fn resolve( self, - AuthArgs { headers }: &AuthArgs, + AuthArgs { headers, ip }: &AuthArgs, ) -> serror::Result { jwt_client() .redeem_exchange_token(&self.token) .with_failure_rate_limit_using_headers( auth_rate_limiter(), headers, + Some(*ip), ) .await } @@ -151,7 +170,7 @@ impl Resolve for ExchangeForJwt { impl Resolve for GetUser { async fn resolve( self, - AuthArgs { headers }: &AuthArgs, + AuthArgs { headers, ip }: &AuthArgs, ) -> serror::Result { async { let user_id = get_user_id_from_headers(headers) @@ -164,6 +183,7 @@ impl Resolve for GetUser { .with_failure_rate_limit_using_headers( auth_rate_limiter(), headers, + Some(*ip), ) .await } diff --git a/bin/core/src/api/listener/router.rs b/bin/core/src/api/listener/router.rs index 33c3a4b7e..c64d74214 100644 --- a/bin/core/src/api/listener/router.rs +++ b/bin/core/src/api/listener/router.rs @@ -1,4 +1,11 @@ -use axum::{Router, extract::Path, http::HeaderMap, routing::post}; +use std::net::{IpAddr, SocketAddr}; + +use axum::{ + Router, + extract::{ConnectInfo, Path}, + http::HeaderMap, + routing::post, +}; use komodo_client::entities::{ action::Action, build::Build, procedure::Procedure, repo::Repo, resource::Resource, stack::Stack, sync::ResourceSync, @@ -48,9 +55,9 @@ pub fn router() -> Router { .route( "/build/{id}", post( - |Path(Id { id }), headers: HeaderMap, body: String| async move { + |Path(Id { id }), headers: HeaderMap, ConnectInfo(info): ConnectInfo, body: String| async move { let build = - auth_webhook::(&id, &headers, &body).await?; + auth_webhook::(&id, &headers, info.ip(), &body).await?; tokio::spawn(async move { let span = info_span!("BuildWebhook", id); async { @@ -74,9 +81,9 @@ pub fn router() -> Router { .route( "/repo/{id}/{option}", post( - |Path(IdAndOption:: { id, option }), headers: HeaderMap, body: String| async move { + |Path(IdAndOption:: { id, option }), headers: HeaderMap, ConnectInfo(info): ConnectInfo, body: String| async move { let repo = - auth_webhook::(&id, &headers, &body).await?; + auth_webhook::(&id, &headers, info.ip(), &body).await?; tokio::spawn(async move { let span = info_span!("RepoWebhook", id); async { @@ -100,9 +107,9 @@ pub fn router() -> Router { .route( "/stack/{id}/{option}", post( - |Path(IdAndOption:: { id, option }), headers: HeaderMap, body: String| async move { + |Path(IdAndOption:: { id, option }), headers: HeaderMap, ConnectInfo(info): ConnectInfo, body: String| async move { let stack = - auth_webhook::(&id, &headers, &body).await?; + auth_webhook::(&id, &headers, info.ip(), &body).await?; tokio::spawn(async move { let span = info_span!("StackWebhook", id); async { @@ -126,9 +133,9 @@ pub fn router() -> Router { .route( "/sync/{id}/{option}", post( - |Path(IdAndOption:: { id, option }), headers: HeaderMap, body: String| async move { + |Path(IdAndOption:: { id, option }), headers: HeaderMap, ConnectInfo(info): ConnectInfo, body: String| async move { let sync = - auth_webhook::(&id, &headers, &body).await?; + auth_webhook::(&id, &headers, info.ip(), &body).await?; tokio::spawn(async move { let span = info_span!("ResourceSyncWebhook", id); async { @@ -152,9 +159,9 @@ pub fn router() -> Router { .route( "/procedure/{id}/{branch}", post( - |Path(IdAndBranch { id, branch }), headers: HeaderMap, body: String| async move { + |Path(IdAndBranch { id, branch }), headers: HeaderMap, ConnectInfo(info): ConnectInfo, body: String| async move { let procedure = - auth_webhook::(&id, &headers, &body).await?; + auth_webhook::(&id, &headers, info.ip(), &body).await?; tokio::spawn(async move { let span = info_span!("ProcedureWebhook", id); async { @@ -178,9 +185,9 @@ pub fn router() -> Router { .route( "/action/{id}/{branch}", post( - |Path(IdAndBranch { id, branch }), headers: HeaderMap, body: String| async move { + |Path(IdAndBranch { id, branch }), headers: HeaderMap, ConnectInfo(info): ConnectInfo, body: String| async move { let action = - auth_webhook::(&id, &headers, &body).await?; + auth_webhook::(&id, &headers, info.ip(), &body).await?; tokio::spawn(async move { let span = info_span!("ActionWebhook", id); async { @@ -206,6 +213,7 @@ pub fn router() -> Router { async fn auth_webhook( id: &str, headers: &HeaderMap, + ip: IpAddr, body: &str, ) -> serror::Result> where @@ -220,6 +228,10 @@ where .status_code(StatusCode::UNAUTHORIZED)?; serror::Result::Ok(resource) } - .with_failure_rate_limit_using_headers(auth_rate_limiter(), headers) + .with_failure_rate_limit_using_headers( + auth_rate_limiter(), + headers, + Some(ip), + ) .await } diff --git a/bin/core/src/api/ws/mod.rs b/bin/core/src/api/ws/mod.rs index b30b9e2f0..0a27e04fc 100644 --- a/bin/core/src/api/ws/mod.rs +++ b/bin/core/src/api/ws/mod.rs @@ -1,3 +1,5 @@ +use std::net::IpAddr; + use crate::{ auth::{auth_api_key_check_enabled, auth_jwt_check_enabled}, helpers::query::get_user, @@ -30,6 +32,7 @@ pub fn router() -> Router { async fn user_ws_login( mut socket: WebSocket, headers: &HeaderMap, + fallback_ip: IpAddr, ) -> Option<(WebSocket, User)> { let res = async { let message = match socket @@ -66,7 +69,11 @@ async fn user_ws_login( } } } - .with_failure_rate_limit_using_headers(auth_rate_limiter(), headers) + .with_failure_rate_limit_using_headers( + auth_rate_limiter(), + headers, + Some(fallback_ip), + ) .await; match res { Ok(user) => { diff --git a/bin/core/src/api/ws/terminal.rs b/bin/core/src/api/ws/terminal.rs index 8b7984d02..f1cb7138a 100644 --- a/bin/core/src/api/ws/terminal.rs +++ b/bin/core/src/api/ws/terminal.rs @@ -1,6 +1,8 @@ +use std::net::SocketAddr; + use anyhow::anyhow; use axum::{ - extract::{FromRequestParts, WebSocketUpgrade, ws}, + extract::{ConnectInfo, FromRequestParts, WebSocketUpgrade, ws}, http::{HeaderMap, request}, response::IntoResponse, }; @@ -22,12 +24,14 @@ use crate::{ #[instrument("ConnectTerminal", skip(ws))] pub async fn handler( Qs(query): Qs, + ConnectInfo(info): ConnectInfo, headers: HeaderMap, ws: WebSocketUpgrade, ) -> impl IntoResponse { - ws.on_upgrade(|socket| async move { + let ip = info.ip(); + ws.on_upgrade(move |socket| async move { let Some((mut client_socket, user)) = - super::user_ws_login(socket, &headers).await + super::user_ws_login(socket, &headers, ip).await else { return; }; diff --git a/bin/core/src/api/ws/update.rs b/bin/core/src/api/ws/update.rs index 053ae98d5..c482ebd04 100644 --- a/bin/core/src/api/ws/update.rs +++ b/bin/core/src/api/ws/update.rs @@ -1,6 +1,8 @@ +use std::net::SocketAddr; + use anyhow::anyhow; use axum::{ - extract::{WebSocketUpgrade, ws::Message}, + extract::{ConnectInfo, WebSocketUpgrade, ws::Message}, http::HeaderMap, response::IntoResponse, }; @@ -19,15 +21,17 @@ use crate::helpers::{ pub async fn handler( headers: HeaderMap, + ConnectInfo(info): ConnectInfo, ws: WebSocketUpgrade, ) -> impl IntoResponse { // get a reveiver for internal update messages. let mut receiver = update_channel().receiver.resubscribe(); + let ip = info.ip(); // handle http -> ws updgrade - ws.on_upgrade(|socket| async move { + ws.on_upgrade(move |socket| async move { let Some((client_socket, user)) = - super::user_ws_login(socket, &headers).await + super::user_ws_login(socket, &headers, ip).await else { return; }; diff --git a/bin/core/src/auth/github/mod.rs b/bin/core/src/auth/github/mod.rs index f59e060b4..d252749d5 100644 --- a/bin/core/src/auth/github/mod.rs +++ b/bin/core/src/auth/github/mod.rs @@ -1,6 +1,11 @@ +use std::net::SocketAddr; + use anyhow::{Context, anyhow}; use axum::{ - Router, extract::Query, http::HeaderMap, response::Redirect, + Router, + extract::{ConnectInfo, Query}, + http::HeaderMap, + response::Redirect, routing::get, }; use database::mongo_indexed::Document; @@ -42,15 +47,20 @@ pub fn router() -> Router { ) .route( "/callback", - get(|query, headers: HeaderMap| async move { - callback(query) - .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) - .with_failure_rate_limit_using_headers( - auth_rate_limiter(), - &headers, - ) - .await - }), + get( + |query, + headers: HeaderMap, + ConnectInfo(info): ConnectInfo| async move { + callback(query) + .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) + .with_failure_rate_limit_using_headers( + auth_rate_limiter(), + &headers, + Some(info.ip()), + ) + .await + }, + ), ) } diff --git a/bin/core/src/auth/google/mod.rs b/bin/core/src/auth/google/mod.rs index a50f1b80a..3b0bc054e 100644 --- a/bin/core/src/auth/google/mod.rs +++ b/bin/core/src/auth/google/mod.rs @@ -1,7 +1,12 @@ +use std::net::SocketAddr; + use anyhow::{Context, anyhow}; use async_timing_util::unix_timestamp_ms; use axum::{ - Router, extract::Query, http::HeaderMap, response::Redirect, + Router, + extract::{ConnectInfo, Query}, + http::HeaderMap, + response::Redirect, routing::get, }; use database::mongo_indexed::Document; @@ -43,15 +48,20 @@ pub fn router() -> Router { ) .route( "/callback", - get(|query, headers: HeaderMap| async move { - callback(query) - .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) - .with_failure_rate_limit_using_headers( - auth_rate_limiter(), - &headers, - ) - .await - }), + get( + |query, + headers: HeaderMap, + ConnectInfo(info): ConnectInfo| async move { + callback(query) + .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) + .with_failure_rate_limit_using_headers( + auth_rate_limiter(), + &headers, + Some(info.ip()), + ) + .await + }, + ), ) } diff --git a/bin/core/src/auth/local.rs b/bin/core/src/auth/local.rs index 32b3e612c..d9faf5344 100644 --- a/bin/core/src/auth/local.rs +++ b/bin/core/src/auth/local.rs @@ -29,12 +29,13 @@ impl Resolve for SignUpLocalUser { #[instrument("SignUpLocalUser", skip(self))] async fn resolve( self, - AuthArgs { headers }: &AuthArgs, + AuthArgs { headers, ip }: &AuthArgs, ) -> serror::Result { sign_up_local_user(self) .with_failure_rate_limit_using_headers( auth_rate_limiter(), headers, + Some(*ip), ) .await } @@ -139,12 +140,13 @@ fn login_local_user_rate_limiter() -> &'static RateLimiter { impl Resolve for LoginLocalUser { async fn resolve( self, - AuthArgs { headers }: &AuthArgs, + AuthArgs { headers, ip }: &AuthArgs, ) -> serror::Result { login_local_user(self) .with_failure_rate_limit_using_headers( login_local_user_rate_limiter(), headers, + Some(*ip), ) .await } diff --git a/bin/core/src/auth/mod.rs b/bin/core/src/auth/mod.rs index 79919e5f1..a150316eb 100644 --- a/bin/core/src/auth/mod.rs +++ b/bin/core/src/auth/mod.rs @@ -1,7 +1,11 @@ +use std::net::SocketAddr; + use anyhow::{Context, anyhow}; use async_timing_util::unix_timestamp_ms; use axum::{ - extract::Request, http::HeaderMap, middleware::Next, + extract::{ConnectInfo, Request}, + http::HeaderMap, + middleware::Next, response::Response, }; use database::mungos::mongodb::bson::doc; @@ -45,11 +49,16 @@ pub async fn auth_request( mut req: Request, next: Next, ) -> serror::Result { + let fallback = req + .extensions() + .get::>() + .map(|addr| addr.ip()); let user = authenticate_check_enabled(&headers) .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) .with_failure_rate_limit_using_headers( auth_rate_limiter(), &headers, + fallback, ) .await?; req.extensions_mut().insert(user); diff --git a/bin/core/src/auth/oidc/mod.rs b/bin/core/src/auth/oidc/mod.rs index b518b92b2..6f95f8030 100644 --- a/bin/core/src/auth/oidc/mod.rs +++ b/bin/core/src/auth/oidc/mod.rs @@ -1,8 +1,11 @@ -use std::sync::OnceLock; +use std::{net::SocketAddr, sync::OnceLock}; use anyhow::{Context, anyhow}; use axum::{ - Router, extract::Query, http::HeaderMap, response::Redirect, + Router, + extract::{ConnectInfo, Query}, + http::HeaderMap, + response::Redirect, routing::get, }; use client::oidc_client; @@ -71,15 +74,20 @@ pub fn router() -> Router { ) .route( "/callback", - get(|query, headers: HeaderMap| async move { - callback(query) - .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) - .with_failure_rate_limit_using_headers( - auth_rate_limiter(), - &headers, - ) - .await - }), + get( + |query, + headers: HeaderMap, + ConnectInfo(info): ConnectInfo| async move { + callback(query) + .map_err(|e| e.status_code(StatusCode::UNAUTHORIZED)) + .with_failure_rate_limit_using_headers( + auth_rate_limiter(), + &headers, + Some(info.ip()), + ) + .await + }, + ), ) } diff --git a/bin/core/src/main.rs b/bin/core/src/main.rs index c56508c7b..e7adf9ea0 100644 --- a/bin/core/src/main.rs +++ b/bin/core/src/main.rs @@ -82,7 +82,8 @@ async fn app() -> anyhow::Result<()> { .instrument(startup_span) .await; - let app = api::app().into_make_service(); + let app = + api::app().into_make_service_with_connect_info::(); let addr = format!("{}:{}", core_config().bind_ip, core_config().port); diff --git a/bin/core/src/startup.rs b/bin/core/src/startup.rs index d562d07ce..e21837e8d 100644 --- a/bin/core/src/startup.rs +++ b/bin/core/src/startup.rs @@ -1,4 +1,7 @@ -use std::str::FromStr; +use std::{ + net::{IpAddr, Ipv4Addr}, + str::FromStr, +}; use anyhow::Context; use colored::Colorize; @@ -304,7 +307,10 @@ async fn ensure_init_user_and_resources() { username: username.clone(), password: config.init_admin_password.clone(), }) - .resolve(&AuthArgs::default()) + .resolve(&AuthArgs { + headers: Default::default(), + ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + }) .await { error!("Failed to create init admin user | {:#}", e.error); diff --git a/lib/rate_limit/src/lib.rs b/lib/rate_limit/src/lib.rs index 603321ab2..e96babcba 100644 --- a/lib/rate_limit/src/lib.rs +++ b/lib/rate_limit/src/lib.rs @@ -101,13 +101,14 @@ where self, limiter: &RateLimiter, headers: &HeaderMap, + fallback: Option, ) -> impl Future> { - async { + async move { // Can skip header ip extraction if disabled if limiter.disabled { return self.await; } - let ip = get_ip_from_headers(headers)?; + let ip = get_ip_from_headers(headers, fallback)?; self.with_failure_rate_limit_using_ip(limiter, &ip).await } } @@ -186,6 +187,7 @@ fn spawn_cleanup_task(limiter: Arc) { pub fn get_ip_from_headers( headers: &HeaderMap, + fallback: Option, ) -> serror::Result { // Check X-Forwarded-For header (first IP in chain) if let Some(forwarded) = headers.get("x-forwarded-for") @@ -202,8 +204,12 @@ pub fn get_ip_from_headers( return ip.trim().parse().status_code(StatusCode::UNAUTHORIZED); } + if let Some(fallback) = fallback { + return Ok(fallback); + } + Err( - anyhow!("'x-forwarded-for' and 'x-real-ip' are both missing") + anyhow!("'x-forwarded-for' and 'x-real-ip' headers are both missing, and no fallback ip could be extracted from the request.") .status_code(StatusCode::UNAUTHORIZED), ) }