mirror of
https://github.com/moghtech/komodo.git
synced 2026-03-11 17:44:19 -05:00
KL-4 ext 2: Improve rate limiting / attempt state conveyance with response
This commit is contained in:
committed by
Maxwell Becker
parent
85787781ee
commit
8c62f2b5c5
@@ -12,7 +12,9 @@ use komodo_client::{
|
||||
api::auth::JwtResponse,
|
||||
entities::{config::core::CoreConfig, random_string},
|
||||
};
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serror::{AddStatusCode as _, AddStatusCodeError as _};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
type ExchangeTokenMap = Mutex<HashMap<String, (JwtResponse, u128)>>;
|
||||
@@ -93,17 +95,21 @@ impl JwtClient {
|
||||
pub async fn redeem_exchange_token(
|
||||
&self,
|
||||
exchange_token: &str,
|
||||
) -> anyhow::Result<JwtResponse> {
|
||||
) -> serror::Result<JwtResponse> {
|
||||
let (jwt, valid_until) = self
|
||||
.exchange_tokens
|
||||
.lock()
|
||||
.await
|
||||
.remove(exchange_token)
|
||||
.context("invalid exchange token: unrecognized")?;
|
||||
.context("Invalid exchange token")
|
||||
.status_code(StatusCode::UNAUTHORIZED)?;
|
||||
if unix_timestamp_ms() < valid_until {
|
||||
Ok(jwt)
|
||||
} else {
|
||||
Err(anyhow!("invalid exchange token: expired"))
|
||||
Err(
|
||||
anyhow!("Invalid exchange token")
|
||||
.status_code(StatusCode::UNAUTHORIZED),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use async_timing_util::unix_timestamp_ms;
|
||||
use database::{
|
||||
@@ -11,7 +13,7 @@ use komodo_client::{
|
||||
},
|
||||
entities::user::{User, UserConfig},
|
||||
};
|
||||
use rate_limit::WithFailureRateLimit;
|
||||
use rate_limit::{RateLimiter, WithFailureRateLimit};
|
||||
use reqwest::StatusCode;
|
||||
use resolver_api::Resolve;
|
||||
use serror::{AddStatusCode as _, AddStatusCodeError};
|
||||
@@ -117,6 +119,23 @@ async fn sign_up_local_user(
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Local login method has a dedicated rate limiter
|
||||
/// so the UI background calls using existing JWT do
|
||||
/// not influence the number of attempts user has
|
||||
/// to log in.
|
||||
fn login_local_user_rate_limiter() -> &'static RateLimiter {
|
||||
static LOGIN_LOCAL_USER_RATE_LIMITER: OnceLock<Arc<RateLimiter>> =
|
||||
OnceLock::new();
|
||||
LOGIN_LOCAL_USER_RATE_LIMITER.get_or_init(|| {
|
||||
let config = core_config();
|
||||
RateLimiter::new(
|
||||
config.auth_rate_limit_disabled,
|
||||
config.auth_rate_limit_max_attempts as usize,
|
||||
config.auth_rate_limit_window_seconds,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
impl Resolve<AuthArgs> for LoginLocalUser {
|
||||
async fn resolve(
|
||||
self,
|
||||
@@ -124,7 +143,7 @@ impl Resolve<AuthArgs> for LoginLocalUser {
|
||||
) -> serror::Result<LoginLocalUserResponse> {
|
||||
login_local_user(self)
|
||||
.with_failure_rate_limit_using_headers(
|
||||
auth_rate_limiter(),
|
||||
login_local_user_rate_limiter(),
|
||||
headers,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -12,10 +12,9 @@ use tokio::sync::RwLock;
|
||||
|
||||
/// Trait to extend fallible futures with stateful
|
||||
/// rate limiting.
|
||||
pub trait WithFailureRateLimit<R, E>
|
||||
pub trait WithFailureRateLimit<R>
|
||||
where
|
||||
Self: Future<Output = Result<R, E>> + Sized,
|
||||
E: Into<serror::Error> + Send,
|
||||
Self: Future<Output = serror::Result<R>> + Sized,
|
||||
{
|
||||
/// Ensure the given IP 'ip' is
|
||||
/// not violating the givin 'limiter' rate limit rules
|
||||
@@ -37,7 +36,7 @@ where
|
||||
) -> impl Future<Output = serror::Result<R>> {
|
||||
async {
|
||||
if limiter.disabled {
|
||||
return self.await.map_err(Into::into);
|
||||
return self.await;
|
||||
}
|
||||
|
||||
// Only locks if entry at key does not exist yet.
|
||||
@@ -49,8 +48,14 @@ where
|
||||
let now = Instant::now();
|
||||
let window_start = now - limiter.window;
|
||||
|
||||
let count =
|
||||
read.iter().filter(|&&time| time > window_start).count();
|
||||
let (first, count) =
|
||||
read.iter().filter(|&&time| time > window_start).fold(
|
||||
(Option::<Instant>::None, 0),
|
||||
|(first, count), &time| {
|
||||
(Some(first.unwrap_or(time)), count + 1)
|
||||
},
|
||||
);
|
||||
|
||||
// Drop the read lock immediately
|
||||
drop(read);
|
||||
|
||||
@@ -60,8 +65,9 @@ where
|
||||
attempts.write().await.retain(|&time| time > window_start);
|
||||
return Err(
|
||||
anyhow!(
|
||||
"Too many attempts. Try again in {:?}",
|
||||
"Too many attempts. Try again in {:.0?}",
|
||||
limiter.window
|
||||
- first.map(|first| now - first).unwrap_or_default(),
|
||||
)
|
||||
.status_code(StatusCode::TOO_MANY_REQUESTS),
|
||||
);
|
||||
@@ -71,15 +77,21 @@ where
|
||||
// The succeeding branch has no write locks
|
||||
// after the initial attempt array initializes.
|
||||
Ok(res) => Ok(res),
|
||||
Err(e) => {
|
||||
Err(mut e) => {
|
||||
// Failing branch takes exclusive write lock.
|
||||
let mut write = attempts.write().await;
|
||||
// Use this opportunity to clear the attempts cache
|
||||
write.retain(|&time| time > window_start);
|
||||
// Always push after failed attempts, eg failed api key check.
|
||||
write.push(now);
|
||||
// Return original error converted to serror::Error
|
||||
Err(e.into())
|
||||
// Add 1 to count because it doesn't include this attempt.
|
||||
let remaining_attempts = limiter.max_attempts - (count + 1);
|
||||
// Return original error with remaining attempts shown
|
||||
e.error = anyhow!(
|
||||
"{:#} | You have {remaining_attempts} attempts remaining",
|
||||
e.error,
|
||||
);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,7 +105,7 @@ where
|
||||
async {
|
||||
// Can skip header ip extraction if disabled
|
||||
if limiter.disabled {
|
||||
return self.await.map_err(Into::into);
|
||||
return self.await;
|
||||
}
|
||||
let ip = get_ip_from_headers(headers)?;
|
||||
self.with_failure_rate_limit_using_ip(limiter, &ip).await
|
||||
@@ -101,10 +113,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, R, E> WithFailureRateLimit<R, E> for F
|
||||
where
|
||||
F: Future<Output = Result<R, E>> + Sized,
|
||||
E: Into<serror::Error> + Send,
|
||||
impl<F, R> WithFailureRateLimit<R> for F where
|
||||
F: Future<Output = serror::Result<R>> + Sized
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user