KL-4 must fallback to axum extracted IP for cases not using reverse proxy

This commit is contained in:
mbecker20
2025-12-04 19:31:53 -08:00
parent 6fa5acd1e3
commit 1621043a21
13 changed files with 167 additions and 68 deletions

View File

@@ -1,6 +1,15 @@
use std::{sync::OnceLock, time::Instant};
use std::{
net::{IpAddr, SocketAddr},
sync::OnceLock,
time::Instant,
};
use axum::{Router, extract::Path, http::HeaderMap, routing::post};
use axum::{
Router,
extract::{ConnectInfo, Path},
http::HeaderMap,
routing::post,
};
use derive_variants::{EnumVariants, ExtractVariant};
use komodo_client::{api::auth::*, entities::user::User};
use rate_limit::WithFailureRateLimit;
@@ -27,9 +36,11 @@ use crate::{
use super::Variant;
#[derive(Default)]
pub struct AuthArgs {
pub headers: HeaderMap,
/// Prefer extracting IP from headers.
/// This IP will be the IP of reverse proxy itself.
pub ip: IpAddr,
}
#[typeshare]
@@ -79,6 +90,7 @@ pub fn router() -> Router {
async fn variant_handler(
headers: HeaderMap,
info: ConnectInfo<SocketAddr>,
Path(Variant { variant }): Path<Variant>,
Json(params): Json<serde_json::Value>,
) -> serror::Result<axum::response::Response> {
@@ -86,11 +98,12 @@ async fn variant_handler(
"type": variant,
"params": params,
}))?;
handler(headers, Json(req)).await
handler(headers, info, Json(req)).await
}
async fn handler(
headers: HeaderMap,
ConnectInfo(info): ConnectInfo<SocketAddr>,
Json(request): Json<AuthRequest>,
) -> serror::Result<axum::response::Response> {
let timer = Instant::now();
@@ -99,7 +112,12 @@ async fn handler(
"/auth request {req_id} | METHOD: {:?}",
request.extract_variant()
);
let res = request.resolve(&AuthArgs { headers }).await;
let res = request
.resolve(&AuthArgs {
headers,
ip: info.ip(),
})
.await;
if let Err(e) = &res {
debug!("/auth request {req_id} | error: {:#}", e.error);
}
@@ -136,13 +154,14 @@ impl Resolve<AuthArgs> for GetLoginOptions {
impl Resolve<AuthArgs> for ExchangeForJwt {
async fn resolve(
self,
AuthArgs { headers }: &AuthArgs,
AuthArgs { headers, ip }: &AuthArgs,
) -> serror::Result<ExchangeForJwtResponse> {
jwt_client()
.redeem_exchange_token(&self.token)
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
headers,
Some(*ip),
)
.await
}
@@ -151,7 +170,7 @@ impl Resolve<AuthArgs> for ExchangeForJwt {
impl Resolve<AuthArgs> for GetUser {
async fn resolve(
self,
AuthArgs { headers }: &AuthArgs,
AuthArgs { headers, ip }: &AuthArgs,
) -> serror::Result<User> {
async {
let user_id = get_user_id_from_headers(headers)
@@ -164,6 +183,7 @@ impl Resolve<AuthArgs> for GetUser {
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
headers,
Some(*ip),
)
.await
}

View File

@@ -1,4 +1,11 @@
use axum::{Router, extract::Path, http::HeaderMap, routing::post};
use std::net::{IpAddr, SocketAddr};
use axum::{
Router,
extract::{ConnectInfo, Path},
http::HeaderMap,
routing::post,
};
use komodo_client::entities::{
action::Action, build::Build, procedure::Procedure, repo::Repo,
resource::Resource, stack::Stack, sync::ResourceSync,
@@ -48,9 +55,9 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
.route(
"/build/{id}",
post(
|Path(Id { id }), headers: HeaderMap, body: String| async move {
|Path(Id { id }), headers: HeaderMap, ConnectInfo(info): ConnectInfo<SocketAddr>, body: String| async move {
let build =
auth_webhook::<P, Build>(&id, &headers, &body).await?;
auth_webhook::<P, Build>(&id, &headers, info.ip(), &body).await?;
tokio::spawn(async move {
let span = info_span!("BuildWebhook", id);
async {
@@ -74,9 +81,9 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
.route(
"/repo/{id}/{option}",
post(
|Path(IdAndOption::<RepoWebhookOption> { id, option }), headers: HeaderMap, body: String| async move {
|Path(IdAndOption::<RepoWebhookOption> { id, option }), headers: HeaderMap, ConnectInfo(info): ConnectInfo<SocketAddr>, body: String| async move {
let repo =
auth_webhook::<P, Repo>(&id, &headers, &body).await?;
auth_webhook::<P, Repo>(&id, &headers, info.ip(), &body).await?;
tokio::spawn(async move {
let span = info_span!("RepoWebhook", id);
async {
@@ -100,9 +107,9 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
.route(
"/stack/{id}/{option}",
post(
|Path(IdAndOption::<StackWebhookOption> { id, option }), headers: HeaderMap, body: String| async move {
|Path(IdAndOption::<StackWebhookOption> { id, option }), headers: HeaderMap, ConnectInfo(info): ConnectInfo<SocketAddr>, body: String| async move {
let stack =
auth_webhook::<P, Stack>(&id, &headers, &body).await?;
auth_webhook::<P, Stack>(&id, &headers, info.ip(), &body).await?;
tokio::spawn(async move {
let span = info_span!("StackWebhook", id);
async {
@@ -126,9 +133,9 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
.route(
"/sync/{id}/{option}",
post(
|Path(IdAndOption::<SyncWebhookOption> { id, option }), headers: HeaderMap, body: String| async move {
|Path(IdAndOption::<SyncWebhookOption> { id, option }), headers: HeaderMap, ConnectInfo(info): ConnectInfo<SocketAddr>, body: String| async move {
let sync =
auth_webhook::<P, ResourceSync>(&id, &headers, &body).await?;
auth_webhook::<P, ResourceSync>(&id, &headers, info.ip(), &body).await?;
tokio::spawn(async move {
let span = info_span!("ResourceSyncWebhook", id);
async {
@@ -152,9 +159,9 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
.route(
"/procedure/{id}/{branch}",
post(
|Path(IdAndBranch { id, branch }), headers: HeaderMap, body: String| async move {
|Path(IdAndBranch { id, branch }), headers: HeaderMap, ConnectInfo(info): ConnectInfo<SocketAddr>, body: String| async move {
let procedure =
auth_webhook::<P, Procedure>(&id, &headers, &body).await?;
auth_webhook::<P, Procedure>(&id, &headers, info.ip(), &body).await?;
tokio::spawn(async move {
let span = info_span!("ProcedureWebhook", id);
async {
@@ -178,9 +185,9 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
.route(
"/action/{id}/{branch}",
post(
|Path(IdAndBranch { id, branch }), headers: HeaderMap, body: String| async move {
|Path(IdAndBranch { id, branch }), headers: HeaderMap, ConnectInfo(info): ConnectInfo<SocketAddr>, body: String| async move {
let action =
auth_webhook::<P, Action>(&id, &headers, &body).await?;
auth_webhook::<P, Action>(&id, &headers, info.ip(), &body).await?;
tokio::spawn(async move {
let span = info_span!("ActionWebhook", id);
async {
@@ -206,6 +213,7 @@ pub fn router<P: VerifySecret + ExtractBranch>() -> Router {
async fn auth_webhook<P, R>(
id: &str,
headers: &HeaderMap,
ip: IpAddr,
body: &str,
) -> serror::Result<Resource<R::Config, R::Info>>
where
@@ -220,6 +228,10 @@ where
.status_code(StatusCode::UNAUTHORIZED)?;
serror::Result::Ok(resource)
}
.with_failure_rate_limit_using_headers(auth_rate_limiter(), headers)
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
headers,
Some(ip),
)
.await
}

View File

@@ -1,3 +1,5 @@
use std::net::IpAddr;
use crate::{
auth::{auth_api_key_check_enabled, auth_jwt_check_enabled},
helpers::query::get_user,
@@ -30,6 +32,7 @@ pub fn router() -> Router {
async fn user_ws_login(
mut socket: WebSocket,
headers: &HeaderMap,
fallback_ip: IpAddr,
) -> Option<(WebSocket, User)> {
let res = async {
let message = match socket
@@ -66,7 +69,11 @@ async fn user_ws_login(
}
}
}
.with_failure_rate_limit_using_headers(auth_rate_limiter(), headers)
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
headers,
Some(fallback_ip),
)
.await;
match res {
Ok(user) => {

View File

@@ -1,6 +1,8 @@
use std::net::SocketAddr;
use anyhow::anyhow;
use axum::{
extract::{FromRequestParts, WebSocketUpgrade, ws},
extract::{ConnectInfo, FromRequestParts, WebSocketUpgrade, ws},
http::{HeaderMap, request},
response::IntoResponse,
};
@@ -22,12 +24,14 @@ use crate::{
#[instrument("ConnectTerminal", skip(ws))]
pub async fn handler(
Qs(query): Qs<ConnectTerminalQuery>,
ConnectInfo(info): ConnectInfo<SocketAddr>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
ws.on_upgrade(|socket| async move {
let ip = info.ip();
ws.on_upgrade(move |socket| async move {
let Some((mut client_socket, user)) =
super::user_ws_login(socket, &headers).await
super::user_ws_login(socket, &headers, ip).await
else {
return;
};

View File

@@ -1,6 +1,8 @@
use std::net::SocketAddr;
use anyhow::anyhow;
use axum::{
extract::{WebSocketUpgrade, ws::Message},
extract::{ConnectInfo, WebSocketUpgrade, ws::Message},
http::HeaderMap,
response::IntoResponse,
};
@@ -19,15 +21,17 @@ use crate::helpers::{
pub async fn handler(
headers: HeaderMap,
ConnectInfo(info): ConnectInfo<SocketAddr>,
ws: WebSocketUpgrade,
) -> impl IntoResponse {
// get a reveiver for internal update messages.
let mut receiver = update_channel().receiver.resubscribe();
let ip = info.ip();
// handle http -> ws updgrade
ws.on_upgrade(|socket| async move {
ws.on_upgrade(move |socket| async move {
let Some((client_socket, user)) =
super::user_ws_login(socket, &headers).await
super::user_ws_login(socket, &headers, ip).await
else {
return;
};

View File

@@ -1,6 +1,11 @@
use std::net::SocketAddr;
use anyhow::{Context, anyhow};
use axum::{
Router, extract::Query, http::HeaderMap, response::Redirect,
Router,
extract::{ConnectInfo, Query},
http::HeaderMap,
response::Redirect,
routing::get,
};
use database::mongo_indexed::Document;
@@ -42,15 +47,20 @@ pub fn router() -> Router {
)
.route(
"/callback",
get(|query, headers: HeaderMap| async move {
callback(query)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
)
.await
}),
get(
|query,
headers: HeaderMap,
ConnectInfo(info): ConnectInfo<SocketAddr>| async move {
callback(query)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
Some(info.ip()),
)
.await
},
),
)
}

View File

@@ -1,7 +1,12 @@
use std::net::SocketAddr;
use anyhow::{Context, anyhow};
use async_timing_util::unix_timestamp_ms;
use axum::{
Router, extract::Query, http::HeaderMap, response::Redirect,
Router,
extract::{ConnectInfo, Query},
http::HeaderMap,
response::Redirect,
routing::get,
};
use database::mongo_indexed::Document;
@@ -43,15 +48,20 @@ pub fn router() -> Router {
)
.route(
"/callback",
get(|query, headers: HeaderMap| async move {
callback(query)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
)
.await
}),
get(
|query,
headers: HeaderMap,
ConnectInfo(info): ConnectInfo<SocketAddr>| async move {
callback(query)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
Some(info.ip()),
)
.await
},
),
)
}

View File

@@ -29,12 +29,13 @@ impl Resolve<AuthArgs> for SignUpLocalUser {
#[instrument("SignUpLocalUser", skip(self))]
async fn resolve(
self,
AuthArgs { headers }: &AuthArgs,
AuthArgs { headers, ip }: &AuthArgs,
) -> serror::Result<SignUpLocalUserResponse> {
sign_up_local_user(self)
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
headers,
Some(*ip),
)
.await
}
@@ -139,12 +140,13 @@ fn login_local_user_rate_limiter() -> &'static RateLimiter {
impl Resolve<AuthArgs> for LoginLocalUser {
async fn resolve(
self,
AuthArgs { headers }: &AuthArgs,
AuthArgs { headers, ip }: &AuthArgs,
) -> serror::Result<LoginLocalUserResponse> {
login_local_user(self)
.with_failure_rate_limit_using_headers(
login_local_user_rate_limiter(),
headers,
Some(*ip),
)
.await
}

View File

@@ -1,7 +1,11 @@
use std::net::SocketAddr;
use anyhow::{Context, anyhow};
use async_timing_util::unix_timestamp_ms;
use axum::{
extract::Request, http::HeaderMap, middleware::Next,
extract::{ConnectInfo, Request},
http::HeaderMap,
middleware::Next,
response::Response,
};
use database::mungos::mongodb::bson::doc;
@@ -45,11 +49,16 @@ pub async fn auth_request(
mut req: Request,
next: Next,
) -> serror::Result<Response> {
let fallback = req
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|addr| addr.ip());
let user = authenticate_check_enabled(&headers)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
fallback,
)
.await?;
req.extensions_mut().insert(user);

View File

@@ -1,8 +1,11 @@
use std::sync::OnceLock;
use std::{net::SocketAddr, sync::OnceLock};
use anyhow::{Context, anyhow};
use axum::{
Router, extract::Query, http::HeaderMap, response::Redirect,
Router,
extract::{ConnectInfo, Query},
http::HeaderMap,
response::Redirect,
routing::get,
};
use client::oidc_client;
@@ -71,15 +74,20 @@ pub fn router() -> Router {
)
.route(
"/callback",
get(|query, headers: HeaderMap| async move {
callback(query)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
)
.await
}),
get(
|query,
headers: HeaderMap,
ConnectInfo(info): ConnectInfo<SocketAddr>| async move {
callback(query)
.map_err(|e| e.status_code(StatusCode::UNAUTHORIZED))
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
&headers,
Some(info.ip()),
)
.await
},
),
)
}

View File

@@ -82,7 +82,8 @@ async fn app() -> anyhow::Result<()> {
.instrument(startup_span)
.await;
let app = api::app().into_make_service();
let app =
api::app().into_make_service_with_connect_info::<SocketAddr>();
let addr =
format!("{}:{}", core_config().bind_ip, core_config().port);

View File

@@ -1,4 +1,7 @@
use std::str::FromStr;
use std::{
net::{IpAddr, Ipv4Addr},
str::FromStr,
};
use anyhow::Context;
use colored::Colorize;
@@ -304,7 +307,10 @@ async fn ensure_init_user_and_resources() {
username: username.clone(),
password: config.init_admin_password.clone(),
})
.resolve(&AuthArgs::default())
.resolve(&AuthArgs {
headers: Default::default(),
ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
})
.await
{
error!("Failed to create init admin user | {:#}", e.error);

View File

@@ -101,13 +101,14 @@ where
self,
limiter: &RateLimiter,
headers: &HeaderMap,
fallback: Option<IpAddr>,
) -> impl Future<Output = serror::Result<R>> {
async {
async move {
// Can skip header ip extraction if disabled
if limiter.disabled {
return self.await;
}
let ip = get_ip_from_headers(headers)?;
let ip = get_ip_from_headers(headers, fallback)?;
self.with_failure_rate_limit_using_ip(limiter, &ip).await
}
}
@@ -186,6 +187,7 @@ fn spawn_cleanup_task(limiter: Arc<RateLimiter>) {
pub fn get_ip_from_headers(
headers: &HeaderMap,
fallback: Option<IpAddr>,
) -> serror::Result<IpAddr> {
// Check X-Forwarded-For header (first IP in chain)
if let Some(forwarded) = headers.get("x-forwarded-for")
@@ -202,8 +204,12 @@ pub fn get_ip_from_headers(
return ip.trim().parse().status_code(StatusCode::UNAUTHORIZED);
}
if let Some(fallback) = fallback {
return Ok(fallback);
}
Err(
anyhow!("'x-forwarded-for' and 'x-real-ip' are both missing")
anyhow!("'x-forwarded-for' and 'x-real-ip' headers are both missing, and no fallback ip could be extracted from the request.")
.status_code(StatusCode::UNAUTHORIZED),
)
}