implement google oauth

This commit is contained in:
mbecker20
2022-12-29 07:59:50 +00:00
parent 15662c951d
commit 3b6d3af7bb
23 changed files with 716 additions and 113 deletions

75
Cargo.lock generated
View File

@@ -147,6 +147,20 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum_oauth2"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"jwt",
"rand",
"reqwest",
"serde",
"serde_derive",
"urlencoding",
]
[[package]]
name = "base64"
version = "0.13.0"
@@ -296,7 +310,6 @@ dependencies = [
"js-sys",
"num-integer",
"num-traits",
"serde",
"time 0.1.44",
"wasm-bindgen",
"winapi",
@@ -368,6 +381,7 @@ dependencies = [
"async_timing_util",
"axum",
"axum-extra",
"axum_oauth2",
"bcrypt",
"db_client",
"diff-struct",
@@ -379,7 +393,6 @@ dependencies = [
"jwt",
"monitor_types",
"mungos",
"oauth2",
"periphery_client",
"serde",
"serde_derive",
@@ -847,10 +860,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@@ -1012,19 +1023,6 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac"
dependencies = [
"http",
"hyper",
"rustls",
"tokio",
"tokio-rustls",
]
[[package]]
name = "hyper-tls"
version = "0.5.0"
@@ -1327,7 +1325,7 @@ dependencies = [
"rand",
"rustc_version_runtime",
"rustls",
"rustls-pemfile 0.3.0",
"rustls-pemfile",
"serde",
"serde_bytes",
"serde_with",
@@ -1555,26 +1553,6 @@ dependencies = [
"libc",
]
[[package]]
name = "oauth2"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d62c436394991641b970a92e23e8eeb4eb9bca74af4f5badc53bcd568daadbd"
dependencies = [
"base64",
"chrono",
"getrandom",
"http",
"rand",
"reqwest",
"serde",
"serde_json",
"serde_path_to_error",
"sha2",
"thiserror",
"url",
]
[[package]]
name = "once_cell"
version = "1.15.0"
@@ -1871,7 +1849,6 @@ dependencies = [
"http",
"http-body",
"hyper",
"hyper-rustls",
"hyper-tls",
"ipnet",
"js-sys",
@@ -1881,20 +1858,16 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
"rustls",
"rustls-pemfile 1.0.1",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-native-tls",
"tokio-rustls",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots",
"winreg 0.10.1",
]
@@ -1972,15 +1945,6 @@ dependencies = [
"base64",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55"
dependencies = [
"base64",
]
[[package]]
name = "rustversion"
version = "1.0.9"
@@ -2725,9 +2689,14 @@ dependencies = [
"form_urlencoded",
"idna 0.3.0",
"percent-encoding",
"serde",
]
[[package]]
name = "urlencoding"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
[[package]]
name = "utf-8"
version = "0.7.6"

View File

@@ -5,6 +5,7 @@ members = [
"core",
"periphery",
"tests",
"lib/axum_oauth2",
"lib/db_client",
"lib/helpers",
"lib/periphery_client",

View File

@@ -19,6 +19,12 @@ const CORE_IMAGE_NAME: &str = "mbecker20/monitor-core";
const PERIPHERY_IMAGE_NAME: &str = "mbecker20/monitor-periphery";
pub fn gen_core_config(sub_matches: &ArgMatches) {
let host = sub_matches
.get_one::<String>("host")
.map(|p| p.as_str())
.unwrap_or("http://localhost:9000")
.to_string();
let path = sub_matches
.get_one::<String>("path")
.map(|p| p.as_str())
@@ -56,10 +62,12 @@ pub fn gen_core_config(sub_matches: &ArgMatches) {
.map(|p| p.to_owned());
let config = CoreConfig {
host,
port,
jwt_valid_for,
slack_url,
github_oauth: Default::default(),
google_oauth: Default::default(),
mongo: MongoConfig {
uri: mongo_uri,
db_name: mongo_db_name,

View File

@@ -110,6 +110,10 @@ fn cli() -> Command {
.subcommand(
Command::new("gen_config")
.about("generate a periphery config file")
.arg(
arg!(--host <HOST> "the host to use with oauth redirect url, whatever host the user hits to access monitor. eg 'https://monitor.mogh.tech'")
.required(true)
)
.arg(
arg!(--path <PATH> "sets path of generated config file. default is '~/.monitor/periphery.config.toml'")
.required(false)

View File

@@ -6,6 +6,9 @@ use strum_macros::{Display, EnumString};
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CoreConfig {
// the host to use with oauth redirect url, whatever host the user hits to access monitor. eg 'https://monitor.mogh.tech'
pub host: String,
// port the core web server runs on
#[serde(default = "default_core_port")]
pub port: u16,
@@ -22,6 +25,9 @@ pub struct CoreConfig {
pub github_webhook_secret: String,
pub github_oauth: OauthCredentials,
// google integration
pub google_oauth: OauthCredentials,
// mongo config
pub mongo: MongoConfig,
}

View File

@@ -10,6 +10,7 @@ helpers = { path = "../lib/helpers" }
db = { package = "db_client", path = "../lib/db_client" }
types = { package = "monitor_types", path = "../lib/types" }
periphery = { package = "periphery_client", path = "../lib/periphery_client" }
axum_oauth2 = { path = "../lib/axum_oauth2" }
tokio = { version = "1.21", features = ["full"] }
tokio-util = "0.7"
axum = { version = "0.6", features = ["ws", "json"] }
@@ -23,7 +24,6 @@ serde_derive = "1.0"
serde_json = "1.0"
dotenv = "0.15"
envy = "0.4"
oauth2 = "4.2.3"
anyhow = "1.0"
bcrypt = "0.13"
jwt = "0.16"

View File

@@ -1,33 +1,95 @@
use std::sync::Arc;
use axum::{Extension, Router};
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use types::CoreConfig;
use anyhow::{anyhow, Context};
use axum::{extract::Query, response::Redirect, routing::get, Extension, Router};
use axum_oauth2::github::{GithubOauthClient, GithubOauthExtension};
use helpers::handle_anyhow_error;
use mungos::{doc, Deserialize};
use types::{monitor_timestamp, CoreConfig, User};
pub type GithubOauthExtension = Extension<Arc<BasicClient>>;
use crate::{response, state::StateExtension};
use super::JwtExtension;
pub fn router(config: &CoreConfig) -> Router {
Router::new().layer(github_oauth_extension(
config,
format!("http://localhost:9000/auth/github/callback"),
))
let client = GithubOauthClient::new(
config.github_oauth.id.clone(),
config.github_oauth.secret.clone(),
format!("{}/auth/github/callback", config.host),
&[],
"monitor".to_string(),
);
Router::new()
.route(
"/login",
get(|Extension(client): GithubOauthExtension| async move {
Redirect::to(&client.get_login_redirect_url())
}),
)
.route(
"/callback",
get(|client, jwt, state, query| async {
let redirect = callback(client, jwt, state, query)
.await
.map_err(handle_anyhow_error)?;
response!(redirect)
}),
)
.layer(Extension(Arc::new(client)))
}
fn github_oauth_extension(config: &CoreConfig, redirect_url: String) -> GithubOauthExtension {
let github_client_id = ClientId::new(config.github_oauth.id.clone());
let github_client_secret = ClientSecret::new(config.github_oauth.secret.clone());
let auth_url = AuthUrl::new("https://github.com/login/oauth/authorize".to_string())
.expect("invalid auth url");
let token_url = TokenUrl::new("https://github.com/login/oauth/access_token".to_string())
.expect("Invalid token endpoint URL");
// Set up the config for the Github OAuth2 process.
let client = BasicClient::new(
github_client_id,
Some(github_client_secret),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"));
Extension(Arc::new(client))
#[derive(Deserialize)]
struct CallbackQuery {
state: String,
code: String,
}
async fn callback(
Extension(client): GithubOauthExtension,
Extension(jwt_client): JwtExtension,
Extension(state): StateExtension,
Query(query): Query<CallbackQuery>,
) -> anyhow::Result<Redirect> {
if !client.check_state(&query.state) {
return Err(anyhow!("state mismatch"));
}
let token = client.get_access_token(&query.code).await?;
let github_user = client.get_github_user(&token.access_token).await?;
let github_id = github_user.id.to_string();
let user = state
.db
.users
.find_one(doc! { "github_id": &github_id }, None)
.await
.context("failed at find user query from mongo")?;
let jwt = match user {
Some(user) => jwt_client
.generate(user.id)
.context("failed to generate jwt")?,
None => {
let ts = monitor_timestamp();
let user = User {
username: github_user.login,
avatar: github_user.avatar_url.into(),
github_id: github_id.into(),
created_at: ts.clone(),
updated_at: ts,
..Default::default()
};
let user_id = state
.db
.users
.create_one(user)
.await
.context("failed to create user on mongo")?;
jwt_client
.generate(user_id)
.context("failed to generate jwt")?
}
};
let exchange_token = jwt_client.create_exchange_token(jwt);
Ok(Redirect::to(&format!(
"{}?token={exchange_token}",
state.config.host
)))
}

95
core/src/auth/google.rs Normal file
View File

@@ -0,0 +1,95 @@
use std::sync::Arc;
use anyhow::{anyhow, Context};
use axum::{Router, Extension, routing::get, response::Redirect, extract::Query};
use axum_oauth2::google::{GoogleOauthClient, GoogleOauthExtension};
use helpers::handle_anyhow_error;
use mungos::{Deserialize, doc};
use types::{CoreConfig, monitor_timestamp, User};
use crate::{state::StateExtension, response};
use super::JwtExtension;
pub fn router(config: &CoreConfig) -> Router {
let client = GoogleOauthClient::new(
config.google_oauth.id.clone(),
config.google_oauth.secret.clone(),
format!("{}/auth/google/callback", config.host),
&[],
"monitor".to_string(),
);
Router::new()
.route(
"/login",
get(|Extension(client): GoogleOauthExtension| async move {
Redirect::to(&client.get_login_redirect_url())
}),
)
.route(
"/callback",
get(|client, jwt, state, query| async {
let redirect = callback(client, jwt, state, query)
.await
.map_err(handle_anyhow_error)?;
response!(redirect)
}),
)
.layer(Extension(Arc::new(client)))
}
#[derive(Deserialize)]
struct CallbackQuery {
state: String,
code: String,
}
async fn callback(
Extension(client): GoogleOauthExtension,
Extension(jwt_client): JwtExtension,
Extension(state): StateExtension,
Query(query): Query<CallbackQuery>,
) -> anyhow::Result<Redirect> {
if !client.check_state(&query.state) {
return Err(anyhow!("state mismatch"));
}
let token = client.get_access_token(&query.code).await?;
let google_user = client.get_google_user(&token.access_token)?;
let google_id = google_user.id.to_string();
let user = state
.db
.users
.find_one(doc! { "google_id": &google_id }, None)
.await
.context("failed at find user query from mongo")?;
let jwt = match user {
Some(user) => jwt_client
.generate(user.id)
.context("failed to generate jwt")?,
None => {
let ts = monitor_timestamp();
let user = User {
username: google_user.email.split("@").collect::<Vec<&str>>().get(0).unwrap().to_string(),
avatar: google_user.picture.into(),
google_id: google_id.into(),
created_at: ts.clone(),
updated_at: ts,
..Default::default()
};
let user_id = state
.db
.users
.create_one(user)
.await
.context("failed to create user on mongo")?;
jwt_client
.generate(user_id)
.context("failed to generate jwt")?
}
};
let exchange_token = jwt_client.create_exchange_token(jwt);
Ok(Redirect::to(&format!(
"{}?token={exchange_token}",
state.config.host
)))
}

View File

@@ -1,8 +1,12 @@
use std::sync::Arc;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use anyhow::{anyhow, Context};
use async_timing_util::{get_timelength_in_ms, unix_timestamp_ms};
use async_timing_util::{get_timelength_in_ms, unix_timestamp_ms, Timelength};
use axum::{body::Body, http::Request, Extension};
use axum_oauth2::random_string;
use hmac::{Hmac, Mac};
use jwt::{SignWithKey, VerifyWithKey};
use mungos::{Deserialize, Serialize};
@@ -14,6 +18,8 @@ use crate::state::State;
pub type JwtExtension = Extension<Arc<JwtClient>>;
pub type RequestUserExtension = Extension<Arc<RequestUser>>;
type ExchangeTokenMap = Mutex<HashMap<String, (String, u128)>>;
pub struct RequestUser {
pub id: String,
pub is_admin: bool,
@@ -27,10 +33,10 @@ pub struct JwtClaims {
pub exp: u128,
}
#[derive(Clone)]
pub struct JwtClient {
key: Hmac<Sha256>,
valid_for_ms: u128,
exchange_tokens: ExchangeTokenMap,
}
impl JwtClient {
@@ -40,6 +46,7 @@ impl JwtClient {
let client = JwtClient {
key,
valid_for_ms: get_timelength_in_ms(config.jwt_valid_for.to_string().parse().unwrap()),
exchange_tokens: Default::default(),
};
Extension(Arc::new(client))
}
@@ -149,4 +156,30 @@ impl JwtClient {
Err(anyhow!("token has expired"))
}
}
pub fn create_exchange_token(&self, jwt: String) -> String {
let exchange_token = random_string(40);
self.exchange_tokens.lock().unwrap().insert(
exchange_token.clone(),
(
jwt,
unix_timestamp_ms() + get_timelength_in_ms(Timelength::OneMinute),
),
);
exchange_token
}
pub fn redeem_exchange_token(&self, exchange_token: &str) -> anyhow::Result<String> {
let (jwt, valid_until) = self
.exchange_tokens
.lock()
.unwrap()
.remove(exchange_token)
.ok_or(anyhow!("invalid exchange token: unrecognized"))?;
if unix_timestamp_ms() < valid_until {
Ok(jwt)
} else {
Err(anyhow!("invalid exchange token: expired"))
}
}
}

View File

@@ -5,11 +5,16 @@ use axum::{
http::{Request, StatusCode},
middleware::Next,
response::Response,
Router,
routing::post,
Extension, Json, Router,
};
use helpers::handle_anyhow_error;
use mungos::Deserialize;
use types::CoreConfig;
use typeshare::typeshare;
mod github;
mod google;
mod jwt;
mod local;
mod secret;
@@ -18,11 +23,34 @@ pub use self::jwt::{JwtClaims, JwtClient, JwtExtension, RequestUser, RequestUser
pub fn router(config: &CoreConfig) -> Router {
Router::new()
.route(
"/exchange",
post(|jwt, body| async {
exchange_for_jwt(jwt, body)
.await
.map_err(handle_anyhow_error)
}),
)
.nest("/local", local::router())
.nest("/github", github::router(config))
.nest("/google", google::router(config))
.nest("/secret", secret::router())
}
#[typeshare]
#[derive(Deserialize)]
struct TokenExchangeBody {
token: String,
}
async fn exchange_for_jwt(
Extension(jwt): JwtExtension,
Json(body): Json<TokenExchangeBody>,
) -> anyhow::Result<String> {
let jwt = jwt.redeem_exchange_token(&body.token)?;
Ok(jwt)
}
pub async fn auth_request(
mut req: Request<Body>,
next: Next<Body>,

View File

@@ -1,3 +1,4 @@
use axum_extra::routing::SpaRouter;
use dotenv::dotenv;
use helpers::parse_config_file;
use mungos::Deserialize;
@@ -7,15 +8,22 @@ use types::CoreConfig;
struct Env {
#[serde(default = "default_config_path")]
pub config_path: String,
#[serde(default = "default_frontend_path")]
pub frontend_path: String,
}
pub fn load() -> CoreConfig {
pub fn load() -> (CoreConfig, SpaRouter) {
dotenv().ok();
let env: Env = envy::from_env().expect("failed to parse environment variables");
let config = parse_config_file(&env.config_path).expect("failed to parse config");
config
let spa_router = SpaRouter::new("/", env.frontend_path);
(config, spa_router)
}
pub fn default_config_path() -> String {
"/config/config.toml".to_string()
}
fn default_frontend_path() -> String {
"/frontend".to_string()
}

View File

@@ -16,11 +16,12 @@ mod ws;
#[tokio::main]
async fn main() {
let config = config::load();
let (config, spa_router) = config::load();
println!("starting monitor core on port {}...", config.port);
let app = Router::new()
.merge(spa_router)
.nest("/api", api::router())
.nest("/auth", auth::router(&config))
.nest("/ws", ws::router())

View File

@@ -4,7 +4,7 @@ import Input from "../shared/Input";
import Grid from "../shared/layout/Grid";
import { createStore } from "solid-js/store";
import Flex from "../shared/layout/Flex";
import { client, pushNotification } from "../..";
import { client, pushNotification, URL } from "../..";
import { combineClasses } from "../../util/helpers";
import Icon from "../shared/Icon";
import { useUser } from "../../state/UserProvider";
@@ -77,12 +77,12 @@ const Login: Component<{}> = (p) => {
sign up
</button>
</Flex>
{/* <button
<button
class={combineClasses(s.LoginItem, "blue")}
onClick={() => client.loginGithub()}
onClick={() => location.replace(`${URL}/auth/github/login`)}
>
log in with github <Icon type="github" />
</button> */}
</button>
</Grid>
</div>
);

View File

@@ -30,20 +30,22 @@ export const client = new Client(URL, token);
export const { Notifications, pushNotification } = makeNotifications();
render(
() => [
<DimensionProvider>
<UserProvider>
<LoginGuard>
<Router>
<AppStateProvider>
<App />
</AppStateProvider>
</Router>
</LoginGuard>
</UserProvider>
</DimensionProvider>,
<Notifications />,
],
document.getElementById("root") as HTMLElement
);
client.initialize().then(() => {
render(
() => [
<DimensionProvider>
<UserProvider>
<LoginGuard>
<Router>
<AppStateProvider>
<App />
</AppStateProvider>
</Router>
</LoginGuard>
</UserProvider>
</DimensionProvider>,
<Notifications />,
],
document.getElementById("root") as HTMLElement
);
});

View File

@@ -23,7 +23,7 @@ export type UserState = {
const UserContext = createContext<UserState>();
export const UserProvider: ParentComponent = (p) => {
const [user, { mutate, refetch }] = createResource(() => client.getUser());
const [user, { mutate, refetch }] = createResource(() => client.get_user());
const logout = async () => {
client.logout();
mutate(false);

View File

@@ -1,4 +1,5 @@
import axios from "axios";
import { URL } from "..";
import {
BasicContainerInfo,
Build,
@@ -35,18 +36,33 @@ import { generateQuery, QueryObject } from "./helpers";
export class Client {
constructor(private baseURL: string, public token: string | null) {}
async initialize() {
const params = new URLSearchParams(location.search);
const exchange_token = params.get("token");
if (exchange_token) {
history.replaceState({}, "", URL);
try {
const jwt = await this.exchange_for_jwt(exchange_token);
this.token = jwt;
localStorage.setItem("access_token", jwt);
} catch (error) {
console.warn(error);
}
}
}
async login(credentials: UserCredentials) {
const jwt: string = await this.post("/auth/local/login", credentials);
this.token = jwt;
localStorage.setItem("access_token", this.token);
return await this.getUser();
return await this.get_user();
}
async signup(credentials: UserCredentials) {
const jwt: string = await this.post("/auth/local/create_user", credentials);
this.token = jwt;
localStorage.setItem("access_token", this.token);
return await this.getUser();
return await this.get_user();
}
logout() {
@@ -54,7 +70,7 @@ export class Client {
this.token = null;
}
async getUser(): Promise<User | false> {
async get_user(): Promise<User | false> {
if (this.token) {
try {
return await this.get("/api/user");
@@ -67,6 +83,10 @@ export class Client {
}
}
async exchange_for_jwt(exchange_token: string): Promise<string> {
return this.post("/auth/exchange", { token: exchange_token });
}
// deployment
list_deployments(query?: QueryObject): Promise<DeploymentWithContainerState[]> {

View File

@@ -54,3 +54,7 @@ export interface CreateServerBody {
address: string;
}
export interface TokenExchangeBody {
token: string;
}

View File

@@ -0,0 +1,16 @@
[package]
name = "axum_oauth2"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = { version = "0.6", features = ["json"] }
reqwest = { version = "0.11", features = ["json"] }
anyhow = "1.0"
serde = "1.0"
serde_derive = "1.0"
urlencoding = "2.1"
rand = "0.8"
jwt = "0.16"

View File

@@ -0,0 +1,179 @@
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, Context};
use axum::Extension;
use reqwest::StatusCode;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::random_string;
pub type GithubOauthExtension = Extension<Arc<GithubOauthClient>>;
pub struct GithubOauthClient {
http: reqwest::Client,
client_id: String,
client_secret: String,
redirect_uri: String,
scopes: String,
states: Mutex<Vec<String>>,
user_agent: String,
}
impl GithubOauthClient {
pub fn new(
client_id: String,
client_secret: String,
redirect_uri: String,
scopes: &[&str],
user_agent: String,
) -> GithubOauthClient {
GithubOauthClient {
http: reqwest::Client::new(),
client_id,
client_secret,
redirect_uri,
user_agent,
scopes: urlencoding::encode(&scopes.join(" ")).to_string(),
states: Default::default(),
}
}
pub fn get_login_redirect_url(&self) -> String {
let state = random_string(40);
let redirect_url = format!(
"https://github.com/login/oauth/authorize?state={state}&client_id={}&redirect_uri={}&scope={}",
self.client_id, self.redirect_uri, self.scopes
);
{
let mut states = self.states.lock().unwrap();
states.push(state);
// println!("{states:#?}");
}
redirect_url
}
pub fn check_state(&self, state: &str) -> bool {
let mut contained = false;
self.states.lock().unwrap().retain(|s| {
if s.as_str() == state {
contained = true;
false
} else {
true
}
});
contained
}
pub async fn get_access_token(&self, code: &str) -> anyhow::Result<AccessTokenResponse> {
self.post::<(), _>(
"https://github.com/login/oauth/access_token",
&[
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("redirect_uri", self.redirect_uri.as_str()),
("code", code),
],
None,
None,
)
.await
.context("failed to get github access token using code")
}
pub async fn get_github_user(&self, token: &str) -> anyhow::Result<GithubUserResponse> {
self.get("https://api.github.com/user", &[], Some(token))
.await
.context("failed to get github user using access token")
}
async fn get<R: DeserializeOwned>(
&self,
endpoint: &str,
query: &[(&str, &str)],
bearer_token: Option<&str>,
) -> anyhow::Result<R> {
let mut req = self
.http
.get(endpoint)
.query(query)
.header("User-Agent", &self.user_agent);
if let Some(bearer_token) = bearer_token {
req = req.header("Authorization", format!("Bearer {bearer_token}"));
}
let res = req.send().await.context("failed to reach github")?;
let status = res.status();
if status == StatusCode::OK {
let body = res
.json()
.await
.context("failed to parse body into expected type")?;
Ok(body)
} else {
let text = res
.text()
.await
.context(format!("status: {status} | failed to get response text"))?;
Err(anyhow!("status: {status} | text: {text}"))
}
}
async fn post<B: Serialize, R: DeserializeOwned>(
&self,
endpoint: &str,
query: &[(&str, &str)],
body: Option<&B>,
bearer_token: Option<&str>,
) -> anyhow::Result<R> {
let mut req = self
.http
.post(endpoint)
.query(query)
.header("Accept", "application/json")
.header("User-Agent", &self.user_agent);
if let Some(body) = body {
req = req.json(body);
}
if let Some(bearer_token) = bearer_token {
req = req.header("Authorization", format!("Bearer {bearer_token}"));
}
let res = req.send().await.context("failed to reach github")?;
let status = res.status();
if status == StatusCode::OK {
let body = res
.json()
.await
.context("failed to parse POST body into expected type")?;
Ok(body)
} else {
let text = res.text().await.context(format!(
"method: POST | status: {status} | failed to get response text"
))?;
Err(anyhow!("method: POST | status: {status} | text: {text}"))
}
}
}
#[derive(Deserialize)]
pub struct AccessTokenResponse {
pub access_token: String,
pub scope: String,
pub token_type: String,
}
#[derive(Deserialize)]
pub struct GithubUserResponse {
pub login: String,
pub id: u128,
pub avatar_url: String,
pub email: Option<String>,
}

View File

@@ -0,0 +1,145 @@
use std::sync::{Mutex, Arc};
use anyhow::{anyhow, Context};
use axum::Extension;
use jwt::{Header, Token};
use reqwest::StatusCode;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::random_string;
pub type GoogleOauthExtension = Extension<Arc<GoogleOauthClient>>;
pub struct GoogleOauthClient {
http: reqwest::Client,
client_id: String,
client_secret: String,
redirect_uri: String,
scopes: String,
states: Mutex<Vec<String>>,
user_agent: String,
}
impl GoogleOauthClient {
pub fn new(
client_id: String,
client_secret: String,
redirect_uri: String,
scopes: &[&str],
user_agent: String,
) -> GoogleOauthClient {
GoogleOauthClient {
http: reqwest::Client::new(),
client_id,
client_secret,
redirect_uri,
user_agent,
scopes: urlencoding::encode(&scopes.join(" ")).to_string(),
states: Default::default(),
}
}
pub fn get_login_redirect_url(&self) -> String {
let state = random_string(40);
let redirect_url = format!(
"https://accounts.google.com/o/oauth2/v2/auth?response_type=code&state={state}&client_id={}&redirect_uri={}&scope={}",
self.client_id, self.redirect_uri, self.scopes
);
{
let mut states = self.states.lock().unwrap();
states.push(state);
// println!("{states:#?}");
}
redirect_url
}
pub fn check_state(&self, state: &str) -> bool {
let mut contained = false;
self.states.lock().unwrap().retain(|s| {
if s.as_str() == state {
contained = true;
false
} else {
true
}
});
contained
}
pub async fn get_access_token(&self, code: &str) -> anyhow::Result<AccessTokenResponse> {
self.post::<(), _>(
"https://oauth2.googleapis.com/token",
&[
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("redirect_uri", self.redirect_uri.as_str()),
("code", code),
("grant_type", "authorization_code"),
],
None,
None,
)
.await
.context("failed to get google access token using code")
}
pub fn get_google_user(&self, token: &str) -> anyhow::Result<GoogleUser> {
let token: Token<Header, GoogleUser, jwt::Unverified> = Token::parse_unverified(token)?;
Ok(token.claims().to_owned())
}
async fn post<B: Serialize, R: DeserializeOwned>(
&self,
endpoint: &str,
query: &[(&str, &str)],
body: Option<&B>,
bearer_token: Option<&str>,
) -> anyhow::Result<R> {
let mut req = self
.http
.post(endpoint)
.query(query)
.header("Accept", "application/json")
.header("User-Agent", &self.user_agent);
if let Some(body) = body {
req = req.json(body);
}
if let Some(bearer_token) = bearer_token {
req = req.header("Authorization", format!("Bearer {bearer_token}"));
}
let res = req.send().await.context("failed to reach google")?;
let status = res.status();
if status == StatusCode::OK {
let body = res
.json()
.await
.context("failed to parse POST body into expected type")?;
Ok(body)
} else {
let text = res.text().await.context(format!(
"method: POST | status: {status} | failed to get response text"
))?;
Err(anyhow!("method: POST | status: {status} | text: {text}"))
}
}
}
#[derive(Deserialize)]
pub struct AccessTokenResponse {
pub access_token: String,
pub scope: String,
pub token_type: String,
}
#[derive(Deserialize, Clone)]
pub struct GoogleUser {
#[serde(rename = "sub")]
pub id: String,
pub email: String,
pub picture: String,
}

View File

@@ -0,0 +1,12 @@
use rand::{distributions::Alphanumeric, thread_rng, Rng};
pub mod github;
pub mod google;
pub fn random_string(length: usize) -> String {
thread_rng()
.sample_iter(&Alphanumeric)
.take(length)
.map(char::from)
.collect()
}

View File

@@ -16,6 +16,9 @@ pub type SecretsMap = HashMap<String, String>; // these are used for injection i
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CoreConfig {
// the host to use with oauth redirect url, whatever host the user hits to access monitor. eg 'https://monitor.mogh.tech'
pub host: String,
// port the core web server runs on
#[serde(default = "default_core_port")]
pub port: u16,
@@ -32,6 +35,9 @@ pub struct CoreConfig {
pub github_webhook_secret: String,
pub github_oauth: OauthCredentials,
// google integration
pub google_oauth: OauthCredentials,
// mongo config
pub mongo: MongoConfig,
}

View File

@@ -7,7 +7,7 @@ use axum::{
middleware::{self, Next},
response::Response,
routing::get,
Router, RequestExt, Json,
Json, RequestExt, Router,
};
use helpers::docker::DockerClient;
use serde_json::Value;
@@ -62,7 +62,11 @@ async fn guard_request(
} else {
let method = req.method().to_owned();
let uri = req.uri().to_owned();
let body = req.extract::<Json<Value>, _>().await.ok().map(|Json(body)| body);
let body = req
.extract::<Json<Value>, _>()
.await
.ok()
.map(|Json(body)| body);
eprintln!(
"{} | unauthorized request from {ip} | method: {method} | uri: {uri} | body: {:?}",
monitor_timestamp(),