KL-4 ext 2: Improve rate limiting / attempt state conveyance with response

This commit is contained in:
mbecker20
2025-11-30 01:30:24 -08:00
committed by Maxwell Becker
parent 85787781ee
commit 8c62f2b5c5
3 changed files with 55 additions and 20 deletions

View File

@@ -12,7 +12,9 @@ use komodo_client::{
api::auth::JwtResponse,
entities::{config::core::CoreConfig, random_string},
};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serror::{AddStatusCode as _, AddStatusCodeError as _};
use tokio::sync::Mutex;
type ExchangeTokenMap = Mutex<HashMap<String, (JwtResponse, u128)>>;
@@ -93,17 +95,21 @@ impl JwtClient {
pub async fn redeem_exchange_token(
&self,
exchange_token: &str,
) -> anyhow::Result<JwtResponse> {
) -> serror::Result<JwtResponse> {
let (jwt, valid_until) = self
.exchange_tokens
.lock()
.await
.remove(exchange_token)
.context("invalid exchange token: unrecognized")?;
.context("Invalid exchange token")
.status_code(StatusCode::UNAUTHORIZED)?;
if unix_timestamp_ms() < valid_until {
Ok(jwt)
} else {
Err(anyhow!("invalid exchange token: expired"))
Err(
anyhow!("Invalid exchange token")
.status_code(StatusCode::UNAUTHORIZED),
)
}
}
}

View File

@@ -1,3 +1,5 @@
use std::sync::{Arc, OnceLock};
use anyhow::{Context, anyhow};
use async_timing_util::unix_timestamp_ms;
use database::{
@@ -11,7 +13,7 @@ use komodo_client::{
},
entities::user::{User, UserConfig},
};
use rate_limit::WithFailureRateLimit;
use rate_limit::{RateLimiter, WithFailureRateLimit};
use reqwest::StatusCode;
use resolver_api::Resolve;
use serror::{AddStatusCode as _, AddStatusCodeError};
@@ -117,6 +119,23 @@ async fn sign_up_local_user(
.map_err(Into::into)
}
/// Local login method has a dedicated rate limiter
/// so the UI background calls using existing JWT do
/// not influence the number of attempts user has
/// to log in.
fn login_local_user_rate_limiter() -> &'static RateLimiter {
static LOGIN_LOCAL_USER_RATE_LIMITER: OnceLock<Arc<RateLimiter>> =
OnceLock::new();
LOGIN_LOCAL_USER_RATE_LIMITER.get_or_init(|| {
let config = core_config();
RateLimiter::new(
config.auth_rate_limit_disabled,
config.auth_rate_limit_max_attempts as usize,
config.auth_rate_limit_window_seconds,
)
})
}
impl Resolve<AuthArgs> for LoginLocalUser {
async fn resolve(
self,
@@ -124,7 +143,7 @@ impl Resolve<AuthArgs> for LoginLocalUser {
) -> serror::Result<LoginLocalUserResponse> {
login_local_user(self)
.with_failure_rate_limit_using_headers(
auth_rate_limiter(),
login_local_user_rate_limiter(),
headers,
)
.await

View File

@@ -12,10 +12,9 @@ use tokio::sync::RwLock;
/// Trait to extend fallible futures with stateful
/// rate limiting.
pub trait WithFailureRateLimit<R, E>
pub trait WithFailureRateLimit<R>
where
Self: Future<Output = Result<R, E>> + Sized,
E: Into<serror::Error> + Send,
Self: Future<Output = serror::Result<R>> + Sized,
{
/// Ensure the given IP 'ip' is
/// not violating the givin 'limiter' rate limit rules
@@ -37,7 +36,7 @@ where
) -> impl Future<Output = serror::Result<R>> {
async {
if limiter.disabled {
return self.await.map_err(Into::into);
return self.await;
}
// Only locks if entry at key does not exist yet.
@@ -49,8 +48,14 @@ where
let now = Instant::now();
let window_start = now - limiter.window;
let count =
read.iter().filter(|&&time| time > window_start).count();
let (first, count) =
read.iter().filter(|&&time| time > window_start).fold(
(Option::<Instant>::None, 0),
|(first, count), &time| {
(Some(first.unwrap_or(time)), count + 1)
},
);
// Drop the read lock immediately
drop(read);
@@ -60,8 +65,9 @@ where
attempts.write().await.retain(|&time| time > window_start);
return Err(
anyhow!(
"Too many attempts. Try again in {:?}",
"Too many attempts. Try again in {:.0?}",
limiter.window
- first.map(|first| now - first).unwrap_or_default(),
)
.status_code(StatusCode::TOO_MANY_REQUESTS),
);
@@ -71,15 +77,21 @@ where
// The succeeding branch has no write locks
// after the initial attempt array initializes.
Ok(res) => Ok(res),
Err(e) => {
Err(mut e) => {
// Failing branch takes exclusive write lock.
let mut write = attempts.write().await;
// Use this opportunity to clear the attempts cache
write.retain(|&time| time > window_start);
// Always push after failed attempts, eg failed api key check.
write.push(now);
// Return original error converted to serror::Error
Err(e.into())
// Add 1 to count because it doesn't include this attempt.
let remaining_attempts = limiter.max_attempts - (count + 1);
// Return original error with remaining attempts shown
e.error = anyhow!(
"{:#} | You have {remaining_attempts} attempts remaining",
e.error,
);
Err(e)
}
}
}
@@ -93,7 +105,7 @@ where
async {
// Can skip header ip extraction if disabled
if limiter.disabled {
return self.await.map_err(Into::into);
return self.await;
}
let ip = get_ip_from_headers(headers)?;
self.with_failure_rate_limit_using_ip(limiter, &ip).await
@@ -101,10 +113,8 @@ where
}
}
impl<F, R, E> WithFailureRateLimit<R, E> for F
where
F: Future<Output = Result<R, E>> + Sized,
E: Into<serror::Error> + Send,
impl<F, R> WithFailureRateLimit<R> for F where
F: Future<Output = serror::Result<R>> + Sized
{
}