mirror of
https://github.com/open-webui/open-webui.git
synced 2026-05-03 18:59:38 -05:00
refac
This commit is contained in:
@@ -3,7 +3,7 @@ from test.util.mock_user import mock_webui_user
|
||||
|
||||
|
||||
class TestAuths(AbstractPostgresTest):
|
||||
BASE_PATH = "/api/v1/auths"
|
||||
BASE_PATH = '/api/v1/auths'
|
||||
|
||||
def setup_class(cls):
|
||||
super().setup_class()
|
||||
@@ -15,171 +15,167 @@ class TestAuths(AbstractPostgresTest):
|
||||
|
||||
def test_get_session_user(self):
|
||||
with mock_webui_user():
|
||||
response = self.fast_api_client.get(self.create_url(""))
|
||||
response = self.fast_api_client.get(self.create_url(''))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"id": "1",
|
||||
"name": "John Doe",
|
||||
"email": "john.doe@openwebui.com",
|
||||
"role": "user",
|
||||
"profile_image_url": "/user.png",
|
||||
'id': '1',
|
||||
'name': 'John Doe',
|
||||
'email': 'john.doe@openwebui.com',
|
||||
'role': 'user',
|
||||
'profile_image_url': '/user.png',
|
||||
}
|
||||
|
||||
def test_update_profile(self):
|
||||
from open_webui.utils.auth import get_password_hash
|
||||
|
||||
user = self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password=get_password_hash("old_password"),
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="user",
|
||||
email='john.doe@openwebui.com',
|
||||
password=get_password_hash('old_password'),
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='user',
|
||||
)
|
||||
|
||||
with mock_webui_user(id=user.id):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/update/profile"),
|
||||
json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
|
||||
self.create_url('/update/profile'),
|
||||
json={'name': 'John Doe 2', 'profile_image_url': '/user2.png'},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
db_user = self.users.get_user_by_id(user.id)
|
||||
assert db_user.name == "John Doe 2"
|
||||
assert db_user.profile_image_url == "/user2.png"
|
||||
assert db_user.name == 'John Doe 2'
|
||||
assert db_user.profile_image_url == '/user2.png'
|
||||
|
||||
def test_update_password(self):
|
||||
from open_webui.utils.auth import get_password_hash
|
||||
|
||||
user = self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password=get_password_hash("old_password"),
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="user",
|
||||
email='john.doe@openwebui.com',
|
||||
password=get_password_hash('old_password'),
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='user',
|
||||
)
|
||||
|
||||
with mock_webui_user(id=user.id):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/update/password"),
|
||||
json={"password": "old_password", "new_password": "new_password"},
|
||||
self.create_url('/update/password'),
|
||||
json={'password': 'old_password', 'new_password': 'new_password'},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
old_auth = self.auths.authenticate_user(
|
||||
"john.doe@openwebui.com", "old_password"
|
||||
)
|
||||
old_auth = self.auths.authenticate_user('john.doe@openwebui.com', 'old_password')
|
||||
assert old_auth is None
|
||||
new_auth = self.auths.authenticate_user(
|
||||
"john.doe@openwebui.com", "new_password"
|
||||
)
|
||||
new_auth = self.auths.authenticate_user('john.doe@openwebui.com', 'new_password')
|
||||
assert new_auth is not None
|
||||
|
||||
def test_signin(self):
|
||||
from open_webui.utils.auth import get_password_hash
|
||||
|
||||
user = self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password=get_password_hash("password"),
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="user",
|
||||
email='john.doe@openwebui.com',
|
||||
password=get_password_hash('password'),
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='user',
|
||||
)
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/signin"),
|
||||
json={"email": "john.doe@openwebui.com", "password": "password"},
|
||||
self.create_url('/signin'),
|
||||
json={'email': 'john.doe@openwebui.com', 'password': 'password'},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == user.id
|
||||
assert data["name"] == "John Doe"
|
||||
assert data["email"] == "john.doe@openwebui.com"
|
||||
assert data["role"] == "user"
|
||||
assert data["profile_image_url"] == "/user.png"
|
||||
assert data["token"] is not None and len(data["token"]) > 0
|
||||
assert data["token_type"] == "Bearer"
|
||||
assert data['id'] == user.id
|
||||
assert data['name'] == 'John Doe'
|
||||
assert data['email'] == 'john.doe@openwebui.com'
|
||||
assert data['role'] == 'user'
|
||||
assert data['profile_image_url'] == '/user.png'
|
||||
assert data['token'] is not None and len(data['token']) > 0
|
||||
assert data['token_type'] == 'Bearer'
|
||||
|
||||
def test_signup(self):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/signup"),
|
||||
self.create_url('/signup'),
|
||||
json={
|
||||
"name": "John Doe",
|
||||
"email": "john.doe@openwebui.com",
|
||||
"password": "password",
|
||||
'name': 'John Doe',
|
||||
'email': 'john.doe@openwebui.com',
|
||||
'password': 'password',
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] is not None and len(data["id"]) > 0
|
||||
assert data["name"] == "John Doe"
|
||||
assert data["email"] == "john.doe@openwebui.com"
|
||||
assert data["role"] in ["admin", "user", "pending"]
|
||||
assert data["profile_image_url"] == "/user.png"
|
||||
assert data["token"] is not None and len(data["token"]) > 0
|
||||
assert data["token_type"] == "Bearer"
|
||||
assert data['id'] is not None and len(data['id']) > 0
|
||||
assert data['name'] == 'John Doe'
|
||||
assert data['email'] == 'john.doe@openwebui.com'
|
||||
assert data['role'] in ['admin', 'user', 'pending']
|
||||
assert data['profile_image_url'] == '/user.png'
|
||||
assert data['token'] is not None and len(data['token']) > 0
|
||||
assert data['token_type'] == 'Bearer'
|
||||
|
||||
def test_add_user(self):
|
||||
with mock_webui_user():
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/add"),
|
||||
self.create_url('/add'),
|
||||
json={
|
||||
"name": "John Doe 2",
|
||||
"email": "john.doe2@openwebui.com",
|
||||
"password": "password2",
|
||||
"role": "admin",
|
||||
'name': 'John Doe 2',
|
||||
'email': 'john.doe2@openwebui.com',
|
||||
'password': 'password2',
|
||||
'role': 'admin',
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] is not None and len(data["id"]) > 0
|
||||
assert data["name"] == "John Doe 2"
|
||||
assert data["email"] == "john.doe2@openwebui.com"
|
||||
assert data["role"] == "admin"
|
||||
assert data["profile_image_url"] == "/user.png"
|
||||
assert data["token"] is not None and len(data["token"]) > 0
|
||||
assert data["token_type"] == "Bearer"
|
||||
assert data['id'] is not None and len(data['id']) > 0
|
||||
assert data['name'] == 'John Doe 2'
|
||||
assert data['email'] == 'john.doe2@openwebui.com'
|
||||
assert data['role'] == 'admin'
|
||||
assert data['profile_image_url'] == '/user.png'
|
||||
assert data['token'] is not None and len(data['token']) > 0
|
||||
assert data['token_type'] == 'Bearer'
|
||||
|
||||
def test_get_admin_details(self):
|
||||
self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
email='john.doe@openwebui.com',
|
||||
password='password',
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='admin',
|
||||
)
|
||||
with mock_webui_user():
|
||||
response = self.fast_api_client.get(self.create_url("/admin/details"))
|
||||
response = self.fast_api_client.get(self.create_url('/admin/details'))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"name": "John Doe",
|
||||
"email": "john.doe@openwebui.com",
|
||||
'name': 'John Doe',
|
||||
'email': 'john.doe@openwebui.com',
|
||||
}
|
||||
|
||||
def test_create_api_key_(self):
|
||||
user = self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
email='john.doe@openwebui.com',
|
||||
password='password',
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='admin',
|
||||
)
|
||||
with mock_webui_user(id=user.id):
|
||||
response = self.fast_api_client.post(self.create_url("/api_key"))
|
||||
response = self.fast_api_client.post(self.create_url('/api_key'))
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["api_key"] is not None
|
||||
assert len(data["api_key"]) > 0
|
||||
assert data['api_key'] is not None
|
||||
assert len(data['api_key']) > 0
|
||||
|
||||
def test_delete_api_key(self):
|
||||
user = self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
email='john.doe@openwebui.com',
|
||||
password='password',
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='admin',
|
||||
)
|
||||
self.users.update_user_api_key_by_id(user.id, "abc")
|
||||
self.users.update_user_api_key_by_id(user.id, 'abc')
|
||||
with mock_webui_user(id=user.id):
|
||||
response = self.fast_api_client.delete(self.create_url("/api_key"))
|
||||
response = self.fast_api_client.delete(self.create_url('/api_key'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == True
|
||||
db_user = self.users.get_user_by_id(user.id)
|
||||
@@ -187,14 +183,14 @@ class TestAuths(AbstractPostgresTest):
|
||||
|
||||
def test_get_api_key(self):
|
||||
user = self.auths.insert_new_auth(
|
||||
email="john.doe@openwebui.com",
|
||||
password="password",
|
||||
name="John Doe",
|
||||
profile_image_url="/user.png",
|
||||
role="admin",
|
||||
email='john.doe@openwebui.com',
|
||||
password='password',
|
||||
name='John Doe',
|
||||
profile_image_url='/user.png',
|
||||
role='admin',
|
||||
)
|
||||
self.users.update_user_api_key_by_id(user.id, "abc")
|
||||
self.users.update_user_api_key_by_id(user.id, 'abc')
|
||||
with mock_webui_user(id=user.id):
|
||||
response = self.fast_api_client.get(self.create_url("/api_key"))
|
||||
response = self.fast_api_client.get(self.create_url('/api_key'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"api_key": "abc"}
|
||||
assert response.json() == {'api_key': 'abc'}
|
||||
|
||||
@@ -3,7 +3,7 @@ from test.util.mock_user import mock_webui_user
|
||||
|
||||
|
||||
class TestModels(AbstractPostgresTest):
|
||||
BASE_PATH = "/api/v1/models"
|
||||
BASE_PATH = '/api/v1/models'
|
||||
|
||||
def setup_class(cls):
|
||||
super().setup_class()
|
||||
@@ -12,50 +12,46 @@ class TestModels(AbstractPostgresTest):
|
||||
cls.models = Model
|
||||
|
||||
def test_models(self):
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.get(self.create_url("/"))
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.get(self.create_url('/'))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
|
||||
with mock_webui_user(id="2"):
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/add"),
|
||||
self.create_url('/add'),
|
||||
json={
|
||||
"id": "my-model",
|
||||
"base_model_id": "base-model-id",
|
||||
"name": "Hello World",
|
||||
"meta": {
|
||||
"profile_image_url": "/static/favicon.png",
|
||||
"description": "description",
|
||||
"capabilities": None,
|
||||
"model_config": {},
|
||||
'id': 'my-model',
|
||||
'base_model_id': 'base-model-id',
|
||||
'name': 'Hello World',
|
||||
'meta': {
|
||||
'profile_image_url': '/static/favicon.png',
|
||||
'description': 'description',
|
||||
'capabilities': None,
|
||||
'model_config': {},
|
||||
},
|
||||
"params": {},
|
||||
'params': {},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.get(self.create_url("/"))
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.get(self.create_url('/'))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.get(
|
||||
self.create_url(query_params={"id": "my-model"})
|
||||
)
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.get(self.create_url(query_params={'id': 'my-model'}))
|
||||
assert response.status_code == 200
|
||||
data = response.json()[0]
|
||||
assert data["id"] == "my-model"
|
||||
assert data["name"] == "Hello World"
|
||||
assert data['id'] == 'my-model'
|
||||
assert data['name'] == 'Hello World'
|
||||
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.delete(
|
||||
self.create_url("/delete?id=my-model")
|
||||
)
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.delete(self.create_url('/delete?id=my-model'))
|
||||
assert response.status_code == 200
|
||||
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.get(self.create_url("/"))
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.get(self.create_url('/'))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 0
|
||||
|
||||
@@ -3,17 +3,17 @@ from test.util.mock_user import mock_webui_user
|
||||
|
||||
|
||||
def _get_user_by_id(data, param):
|
||||
return next((item for item in data if item["id"] == param), None)
|
||||
return next((item for item in data if item['id'] == param), None)
|
||||
|
||||
|
||||
def _assert_user(data, id, **kwargs):
|
||||
user = _get_user_by_id(data, id)
|
||||
assert user is not None
|
||||
comparison_data = {
|
||||
"name": f"user {id}",
|
||||
"email": f"user{id}@openwebui.com",
|
||||
"profile_image_url": f"/api/v1/users/{id}/profile/image",
|
||||
"role": "user",
|
||||
'name': f'user {id}',
|
||||
'email': f'user{id}@openwebui.com',
|
||||
'profile_image_url': f'/api/v1/users/{id}/profile/image',
|
||||
'role': 'user',
|
||||
**kwargs,
|
||||
}
|
||||
for key, value in comparison_data.items():
|
||||
@@ -21,7 +21,7 @@ def _assert_user(data, id, **kwargs):
|
||||
|
||||
|
||||
class TestUsers(AbstractPostgresTest):
|
||||
BASE_PATH = "/api/v1/users"
|
||||
BASE_PATH = '/api/v1/users'
|
||||
|
||||
def setup_class(cls):
|
||||
super().setup_class()
|
||||
@@ -32,136 +32,134 @@ class TestUsers(AbstractPostgresTest):
|
||||
def setup_method(self):
|
||||
super().setup_method()
|
||||
self.users.insert_new_user(
|
||||
id="1",
|
||||
name="user 1",
|
||||
email="user1@openwebui.com",
|
||||
profile_image_url="/user1.png",
|
||||
role="user",
|
||||
id='1',
|
||||
name='user 1',
|
||||
email='user1@openwebui.com',
|
||||
profile_image_url='/user1.png',
|
||||
role='user',
|
||||
)
|
||||
self.users.insert_new_user(
|
||||
id="2",
|
||||
name="user 2",
|
||||
email="user2@openwebui.com",
|
||||
profile_image_url="/user2.png",
|
||||
role="user",
|
||||
id='2',
|
||||
name='user 2',
|
||||
email='user2@openwebui.com',
|
||||
profile_image_url='/user2.png',
|
||||
role='user',
|
||||
)
|
||||
|
||||
def test_users(self):
|
||||
# Get all users
|
||||
with mock_webui_user(id="3"):
|
||||
response = self.fast_api_client.get(self.create_url(""))
|
||||
with mock_webui_user(id='3'):
|
||||
response = self.fast_api_client.get(self.create_url(''))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
data = response.json()
|
||||
_assert_user(data, "1")
|
||||
_assert_user(data, "2")
|
||||
_assert_user(data, '1')
|
||||
_assert_user(data, '2')
|
||||
|
||||
# update role
|
||||
with mock_webui_user(id="3"):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/update/role"), json={"id": "2", "role": "admin"}
|
||||
)
|
||||
with mock_webui_user(id='3'):
|
||||
response = self.fast_api_client.post(self.create_url('/update/role'), json={'id': '2', 'role': 'admin'})
|
||||
assert response.status_code == 200
|
||||
_assert_user([response.json()], "2", role="admin")
|
||||
_assert_user([response.json()], '2', role='admin')
|
||||
|
||||
# Get all users
|
||||
with mock_webui_user(id="3"):
|
||||
response = self.fast_api_client.get(self.create_url(""))
|
||||
with mock_webui_user(id='3'):
|
||||
response = self.fast_api_client.get(self.create_url(''))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
data = response.json()
|
||||
_assert_user(data, "1")
|
||||
_assert_user(data, "2", role="admin")
|
||||
_assert_user(data, '1')
|
||||
_assert_user(data, '2', role='admin')
|
||||
|
||||
# Get (empty) user settings
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.get(self.create_url("/user/settings"))
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.get(self.create_url('/user/settings'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() is None
|
||||
|
||||
# Update user settings
|
||||
with mock_webui_user(id="2"):
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/user/settings/update"),
|
||||
self.create_url('/user/settings/update'),
|
||||
json={
|
||||
"ui": {"attr1": "value1", "attr2": "value2"},
|
||||
"model_config": {"attr3": "value3", "attr4": "value4"},
|
||||
'ui': {'attr1': 'value1', 'attr2': 'value2'},
|
||||
'model_config': {'attr3': 'value3', 'attr4': 'value4'},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Get user settings
|
||||
with mock_webui_user(id="2"):
|
||||
response = self.fast_api_client.get(self.create_url("/user/settings"))
|
||||
with mock_webui_user(id='2'):
|
||||
response = self.fast_api_client.get(self.create_url('/user/settings'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"ui": {"attr1": "value1", "attr2": "value2"},
|
||||
"model_config": {"attr3": "value3", "attr4": "value4"},
|
||||
'ui': {'attr1': 'value1', 'attr2': 'value2'},
|
||||
'model_config': {'attr3': 'value3', 'attr4': 'value4'},
|
||||
}
|
||||
|
||||
# Get (empty) user info
|
||||
with mock_webui_user(id="1"):
|
||||
response = self.fast_api_client.get(self.create_url("/user/info"))
|
||||
with mock_webui_user(id='1'):
|
||||
response = self.fast_api_client.get(self.create_url('/user/info'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() is None
|
||||
|
||||
# Update user info
|
||||
with mock_webui_user(id="1"):
|
||||
with mock_webui_user(id='1'):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/user/info/update"),
|
||||
json={"attr1": "value1", "attr2": "value2"},
|
||||
self.create_url('/user/info/update'),
|
||||
json={'attr1': 'value1', 'attr2': 'value2'},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Get user info
|
||||
with mock_webui_user(id="1"):
|
||||
response = self.fast_api_client.get(self.create_url("/user/info"))
|
||||
with mock_webui_user(id='1'):
|
||||
response = self.fast_api_client.get(self.create_url('/user/info'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"attr1": "value1", "attr2": "value2"}
|
||||
assert response.json() == {'attr1': 'value1', 'attr2': 'value2'}
|
||||
|
||||
# Get user by id
|
||||
with mock_webui_user(id="1"):
|
||||
response = self.fast_api_client.get(self.create_url("/2"))
|
||||
with mock_webui_user(id='1'):
|
||||
response = self.fast_api_client.get(self.create_url('/2'))
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"name": "user 2", "profile_image_url": "/user2.png"}
|
||||
assert response.json() == {'name': 'user 2', 'profile_image_url': '/user2.png'}
|
||||
|
||||
# Update user by id
|
||||
with mock_webui_user(id="1"):
|
||||
with mock_webui_user(id='1'):
|
||||
response = self.fast_api_client.post(
|
||||
self.create_url("/2/update"),
|
||||
self.create_url('/2/update'),
|
||||
json={
|
||||
"name": "user 2 updated",
|
||||
"email": "user2-updated@openwebui.com",
|
||||
"profile_image_url": "/user2-updated.png",
|
||||
'name': 'user 2 updated',
|
||||
'email': 'user2-updated@openwebui.com',
|
||||
'profile_image_url': '/user2-updated.png',
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Get all users
|
||||
with mock_webui_user(id="3"):
|
||||
response = self.fast_api_client.get(self.create_url(""))
|
||||
with mock_webui_user(id='3'):
|
||||
response = self.fast_api_client.get(self.create_url(''))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 2
|
||||
data = response.json()
|
||||
_assert_user(data, "1")
|
||||
_assert_user(data, '1')
|
||||
_assert_user(
|
||||
data,
|
||||
"2",
|
||||
role="admin",
|
||||
name="user 2 updated",
|
||||
email="user2-updated@openwebui.com",
|
||||
profile_image_url=f"/api/v1/users/2/profile/image",
|
||||
'2',
|
||||
role='admin',
|
||||
name='user 2 updated',
|
||||
email='user2-updated@openwebui.com',
|
||||
profile_image_url=f'/api/v1/users/2/profile/image',
|
||||
)
|
||||
|
||||
# Delete user by id
|
||||
with mock_webui_user(id="1"):
|
||||
response = self.fast_api_client.delete(self.create_url("/2"))
|
||||
with mock_webui_user(id='1'):
|
||||
response = self.fast_api_client.delete(self.create_url('/2'))
|
||||
assert response.status_code == 200
|
||||
|
||||
# Get all users
|
||||
with mock_webui_user(id="3"):
|
||||
response = self.fast_api_client.get(self.create_url(""))
|
||||
with mock_webui_user(id='3'):
|
||||
response = self.fast_api_client.get(self.create_url(''))
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
data = response.json()
|
||||
_assert_user(data, "1")
|
||||
_assert_user(data, '1')
|
||||
|
||||
@@ -13,9 +13,9 @@ from unittest.mock import MagicMock
|
||||
|
||||
def mock_upload_dir(monkeypatch, tmp_path):
|
||||
"""Fixture to monkey-patch the UPLOAD_DIR and create a temporary directory."""
|
||||
directory = tmp_path / "uploads"
|
||||
directory = tmp_path / 'uploads'
|
||||
directory.mkdir()
|
||||
monkeypatch.setattr(provider, "UPLOAD_DIR", str(directory))
|
||||
monkeypatch.setattr(provider, 'UPLOAD_DIR', str(directory))
|
||||
return directory
|
||||
|
||||
|
||||
@@ -29,16 +29,16 @@ def test_imports():
|
||||
|
||||
|
||||
def test_get_storage_provider():
|
||||
Storage = provider.get_storage_provider("local")
|
||||
Storage = provider.get_storage_provider('local')
|
||||
assert isinstance(Storage, provider.LocalStorageProvider)
|
||||
Storage = provider.get_storage_provider("s3")
|
||||
Storage = provider.get_storage_provider('s3')
|
||||
assert isinstance(Storage, provider.S3StorageProvider)
|
||||
Storage = provider.get_storage_provider("gcs")
|
||||
Storage = provider.get_storage_provider('gcs')
|
||||
assert isinstance(Storage, provider.GCSStorageProvider)
|
||||
Storage = provider.get_storage_provider("azure")
|
||||
Storage = provider.get_storage_provider('azure')
|
||||
assert isinstance(Storage, provider.AzureStorageProvider)
|
||||
with pytest.raises(RuntimeError):
|
||||
provider.get_storage_provider("invalid")
|
||||
provider.get_storage_provider('invalid')
|
||||
|
||||
|
||||
def test_class_instantiation():
|
||||
@@ -58,10 +58,10 @@ def test_class_instantiation():
|
||||
|
||||
class TestLocalStorageProvider:
|
||||
Storage = provider.LocalStorageProvider()
|
||||
file_content = b"test content"
|
||||
file_content = b'test content'
|
||||
file_bytesio = io.BytesIO(file_content)
|
||||
filename = "test.txt"
|
||||
filename_extra = "test_exyta.txt"
|
||||
filename = 'test.txt'
|
||||
filename_extra = 'test_exyta.txt'
|
||||
file_bytesio_empty = io.BytesIO()
|
||||
|
||||
def test_upload_file(self, monkeypatch, tmp_path):
|
||||
@@ -99,14 +99,13 @@ class TestLocalStorageProvider:
|
||||
|
||||
@mock_aws
|
||||
class TestS3StorageProvider:
|
||||
|
||||
def __init__(self):
|
||||
self.Storage = provider.S3StorageProvider()
|
||||
self.Storage.bucket_name = "my-bucket"
|
||||
self.s3_client = boto3.resource("s3", region_name="us-east-1")
|
||||
self.file_content = b"test content"
|
||||
self.filename = "test.txt"
|
||||
self.filename_extra = "test_exyta.txt"
|
||||
self.Storage.bucket_name = 'my-bucket'
|
||||
self.s3_client = boto3.resource('s3', region_name='us-east-1')
|
||||
self.file_content = b'test content'
|
||||
self.filename = 'test.txt'
|
||||
self.filename_extra = 'test_exyta.txt'
|
||||
self.file_bytesio_empty = io.BytesIO()
|
||||
super().__init__()
|
||||
|
||||
@@ -116,25 +115,21 @@ class TestS3StorageProvider:
|
||||
with pytest.raises(Exception):
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
|
||||
contents, s3_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, s3_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
|
||||
assert self.file_content == object.get()["Body"].read()
|
||||
assert self.file_content == object.get()['Body'].read()
|
||||
# local checks
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
assert contents == self.file_content
|
||||
assert s3_file_path == "s3://" + self.Storage.bucket_name + "/" + self.filename
|
||||
assert s3_file_path == 's3://' + self.Storage.bucket_name + '/' + self.filename
|
||||
with pytest.raises(ValueError):
|
||||
self.Storage.upload_file(self.file_bytesio_empty, self.filename)
|
||||
|
||||
def test_get_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
|
||||
contents, s3_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, s3_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
file_path = self.Storage.get_file(s3_file_path)
|
||||
assert file_path == str(upload_dir / self.filename)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
@@ -142,17 +137,15 @@ class TestS3StorageProvider:
|
||||
def test_delete_file(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
|
||||
contents, s3_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, s3_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
self.Storage.delete_file(s3_file_path)
|
||||
assert not (upload_dir / self.filename).exists()
|
||||
with pytest.raises(ClientError) as exc:
|
||||
self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
|
||||
error = exc.value.response["Error"]
|
||||
assert error["Code"] == "404"
|
||||
assert error["Message"] == "Not Found"
|
||||
error = exc.value.response['Error']
|
||||
assert error['Code'] == '404'
|
||||
assert error['Message'] == 'Not Found'
|
||||
|
||||
def test_delete_all_files(self, monkeypatch, tmp_path):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
@@ -160,12 +153,12 @@ class TestS3StorageProvider:
|
||||
self.s3_client.create_bucket(Bucket=self.Storage.bucket_name)
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
object = self.s3_client.Object(self.Storage.bucket_name, self.filename)
|
||||
assert self.file_content == object.get()["Body"].read()
|
||||
assert self.file_content == object.get()['Body'].read()
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename_extra)
|
||||
object = self.s3_client.Object(self.Storage.bucket_name, self.filename_extra)
|
||||
assert self.file_content == object.get()["Body"].read()
|
||||
assert self.file_content == object.get()['Body'].read()
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
|
||||
@@ -173,15 +166,15 @@ class TestS3StorageProvider:
|
||||
assert not (upload_dir / self.filename).exists()
|
||||
with pytest.raises(ClientError) as exc:
|
||||
self.s3_client.Object(self.Storage.bucket_name, self.filename).load()
|
||||
error = exc.value.response["Error"]
|
||||
assert error["Code"] == "404"
|
||||
assert error["Message"] == "Not Found"
|
||||
error = exc.value.response['Error']
|
||||
assert error['Code'] == '404'
|
||||
assert error['Message'] == 'Not Found'
|
||||
assert not (upload_dir / self.filename_extra).exists()
|
||||
with pytest.raises(ClientError) as exc:
|
||||
self.s3_client.Object(self.Storage.bucket_name, self.filename_extra).load()
|
||||
error = exc.value.response["Error"]
|
||||
assert error["Code"] == "404"
|
||||
assert error["Message"] == "Not Found"
|
||||
error = exc.value.response['Error']
|
||||
assert error['Code'] == '404'
|
||||
assert error['Message'] == 'Not Found'
|
||||
|
||||
self.Storage.delete_all_files()
|
||||
assert not (upload_dir / self.filename).exists()
|
||||
@@ -190,8 +183,8 @@ class TestS3StorageProvider:
|
||||
def test_init_without_credentials(self, monkeypatch):
|
||||
"""Test that S3StorageProvider can initialize without explicit credentials."""
|
||||
# Temporarily unset the environment variables
|
||||
monkeypatch.setattr(provider, "S3_ACCESS_KEY_ID", None)
|
||||
monkeypatch.setattr(provider, "S3_SECRET_ACCESS_KEY", None)
|
||||
monkeypatch.setattr(provider, 'S3_ACCESS_KEY_ID', None)
|
||||
monkeypatch.setattr(provider, 'S3_SECRET_ACCESS_KEY', None)
|
||||
|
||||
# Should not raise an exception
|
||||
storage = provider.S3StorageProvider()
|
||||
@@ -201,19 +194,19 @@ class TestS3StorageProvider:
|
||||
|
||||
class TestGCSStorageProvider:
|
||||
Storage = provider.GCSStorageProvider()
|
||||
Storage.bucket_name = "my-bucket"
|
||||
file_content = b"test content"
|
||||
filename = "test.txt"
|
||||
filename_extra = "test_exyta.txt"
|
||||
Storage.bucket_name = 'my-bucket'
|
||||
file_content = b'test content'
|
||||
filename = 'test.txt'
|
||||
filename_extra = 'test_exyta.txt'
|
||||
file_bytesio_empty = io.BytesIO()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
@pytest.fixture(scope='class')
|
||||
def setup(self):
|
||||
host, port = "localhost", 9023
|
||||
host, port = 'localhost', 9023
|
||||
|
||||
server = create_server(host, port, in_memory=True)
|
||||
server.start()
|
||||
os.environ["STORAGE_EMULATOR_HOST"] = f"http://{host}:{port}"
|
||||
os.environ['STORAGE_EMULATOR_HOST'] = f'http://{host}:{port}'
|
||||
|
||||
gcs_client = storage.Client()
|
||||
bucket = gcs_client.bucket(self.Storage.bucket_name)
|
||||
@@ -227,36 +220,30 @@ class TestGCSStorageProvider:
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
# catch error if bucket does not exist
|
||||
with pytest.raises(Exception):
|
||||
self.Storage.bucket = monkeypatch(self.Storage, "bucket", None)
|
||||
self.Storage.bucket = monkeypatch(self.Storage, 'bucket', None)
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
contents, gcs_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, gcs_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
object = self.Storage.bucket.get_blob(self.filename)
|
||||
assert self.file_content == object.download_as_bytes()
|
||||
# local checks
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
assert contents == self.file_content
|
||||
assert gcs_file_path == "gs://" + self.Storage.bucket_name + "/" + self.filename
|
||||
assert gcs_file_path == 'gs://' + self.Storage.bucket_name + '/' + self.filename
|
||||
# test error if file is empty
|
||||
with pytest.raises(ValueError):
|
||||
self.Storage.upload_file(self.file_bytesio_empty, self.filename)
|
||||
|
||||
def test_get_file(self, monkeypatch, tmp_path, setup):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
contents, gcs_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, gcs_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
file_path = self.Storage.get_file(gcs_file_path)
|
||||
assert file_path == str(upload_dir / self.filename)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
|
||||
def test_delete_file(self, monkeypatch, tmp_path, setup):
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
contents, gcs_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, gcs_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
# ensure that local directory has the uploaded file as well
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert self.Storage.bucket.get_blob(self.filename).name == self.filename
|
||||
@@ -278,10 +265,7 @@ class TestGCSStorageProvider:
|
||||
object = self.Storage.bucket.get_blob(self.filename_extra)
|
||||
assert (upload_dir / self.filename_extra).exists()
|
||||
assert (upload_dir / self.filename_extra).read_bytes() == self.file_content
|
||||
assert (
|
||||
self.Storage.bucket.get_blob(self.filename_extra).name
|
||||
== self.filename_extra
|
||||
)
|
||||
assert self.Storage.bucket.get_blob(self.filename_extra).name == self.filename_extra
|
||||
assert self.file_content == object.download_as_bytes()
|
||||
|
||||
self.Storage.delete_all_files()
|
||||
@@ -295,7 +279,7 @@ class TestAzureStorageProvider:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
@pytest.fixture(scope='class')
|
||||
def setup_storage(self, monkeypatch):
|
||||
# Create mock Blob Service Client and related clients
|
||||
mock_blob_service_client = MagicMock()
|
||||
@@ -303,32 +287,28 @@ class TestAzureStorageProvider:
|
||||
mock_blob_client = MagicMock()
|
||||
|
||||
# Set up return values for the mock
|
||||
mock_blob_service_client.get_container_client.return_value = (
|
||||
mock_container_client
|
||||
)
|
||||
mock_blob_service_client.get_container_client.return_value = mock_container_client
|
||||
mock_container_client.get_blob_client.return_value = mock_blob_client
|
||||
|
||||
# Monkeypatch the Azure classes to return our mocks
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob,
|
||||
"BlobServiceClient",
|
||||
'BlobServiceClient',
|
||||
lambda *args, **kwargs: mock_blob_service_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob,
|
||||
"ContainerClient",
|
||||
'ContainerClient',
|
||||
lambda *args, **kwargs: mock_container_client,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
azure.storage.blob, "BlobClient", lambda *args, **kwargs: mock_blob_client
|
||||
)
|
||||
monkeypatch.setattr(azure.storage.blob, 'BlobClient', lambda *args, **kwargs: mock_blob_client)
|
||||
|
||||
self.Storage = provider.AzureStorageProvider()
|
||||
self.Storage.endpoint = "https://myaccount.blob.core.windows.net"
|
||||
self.Storage.container_name = "my-container"
|
||||
self.file_content = b"test content"
|
||||
self.filename = "test.txt"
|
||||
self.filename_extra = "test_extra.txt"
|
||||
self.Storage.endpoint = 'https://myaccount.blob.core.windows.net'
|
||||
self.Storage.container_name = 'my-container'
|
||||
self.file_content = b'test content'
|
||||
self.filename = 'test.txt'
|
||||
self.filename_extra = 'test_extra.txt'
|
||||
self.file_bytesio_empty = io.BytesIO()
|
||||
|
||||
# Apply mocks to the Storage instance
|
||||
@@ -339,18 +319,14 @@ class TestAzureStorageProvider:
|
||||
upload_dir = mock_upload_dir(monkeypatch, tmp_path)
|
||||
|
||||
# Simulate an error when container does not exist
|
||||
self.Storage.container_client.get_blob_client.side_effect = Exception(
|
||||
"Container does not exist"
|
||||
)
|
||||
self.Storage.container_client.get_blob_client.side_effect = Exception('Container does not exist')
|
||||
with pytest.raises(Exception):
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
|
||||
# Reset side effect and create container
|
||||
self.Storage.container_client.get_blob_client.side_effect = None
|
||||
self.Storage.create_container()
|
||||
contents, azure_file_path = self.Storage.upload_file(
|
||||
io.BytesIO(self.file_content), self.filename
|
||||
)
|
||||
contents, azure_file_path = self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
|
||||
# Assertions
|
||||
self.Storage.container_client.get_blob_client.assert_called_with(self.filename)
|
||||
@@ -359,8 +335,7 @@ class TestAzureStorageProvider:
|
||||
)
|
||||
assert contents == self.file_content
|
||||
assert (
|
||||
azure_file_path
|
||||
== f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
azure_file_path == f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}'
|
||||
)
|
||||
assert (upload_dir / self.filename).exists()
|
||||
assert (upload_dir / self.filename).read_bytes() == self.file_content
|
||||
@@ -375,11 +350,9 @@ class TestAzureStorageProvider:
|
||||
# Mock upload behavior
|
||||
self.Storage.upload_file(io.BytesIO(self.file_content), self.filename)
|
||||
# Mock blob download behavior
|
||||
self.Storage.container_client.get_blob_client().download_blob().readall.return_value = (
|
||||
self.file_content
|
||||
)
|
||||
self.Storage.container_client.get_blob_client().download_blob().readall.return_value = self.file_content
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
file_url = f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}'
|
||||
file_path = self.Storage.get_file(file_url)
|
||||
|
||||
assert file_path == str(upload_dir / self.filename)
|
||||
@@ -395,7 +368,7 @@ class TestAzureStorageProvider:
|
||||
# Mock deletion
|
||||
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
file_url = f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}'
|
||||
self.Storage.delete_file(file_url)
|
||||
|
||||
self.Storage.container_client.get_blob_client().delete_blob.assert_called_once()
|
||||
@@ -411,8 +384,8 @@ class TestAzureStorageProvider:
|
||||
|
||||
# Mock listing and deletion behavior
|
||||
self.Storage.container_client.list_blobs.return_value = [
|
||||
{"name": self.filename},
|
||||
{"name": self.filename_extra},
|
||||
{'name': self.filename},
|
||||
{'name': self.filename_extra},
|
||||
]
|
||||
self.Storage.container_client.get_blob_client().delete_blob.return_value = None
|
||||
|
||||
@@ -426,10 +399,8 @@ class TestAzureStorageProvider:
|
||||
def test_get_file_not_found(self, monkeypatch):
|
||||
self.Storage.create_container()
|
||||
|
||||
file_url = f"https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}"
|
||||
file_url = f'https://myaccount.blob.core.windows.net/{self.Storage.container_name}/{self.filename}'
|
||||
# Mock behavior to raise an error for missing blobs
|
||||
self.Storage.container_client.get_blob_client().download_blob.side_effect = (
|
||||
Exception("Blob not found")
|
||||
)
|
||||
with pytest.raises(Exception, match="Blob not found"):
|
||||
self.Storage.container_client.get_blob_client().download_blob.side_effect = Exception('Blob not found')
|
||||
with pytest.raises(Exception, match='Blob not found'):
|
||||
self.Storage.get_file(file_url)
|
||||
|
||||
Reference in New Issue
Block a user