add basic org policy check in middleware

This commit is contained in:
miloschwartz
2025-10-23 18:15:00 -07:00
parent 5a7b5d65a4
commit ddcf77a62d
6 changed files with 151 additions and 11 deletions

View File

@@ -25,7 +25,8 @@ export const orgs = pgTable("orgs", {
orgId: varchar("orgId").primaryKey(),
name: varchar("name").notNull(),
subnet: varchar("subnet"),
createdAt: text("createdAt")
createdAt: text("createdAt"),
requireTwoFactor: boolean("requireTwoFactor").default(false)
});
export const orgDomains = pgTable("orgDomains", {

View File

@@ -18,7 +18,8 @@ export const orgs = sqliteTable("orgs", {
orgId: text("orgId").primaryKey(),
name: text("name").notNull(),
subnet: text("subnet"),
createdAt: text("createdAt")
createdAt: text("createdAt"),
requireTwoFactor: integer("requireTwoFactor", { mode: "boolean" })
});
export const userDomains = sqliteTable("userDomains", {
@@ -141,11 +142,15 @@ export const targets = sqliteTable("targets", {
});
export const targetHealthCheck = sqliteTable("targetHealthCheck", {
targetHealthCheckId: integer("targetHealthCheckId").primaryKey({ autoIncrement: true }),
targetHealthCheckId: integer("targetHealthCheckId").primaryKey({
autoIncrement: true
}),
targetId: integer("targetId")
.notNull()
.references(() => targets.targetId, { onDelete: "cascade" }),
hcEnabled: integer("hcEnabled", { mode: "boolean" }).notNull().default(false),
hcEnabled: integer("hcEnabled", { mode: "boolean" })
.notNull()
.default(false),
hcPath: text("hcPath"),
hcScheme: text("hcScheme"),
hcMode: text("hcMode").default("http"),
@@ -155,7 +160,9 @@ export const targetHealthCheck = sqliteTable("targetHealthCheck", {
hcUnhealthyInterval: integer("hcUnhealthyInterval").default(30), // in seconds
hcTimeout: integer("hcTimeout").default(5), // in seconds
hcHeaders: text("hcHeaders"),
hcFollowRedirects: integer("hcFollowRedirects", { mode: "boolean" }).default(true),
hcFollowRedirects: integer("hcFollowRedirects", {
mode: "boolean"
}).default(true),
hcMethod: text("hcMethod").default("GET"),
hcStatus: integer("hcStatus"), // http code
hcHealth: text("hcHealth").default("unknown") // "unknown", "healthy", "unhealthy"

View File

@@ -0,0 +1,17 @@
import { Org, User } from "@server/db";
type CheckOrgAccessPolicyProps = {
orgId?: string;
org?: Org;
userId?: string;
user?: User;
};
export async function checkOrgAccessPolicy(
props: CheckOrgAccessPolicyProps
): Promise<{
success: boolean;
error?: string;
}> {
return { success: true };
}

View File

@@ -40,6 +40,10 @@ export class License {
public setServerSecret(secret: string) {
this.serverSecret = secret;
}
public async isUnlocked() {
return false;
}
}
await setHostMeta();

View File

@@ -1,9 +1,10 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { db, orgs } from "@server/db";
import { userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
export async function verifyOrgAccess(
req: Request,
@@ -43,12 +44,27 @@ export async function verifyOrgAccess(
"User does not have access to this organization"
)
);
} else {
// User has access, attach the user's role to the request for potential future use
req.userOrgRoleId = req.userOrg.roleId;
req.userOrgId = orgId;
return next();
}
const policyCheck = await checkOrgAccessPolicy({
orgId,
userId
});
if (!policyCheck.success || policyCheck.error) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Failed organization access policy check: " +
(policyCheck.error || "Unknown error")
)
);
}
// User has access, attach the user's role to the request for potential future use
req.userOrgRoleId = req.userOrg.roleId;
req.userOrgId = orgId;
return next();
} catch (e) {
return next(
createHttpError(

View File

@@ -0,0 +1,95 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { build } from "@server/build";
import { db, Org, orgs, User, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license";
import { eq } from "drizzle-orm";
type CheckOrgAccessPolicyProps = {
orgId?: string;
org?: Org;
userId?: string;
user?: User;
};
export async function checkOrgAccessPolicy(
props: CheckOrgAccessPolicyProps
): Promise<{
success: boolean;
error?: string;
}> {
const userId = props.userId || props.user?.userId;
const orgId = props.orgId || props.org?.orgId;
if (!orgId) {
return { success: false, error: "Organization ID is required" };
}
if (!userId) {
return { success: false, error: "User ID is required" };
}
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
// if not subscribed, don't check the policies
if (!subscribed) {
return { success: true };
}
}
if (build === "enterprise") {
const isUnlocked = await license.isUnlocked();
// if not licensed, don't check the policies
if (!isUnlocked) {
return { success: true };
}
}
// get the needed data
if (!props.org) {
const [orgQuery] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
props.org = orgQuery;
if (!props.org) {
return { success: false, error: "Organization not found" };
}
}
if (!props.user) {
const [userQuery] = await db
.select()
.from(users)
.where(eq(users.userId, userId));
props.user = userQuery;
if (!props.user) {
return { success: false, error: "User not found" };
}
}
// now check the policies
if (!props.org.requireTwoFactor && !props.user.twoFactorEnabled) {
return {
success: false,
error: "Two-factor authentication is required"
};
}
return { success: true };
}