mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-30 17:28:51 -05:00
refac: scim
This commit is contained in:
@@ -572,6 +572,14 @@ ENABLE_SCIM = (
|
||||
== "true"
|
||||
)
|
||||
SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "")
|
||||
SCIM_AUTH_PROVIDER = os.environ.get("SCIM_AUTH_PROVIDER", "")
|
||||
|
||||
if ENABLE_SCIM and not SCIM_AUTH_PROVIDER:
|
||||
log.warning(
|
||||
"SCIM is enabled but SCIM_AUTH_PROVIDER is not set. "
|
||||
"Set SCIM_AUTH_PROVIDER to the OAuth provider name (e.g. 'microsoft', 'oidc') "
|
||||
"to enable externalId storage."
|
||||
)
|
||||
|
||||
####################################
|
||||
# LICENSE_KEY
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add scim column to user table
|
||||
|
||||
Revision ID: b2c3d4e5f6a7
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2026-02-13 14:19:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b2c3d4e5f6a7"
|
||||
down_revision: Union[str, None] = "a1b2c3d4e5f6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "scim")
|
||||
@@ -71,6 +71,7 @@ class User(Base):
|
||||
settings = Column(JSON, nullable=True)
|
||||
|
||||
oauth = Column(JSON, nullable=True)
|
||||
scim = Column(JSON, nullable=True)
|
||||
|
||||
last_active_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
@@ -103,6 +104,7 @@ class UserModel(BaseModel):
|
||||
settings: Optional[UserSettings] = None
|
||||
|
||||
oauth: Optional[dict] = None
|
||||
scim: Optional[dict] = None
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
@@ -351,6 +353,31 @@ class UsersTable:
|
||||
# You may want to log the exception here
|
||||
return None
|
||||
|
||||
def get_user_by_scim_external_id(
|
||||
self, provider: str, external_id: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db: # type: Session
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
query = db.query(User)
|
||||
if dialect_name == "sqlite":
|
||||
query = query.filter(
|
||||
User.scim.contains(
|
||||
{provider: {"external_id": external_id}}
|
||||
)
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
query = query.filter(
|
||||
User.scim[provider].cast(JSONB)["external_id"].astext
|
||||
== external_id
|
||||
)
|
||||
|
||||
user = query.first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_users(
|
||||
self,
|
||||
filter: Optional[dict] = None,
|
||||
@@ -646,6 +673,38 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_scim_by_id(
|
||||
self,
|
||||
id: str,
|
||||
provider: str,
|
||||
external_id: str,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[UserModel]:
|
||||
"""
|
||||
Update or insert a SCIM provider/external_id pair into the user's scim JSON field.
|
||||
Example resulting structure:
|
||||
{
|
||||
"microsoft": { "external_id": "abc" },
|
||||
"okta": { "external_id": "def" }
|
||||
}
|
||||
"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
scim = user.scim or {}
|
||||
scim[provider] = {"external_id": external_id}
|
||||
|
||||
db.query(User).filter_by(id=id).update({"scim": scim})
|
||||
db.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
|
||||
@@ -25,6 +25,9 @@ from open_webui.utils.auth import (
|
||||
)
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
|
||||
from open_webui.config import OAUTH_PROVIDERS
|
||||
from open_webui.env import SCIM_AUTH_PROVIDER
|
||||
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from open_webui.internal.db import get_session
|
||||
@@ -300,6 +303,45 @@ def get_scim_auth(
|
||||
)
|
||||
|
||||
|
||||
def get_external_id(user: UserModel) -> Optional[str]:
|
||||
"""Extract externalId from a user's scim data.
|
||||
|
||||
Checks all stored provider entries and returns the first external_id found.
|
||||
"""
|
||||
if not user.scim:
|
||||
return None
|
||||
for provider_data in user.scim.values():
|
||||
if isinstance(provider_data, dict) and "external_id" in provider_data:
|
||||
return provider_data["external_id"]
|
||||
return None
|
||||
|
||||
|
||||
def get_scim_provider() -> str:
|
||||
"""Return the configured SCIM auth provider.
|
||||
|
||||
Requires SCIM_AUTH_PROVIDER env var to be set (e.g. 'microsoft', 'oidc').
|
||||
"""
|
||||
if not SCIM_AUTH_PROVIDER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="SCIM_AUTH_PROVIDER environment variable is required when SCIM is enabled",
|
||||
)
|
||||
return SCIM_AUTH_PROVIDER
|
||||
|
||||
|
||||
def find_user_by_external_id(
|
||||
external_id: str, db=None
|
||||
) -> Optional[UserModel]:
|
||||
"""Find a user by SCIM externalId, falling back to OAuth sub match."""
|
||||
provider = get_scim_provider()
|
||||
user = Users.get_user_by_scim_external_id(provider, external_id, db=db)
|
||||
if user:
|
||||
return user
|
||||
|
||||
# Fallback: check if externalId matches an existing OAuth sub (account linking)
|
||||
return Users.get_user_by_oauth_sub(provider, external_id, db=db)
|
||||
|
||||
|
||||
def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser:
|
||||
"""Convert internal User model to SCIM User"""
|
||||
# Parse display name into name components
|
||||
@@ -321,6 +363,7 @@ def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser:
|
||||
|
||||
return SCIMUser(
|
||||
id=user.id,
|
||||
externalId=get_external_id(user),
|
||||
userName=user.email,
|
||||
name=SCIMName(
|
||||
formatted=user.name,
|
||||
@@ -494,13 +537,17 @@ async def get_users(
|
||||
|
||||
# Get users from database
|
||||
if filter:
|
||||
# Simple filter parsing - supports userName eq "email"
|
||||
# In production, you'd want a more robust filter parser
|
||||
# Simple filter parsing - supports userName eq, externalId eq
|
||||
if "userName eq" in filter:
|
||||
email = filter.split('"')[1]
|
||||
user = Users.get_user_by_email(email, db=db)
|
||||
users_list = [user] if user else []
|
||||
total = 1 if user else 0
|
||||
elif "externalId eq" in filter:
|
||||
external_id = filter.split('"')[1]
|
||||
user = find_user_by_external_id(external_id, db=db)
|
||||
users_list = [user] if user else []
|
||||
total = 1 if user else 0
|
||||
else:
|
||||
response = Users.get_users(skip=skip, limit=limit, db=db)
|
||||
users_list = response["users"]
|
||||
@@ -546,17 +593,33 @@ async def create_user(
|
||||
db: Session = Depends(get_session),
|
||||
):
|
||||
"""Create SCIM User"""
|
||||
# Check if user already exists
|
||||
existing_user = Users.get_user_by_email(user_data.userName, db=db)
|
||||
# Check for duplicate by externalId
|
||||
if user_data.externalId:
|
||||
existing_user = find_user_by_external_id(user_data.externalId, db=db)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"User with externalId {user_data.externalId} already exists",
|
||||
)
|
||||
|
||||
# Determine primary email (lowercased per RFC 5321)
|
||||
email = user_data.userName
|
||||
for entry in user_data.emails:
|
||||
if entry.primary:
|
||||
email = entry.value
|
||||
break
|
||||
email = email.lower()
|
||||
|
||||
# Check for duplicate by email
|
||||
existing_user = Users.get_user_by_email(email, db=db)
|
||||
if existing_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"User with email {user_data.userName} already exists",
|
||||
detail=f"User with email {email} already exists",
|
||||
)
|
||||
|
||||
# Create user
|
||||
user_id = str(uuid.uuid4())
|
||||
email = user_data.emails[0].value if user_data.emails else user_data.userName
|
||||
|
||||
# Parse name if provided
|
||||
name = user_data.displayName
|
||||
@@ -571,7 +634,6 @@ async def create_user(
|
||||
if user_data.photos and len(user_data.photos) > 0:
|
||||
profile_image = user_data.photos[0].value
|
||||
|
||||
# Create user
|
||||
new_user = Users.insert_new_user(
|
||||
id=user_id,
|
||||
name=name,
|
||||
@@ -587,6 +649,14 @@ async def create_user(
|
||||
detail="Failed to create user",
|
||||
)
|
||||
|
||||
# Store externalId in the scim field
|
||||
if user_data.externalId:
|
||||
provider = get_scim_provider()
|
||||
Users.update_user_scim_by_id(
|
||||
user_id, provider, user_data.externalId, db=db
|
||||
)
|
||||
new_user = Users.get_user_by_id(user_id, db=db)
|
||||
|
||||
return user_to_scim(new_user, request, db=db)
|
||||
|
||||
|
||||
@@ -631,7 +701,6 @@ async def update_user(
|
||||
if user_data.photos and len(user_data.photos) > 0:
|
||||
update_data["profile_image_url"] = user_data.photos[0].value
|
||||
|
||||
# Update user
|
||||
updated_user = Users.update_user_by_id(user_id, update_data, db=db)
|
||||
if not updated_user:
|
||||
raise HTTPException(
|
||||
@@ -639,6 +708,14 @@ async def update_user(
|
||||
detail="Failed to update user",
|
||||
)
|
||||
|
||||
# Update externalId in the scim field
|
||||
if user_data.externalId:
|
||||
provider = get_scim_provider()
|
||||
Users.update_user_scim_by_id(
|
||||
user_id, provider, user_data.externalId, db=db
|
||||
)
|
||||
updated_user = Users.get_user_by_id(user_id, db=db)
|
||||
|
||||
return user_to_scim(updated_user, request, db=db)
|
||||
|
||||
|
||||
@@ -676,6 +753,11 @@ async def patch_user(
|
||||
update_data["email"] = value
|
||||
elif path == "name.formatted":
|
||||
update_data["name"] = value
|
||||
elif path == "externalId":
|
||||
provider = get_scim_provider()
|
||||
Users.update_user_scim_by_id(
|
||||
user_id, provider, value, db=db
|
||||
)
|
||||
|
||||
# Update user
|
||||
if update_data:
|
||||
|
||||
Reference in New Issue
Block a user