feat: model update

This commit is contained in:
Timothy J. Baek
2024-05-24 18:26:36 -07:00
parent 0a48114bd2
commit 708d755eda
8 changed files with 396 additions and 407 deletions

View File

@@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER

View File

@@ -33,6 +33,8 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png"
description: Optional[str] = None
"""
User-facing description of the model.
@@ -84,6 +86,7 @@ class Model(pw.Model):
class ModelModel(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
@@ -123,18 +126,26 @@ class ModelsTable:
self.db = db
self.db.create_tables([Model])
def insert_new_model(self, model: ModelForm, user_id: str) -> Optional[ModelModel]:
def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
model = ModelModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
model = Model.create(
**{
**model.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
return ModelModel(**model_to_dict(model))
except:
result = Model.create(**model.model_dump())
if result:
return model
else:
return None
except Exception as e:
print(e)
return None
def get_all_models(self) -> List[ModelModel]:

View File

@@ -1,4 +1,4 @@
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
@@ -65,17 +65,28 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/{id}/update", response_model=Optional[ModelModel])
async def update_model_by_id(
id: str, form_data: ModelForm, user=Depends(get_admin_user)
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(id, form_data)
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
print(model)
if model:
return model
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################