wip: access control backend

This commit is contained in:
Timothy Jaeryang Baek
2024-11-15 01:29:07 -08:00
parent b80ec76435
commit 2ab5b2fd71
8 changed files with 282 additions and 52 deletions

View File

@@ -68,7 +68,6 @@ class GroupResponse(BaseModel):
permissions: Optional[dict] = None
meta: Optional[dict] = None
user_ids: list[str] = []
admin_ids: list[str] = []
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -119,6 +118,16 @@ class GroupTable:
for group in db.query(Group).order_by(Group.updated_at.desc()).all()
]
def get_groups_by_member_id(self, user_id: str) -> list[GroupModel]:
with get_db() as db:
return [
GroupModel.model_validate(group)
for group in db.query(Group)
.filter(Group.user_ids.contains([user_id]))
.order_by(Group.updated_at.desc())
.all()
]
def get_group_by_id(self, id: str) -> Optional[GroupModel]:
try:
with get_db() as db:

View File

@@ -4,9 +4,20 @@ from typing import Optional
from open_webui.apps.webui.internal.db import Base, JSONField, get_db
from open_webui.env import SRC_LOG_LEVELS
from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import or_, and_, func
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy import BigInteger, Column, Text, JSON
from open_webui.utils.utils import has_access
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -112,8 +123,14 @@ class ModelModel(BaseModel):
class ModelResponse(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
params: ModelParams
meta: ModelMeta
access_control: Optional[dict] = None
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
@@ -157,6 +174,24 @@ class ModelsTable:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_models(self) -> list[ModelModel]:
with get_db() as db:
return [
ModelModel.model_validate(model)
for model in db.query(Model).filter(Model.base_model_id != None).all()
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[ModelModel]:
models = self.get_all_models()
return [
model
for model in models
if model.user_id == user_id
or has_access(user_id, permission, model.access_control)
]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:

View File

@@ -2,6 +2,8 @@ import time
from typing import Optional
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.apps.webui.models.groups import Groups
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -100,6 +102,64 @@ class PromptsTable:
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write"
) -> list[PromptModel]:
prompts = self.get_prompts()
groups = Groups.get_groups_by_member_id(user_id)
group_ids = [group.id for group in groups]
if permission == "write":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or (
prompt.access_control
and (
any(
group_id
in prompt.access_control.get(permission, {}).get(
"group_ids", []
)
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
elif permission == "read":
return [
prompt
for prompt in prompts
if prompt.user_id == user_id
or prompt.access_control is None
or (
prompt.access_control
and (
any(
prompt.access_control.get(permission, {}).get(
"group_ids", []
)
in group_id
for group_id in group_ids
)
or (
user_id
in prompt.access_control.get(permission, {}).get(
"user_ids", []
)
)
)
)
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: