Mirror of https://github.com/open-webui/open-webui.git, synced 2026-03-22 14:13:08 -05:00
Compare commits
581 Commits
CHANGELOG.md (136 lines changed)
@@ -5,6 +5,142 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.5.2] - 2024-12-26

### Added

- **🖊️ Typing Indicators in Channels**: Know exactly who’s typing in real-time within your channels, enhancing collaboration and keeping everyone engaged.
- **👤 User Status Indicators**: Quickly view a user’s status by clicking their profile image in channels for better coordination and availability insights.
- **🔒 Configurable API Key Authentication Restrictions**: Flexibly configure endpoint restrictions for API key authentication, now off by default for a smoother setup in trusted environments.

### Fixed

- **🔧 Playground Functionality Restored**: Resolved a critical issue where the playground wasn’t working, ensuring seamless experimentation and troubleshooting workflows.
- **📊 Corrected Ollama Usage Statistics**: Fixed a calculation error in Ollama’s usage statistics, providing more accurate tracking and insights for better resource management.
- **🔗 Pipelines Outlet Hook Registration**: Addressed an issue where outlet hooks for pipelines weren’t registered, restoring functionality and consistency in pipeline workflows.
- **🎨 Image Generation Error**: Resolved a persistent issue causing errors with 'get_automatic1111_api_auth()' to ensure smooth image generation workflows.
- **🎙️ Text-to-Speech Error**: Fixed the missing argument in Eleven Labs’ 'get_available_voices()', restoring full text-to-speech capabilities for uninterrupted voice interactions.
- **🖋️ Title Generation Issue**: Fixed a bug where title generation was not working in certain cases, ensuring consistent and reliable chat organization.

## [0.5.1] - 2024-12-25

### Added

- **🔕 Notification Sound Toggle**: Added a new setting under Settings > Interface to disable notification sounds, giving you greater control over your workspace environment and focus.

### Fixed

- **🔄 Non-Streaming Response Visibility**: Resolved an issue where non-streaming responses were not displayed, ensuring all responses are now reliably shown in your conversations.
- **🖋️ Title Generation with OpenAI APIs**: Fixed a bug preventing title generation when using OpenAI APIs, restoring the ability to automatically generate chat titles for smoother organization.
- **👥 Admin Panel User List**: Addressed the issue where only 50 users were visible in the admin panel. You can now manage and view all users without restrictions.
- **🖼️ Image Generation Error**: Fixed the issue causing 'get_automatic1111_api_auth()' errors in image generation, ensuring seamless creative workflows.
- **⚙️ Pipeline Settings Loading Issue**: Resolved a problem where pipeline settings were stuck at the loading screen, restoring full configurability in the admin panel.

## [0.5.0] - 2024-12-25

### Added

- **💬 True Asynchronous Chat Support**: Create chats, navigate away, and return anytime with responses ready. Ideal for reasoning models and multi-agent workflows, enhancing multitasking like never before.
- **🔔 Chat Completion Notifications**: Never miss a completed response. Receive instant in-UI notifications when a chat finishes in a non-active tab, keeping you updated while you work elsewhere.
- **🌐 Notification Webhook Integration**: Get alerts via webhooks even when your tab is closed! Configure your webhook URL in Settings > Account and receive timely updates for long-running chats or external integration needs.
- **📚 Channels (Beta)**: Explore Discord/Slack-style chat rooms designed for real-time collaboration between users and AIs. Build bots for channels and unlock asynchronous communication for proactive multi-agent workflows. Opt-in via Admin Settings > General. A comprehensive Bot SDK tutorial (https://github.com/open-webui/bot) is incoming, so stay tuned!
- **🖼️ Client-Side Image Compression**: Now compress images before upload (Settings > Interface), saving bandwidth and improving performance seamlessly.
- **🛠️ OAuth Management for User Groups**: Enable group-level management via OAuth integration for enhanced control and scalability in collaborative environments.
- **✅ Structured Output for Ollama**: Pass structured data output directly to Ollama, unlocking new possibilities for streamlined automation and precise data handling (see the sketch after this list).
- **📜 Offline Swagger Documentation**: Developer-friendly Swagger API docs are now available offline, ensuring full accessibility wherever you are.
- **📸 Quick Screen Capture Button**: Effortlessly capture your screen with a single click from the message input menu.
- **🌍 i18n Updates**: Improved and refined translations across several languages, including Ukrainian, German, Brazilian Portuguese, Catalan, and more, ensuring a seamless global user experience.
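To make the structured-output item concrete: Ollama exposes this by accepting a JSON schema in the `format` field of a chat request and constraining the model's reply to match it. A minimal sketch, assuming a local Ollama server; the model name, schema, and URL are illustrative placeholders, not part of this release:

```python
import json
import requests

# A minimal sketch, assuming Ollama runs locally on its default port.
# The schema and model name below are illustrative placeholders.
schema = {
    "type": "object",
    "properties": {
        "city": {"type": "string"},
        "population": {"type": "integer"},
    },
    "required": ["city", "population"],
}

resp = requests.post(
    "http://localhost:11434/api/chat",
    json={
        "model": "llama3.2",
        "messages": [{"role": "user", "content": "Tell me about Toronto."}],
        "format": schema,  # constrains the reply to this JSON schema
        "stream": False,
    },
)
resp.raise_for_status()
print(json.loads(resp.json()["message"]["content"]))
```

Passing a full JSON schema in `format` is Ollama's structured-outputs API from late 2024; the changelog item means Open WebUI can now hand such requests through to Ollama.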
### Fixed

- **📋 Table Export to CSV**: Resolved issues with CSV export where headers were missing or errors occurred due to values with commas, ensuring smooth and reliable data handling.
- **🔓 BYPASS_MODEL_ACCESS_CONTROL**: Fixed an issue where users could see models but couldn’t use them with 'BYPASS_MODEL_ACCESS_CONTROL=True', restoring proper functionality for environments leveraging this setting.

### Changed

- **💡 API Key Authentication Restriction**: Narrowed API key auth permissions to '/api/models' and '/api/chat/completions' for enhanced security and better API governance (a sketch follows this list).
- **⚙️ Backend Overhaul for Performance**: Major backend restructuring; a heads-up that some "Functions" using internal variables may face compatibility issues. Moving forward, websocket support is mandatory to ensure Open WebUI operates seamlessly.
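For orientation, a rough sketch of what remains reachable with an API key after the narrowing; the base URL, key, and model id are placeholders and not prescribed by this release:

```python
import requests

BASE_URL = "http://localhost:3000"  # placeholder Open WebUI instance
API_KEY = "sk-..."  # placeholder Open WebUI API key (Bearer auth)

headers = {"Authorization": f"Bearer {API_KEY}"}

# The two endpoints the changelog names as still permitted for API keys:
models = requests.get(f"{BASE_URL}/api/models", headers=headers).json()

completion = requests.post(
    f"{BASE_URL}/api/chat/completions",
    headers=headers,
    json={
        "model": "llama3.2",  # illustrative model id
        "messages": [{"role": "user", "content": "Hello!"}],
    },
).json()
print(models, completion)
```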
### Removed

- **⚠️ Legacy Functionality Clean-Up**: Deprecated outdated backend systems that were non-essential or overlapped with newer implementations, allowing for a leaner, more efficient platform.

## [0.4.8] - 2024-12-07

### Added

- **🔓 Bypass Model Access Control**: Introduced the 'BYPASS_MODEL_ACCESS_CONTROL' environment variable. Easily bypass model access controls for user roles when access control isn't required, simplifying workflows for trusted environments (a small sketch follows this list).
- **📝 Markdown in Banners**: Now supports markdown for banners, enabling richer, more visually engaging announcements.
- **🌐 Internationalization Updates**: Enhanced translations across multiple languages, further improving accessibility and global user experience.
- **🎨 Styling Enhancements**: General UI style refinements for a cleaner and more polished interface.
- **📋 Rich Text Reliability**: Improved the reliability and stability of rich text input across chats for smoother interactions.
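As a small illustration of how such a flag is usually wired up, a hedged sketch of the common environment-variable parsing pattern; the exact handling inside the codebase may differ:

```python
import os

# Assumption: the flag is parsed case-insensitively from a "true"/"false"
# string, defaulting to off, as is the common pattern for boolean env vars.
BYPASS_MODEL_ACCESS_CONTROL = (
    os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)

if BYPASS_MODEL_ACCESS_CONTROL:
    print("Model access control checks are bypassed for user roles.")
```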
### Fixed

- **💡 Tailwind Build Issue**: Resolved a breaking bug caused by Tailwind, ensuring smoother builds and overall system reliability.
- **📚 Knowledge Collection Query Fix**: Addressed API endpoint issues with querying knowledge collections, ensuring accurate and reliable information retrieval.

## [0.4.7] - 2024-12-01

### Added

- **✨ Prompt Input Auto-Completion**: Type a prompt and let AI intelligently suggest and complete your inputs. Simply press 'Tab' or swipe right on mobile to confirm. Available only with Rich Text Input (default setting). Disable via Admin Settings for full control.
- **🌍 Improved Translations**: Enhanced localization for multiple languages, ensuring a more polished and accessible experience for international users.

### Fixed

- **🛠️ Tools Export Issue**: Resolved a critical issue where exporting tools wasn’t functioning, restoring seamless export capabilities.
- **🔗 Model ID Registration**: Fixed an issue where model IDs weren’t registering correctly in the model editor, ensuring reliable model setup and tracking.
- **🖋️ Textarea Auto-Expansion**: Corrected a bug where textareas didn’t expand automatically on certain browsers, improving usability for multi-line inputs.
- **🔧 Ollama Embed Endpoint**: Addressed the /ollama/embed endpoint malfunction, ensuring consistent performance and functionality.

### Changed

- **🎨 Knowledge Base Styling**: Refined knowledge base visuals for a cleaner, more modern look, laying the groundwork for further enhancements in upcoming releases.

## [0.4.6] - 2024-11-26

### Added

- **🌍 Enhanced Translations**: Various language translations improved to make the WebUI more accessible and user-friendly worldwide.

### Fixed

- **✏️ Textarea Shifting Bug**: Resolved the issue where the textarea shifted unexpectedly, ensuring a smoother typing experience.
- **⚙️ Model Configuration Modal**: Fixed the issue where the models configuration modal introduced in 0.4.5 wasn’t working for some users.
- **🔍 Legacy Query Support**: Restored functionality for custom query generation in RAG when using legacy prompts, ensuring both default and custom templates now work seamlessly.
- **⚡ Improved General Reliability**: Various minor fixes improve platform stability and ensure a smoother overall experience across workflows.

## [0.4.5] - 2024-11-26

### Added

- **🎨 Model Order/Defaults Reintroduced**: Brought back the ability to set model order and default models, now configurable via Admin Settings > Models > Configure (Gear Icon).

### Fixed

- **🔍 Query Generation Issue**: Resolved an error in web search query generation, enhancing search accuracy and ensuring smoother search workflows.
- **📏 Textarea Auto Height Bug**: Fixed a layout issue where textarea input height was shifting unpredictably, particularly when editing system prompts.
- **🔑 Ollama Authentication**: Corrected an issue with Ollama’s authorization headers, guaranteeing reliable authentication across all endpoints.
- **⚙️ Missing Min_P Save**: Resolved an issue where the 'min_p' parameter was not being saved in configurations.
- **🛠️ Tools Description**: Fixed a key issue that omitted tool descriptions in tools payload.

## [0.4.4] - 2024-11-22

### Added

- **🌐 Translation Updates**: Refreshed Catalan, Brazilian Portuguese, German, and Ukrainian translations, further enhancing the platform's accessibility and improving the experience for international users.

### Fixed

- **📱 Mobile Controls Visibility**: Resolved an issue where the controls button was not displaying on the new chats page for mobile users, ensuring smoother navigation and functionality on smaller screens.
- **📷 LDAP Profile Image Issue**: Fixed an LDAP integration bug related to profile images, ensuring seamless authentication and a reliable login experience for users.
- **⏳ RAG Query Generation Issue**: Addressed a significant problem where RAG query generation occurred unnecessarily without attached files, drastically improving speed and reducing delays during chat completions.

### Changed

- **⚙️ Legacy Event Emitter Support**: Reintroduced compatibility with legacy "citation" types for event emitters in tools and functions, providing smoother workflows and broader tool support for users.

## [0.4.3] - 2024-11-21

### Added
@@ -2,76 +2,98 @@

## Our Pledge

-We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
+As members, contributors, and leaders of this community, we pledge to make participation in our open-source project a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.

-We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
+We are committed to creating and maintaining an open, respectful, and professional environment where positive contributions and meaningful discussions can flourish. By participating in this project, you agree to uphold these values and align your behavior to the standards outlined in this Code of Conduct.

+## Why These Standards Are Important
+
+Open-source projects rely on a community of volunteers dedicating their time, expertise, and effort toward a shared goal. These projects are inherently collaborative but also fragile, as the success of the project depends on the goodwill, energy, and productivity of those involved.
+
+Maintaining a positive and respectful environment is essential to safeguarding the integrity of this project and protecting contributors' efforts. Behavior that disrupts this atmosphere—whether through hostility, entitlement, or unprofessional conduct—can severely harm the morale and productivity of the community. **Strict enforcement of these standards ensures a safe and supportive space for meaningful collaboration.**
+
+This is a community where **respect and professionalism are mandatory.** Violations of these standards will result in **zero tolerance** and immediate enforcement to prevent disruption and ensure the well-being of all participants.

## Our Standards

-Examples of behavior that contributes to a positive environment for our community include:
+Examples of behavior that contribute to a positive and professional community include:

-- Demonstrating empathy and kindness toward other people
-- Being respectful of differing opinions, viewpoints, and experiences
-- Giving and gracefully accepting constructive feedback
-- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience
-- Focusing on what is best not just for us as individuals, but for the overall community
+- **Respecting others.** Be considerate, listen actively, and engage with empathy toward others' viewpoints and experiences.
+- **Constructive feedback.** Provide actionable, thoughtful, and respectful feedback that helps improve the project and encourages collaboration. Avoid unproductive negativity or hypercriticism.
+- **Recognizing volunteer contributions.** Appreciate that contributors dedicate their free time and resources selflessly. Approach them with gratitude and patience.
+- **Focusing on shared goals.** Collaborate in ways that prioritize the health, success, and sustainability of the community over individual agendas.

Examples of unacceptable behavior include:

-- The use of sexualized language or imagery, and sexual attention or advances of any kind
-- Trolling, insulting or derogatory comments, and personal or political attacks
-- Public or private harassment
-- Publishing others' private information, such as a physical or email address, without their explicit permission
-- **Spamming of any kind**
-- Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully
-- Other conduct which could reasonably be considered inappropriate in a professional setting
+- The use of discriminatory, demeaning, or sexualized language or behavior.
+- Personal attacks, derogatory comments, trolling, or inflammatory political or ideological arguments.
+- Harassment, intimidation, or any behavior intended to create a hostile, uncomfortable, or unsafe environment.
+- Publishing others' private information (e.g., physical or email addresses) without explicit permission.
+- **Entitlement, demand, or aggression toward contributors.** Volunteers are under no obligation to provide immediate or personalized support. Rude or dismissive behavior will not be tolerated.
+- **Unproductive or destructive behavior.** This includes venting frustration as hostility ("tantrums"), hypercriticism, attention-seeking negativity, or anything that distracts from the project's goals.
+- **Spamming and promotional exploitation.** Sharing irrelevant product promotions or self-promotion in the community is not allowed unless it directly contributes value to the discussion.

+### Feedback and Community Engagement
+
+- **Constructive feedback is encouraged, but hostile or entitled behavior will result in immediate action.** If you disagree with elements of the project, we encourage you to offer meaningful improvements or fork the project if necessary. Healthy discussions and technical disagreements are welcome only when handled with professionalism.
+- **Respect contributors' time and efforts.** No one is entitled to personalized or on-demand assistance. This is a community built on collaboration and shared effort; demanding or demeaning behavior undermines that trust and will not be allowed.
+
+### Zero Tolerance: No Warnings, Immediate Action
+
+This community operates under a **zero-tolerance policy.** Any behavior deemed unacceptable under this Code of Conduct will result in **immediate enforcement, without prior warning.**
+
+We employ this approach to ensure that unproductive or disruptive behavior does not escalate further or cause unnecessary harm to other contributors. The standards are clear, and violations of any kind—whether mild or severe—will be addressed decisively to protect the community.

## Enforcement Responsibilities

-Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
+Community leaders are responsible for upholding and enforcing these standards. They are empowered to take **immediate and appropriate action** to address any behaviors they deem unacceptable under this Code of Conduct. These actions are taken with the goal of protecting the community and preserving its safe, positive, and productive environment.

## Scope

-This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
+This Code of Conduct applies to all community spaces, including forums, repositories, social media accounts, and in-person events. It also applies when an individual represents the community in public settings, such as conferences or official communications.
+
+Additionally, any behavior outside of these defined spaces that negatively impacts the community or its members may fall within the scope of this Code of Conduct.

-## Enforcement
+## Reporting Violations

-Instances of abusive, harassing, spamming, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at hello@openwebui.com. All complaints will be reviewed and investigated promptly and fairly.
+Instances of unacceptable behavior can be reported to the leadership team at **hello@openwebui.com**. Reports will be handled promptly, confidentially, and with consideration for the safety and well-being of the reporter.

-All community leaders are obligated to respect the privacy and security of the reporter of any incident.
+All community leaders are required to uphold confidentiality and impartiality when addressing reports of violations.

## Enforcement Guidelines

-Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
+### Ban

-### 1. Temporary Ban
+**Community Impact**: Community leaders will issue a ban to any participant whose behavior is deemed unacceptable according to this Code of Conduct. Bans are enforced immediately and without prior notice.

-**Community Impact**: Any violation of community standards, including but not limited to inappropriate language, unprofessional behavior, harassment, or spamming.
+A ban may be temporary or permanent, depending on the severity of the violation. This includes—but is not limited to—behavior such as:

-**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
+- Harassment or abusive behavior toward contributors.
+- Persistent negativity or hostility that disrupts the collaborative environment.
+- Disrespectful, demanding, or aggressive interactions with others.
+- Attempts to cause harm or sabotage the community.

-### 2. Permanent Ban
+**Consequence**: A banned individual is immediately removed from access to all community spaces, communication channels, and events. Community leaders reserve the right to enforce either a time-limited suspension or a permanent ban based on the specific circumstances of the violation.

-**Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
+This approach ensures that disruptive behaviors are addressed swiftly and decisively in order to maintain the integrity and productivity of the community.

-**Consequence**: A permanent ban from any sort of public interaction within the community.
+## Why Zero Tolerance Is Necessary

+Open-source projects thrive on collaboration, goodwill, and mutual respect. Toxic behaviors—such as entitlement, hostility, or persistent negativity—threaten not just individual contributors but the health of the project as a whole. Allowing such behaviors to persist robs contributors of their time, energy, and enthusiasm for the work they do.
+
+By enforcing a zero-tolerance policy, we ensure that the community remains a safe, welcoming space for all participants. These measures are not about harshness—they are about protecting contributors and fostering a productive environment where innovation can thrive.
+
+Our expectations are clear, and our enforcement reflects our commitment to this project's long-term success.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
@@ -1,713 +0,0 @@
import hashlib
import json
import logging
import os
import uuid
from functools import lru_cache
from pathlib import Path

from pydub import AudioSegment
from pydub.silence import split_on_silence

import requests
from open_webui.config import (
    AUDIO_STT_ENGINE,
    AUDIO_STT_MODEL,
    AUDIO_STT_OPENAI_API_BASE_URL,
    AUDIO_STT_OPENAI_API_KEY,
    AUDIO_TTS_API_KEY,
    AUDIO_TTS_ENGINE,
    AUDIO_TTS_MODEL,
    AUDIO_TTS_OPENAI_API_BASE_URL,
    AUDIO_TTS_OPENAI_API_KEY,
    AUDIO_TTS_SPLIT_ON,
    AUDIO_TTS_VOICE,
    AUDIO_TTS_AZURE_SPEECH_REGION,
    AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT,
    CACHE_DIR,
    CORS_ALLOW_ORIGIN,
    WHISPER_MODEL,
    WHISPER_MODEL_AUTO_UPDATE,
    WHISPER_MODEL_DIR,
    AppConfig,
)

from open_webui.constants import ERROR_MESSAGES
from open_webui.env import (
    ENV,
    SRC_LOG_LEVELS,
    DEVICE_TYPE,
    ENABLE_FORWARD_USER_INFO_HEADERS,
)

from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from open_webui.utils.utils import get_admin_user, get_verified_user

# Constants
MAX_FILE_SIZE_MB = 25
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGIN,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL

app.state.config.WHISPER_MODEL = WHISPER_MODEL
app.state.faster_whisper_model = None

app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY
app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON


app.state.speech_synthesiser = None
app.state.speech_speaker_embeddings_dataset = None

app.state.config.TTS_AZURE_SPEECH_REGION = AUDIO_TTS_AZURE_SPEECH_REGION
app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT

# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
log.info(f"whisper_device_type: {whisper_device_type}")

SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


def set_faster_whisper_model(model: str, auto_update: bool = False):
    if model and app.state.config.STT_ENGINE == "":
        from faster_whisper import WhisperModel

        faster_whisper_kwargs = {
            "model_size_or_path": model,
            "device": whisper_device_type,
            "compute_type": "int8",
            "download_root": WHISPER_MODEL_DIR,
            "local_files_only": not auto_update,
        }

        try:
            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
        except Exception:
            log.warning(
                "WhisperModel initialization failed, attempting download with local_files_only=False"
            )
            faster_whisper_kwargs["local_files_only"] = False
            app.state.faster_whisper_model = WhisperModel(**faster_whisper_kwargs)
    else:
        app.state.faster_whisper_model = None


class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str
    SPLIT_ON: str
    AZURE_SPEECH_REGION: str
    AZURE_SPEECH_OUTPUT_FORMAT: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    WHISPER_MODEL: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm
from pydub import AudioSegment
from pydub.utils import mediainfo


def is_mp4_audio(file_path):
    """Check if the given file is an MP4 audio file."""
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return False

    info = mediainfo(file_path)
    if (
        info.get("codec_name") == "aac"
        and info.get("codec_type") == "audio"
        and info.get("codec_tag_string") == "mp4a"
    ):
        return True
    return False


def convert_mp4_to_wav(file_path, output_path):
    """Convert MP4 audio file to WAV format."""
    audio = AudioSegment.from_file(file_path, format="mp4")
    audio.export(output_path, format="wav")
    print(f"Converted {file_path} to {output_path}")


@app.get("/config")
async def get_audio_config(user=Depends(get_admin_user)):
    return {
        "tts": {
            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": app.state.config.TTS_API_KEY,
            "ENGINE": app.state.config.TTS_ENGINE,
            "MODEL": app.state.config.TTS_MODEL,
            "VOICE": app.state.config.TTS_VOICE,
            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": app.state.config.STT_ENGINE,
            "MODEL": app.state.config.STT_MODEL,
            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
        },
    }


@app.post("/config/update")
async def update_audio_config(
    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
    app.state.config.TTS_API_KEY = form_data.tts.API_KEY
    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
    app.state.config.TTS_MODEL = form_data.tts.MODEL
    app.state.config.TTS_VOICE = form_data.tts.VOICE
    app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
    app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
    app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
        form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
    )

    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
    app.state.config.STT_ENGINE = form_data.stt.ENGINE
    app.state.config.STT_MODEL = form_data.stt.MODEL
    app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
    set_faster_whisper_model(form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE)

    return {
        "tts": {
            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
            "API_KEY": app.state.config.TTS_API_KEY,
            "ENGINE": app.state.config.TTS_ENGINE,
            "MODEL": app.state.config.TTS_MODEL,
            "VOICE": app.state.config.TTS_VOICE,
            "SPLIT_ON": app.state.config.TTS_SPLIT_ON,
            "AZURE_SPEECH_REGION": app.state.config.TTS_AZURE_SPEECH_REGION,
            "AZURE_SPEECH_OUTPUT_FORMAT": app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
        },
        "stt": {
            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
            "ENGINE": app.state.config.STT_ENGINE,
            "MODEL": app.state.config.STT_MODEL,
            "WHISPER_MODEL": app.state.config.WHISPER_MODEL,
        },
    }
def load_speech_pipeline():
    from transformers import pipeline
    from datasets import load_dataset

    if app.state.speech_synthesiser is None:
        app.state.speech_synthesiser = pipeline(
            "text-to-speech", "microsoft/speecht5_tts"
        )

    if app.state.speech_speaker_embeddings_dataset is None:
        app.state.speech_speaker_embeddings_dataset = load_dataset(
            "Matthijs/cmu-arctic-xvectors", split="validation"
        )


@app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    if app.state.config.TTS_ENGINE == "openai":
        headers = {}
        headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
        headers["Content-Type"] = "application/json"

        if ENABLE_FORWARD_USER_INFO_HEADERS:
            headers["X-OpenWebUI-User-Name"] = user.name
            headers["X-OpenWebUI-User-Id"] = user.id
            headers["X-OpenWebUI-User-Email"] = user.email
            headers["X-OpenWebUI-User-Role"] = user.role

        try:
            body = body.decode("utf-8")
            body = json.loads(body)
            body["model"] = app.state.config.TTS_MODEL
            body = json.dumps(body).encode("utf-8")
        except Exception:
            pass

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
                data=body,
                headers=headers,
                stream=True,
            )

            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r is not None else 500,
                detail=error_detail,
            )

    elif app.state.config.TTS_ENGINE == "elevenlabs":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        voice_id = payload.get("voice", "")

        if voice_id not in get_available_voices():
            raise HTTPException(
                status_code=400,
                detail="Invalid voice id",
            )

        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": app.state.config.TTS_API_KEY,
        }

        data = {
            "text": payload["input"],
            "model_id": app.state.config.TTS_MODEL,
            "voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
        }

        try:
            r = requests.post(url, json=data, headers=headers)

            r.raise_for_status()

            # Save the streaming content to a file
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r is not None else 500,
                detail=error_detail,
            )

    elif app.state.config.TTS_ENGINE == "azure":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        region = app.state.config.TTS_AZURE_SPEECH_REGION
        language = app.state.config.TTS_VOICE
        locale = "-".join(app.state.config.TTS_VOICE.split("-")[:1])
        output_format = app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
        url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"

        headers = {
            "Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY,
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": output_format,
        }

        data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
            <voice name="{language}">{payload["input"]}</voice>
        </speak>"""

        response = requests.post(url, headers=headers, data=data)

        if response.status_code == 200:
            with open(file_path, "wb") as f:
                f.write(response.content)
            return FileResponse(file_path)
        else:
            log.error(f"Error synthesizing speech - {response.reason}")
            raise HTTPException(
                status_code=500, detail=f"Error synthesizing speech - {response.reason}"
            )
    elif app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline()

        embeddings_dataset = app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])
        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        return FileResponse(file_path)
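For orientation, a client-side sketch of exercising this endpoint. The mount prefix, port, and Bearer-key auth are assumptions about a particular deployment rather than anything this file defines; the JSON body mirrors the `input`/`voice` fields the handler reads above.

```python
import requests

# Hypothetical deployment values; the mount prefix and auth scheme are assumptions.
BASE_URL = "http://localhost:8080/audio/api/v1"
API_KEY = "sk-..."  # placeholder Open WebUI API key

resp = requests.post(
    f"{BASE_URL}/speech",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"input": "Hello from Open WebUI", "voice": "alloy"},
)
resp.raise_for_status()

# The handler returns the synthesized audio, cached server-side as an MP3
# keyed by the SHA-256 of the request body.
with open("speech.mp3", "wb") as f:
    f.write(resp.content)
```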
|
||||
|
||||
def transcribe(file_path):
|
||||
print("transcribe", file_path)
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
id = filename.split(".")[0]
|
||||
|
||||
if app.state.config.STT_ENGINE == "":
|
||||
if app.state.faster_whisper_model is None:
|
||||
set_faster_whisper_model(app.state.config.WHISPER_MODEL)
|
||||
|
||||
        model = app.state.faster_whisper_model
        segments, info = model.transcribe(file_path, beam_size=5)
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif app.state.config.STT_ENGINE == "openai":
        if is_mp4_audio(file_path):
            log.debug("is_mp4_audio")
            os.rename(file_path, file_path.replace(".wav", ".mp4"))
            # Convert MP4 audio file to WAV format
            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)

        headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}

        files = {"file": (filename, open(file_path, "rb"))}
        data = {"model": app.state.config.STT_MODEL}

        log.debug(f"files: {files}, data: {data}")

        r = None
        try:
            r = requests.post(
                url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                headers=headers,
                files=files,
                data=data,
            )

            r.raise_for_status()

            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            log.debug(data)
            return data
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"External: {res['error']['message']}"
                except Exception:
                    error_detail = f"External: {e}"

            raise Exception(error_detail)


@app.post("/transcriptions")
def transcription(
    file: UploadFile = File(...),
    user=Depends(get_verified_user),
):
    log.info(f"file.content_type: {file.content_type}")

    if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        ext = file.filename.split(".")[-1]
        id = uuid.uuid4()

        filename = f"{id}.{ext}"
        contents = file.file.read()

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        with open(file_path, "wb") as f:
            f.write(contents)

        try:
            if os.path.getsize(file_path) > MAX_FILE_SIZE:  # file is bigger than 25MB
                log.debug(f"File size is larger than {MAX_FILE_SIZE_MB}MB")
                audio = AudioSegment.from_file(file_path)
                audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
                compressed_path = f"{file_dir}/{id}_compressed.opus"
                audio.export(compressed_path, format="opus", bitrate="32k")
                log.debug(f"Compressed audio to {compressed_path}")
                file_path = compressed_path

                if (
                    os.path.getsize(file_path) > MAX_FILE_SIZE
                ):  # Still larger than 25MB after compression
                    log.debug(
                        f"Compressed file size is still larger than {MAX_FILE_SIZE_MB}MB: {os.path.getsize(file_path)}"
                    )
                    raise HTTPException(
                        status_code=status.HTTP_400_BAD_REQUEST,
                        detail=ERROR_MESSAGES.FILE_TOO_LARGE(
                            size=f"{MAX_FILE_SIZE_MB}MB"
                        ),
                    )

                data = transcribe(file_path)
            else:
                data = transcribe(file_path)

            file_path = file_path.split("/")[-1]
            return {**data, "filename": file_path}
        except Exception as e:
            log.exception(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )

    except Exception as e:
        log.exception(e)

        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
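
# For reference, a minimal client-side sketch of calling the endpoint above with
# `requests`. The base URL and token are assumptions: they depend on where this
# router is mounted in your deployment and on your auth setup.
#
# import requests
#
# BASE_URL = "http://localhost:8080/audio/api/v1"  # hypothetical mount point
# TOKEN = "YOUR_API_TOKEN"  # placeholder
#
# with open("sample.wav", "rb") as audio:
#     r = requests.post(
#         f"{BASE_URL}/transcriptions",
#         headers={"Authorization": f"Bearer {TOKEN}"},
#         # the third tuple element sets content_type, which the endpoint checks
#         files={"file": ("sample.wav", audio, "audio/wav")},
#     )
# r.raise_for_status()
# print(r.json())  # expected shape: {"text": "...", "filename": "<uuid>.wav"}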


def get_available_models() -> list[dict]:
    if app.state.config.TTS_ENGINE == "openai":
        return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif app.state.config.TTS_ENGINE == "elevenlabs":
        headers = {
            "xi-api-key": app.state.config.TTS_API_KEY,
            "Content-Type": "application/json",
        }

        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models", headers=headers, timeout=5
            )
            response.raise_for_status()
            models = response.json()
            return [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")
    return []


@app.get("/models")
async def get_models(user=Depends(get_verified_user)):
    return {"models": get_available_models()}


def get_available_voices() -> dict:
    """Returns {voice_id: voice_name} dict"""
    ret = {}
    if app.state.config.TTS_ENGINE == "openai":
        ret = {
            "alloy": "alloy",
            "echo": "echo",
            "fable": "fable",
            "onyx": "onyx",
            "nova": "nova",
            "shimmer": "shimmer",
        }
    elif app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            ret = get_elevenlabs_voices()
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif app.state.config.TTS_ENGINE == "azure":
        try:
            region = app.state.config.TTS_AZURE_SPEECH_REGION
            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
            headers = {"Ocp-Apim-Subscription-Key": app.state.config.TTS_API_KEY}

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()
            for voice in voices:
                ret[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return ret


@lru_cache
def get_elevenlabs_voices() -> dict:
    """
    Note, set the following in your .env file to use Elevenlabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your Elevenlabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """
    headers = {
        "xi-api-key": app.state.config.TTS_API_KEY,
        "Content-Type": "application/json",
    }
    try:
        # TODO: Add retries
        response = requests.get("https://api.elevenlabs.io/v1/voices", headers=headers)
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices


@app.get("/voices")
async def get_voices(user=Depends(get_verified_user)):
    return {"voices": [{"id": k, "name": v} for k, v in get_available_voices().items()]}
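
# A short sketch of querying the two read-only endpoints above. The mount point
# and token are placeholders; adjust for your deployment.
#
# import requests
#
# BASE_URL = "http://localhost:8080/audio/api/v1"  # hypothetical mount point
# headers = {"Authorization": "Bearer YOUR_API_TOKEN"}  # placeholder
#
# print(requests.get(f"{BASE_URL}/models", headers=headers).json())
# # -> {"models": [{"id": "tts-1"}, {"id": "tts-1-hd"}]} when TTS_ENGINE is "openai"
# print(requests.get(f"{BASE_URL}/voices", headers=headers).json())
# # -> {"voices": [{"id": "alloy", "name": "alloy"}, ...]}
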
[diff omitted: file too large]
@@ -1,22 +0,0 @@
from open_webui.config import VECTOR_DB

if VECTOR_DB == "milvus":
    from open_webui.apps.retrieval.vector.dbs.milvus import MilvusClient

    VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
    from open_webui.apps.retrieval.vector.dbs.qdrant import QdrantClient

    VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
    from open_webui.apps.retrieval.vector.dbs.opensearch import OpenSearchClient

    VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
    from open_webui.apps.retrieval.vector.dbs.pgvector import PgvectorClient

    VECTOR_DB_CLIENT = PgvectorClient()
else:
    from open_webui.apps.retrieval.vector.dbs.chroma import ChromaClient

    VECTOR_DB_CLIENT = ChromaClient()
@@ -1,221 +0,0 @@
# TODO: move socket to webui app

import asyncio
import socketio
import logging
import sys
import time

from open_webui.apps.webui.models.users import Users
from open_webui.env import (
    ENABLE_WEBSOCKET_SUPPORT,
    WEBSOCKET_MANAGER,
    WEBSOCKET_REDIS_URL,
)
from open_webui.utils.utils import decode_token
from open_webui.apps.socket.utils import RedisDict

from open_webui.env import (
    GLOBAL_LOG_LEVEL,
    SRC_LOG_LEVELS,
)


logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])


if WEBSOCKET_MANAGER == "redis":
    mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
    sio = socketio.AsyncServer(
        cors_allowed_origins=[],
        async_mode="asgi",
        transports=(
            ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
        ),
        allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
        always_connect=True,
        client_manager=mgr,
    )
else:
    sio = socketio.AsyncServer(
        cors_allowed_origins=[],
        async_mode="asgi",
        transports=(
            ["polling", "websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]
        ),
        allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
        always_connect=True,
    )


# Dictionary to maintain the user pool

if WEBSOCKET_MANAGER == "redis":
    SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
    USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
    USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
else:
    SESSION_POOL = {}
    USER_POOL = {}
    USAGE_POOL = {}


# Timeout duration in seconds
TIMEOUT_DURATION = 3


async def periodic_usage_pool_cleanup():
    while True:
        now = int(time.time())
        for model_id, connections in list(USAGE_POOL.items()):
            # Creating a list of sids to remove if they have timed out
            expired_sids = [
                sid
                for sid, details in connections.items()
                if now - details["updated_at"] > TIMEOUT_DURATION
            ]

            for sid in expired_sids:
                del connections[sid]

            if not connections:
                log.debug(f"Cleaning up model {model_id} from usage pool")
                del USAGE_POOL[model_id]
            else:
                USAGE_POOL[model_id] = connections

        # Emit updated usage information after cleaning
        await sio.emit("usage", {"models": get_models_in_use()})

        await asyncio.sleep(TIMEOUT_DURATION)


app = socketio.ASGIApp(
    sio,
    socketio_path="/ws/socket.io",
)


def get_models_in_use():
    # List models that are currently in use
    models_in_use = list(USAGE_POOL.keys())
    return models_in_use


@sio.on("usage")
async def usage(sid, data):
    model_id = data["model"]
    # Record the timestamp for the last update
    current_time = int(time.time())

    # Store the new usage data and task
    USAGE_POOL[model_id] = {
        **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
        sid: {"updated_at": current_time},
    }

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})
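
# A minimal sketch of a client talking to the server above via the
# python-socketio AsyncClient. The URL and token are placeholders; the event
# names, payload shapes, and socketio_path come from the code above.
#
# import asyncio
# import socketio
#
# async def main():
#     sio = socketio.AsyncClient()
#
#     @sio.on("usage")
#     async def on_usage(data):
#         # The server broadcasts the list of models currently in use.
#         print("models in use:", data["models"])
#
#     await sio.connect(
#         "http://localhost:8080",           # hypothetical server URL
#         auth={"token": "YOUR_JWT_TOKEN"},  # placeholder; validated in connect()
#         socketio_path="/ws/socket.io",
#     )
#     # Report that this session is using a model; the entry expires after
#     # TIMEOUT_DURATION seconds unless refreshed by another "usage" emit.
#     await sio.emit("usage", {"model": "llama3"})
#     await asyncio.sleep(5)
#     await sio.disconnect()
#
# asyncio.run(main())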


@sio.event
async def connect(sid, environ, auth):
    user = None
    if auth and "token" in auth:
        data = decode_token(auth["token"])

        if data is not None and "id" in data:
            user = Users.get_user_by_id(data["id"])

        if user:
            SESSION_POOL[sid] = user.id
            if user.id in USER_POOL:
                USER_POOL[user.id].append(sid)
            else:
                USER_POOL[user.id] = [sid]

            # print(f"user {user.name}({user.id}) connected with session ID {sid}")

            await sio.emit("user-count", {"count": len(USER_POOL.items())})
            await sio.emit("usage", {"models": get_models_in_use()})


@sio.on("user-join")
async def user_join(sid, data):
    # print("user-join", sid, data)

    auth = data["auth"] if "auth" in data else None
    if not auth or "token" not in auth:
        return

    data = decode_token(auth["token"])
    if data is None or "id" not in data:
        return

    user = Users.get_user_by_id(data["id"])
    if not user:
        return

    SESSION_POOL[sid] = user.id
    if user.id in USER_POOL:
        USER_POOL[user.id].append(sid)
    else:
        USER_POOL[user.id] = [sid]

    # print(f"user {user.name}({user.id}) connected with session ID {sid}")

    await sio.emit("user-count", {"count": len(USER_POOL.items())})


@sio.on("user-count")
async def user_count(sid):
    await sio.emit("user-count", {"count": len(USER_POOL.items())})


@sio.event
async def disconnect(sid):
    if sid in SESSION_POOL:
        user_id = SESSION_POOL[sid]
        del SESSION_POOL[sid]

        USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]

        if len(USER_POOL[user_id]) == 0:
            del USER_POOL[user_id]

        await sio.emit("user-count", {"count": len(USER_POOL)})
    else:
        pass
        # print(f"Unknown session ID {sid} disconnected")


def get_event_emitter(request_info):
    async def __event_emitter__(event_data):
        await sio.emit(
            "chat-events",
            {
                "chat_id": request_info["chat_id"],
                "message_id": request_info["message_id"],
                "data": event_data,
            },
            to=request_info["session_id"],
        )

    return __event_emitter__


def get_event_call(request_info):
    async def __event_call__(event_data):
        response = await sio.call(
            "chat-events",
            {
                "chat_id": request_info["chat_id"],
                "message_id": request_info["message_id"],
                "data": event_data,
            },
            to=request_info["session_id"],
        )
        return response

    return __event_call__
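
# A sketch of how these two factories are typically consumed: build callables
# bound to one chat/message/session, then await them with event payloads. The
# request_info values and the payload shape are illustrative.
#
# request_info = {
#     "chat_id": "chat-123",
#     "message_id": "msg-456",
#     "session_id": "socket-sid-789",
# }
#
# event_emitter = get_event_emitter(request_info)
# event_call = get_event_call(request_info)
#
# async def example():
#     # Fire-and-forget progress update to the originating session.
#     await event_emitter({"type": "status", "data": {"description": "Searching..."}})
#     # Round-trip call that waits for the client's acknowledgement payload.
#     reply = await event_call({"type": "confirmation", "data": {"prompt": "Proceed?"}})
#     return reply
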
@@ -1,495 +0,0 @@
import inspect
import json
import logging
import time
from typing import AsyncGenerator, Generator, Iterator

from open_webui.apps.socket.main import get_event_call, get_event_emitter
from open_webui.apps.webui.models.functions import Functions
from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.routers import (
    auths,
    chats,
    folders,
    configs,
    groups,
    files,
    functions,
    memories,
    models,
    knowledge,
    prompts,
    evaluations,
    tools,
    users,
    utils,
)
from open_webui.apps.webui.utils import load_function_module_by_id
from open_webui.config import (
    ADMIN_EMAIL,
    CORS_ALLOW_ORIGIN,
    DEFAULT_MODELS,
    DEFAULT_PROMPT_SUGGESTIONS,
    DEFAULT_USER_ROLE,
    ENABLE_COMMUNITY_SHARING,
    ENABLE_LOGIN_FORM,
    ENABLE_MESSAGE_RATING,
    ENABLE_SIGNUP,
    ENABLE_API_KEY,
    ENABLE_EVALUATION_ARENA_MODELS,
    EVALUATION_ARENA_MODELS,
    DEFAULT_ARENA_MODEL,
    JWT_EXPIRES_IN,
    ENABLE_OAUTH_ROLE_MANAGEMENT,
    OAUTH_ROLES_CLAIM,
    OAUTH_EMAIL_CLAIM,
    OAUTH_PICTURE_CLAIM,
    OAUTH_USERNAME_CLAIM,
    OAUTH_ALLOWED_ROLES,
    OAUTH_ADMIN_ROLES,
    SHOW_ADMIN_DETAILS,
    USER_PERMISSIONS,
    WEBHOOK_URL,
    WEBUI_AUTH,
    WEBUI_BANNERS,
    ENABLE_LDAP,
    LDAP_SERVER_LABEL,
    LDAP_SERVER_HOST,
    LDAP_SERVER_PORT,
    LDAP_ATTRIBUTE_FOR_USERNAME,
    LDAP_SEARCH_FILTERS,
    LDAP_SEARCH_BASE,
    LDAP_APP_DN,
    LDAP_APP_PASSWORD,
    LDAP_USE_TLS,
    LDAP_CA_CERT_FILE,
    LDAP_CIPHERS,
    AppConfig,
)
from open_webui.env import (
    ENV,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
    WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from open_webui.utils.misc import (
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)


from open_webui.utils.tools import get_tools

app = FastAPI(
    docs_url="/docs" if ENV == "dev" else None,
    openapi_url="/openapi.json" if ENV == "dev" else None,
    redoc_url=None,
)

log = logging.getLogger(__name__)

app.state.config = AppConfig()

app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY

app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER


app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL


app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE


app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS

app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING

app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS

app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM

app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES

app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS

app.state.TOOLS = {}
app.state.FUNCTIONS = {}

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGIN,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


app.include_router(configs.router, prefix="/configs", tags=["configs"])

app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])

app.include_router(chats.router, prefix="/chats", tags=["chats"])

app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])

app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])

app.include_router(groups.router, prefix="/groups", tags=["groups"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])


app.include_router(utils.router, prefix="/utils", tags=["utils"])


@app.get("/")
async def get_status():
    return {
        "status": True,
        "auth": WEBUI_AUTH,
        "default_models": app.state.config.DEFAULT_MODELS,
        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
    }


async def get_all_models():
    models = []
    pipe_models = await get_pipe_models()
    models = models + pipe_models

    if app.state.config.ENABLE_EVALUATION_ARENA_MODELS:
        arena_models = []
        if len(app.state.config.EVALUATION_ARENA_MODELS) > 0:
            arena_models = [
                {
                    "id": model["id"],
                    "name": model["name"],
                    "info": {
                        "meta": model["meta"],
                    },
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "arena",
                    "arena": True,
                }
                for model in app.state.config.EVALUATION_ARENA_MODELS
            ]
        else:
            # Add default arena model
            arena_models = [
                {
                    "id": DEFAULT_ARENA_MODEL["id"],
                    "name": DEFAULT_ARENA_MODEL["name"],
                    "info": {
                        "meta": DEFAULT_ARENA_MODEL["meta"],
                    },
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "arena",
                    "arena": True,
                }
            ]
        models = models + arena_models
    return models


def get_function_module(pipe_id: str):
    # Check if function is already loaded
    if pipe_id not in app.state.FUNCTIONS:
        function_module, _, _ = load_function_module_by_id(pipe_id)
        app.state.FUNCTIONS[pipe_id] = function_module
    else:
        function_module = app.state.FUNCTIONS[pipe_id]

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(pipe_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))
    return function_module


async def get_pipe_models():
    pipes = Functions.get_functions_by_type("pipe", active_only=True)
    pipe_models = []

    for pipe in pipes:
        function_module = get_function_module(pipe.id)

        # Check if function is a manifold
        if hasattr(function_module, "pipes"):
            sub_pipes = []

            # Check if pipes is a function or a list

            try:
                if callable(function_module.pipes):
                    sub_pipes = function_module.pipes()
                else:
                    sub_pipes = function_module.pipes
            except Exception as e:
                log.exception(e)
                sub_pipes = []

            log.debug(sub_pipes)

            for p in sub_pipes:
                sub_pipe_id = f'{pipe.id}.{p["id"]}'
                sub_pipe_name = p["name"]

                if hasattr(function_module, "name"):
                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"

                pipe_flag = {"type": pipe.type}
                pipe_models.append(
                    {
                        "id": sub_pipe_id,
                        "name": sub_pipe_name,
                        "object": "model",
                        "created": pipe.created_at,
                        "owned_by": "openai",
                        "pipe": pipe_flag,
                    }
                )
        else:
            pipe_flag = {"type": "pipe"}

            pipe_models.append(
                {
                    "id": pipe.id,
                    "name": pipe.name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": pipe_flag,
                }
            )

    return pipe_models


async def execute_pipe(pipe, params):
    if inspect.iscoroutinefunction(pipe):
        return await pipe(**params)
    else:
        return pipe(**params)


async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
    if isinstance(res, str):
        return res
    if isinstance(res, Generator):
        return "".join(map(str, res))
    if isinstance(res, AsyncGenerator):
        return "".join([str(stream) async for stream in res])


def process_line(form_data: dict, line):
    if isinstance(line, BaseModel):
        line = line.model_dump_json()
        line = f"data: {line}"
    if isinstance(line, dict):
        line = f"data: {json.dumps(line)}"

    try:
        line = line.decode("utf-8")
    except Exception:
        pass

    if line.startswith("data:"):
        return f"{line}\n\n"
    else:
        line = openai_chat_chunk_message_template(form_data["model"], line)
        return f"data: {json.dumps(line)}\n\n"


def get_pipe_id(form_data: dict) -> str:
    pipe_id = form_data["model"]
    if "." in pipe_id:
        pipe_id, _ = pipe_id.split(".", 1)
    log.debug(pipe_id)
    return pipe_id


def get_function_params(function_module, form_data, user, extra_params=None):
    if extra_params is None:
        extra_params = {}

    pipe_id = get_pipe_id(form_data)

    # Get the signature of the function
    sig = inspect.signature(function_module.pipe)
    params = {"body": form_data} | {
        k: v for k, v in extra_params.items() if k in sig.parameters
    }

    if "__user__" in params and hasattr(function_module, "UserValves"):
        user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
        try:
            params["__user__"]["valves"] = function_module.UserValves(**user_valves)
        except Exception as e:
            log.exception(e)
            params["__user__"]["valves"] = function_module.UserValves()

    return params
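
# A minimal sketch of the pipe "function module" shape that get_function_module,
# get_function_params, and get_pipe_models expect. The names and settings below
# are illustrative, not from the repo; only the attribute names (pipe, pipes,
# name, valves, Valves, UserValves) are grounded in the code above.
#
# from pydantic import BaseModel
#
# class Valves(BaseModel):
#     # Admin-level settings, refreshed from Functions.get_function_valves_by_id.
#     api_base: str = "http://localhost:11434"  # hypothetical setting
#
# class UserValves(BaseModel):
#     # Per-user settings, injected into params["__user__"]["valves"].
#     temperature: float = 0.7  # hypothetical setting
#
# valves = Valves()
# name = "Demo/"  # optional; prefixed onto sub-pipe names
#
# def pipes():
#     # Optional: a manifold exposes several sub-models via `pipes`
#     # (either a list or a zero-argument callable returning one).
#     return [{"id": "small", "name": "demo-small"}, {"id": "large", "name": "demo-large"}]
#
# def pipe(body: dict, __user__: dict = None):
#     # Only parameters present in this signature are passed in by
#     # get_function_params; return a str, dict, generator, or async generator.
#     return f"echo: {body['messages'][-1]['content']}"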


async def generate_function_chat_completion(form_data, user, models: dict = {}):
    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", {})

    files = metadata.get("files", [])
    tool_ids = metadata.get("tool_ids", [])
    # Check if tool_ids is None
    if tool_ids is None:
        tool_ids = []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)
        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
    }
    extra_params["__tools__"] = get_tools(
        app,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module(pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)

    if form_data.get("stream", False):

        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)

        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)
@@ -10,7 +10,7 @@ from urllib.parse import urlparse
import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.internal.db import Base, get_db
from open_webui.env import (
    OPEN_WEBUI_DIR,
    DATA_DIR,
@@ -21,6 +21,7 @@ from open_webui.env import (
    WEBUI_NAME,
    log,
    DATABASE_URL,
    OFFLINE_MODE,
)
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func
@@ -271,6 +272,18 @@ ENABLE_API_KEY = PersistentConfig(
    os.environ.get("ENABLE_API_KEY", "True").lower() == "true",
)

ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = PersistentConfig(
    "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS",
    "auth.api_key.endpoint_restrictions",
    os.environ.get("ENABLE_API_KEY_ENDPOINT_RESTRICTIONS", "False").lower() == "true",
)

API_KEY_ALLOWED_ENDPOINTS = PersistentConfig(
    "API_KEY_ALLOWED_ENDPOINTS",
    "auth.api_key.allowed_endpoints",
    os.environ.get("API_KEY_ALLOWED_ENDPOINTS", ""),
)


JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
@@ -306,6 +319,7 @@ GOOGLE_CLIENT_SECRET = PersistentConfig(
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)


GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
@@ -402,12 +416,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
    os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)

OAUTH_GROUPS_CLAIM = PersistentConfig(
    "OAUTH_GROUPS_CLAIM",
    "oauth.oidc.group_claim",
    os.environ.get("OAUTH_GROUP_CLAIM", "groups"),
)

ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
    "ENABLE_OAUTH_ROLE_MANAGEMENT",
    "oauth.enable_role_mapping",
    os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)

ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig(
    "ENABLE_OAUTH_GROUP_MANAGEMENT",
    "oauth.enable_group_mapping",
    os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true",
)

OAUTH_ROLES_CLAIM = PersistentConfig(
    "OAUTH_ROLES_CLAIM",
    "oauth.roles_claim",
@@ -429,6 +455,15 @@ OAUTH_ADMIN_ROLES = PersistentConfig(
    [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)

OAUTH_ALLOWED_DOMAINS = PersistentConfig(
    "OAUTH_ALLOWED_DOMAINS",
    "oauth.allowed_domains",
    [
        domain.strip()
        for domain in os.environ.get("OAUTH_ALLOWED_DOMAINS", "*").split(",")
    ],
)


def load_oauth_providers():
    OAUTH_PROVIDERS.clear()
@@ -583,6 +618,12 @@ OLLAMA_API_BASE_URL = os.environ.get(
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
if OLLAMA_BASE_URL:
    # Remove trailing slash
    OLLAMA_BASE_URL = (
        OLLAMA_BASE_URL[:-1] if OLLAMA_BASE_URL.endswith("/") else OLLAMA_BASE_URL
    )


K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
@@ -680,6 +721,12 @@ OPENAI_API_BASE_URL = "https://api.openai.com/v1"
# WEBUI
####################################


WEBUI_URL = PersistentConfig(
    "WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000")
)


ENABLE_SIGNUP = PersistentConfig(
    "ENABLE_SIGNUP",
    "ui.enable_signup",
@@ -696,6 +743,7 @@ ENABLE_LOGIN_FORM = PersistentConfig(
    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)


DEFAULT_LOCALE = PersistentConfig(
    "DEFAULT_LOCALE",
    "ui.default_locale",
@@ -740,13 +788,18 @@ DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    ],
)

MODEL_ORDER_LIST = PersistentConfig(
    "MODEL_ORDER_LIST",
    "ui.model_order_list",
    [],
)

DEFAULT_USER_ROLE = PersistentConfig(
    "DEFAULT_USER_ROLE",
    "ui.default_user_role",
    os.getenv("DEFAULT_USER_ROLE", "pending"),
)


USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = (
    os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower()
    == "true"
@@ -801,6 +854,12 @@ USER_PERMISSIONS = PersistentConfig(
    },
)

ENABLE_CHANNELS = PersistentConfig(
    "ENABLE_CHANNELS",
    "channels.enable",
    os.environ.get("ENABLE_CHANNELS", "False").lower() == "true",
)


ENABLE_EVALUATION_ARENA_MODELS = PersistentConfig(
    "ENABLE_EVALUATION_ARENA_MODELS",
@@ -936,12 +995,45 @@ TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)

DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE = """Create a concise, 3-5 word title with an emoji as a title for the chat history, in the given language. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.

Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights

<chat_history>
{{MESSAGES:END:2}}
</chat_history>"""


TAGS_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TAGS_GENERATION_PROMPT_TEMPLATE",
    "task.tags.prompt_template",
    os.environ.get("TAGS_GENERATION_PROMPT_TEMPLATE", ""),
)

DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE = """### Task:
Generate 1-3 broad tags categorizing the main themes of the chat history, along with 1-3 more specific subtopic tags.

### Guidelines:
- Start with high-level domains (e.g. Science, Technology, Philosophy, Arts, Politics, Business, Health, Sports, Entertainment, Education)
- Consider including relevant subfields/subdomains if they are strongly represented throughout the conversation
- If content is too short (less than 3 messages) or too diverse, use only ["General"]
- Use the chat's primary language; default to English if multilingual
- Prioritize accuracy over specificity

### Output:
JSON format: { "tags": ["tag1", "tag2", "tag3"] }

### Chat History:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>"""

ENABLE_TAGS_GENERATION = PersistentConfig(
    "ENABLE_TAGS_GENERATION",
    "task.tags.enable",
@@ -969,19 +1061,20 @@ QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
)

DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE = """### Task:
Based on the chat history, determine whether a search is necessary, and if so, generate a 1-3 broad search queries to retrieve comprehensive and updated information. If no search is required, return an empty list.
Analyze the chat history to determine the necessity of generating search queries, in the given language. By default, **prioritize generating 1-3 broad and relevant search queries** unless it is absolutely certain that no additional information is required. The aim is to retrieve comprehensive, updated, and valuable information even with minimal uncertainty. If no search is unequivocally needed, return an empty list.

### Guidelines:
- Respond exclusively with a JSON object.
- If a search query is needed, return an object like: { "queries": ["query1", "query2"] } where each query is distinct and concise.
- If no search query is necessary, output should be: { "queries": [] }
- Default to suggesting a search query to ensure accurate and updated information, unless it is definitively clear no search is required.
- Be concise, focusing strictly on composing search queries with no additional commentary or text.
- When in doubt, prefer to suggest a search for comprehensiveness.
- Today's date is: {{CURRENT_DATE}}
- Respond **EXCLUSIVELY** with a JSON object. Any form of extra commentary, explanation, or additional text is strictly prohibited.
- When generating search queries, respond in the format: { "queries": ["query1", "query2"] }, ensuring each query is distinct, concise, and relevant to the topic.
- If and only if it is entirely certain that no useful results can be retrieved by a search, return: { "queries": [] }.
- Err on the side of suggesting search queries if there is **any chance** they might provide useful or updated information.
- Be concise and focused on composing high-quality search queries, avoiding unnecessary elaboration, commentary, or assumptions.
- Today's date is: {{CURRENT_DATE}}.
- Always prioritize providing actionable and broad queries that maximize informational coverage.

### Output:
JSON format: {
Strictly return in JSON format:
{
    "queries": ["query1", "query2"]
}

@@ -991,6 +1084,66 @@ JSON format: {
</chat_history>
"""

ENABLE_AUTOCOMPLETE_GENERATION = PersistentConfig(
    "ENABLE_AUTOCOMPLETE_GENERATION",
    "task.autocomplete.enable",
    os.environ.get("ENABLE_AUTOCOMPLETE_GENERATION", "True").lower() == "true",
)

AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = PersistentConfig(
    "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH",
    "task.autocomplete.input_max_length",
    int(os.environ.get("AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH", "-1")),
)

AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE",
    "task.autocomplete.prompt_template",
    os.environ.get("AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE", ""),
)


DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE = """### Task:
You are an autocompletion system. Continue the text in `<text>` based on the **completion type** in `<type>` and the given language.

### **Instructions**:
1. Analyze `<text>` for context and meaning.
2. Use `<type>` to guide your output:
   - **General**: Provide a natural, concise continuation.
   - **Search Query**: Complete as if generating a realistic search query.
3. Start as if you are directly continuing `<text>`. Do **not** repeat, paraphrase, or respond as a model. Simply complete the text.
4. Ensure the continuation:
   - Flows naturally from `<text>`.
   - Avoids repetition, overexplaining, or unrelated ideas.
5. If unsure, return: `{ "text": "" }`.

### **Output Rules**:
- Respond only in JSON format: `{ "text": "<your_completion>" }`.

### **Examples**:
#### Example 1:
Input:
<type>General</type>
<text>The sun was setting over the horizon, painting the sky</text>
Output:
{ "text": "with vibrant shades of orange and pink." }

#### Example 2:
Input:
<type>Search Query</type>
<text>Top-rated restaurants in</text>
Output:
{ "text": "New York City for Italian cuisine." }

---
### Context:
<chat_history>
{{MESSAGES:END:6}}
</chat_history>
<type>{{TYPE}}</type>
<text>{{PROMPT}}</text>
#### Output:
"""

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
@@ -999,6 +1152,19 @@ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
)


DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = """Available Tools: {{TOOLS}}\nReturn an empty string if no tools match the query. If a function tool matches, construct and return a JSON object in the format {\"name\": \"functionName\", \"parameters\": {\"requiredFunctionParamKey\": \"requiredFunctionParamValue\"}} using the appropriate tool and its parameters. Only return the object and limit the response to the JSON object without additional text."""


DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE = """Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).

Message: ```{{prompt}}```"""

DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"

Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.

Responses from models: {{responses}}"""

####################################
# Vector Database
####################################
@@ -1050,6 +1216,26 @@ if VECTOR_DB == "pgvector" and not PGVECTOR_DB_URL.startswith("postgres"):
# Information Retrieval (RAG)
####################################


# If configured, Google Drive will be available as an upload option.
ENABLE_GOOGLE_DRIVE_INTEGRATION = PersistentConfig(
    "ENABLE_GOOGLE_DRIVE_INTEGRATION",
    "google_drive.enable",
    os.getenv("ENABLE_GOOGLE_DRIVE_INTEGRATION", "False").lower() == "true",
)

GOOGLE_DRIVE_CLIENT_ID = PersistentConfig(
    "GOOGLE_DRIVE_CLIENT_ID",
    "google_drive.client_id",
    os.environ.get("GOOGLE_DRIVE_CLIENT_ID", ""),
)

GOOGLE_DRIVE_API_KEY = PersistentConfig(
    "GOOGLE_DRIVE_API_KEY",
    "google_drive.api_key",
    os.environ.get("GOOGLE_DRIVE_API_KEY", ""),
)

# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
@@ -1124,7 +1310,8 @@ RAG_EMBEDDING_MODEL = PersistentConfig(
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")

RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
    not OFFLINE_MODE
    and os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
@@ -1149,7 +1336,8 @@ if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")

RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
    not OFFLINE_MODE
    and os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "True").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
@@ -1252,6 +1440,12 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)

YOUTUBE_LOADER_PROXY_URL = PersistentConfig(
    "YOUTUBE_LOADER_PROXY_URL",
    "rag.youtube_loader_proxy_url",
    os.getenv("YOUTUBE_LOADER_PROXY_URL", ""),
)


ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
@@ -1277,6 +1471,7 @@ RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    ],
)


SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
@@ -1301,6 +1496,12 @@ BRAVE_SEARCH_API_KEY = PersistentConfig(
    os.getenv("BRAVE_SEARCH_API_KEY", ""),
)

KAGI_SEARCH_API_KEY = PersistentConfig(
    "KAGI_SEARCH_API_KEY",
    "rag.web.search.kagi_search_api_key",
    os.getenv("KAGI_SEARCH_API_KEY", ""),
)

MOJEEK_SEARCH_API_KEY = PersistentConfig(
    "MOJEEK_SEARCH_API_KEY",
    "rag.web.search.mojeek_search_api_key",
@@ -1446,6 +1647,12 @@ COMFYUI_BASE_URL = PersistentConfig(
    os.getenv("COMFYUI_BASE_URL", ""),
)

COMFYUI_API_KEY = PersistentConfig(
    "COMFYUI_API_KEY",
    "image_generation.comfyui.api_key",
    os.getenv("COMFYUI_API_KEY", ""),
)

COMFYUI_DEFAULT_WORKFLOW = """
{
    "3": {
@@ -1607,7 +1814,8 @@ WHISPER_MODEL = PersistentConfig(

WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
    not OFFLINE_MODE
    and os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)



@@ -113,5 +113,6 @@ class TASKS(str, Enum):
    TAGS_GENERATION = "tags_generation"
    EMOJI_GENERATION = "emoji_generation"
    QUERY_GENERATION = "query_generation"
    AUTOCOMPLETE_GENERATION = "autocomplete_generation"
    FUNCTION_CALLING = "function_calling"
    MOA_RESPONSE_GENERATION = "moa_response_generation"

@@ -103,8 +103,6 @@ WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
    WEBUI_NAME += " (Open WebUI)"

WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")

WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"


@@ -329,6 +327,9 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
)
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)

BYPASS_MODEL_ACCESS_CONTROL = (
    os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true"
)

####################################
# WEBUI_SECRET_KEY
@@ -373,7 +374,7 @@ else:
    AIOHTTP_CLIENT_TIMEOUT = 300

AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = os.environ.get(
    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", "3"
    "AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST", ""
)

if AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST == "":
@@ -384,7 +385,7 @@ else:
            AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
        )
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 3
        AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST = 5

####################################
# OFFLINE_MODE

316 backend/open_webui/functions.py (new file)
@@ -0,0 +1,316 @@
import logging
import sys
import inspect
import json

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from starlette.responses import Response, StreamingResponse


from open_webui.socket.main import (
    get_event_call,
    get_event_emitter,
)


from open_webui.models.functions import Functions
from open_webui.models.models import Models

from open_webui.utils.plugin import load_function_module_by_id
from open_webui.utils.tools import get_tools
from open_webui.utils.access_control import has_access

from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

from open_webui.utils.misc import (
    add_or_update_system_message,
    get_last_user_message,
    prepend_to_first_user_message_content,
    openai_chat_chunk_message_template,
    openai_chat_completion_message_template,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)


logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])


def get_function_module_by_id(request: Request, pipe_id: str):
    # Check if function is already loaded
    if pipe_id not in request.app.state.FUNCTIONS:
        function_module, _, _ = load_function_module_by_id(pipe_id)
        request.app.state.FUNCTIONS[pipe_id] = function_module
    else:
        function_module = request.app.state.FUNCTIONS[pipe_id]

    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
        valves = Functions.get_function_valves_by_id(pipe_id)
        function_module.valves = function_module.Valves(**(valves if valves else {}))
    return function_module


async def get_function_models(request):
    pipes = Functions.get_functions_by_type("pipe", active_only=True)
    pipe_models = []

    for pipe in pipes:
        function_module = get_function_module_by_id(request, pipe.id)

        # Check if function is a manifold
        if hasattr(function_module, "pipes"):
            sub_pipes = []

            # Check if pipes is a function or a list

            try:
                if callable(function_module.pipes):
                    sub_pipes = function_module.pipes()
                else:
                    sub_pipes = function_module.pipes
            except Exception as e:
                log.exception(e)
                sub_pipes = []

            log.debug(
                f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
            )

            for p in sub_pipes:
                sub_pipe_id = f'{pipe.id}.{p["id"]}'
                sub_pipe_name = p["name"]

                if hasattr(function_module, "name"):
                    sub_pipe_name = f"{function_module.name}{sub_pipe_name}"

                pipe_flag = {"type": pipe.type}

                pipe_models.append(
                    {
                        "id": sub_pipe_id,
                        "name": sub_pipe_name,
                        "object": "model",
                        "created": pipe.created_at,
                        "owned_by": "openai",
                        "pipe": pipe_flag,
                    }
                )
        else:
            pipe_flag = {"type": "pipe"}

            log.debug(
                f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
            )

            pipe_models.append(
                {
                    "id": pipe.id,
                    "name": pipe.name,
                    "object": "model",
                    "created": pipe.created_at,
                    "owned_by": "openai",
                    "pipe": pipe_flag,
                }
            )

    return pipe_models


async def generate_function_chat_completion(
    request, form_data, user, models: dict = {}
):
    async def execute_pipe(pipe, params):
        if inspect.iscoroutinefunction(pipe):
            return await pipe(**params)
        else:
            return pipe(**params)

    async def get_message_content(res: str | Generator | AsyncGenerator) -> str:
        if isinstance(res, str):
            return res
        if isinstance(res, Generator):
            return "".join(map(str, res))
        if isinstance(res, AsyncGenerator):
            return "".join([str(stream) async for stream in res])

    def process_line(form_data: dict, line):
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
            line = f"data: {line}"
        if isinstance(line, dict):
            line = f"data: {json.dumps(line)}"

        try:
            line = line.decode("utf-8")
        except Exception:
            pass

        if line.startswith("data:"):
            return f"{line}\n\n"
        else:
            line = openai_chat_chunk_message_template(form_data["model"], line)
            return f"data: {json.dumps(line)}\n\n"

    def get_pipe_id(form_data: dict) -> str:
        pipe_id = form_data["model"]
        if "." in pipe_id:
            pipe_id, _ = pipe_id.split(".", 1)
        return pipe_id

    def get_function_params(function_module, form_data, user, extra_params=None):
        if extra_params is None:
            extra_params = {}

        pipe_id = get_pipe_id(form_data)

        # Get the signature of the function
        sig = inspect.signature(function_module.pipe)
        params = {"body": form_data} | {
            k: v for k, v in extra_params.items() if k in sig.parameters
        }

        if "__user__" in params and hasattr(function_module, "UserValves"):
            user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
            try:
                params["__user__"]["valves"] = function_module.UserValves(**user_valves)
            except Exception as e:
                log.exception(e)
                params["__user__"]["valves"] = function_module.UserValves()

        return params

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    metadata = form_data.pop("metadata", {})

    files = metadata.get("files", [])
    tool_ids = metadata.get("tool_ids", [])
    # Check if tool_ids is None
    if tool_ids is None:
        tool_ids = []

    __event_emitter__ = None
    __event_call__ = None
    __task__ = None
    __task_body__ = None

    if metadata:
        if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
            __event_emitter__ = get_event_emitter(metadata)
            __event_call__ = get_event_call(metadata)
        __task__ = metadata.get("task", None)
        __task_body__ = metadata.get("task_body", None)

    extra_params = {
        "__event_emitter__": __event_emitter__,
        "__event_call__": __event_call__,
        "__task__": __task__,
        "__task_body__": __task_body__,
        "__files__": files,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
    }
    extra_params["__tools__"] = get_tools(
        request,
        tool_ids,
        user,
        {
            **extra_params,
            "__model__": models.get(form_data["model"], None),
            "__messages__": form_data["messages"],
            "__files__": files,
        },
    )

    if model_info:
        if model_info.base_model_id:
            form_data["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        form_data = apply_model_params_to_body_openai(params, form_data)
        form_data = apply_model_system_prompt_to_body(params, form_data, user)

    pipe_id = get_pipe_id(form_data)
    function_module = get_function_module_by_id(request, pipe_id)

    pipe = function_module.pipe
    params = get_function_params(function_module, form_data, user, extra_params)

    if form_data.get("stream", False):

        async def stream_content():
            try:
                res = await execute_pipe(pipe, params)

                # Directly return if the response is a StreamingResponse
                if isinstance(res, StreamingResponse):
                    async for data in res.body_iterator:
                        yield data
                    return
                if isinstance(res, dict):
                    yield f"data: {json.dumps(res)}\n\n"
                    return

            except Exception as e:
                log.error(f"Error: {e}")
                yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
                return

            if isinstance(res, str):
                message = openai_chat_chunk_message_template(form_data["model"], res)
                yield f"data: {json.dumps(message)}\n\n"

            if isinstance(res, Iterator):
                for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, AsyncGenerator):
                async for line in res:
                    yield process_line(form_data, line)

            if isinstance(res, str) or isinstance(res, Generator):
                finish_message = openai_chat_chunk_message_template(
                    form_data["model"], ""
                )
                finish_message["choices"][0]["finish_reason"] = "stop"
                yield f"data: {json.dumps(finish_message)}\n\n"
                yield "data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:
        try:
            res = await execute_pipe(pipe, params)

        except Exception as e:
            log.error(f"Error: {e}")
            return {"error": {"detail": str(e)}}

        if isinstance(res, StreamingResponse) or isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = await get_message_content(res)
        return openai_chat_completion_message_template(form_data["model"], message)
@@ -3,7 +3,7 @@ import logging
from contextlib import contextmanager
from typing import Any, Optional

from open_webui.apps.webui.internal.wrappers import register_connection
from open_webui.internal.wrappers import register_connection
from open_webui.env import (
    OPEN_WEBUI_DIR,
    DATABASE_URL,
@@ -54,7 +54,7 @@ def handle_peewee_migration(DATABASE_URL):
    try:
        # Replace the postgresql:// with postgres:// to handle the peewee migration
        db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
        migrate_dir = OPEN_WEBUI_DIR / "apps" / "webui" / "internal" / "migrations"
        migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
        router = Router(db, logger=log, migrate_dir=migrate_dir)
        router.run()
        db.close()
File diff suppressed because it is too large
@@ -1,7 +1,7 @@
 from logging.config import fileConfig

 from alembic import context
-from open_webui.apps.webui.models.auths import Auth
+from open_webui.models.auths import Auth
 from open_webui.env import DATABASE_URL
 from sqlalchemy import engine_from_config, pool


@@ -9,7 +9,7 @@ from typing import Sequence, Union

 from alembic import op
 import sqlalchemy as sa
-import open_webui.apps.webui.internal.db
+import open_webui.internal.db
 ${imports if imports else ""}

 # revision identifiers, used by Alembic.

@@ -0,0 +1,48 @@
"""Add channel table

Revision ID: 57c599a3cb57
Revises: 922e7a387820
Create Date: 2024-12-22 03:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

revision = "57c599a3cb57"
down_revision = "922e7a387820"
branch_labels = None
depends_on = None


def upgrade():
    op.create_table(
        "channel",
        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
        sa.Column("user_id", sa.Text()),
        sa.Column("name", sa.Text()),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("data", sa.JSON(), nullable=True),
        sa.Column("meta", sa.JSON(), nullable=True),
        sa.Column("access_control", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.BigInteger(), nullable=True),
        sa.Column("updated_at", sa.BigInteger(), nullable=True),
    )

    op.create_table(
        "message",
        sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
        sa.Column("user_id", sa.Text()),
        sa.Column("channel_id", sa.Text(), nullable=True),
        sa.Column("content", sa.Text()),
        sa.Column("data", sa.JSON(), nullable=True),
        sa.Column("meta", sa.JSON(), nullable=True),
        sa.Column("created_at", sa.BigInteger(), nullable=True),
        sa.Column("updated_at", sa.BigInteger(), nullable=True),
    )


def downgrade():
    op.drop_table("channel")

    op.drop_table("message")
@@ -0,0 +1,26 @@
"""Update file table

Revision ID: 7826ab40b532
Revises: 57c599a3cb57
Create Date: 2024-12-23 03:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

revision = "7826ab40b532"
down_revision = "57c599a3cb57"
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        "file",
        sa.Column("access_control", sa.JSON(), nullable=True),
    )


def downgrade():
    op.drop_column("file", "access_control")
@@ -11,8 +11,8 @@ from typing import Sequence, Union
 import sqlalchemy as sa
 from alembic import op

-import open_webui.apps.webui.internal.db
-from open_webui.apps.webui.internal.db import JSONField
+import open_webui.internal.db
+from open_webui.internal.db import JSONField
 from open_webui.migrations.util import get_existing_tables

 # revision identifiers, used by Alembic.

@@ -2,12 +2,12 @@ import logging
 import uuid
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.users import UserModel, Users
+from open_webui.internal.db import Base, get_db
+from open_webui.models.users import UserModel, Users
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel
 from sqlalchemy import Boolean, Column, String, Text
-from open_webui.utils.utils import verify_password
+from open_webui.utils.auth import verify_password

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
132 backend/open_webui/models/channels.py Normal file
@@ -0,0 +1,132 @@
import json
import time
import uuid
from typing import Optional

from open_webui.internal.db import Base, get_db
from open_webui.utils.access_control import has_access

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists

####################
# Channel DB Schema
####################


class Channel(Base):
    __tablename__ = "channel"

    id = Column(Text, primary_key=True)
    user_id = Column(Text)

    name = Column(Text)
    description = Column(Text, nullable=True)

    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)
    access_control = Column(JSON, nullable=True)

    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)


class ChannelModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    description: Optional[str] = None

    name: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    access_control: Optional[dict] = None

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch


####################
# Forms
####################


class ChannelForm(BaseModel):
    name: str
    description: Optional[str] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None
    access_control: Optional[dict] = None


class ChannelTable:
    def insert_new_channel(
        self, form_data: ChannelForm, user_id: str
    ) -> Optional[ChannelModel]:
        with get_db() as db:
            channel = ChannelModel(
                **{
                    **form_data.model_dump(),
                    "name": form_data.name.lower(),
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    "created_at": int(time.time_ns()),
                    "updated_at": int(time.time_ns()),
                }
            )

            new_channel = Channel(**channel.model_dump())

            db.add(new_channel)
            db.commit()
            return channel

    def get_channels(self) -> list[ChannelModel]:
        with get_db() as db:
            channels = db.query(Channel).all()
            return [ChannelModel.model_validate(channel) for channel in channels]

    def get_channels_by_user_id(
        self, user_id: str, permission: str = "read"
    ) -> list[ChannelModel]:
        channels = self.get_channels()
        return [
            channel
            for channel in channels
            if channel.user_id == user_id
            or has_access(user_id, permission, channel.access_control)
        ]

    def get_channel_by_id(self, id: str) -> Optional[ChannelModel]:
        with get_db() as db:
            channel = db.query(Channel).filter(Channel.id == id).first()
            return ChannelModel.model_validate(channel) if channel else None

    def update_channel_by_id(
        self, id: str, form_data: ChannelForm
    ) -> Optional[ChannelModel]:
        with get_db() as db:
            channel = db.query(Channel).filter(Channel.id == id).first()
            if not channel:
                return None

            channel.name = form_data.name
            channel.data = form_data.data
            channel.meta = form_data.meta
            channel.access_control = form_data.access_control
            channel.updated_at = int(time.time_ns())

            db.commit()
            return ChannelModel.model_validate(channel) if channel else None

    def delete_channel_by_id(self, id: str):
        with get_db() as db:
            db.query(Channel).filter(Channel.id == id).delete()
            db.commit()
            return True


Channels = ChannelTable()
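A minimal usage sketch of the `Channels` singleton added above (ids are hypothetical; running this assumes an initialized Open WebUI database):

from open_webui.models.channels import ChannelForm, Channels

# insert_new_channel() lowercases the name and stamps time_ns timestamps.
channel = Channels.insert_new_channel(
    ChannelForm(name="General", description="Team-wide discussion"),
    user_id="user-123",  # hypothetical id
)

# Visibility combines ownership with has_access() over access_control.
visible = Channels.get_channels_by_user_id("user-123", permission="read")
print([c.name for c in visible])  # e.g. ["general"]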
@@ -3,8 +3,8 @@ import time
 import uuid
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.tags import TagModel, Tag, Tags
+from open_webui.internal.db import Base, get_db
+from open_webui.models.tags import TagModel, Tag, Tags


 from pydantic import BaseModel, ConfigDict
@@ -168,6 +168,91 @@ class ChatTable:
         except Exception:
             return None

+    def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        chat["title"] = title
+
+        return self.update_chat_by_id(id, chat)
+
+    def update_chat_tags_by_id(
+        self, id: str, tags: list[str], user
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        self.delete_all_tags_by_id_and_user_id(id, user.id)
+
+        for tag in chat.meta.get("tags", []):
+            if self.count_chats_by_tag_name_and_user_id(tag, user.id) == 0:
+                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+
+        for tag_name in tags:
+            if tag_name.lower() == "none":
+                continue
+
+            self.add_chat_tag_by_id_and_user_id_and_tag_name(id, user.id, tag_name)
+        return self.get_chat_by_id(id)
+
+    def get_chat_title_by_id(self, id: str) -> Optional[str]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        return chat.chat.get("title", "New Chat")
+
+    def get_messages_by_chat_id(self, id: str) -> Optional[dict]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        return chat.chat.get("history", {}).get("messages", {}) or {}
+
+    def upsert_message_to_chat_by_id_and_message_id(
+        self, id: str, message_id: str, message: dict
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        history = chat.get("history", {})
+
+        if message_id in history.get("messages", {}):
+            history["messages"][message_id] = {
+                **history["messages"][message_id],
+                **message,
+            }
+        else:
+            history["messages"][message_id] = message
+
+        history["currentId"] = message_id
+
+        chat["history"] = history
+        return self.update_chat_by_id(id, chat)
+
+    def add_message_status_to_chat_by_id_and_message_id(
+        self, id: str, message_id: str, status: dict
+    ) -> Optional[ChatModel]:
+        chat = self.get_chat_by_id(id)
+        if chat is None:
+            return None
+
+        chat = chat.chat
+        history = chat.get("history", {})
+
+        if message_id in history.get("messages", {}):
+            status_history = history["messages"][message_id].get("statusHistory", [])
+            status_history.append(status)
+            history["messages"][message_id]["statusHistory"] = status_history
+
+        chat["history"] = history
+        return self.update_chat_by_id(id, chat)
+
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_db() as db:
             # Get the existing chat to share
@@ -3,8 +3,8 @@ import time
 import uuid
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, get_db
+from open_webui.models.chats import Chats

 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
@@ -2,7 +2,7 @@ import logging
 import time
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
+from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -27,6 +27,8 @@ class File(Base):
     data = Column(JSON, nullable=True)
     meta = Column(JSON, nullable=True)

+    access_control = Column(JSON, nullable=True)
+
     created_at = Column(BigInteger)
     updated_at = Column(BigInteger)

@@ -44,6 +46,8 @@ class FileModel(BaseModel):
     data: Optional[dict] = None
     meta: Optional[dict] = None

+    access_control: Optional[dict] = None
+
     created_at: Optional[int]  # timestamp in epoch
     updated_at: Optional[int]  # timestamp in epoch

@@ -90,6 +94,7 @@ class FileForm(BaseModel):
     path: str
     data: dict = {}
     meta: dict = {}
+    access_control: Optional[dict] = None


 class FilesTable:
@@ -3,8 +3,8 @@ import time
 import uuid
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, get_db
+from open_webui.models.chats import Chats

 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
@@ -2,8 +2,8 @@ import logging
 import time
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.users import Users
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.users import Users
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Boolean, Column, String, Text
@@ -4,10 +4,10 @@ import time
 from typing import Optional
 import uuid

-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS

-from open_webui.apps.webui.models.files import FileMetadataResponse
+from open_webui.models.files import FileMetadataResponse


 from pydantic import BaseModel, ConfigDict
@@ -146,6 +146,13 @@ class GroupTable:
         except Exception:
             return None

+    def get_group_user_ids_by_id(self, id: str) -> Optional[str]:
+        group = self.get_group_by_id(id)
+        if group:
+            return group.user_ids
+        else:
+            return None
+
     def update_group_by_id(
         self, id: str, form_data: GroupUpdateForm, overwrite: bool = False
     ) -> Optional[GroupModel]:
@@ -4,11 +4,11 @@ import time
 from typing import Optional
 import uuid

-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from open_webui.env import SRC_LOG_LEVELS

-from open_webui.apps.webui.models.files import FileMetadataResponse
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.models.files import FileMetadataResponse
+from open_webui.models.users import Users, UserResponse


 from pydantic import BaseModel, ConfigDict
@@ -2,7 +2,7 @@ import time
 import uuid
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text

141 backend/open_webui/models/messages.py Normal file
@@ -0,0 +1,141 @@
import json
import time
import uuid
from typing import Optional

from open_webui.internal.db import Base, get_db
from open_webui.models.tags import TagModel, Tag, Tags


from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, select, and_, text
from sqlalchemy.sql import exists

####################
# Message DB Schema
####################


class Message(Base):
    __tablename__ = "message"
    id = Column(Text, primary_key=True)

    user_id = Column(Text)
    channel_id = Column(Text, nullable=True)

    content = Column(Text)
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)

    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns


class MessageModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    channel_id: Optional[str] = None

    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch


####################
# Forms
####################


class MessageForm(BaseModel):
    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None


class MessageTable:
    def insert_new_message(
        self, form_data: MessageForm, channel_id: str, user_id: str
    ) -> Optional[MessageModel]:
        with get_db() as db:
            id = str(uuid.uuid4())

            ts = int(time.time_ns())
            message = MessageModel(
                **{
                    "id": id,
                    "user_id": user_id,
                    "channel_id": channel_id,
                    "content": form_data.content,
                    "data": form_data.data,
                    "meta": form_data.meta,
                    "created_at": ts,
                    "updated_at": ts,
                }
            )

            result = Message(**message.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return MessageModel.model_validate(result) if result else None

    def get_message_by_id(self, id: str) -> Optional[MessageModel]:
        with get_db() as db:
            message = db.get(Message, id)
            return MessageModel.model_validate(message) if message else None

    def get_messages_by_channel_id(
        self, channel_id: str, skip: int = 0, limit: int = 50
    ) -> list[MessageModel]:
        with get_db() as db:
            all_messages = (
                db.query(Message)
                .filter_by(channel_id=channel_id)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [MessageModel.model_validate(message) for message in all_messages]

    def get_messages_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> list[MessageModel]:
        with get_db() as db:
            all_messages = (
                db.query(Message)
                .filter_by(user_id=user_id)
                .order_by(Message.created_at.desc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [MessageModel.model_validate(message) for message in all_messages]

    def update_message_by_id(
        self, id: str, form_data: MessageForm
    ) -> Optional[MessageModel]:
        with get_db() as db:
            message = db.get(Message, id)
            message.content = form_data.content
            message.data = form_data.data
            message.meta = form_data.meta
            message.updated_at = int(time.time_ns())
            db.commit()
            db.refresh(message)
            return MessageModel.model_validate(message) if message else None

    def delete_message_by_id(self, id: str) -> bool:
        with get_db() as db:
            db.query(Message).filter_by(id=id).delete()
            db.commit()
            return True


Messages = MessageTable()
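A companion sketch for the new `Messages` table (ids are hypothetical; note that `get_messages_by_channel_id` pages newest-first via `created_at.desc()`):

from open_webui.models.messages import MessageForm, Messages

msg = Messages.insert_new_message(
    MessageForm(content="hello channel"),
    channel_id="chan-1",  # hypothetical id
    user_id="user-123",
)

page = Messages.get_messages_by_channel_id("chan-1", skip=0, limit=10)
print([m.content for m in page])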
@@ -2,10 +2,10 @@ import logging
 import time
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
+from open_webui.internal.db import Base, JSONField, get_db
 from open_webui.env import SRC_LOG_LEVELS

-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.models.users import Users, UserResponse


 from pydantic import BaseModel, ConfigDict
@@ -1,8 +1,8 @@
 import time
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.internal.db import Base, get_db
+from open_webui.models.users import Users, UserResponse

 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -3,7 +3,7 @@ import time
 import uuid
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, get_db
+from open_webui.internal.db import Base, get_db


 from open_webui.env import SRC_LOG_LEVELS
@@ -2,8 +2,8 @@ import logging
 import time
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.users import Users, UserResponse
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.users import Users, UserResponse
 from open_webui.env import SRC_LOG_LEVELS
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text, JSON
@@ -76,6 +76,10 @@ class ToolModel(BaseModel):
 ####################


+class ToolUserModel(ToolModel):
+    user: Optional[UserResponse] = None
+
+
 class ToolResponse(BaseModel):
     id: str
     user_id: str
@@ -138,13 +142,13 @@ class ToolsTable:
         except Exception:
             return None

-    def get_tools(self) -> list[ToolUserResponse]:
+    def get_tools(self) -> list[ToolUserModel]:
         with get_db() as db:
             tools = []
             for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all():
                 user = Users.get_user_by_id(tool.user_id)
                 tools.append(
-                    ToolUserResponse.model_validate(
+                    ToolUserModel.model_validate(
                         {
                             **ToolModel.model_validate(tool).model_dump(),
                             "user": user.model_dump() if user else None,
@@ -155,7 +159,7 @@ class ToolsTable:

     def get_tools_by_user_id(
         self, user_id: str, permission: str = "write"
-    ) -> list[ToolUserResponse]:
+    ) -> list[ToolUserModel]:
         tools = self.get_tools()

         return [
@@ -1,8 +1,8 @@
 import time
 from typing import Optional

-from open_webui.apps.webui.internal.db import Base, JSONField, get_db
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.internal.db import Base, JSONField, get_db
+from open_webui.models.chats import Chats
 from pydantic import BaseModel, ConfigDict
 from sqlalchemy import BigInteger, Column, String, Text

@@ -70,6 +70,13 @@ class UserResponse(BaseModel):
     profile_image_url: str


+class UserNameResponse(BaseModel):
+    id: str
+    name: str
+    role: str
+    profile_image_url: str
+
+
 class UserRoleUpdateForm(BaseModel):
     id: str
     role: str
@@ -147,13 +154,25 @@ class UsersTable:
         except Exception:
             return None

-    def get_users(self, skip: int = 0, limit: int = 50) -> list[UserModel]:
+    def get_users(
+        self, skip: Optional[int] = None, limit: Optional[int] = None
+    ) -> list[UserModel]:
         with get_db() as db:
-            users = (
-                db.query(User)
-                # .offset(skip).limit(limit)
-                .all()
-            )
+
+            query = db.query(User).order_by(User.created_at.desc())
+
+            if skip:
+                query = query.offset(skip)
+            if limit:
+                query = query.limit(limit)
+
+            users = query.all()
+
             return [UserModel.model_validate(user) for user in users]

     def get_users_by_user_ids(self, user_ids: list[str]) -> list[UserModel]:
         with get_db() as db:
             users = db.query(User).filter(User.id.in_(user_ids)).all()
             return [UserModel.model_validate(user) for user in users]

     def get_num_users(self) -> Optional[int]:
@@ -168,6 +187,22 @@ class UsersTable:
         except Exception:
             return None

+    def get_user_webhook_url_by_id(self, id: str) -> Optional[str]:
+        try:
+            with get_db() as db:
+                user = db.query(User).filter_by(id=id).first()
+
+                if user.settings is None:
+                    return None
+                else:
+                    return (
+                        user.settings.get("ui", {})
+                        .get("notifications", {})
+                        .get("webhook_url", None)
+                    )
+        except Exception:
+            return None
+
     def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
         try:
             with get_db() as db:
@@ -1,6 +1,7 @@
 import requests
 import logging
 import ftfy
+import sys

 from langchain_community.document_loaders import (
     BSHTMLLoader,
@@ -18,8 +19,9 @@ from langchain_community.document_loaders import (
     YoutubeLoader,
 )
 from langchain_core.documents import Document
-from open_webui.env import SRC_LOG_LEVELS
+from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL

+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

@@ -106,7 +108,7 @@ class TikaLoader:
         if "Content-Type" in raw_metadata:
             headers["Content-Type"] = raw_metadata["Content-Type"]

-            log.info("Tika extracted text: %s", text)
+            log.debug("Tika extracted text: %s", text)

             return [Document(page_content=text, metadata=headers)]
         else:
@@ -1,7 +1,12 @@
 import logging

 from typing import Any, Dict, Generator, List, Optional, Sequence, Union
 from urllib.parse import parse_qs, urlparse
 from langchain_core.documents import Document
+from open_webui.env import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])

 ALLOWED_SCHEMES = {"http", "https"}
 ALLOWED_NETLOCS = {
@@ -51,12 +56,14 @@
         self,
         video_id: str,
         language: Union[str, Sequence[str]] = "en",
+        proxy_url: Optional[str] = None,
     ):
         """Initialize with YouTube video ID."""
        _video_id = _parse_video_id(video_id)
         self.video_id = _video_id if _video_id is not None else video_id
         self._metadata = {"source": video_id}
         self.language = language
+        self.proxy_url = proxy_url
         if isinstance(language, str):
             self.language = [language]
         else:
@@ -76,10 +83,22 @@
                 "Please install it with `pip install youtube-transcript-api`."
             )

+        if self.proxy_url:
+            youtube_proxies = {
+                "http": self.proxy_url,
+                "https": self.proxy_url,
+            }
+            # Don't log complete URL because it might contain secrets
+            log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
+        else:
+            youtube_proxies = None
+
         try:
-            transcript_list = YouTubeTranscriptApi.list_transcripts(self.video_id)
+            transcript_list = YouTubeTranscriptApi.list_transcripts(
+                self.video_id, proxies=youtube_proxies
+            )
         except Exception as e:
-            print(e)
+            log.exception("Loading YouTube transcript failed")
             return []

         try:
@@ -11,12 +11,10 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document

-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message

 from open_webui.env import SRC_LOG_LEVELS
-from open_webui.config import DEFAULT_RAG_TEMPLATE


 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -72,7 +70,9 @@ def query_doc(
             limit=k,
         )

-        log.info(f"query_doc:result {result.ids} {result.metadatas}")
+        if result:
+            log.info(f"query_doc:result {result.ids} {result.metadatas}")
+
         return result
     except Exception as e:
         print(e)
@@ -238,44 +238,6 @@ def query_collection_with_hybrid_search(
     return merge_and_sort_query_results(results, k=k, reverse=True)


-def rag_template(template: str, context: str, query: str):
-    if template == "":
-        template = DEFAULT_RAG_TEMPLATE
-
-    if "[context]" not in template and "{{CONTEXT}}" not in template:
-        log.debug(
-            "WARNING: The RAG template does not contain the '[context]' or '{{CONTEXT}}' placeholder."
-        )
-
-    if "<context>" in context and "</context>" in context:
-        log.debug(
-            "WARNING: Potential prompt injection attack: the RAG "
-            "context contains '<context>' and '</context>'. This might be "
-            "nothing, or the user might be trying to hack something."
-        )
-
-    query_placeholders = []
-    if "[query]" in context:
-        query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
-        template = template.replace("[query]", query_placeholder)
-        query_placeholders.append(query_placeholder)
-
-    if "{{QUERY}}" in context:
-        query_placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
-        template = template.replace("{{QUERY}}", query_placeholder)
-        query_placeholders.append(query_placeholder)
-
-    template = template.replace("[context]", context)
-    template = template.replace("{{CONTEXT}}", context)
-    template = template.replace("[query]", query)
-    template = template.replace("{{QUERY}}", query)
-
-    for query_placeholder in query_placeholders:
-        template = template.replace(query_placeholder, query)
-
-    return template
-
-
 def get_embedding_function(
     embedding_engine,
     embedding_model,
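The `-` lines above remove `rag_template` from this module as part of the refactor. The helper's key trick is masking any `[query]`/`{{QUERY}}` tokens that also occur in the retrieved context with a one-off UUID placeholder, so document text cannot capture the template's real query slot. A condensed, runnable illustration of that logic:

import uuid


def rag_template(template: str, context: str, query: str) -> str:
    # Condensed from the removed helper: mask the template's [query] slot
    # with a UUID placeholder whenever the context also contains "[query]".
    query_placeholders = []
    if "[query]" in context:
        placeholder = "{{QUERY" + str(uuid.uuid4()) + "}}"
        template = template.replace("[query]", placeholder)
        query_placeholders.append(placeholder)

    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    for placeholder in query_placeholders:
        template = template.replace(placeholder, query)
    return template


print(
    rag_template(
        "Use [context] to answer: [query]",
        "a document containing a literal [query] token",
        "what is this?",
    )
)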
@@ -469,7 +431,7 @@


 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str
+    model: str, texts: list[str], url: str, key: str = ""
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
22 backend/open_webui/retrieval/vector/connector.py Normal file
@@ -0,0 +1,22 @@
from open_webui.config import VECTOR_DB

if VECTOR_DB == "milvus":
    from open_webui.retrieval.vector.dbs.milvus import MilvusClient

    VECTOR_DB_CLIENT = MilvusClient()
elif VECTOR_DB == "qdrant":
    from open_webui.retrieval.vector.dbs.qdrant import QdrantClient

    VECTOR_DB_CLIENT = QdrantClient()
elif VECTOR_DB == "opensearch":
    from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient

    VECTOR_DB_CLIENT = OpenSearchClient()
elif VECTOR_DB == "pgvector":
    from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient

    VECTOR_DB_CLIENT = PgvectorClient()
else:
    from open_webui.retrieval.vector.dbs.chroma import ChromaClient

    VECTOR_DB_CLIENT = ChromaClient()
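Callers import the singleton rather than a concrete client, so retrieval code stays backend-agnostic; the `VECTOR_DB` value comes from `open_webui.config`, with Chroma as the fallback per the `else` branch above. A tiny sketch:

from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT

# All five client classes implement the shared interface from
# open_webui.retrieval.vector.main (VectorItem / SearchResult / GetResult).
print(type(VECTOR_DB_CLIENT).__name__)  # e.g. "ChromaClient" by default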
@@ -4,7 +4,7 @@ from chromadb.utils.batch_utils import create_batches

 from typing import Optional

-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     CHROMA_DATA_PATH,
     CHROMA_HTTP_HOST,
@@ -4,7 +4,7 @@ import json

 from typing import Optional

-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     MILVUS_URI,
 )
@@ -1,7 +1,7 @@
 from opensearchpy import OpenSearch
 from typing import Optional

-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     OPENSEARCH_URI,
     OPENSEARCH_SSL,
@@ -18,7 +18,7 @@ from sqlalchemy.dialects.postgresql import JSONB, array
 from pgvector.sqlalchemy import Vector
 from sqlalchemy.ext.mutable import MutableDict

-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import PGVECTOR_DB_URL

 VECTOR_LENGTH = 1536
@@ -40,7 +40,7 @@ class PgvectorClient:

         # if no pgvector uri, use the existing database connection
         if not PGVECTOR_DB_URL:
-            from open_webui.apps.webui.internal.db import Session
+            from open_webui.internal.db import Session

             self.session = Session
         else:
@@ -4,7 +4,7 @@ from qdrant_client import QdrantClient as Qclient
 from qdrant_client.http.models import PointStruct
 from qdrant_client.models import models

-from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult
+from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import QDRANT_URI, QDRANT_API_KEY

 NO_LIMIT = 999999999
@@ -3,7 +3,7 @@ import os
 from pprint import pprint
 from typing import Optional
 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS
 import argparse

@@ -2,7 +2,7 @@ import logging
 from typing import Optional

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -1,7 +1,7 @@
 import logging
 from typing import Optional

-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
 from open_webui.env import SRC_LOG_LEVELS

@@ -2,7 +2,7 @@ import logging
 from typing import Optional

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -1,7 +1,7 @@
 import logging

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS
 from yarl import URL

48 backend/open_webui/retrieval/web/kagi.py Normal file
@@ -0,0 +1,48 @@
import logging
from typing import Optional

import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
from open_webui.env import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_kagi(
    api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
    """Search using Kagi's Search API and return the results as a list of SearchResult objects.

    The Search API will inherit the settings in your account, including results personalization and snippet length.

    Args:
        api_key (str): A Kagi Search API key
        query (str): The query to search for
        count (int): The number of results to return
    """
    url = "https://kagi.com/api/v0/search"
    headers = {
        "Authorization": f"Bot {api_key}",
    }
    params = {"q": query, "limit": count}

    response = requests.get(url, headers=headers, params=params)
    response.raise_for_status()
    json_response = response.json()
    search_results = json_response.get("data", [])

    results = [
        SearchResult(
            link=result["url"], title=result["title"], snippet=result.get("snippet")
        )
        for result in search_results
        if result["t"] == 0
    ]

    print(results)

    if filter_list:
        results = get_filtered_results(results, filter_list)

    return results
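A usage sketch for the new Kagi connector (the API key is a placeholder; per the list comprehension above, only items with `t == 0` are kept, and `filter_list` restricts result domains):

from open_webui.retrieval.web.kagi import search_kagi

hits = search_kagi(
    "KAGI_API_KEY",  # placeholder key
    "open webui",
    count=3,
    filter_list=["github.com"],
)
for hit in hits:
    print(hit.link, "-", hit.title)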
@@ -2,7 +2,7 @@ import logging
 from typing import Optional

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -2,7 +2,7 @@ import logging
 from typing import Optional

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -3,7 +3,7 @@ import logging
 from typing import Optional

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -3,7 +3,7 @@ from typing import Optional
 from urllib.parse import urlencode

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -2,7 +2,7 @@ import logging
 from typing import Optional

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult, get_filtered_results
+from open_webui.retrieval.web.main import SearchResult, get_filtered_results
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -1,7 +1,7 @@
 import logging

 import requests
-from open_webui.apps.retrieval.web.main import SearchResult
+from open_webui.retrieval.web.main import SearchResult
 from open_webui.env import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
@@ -82,15 +82,15 @@ class SafeWebBaseLoader(WebBaseLoader):


 def get_web_loader(
-    url: Union[str, Sequence[str]],
+    urls: Union[str, Sequence[str]],
     verify_ssl: bool = True,
     requests_per_second: int = 2,
 ):
     # Check if the URL is valid
-    if not validate_url(url):
+    if not validate_url(urls):
         raise ValueError(ERROR_MESSAGES.INVALID_URL)
     return SafeWebBaseLoader(
-        url,
+        urls,
         verify_ssl=verify_ssl,
         requests_per_second=requests_per_second,
         continue_on_failure=True,
713 backend/open_webui/routers/audio.py Normal file
@@ -0,0 +1,713 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from pydub import AudioSegment
|
||||
from pydub.silence import split_on_silence
|
||||
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import requests
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
File,
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
status,
|
||||
APIRouter,
|
||||
)
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.config import (
|
||||
WHISPER_MODEL_AUTO_UPDATE,
|
||||
WHISPER_MODEL_DIR,
|
||||
CACHE_DIR,
|
||||
)
|
||||
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import (
|
||||
ENV,
|
||||
SRC_LOG_LEVELS,
|
||||
DEVICE_TYPE,
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Constants
|
||||
MAX_FILE_SIZE_MB = 25
|
||||
MAX_FILE_SIZE = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["AUDIO"])
|
||||
|
||||
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
|
||||
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# Utility functions
|
||||
#
|
||||
##########################################
|
||||
|
||||
from pydub import AudioSegment
|
||||
from pydub.utils import mediainfo
|
||||
|
||||
|
||||
def is_mp4_audio(file_path):
|
||||
"""Check if the given file is an MP4 audio file."""
|
||||
if not os.path.isfile(file_path):
|
||||
print(f"File not found: {file_path}")
|
||||
return False
|
||||
|
||||
info = mediainfo(file_path)
|
||||
if (
|
||||
info.get("codec_name") == "aac"
|
||||
and info.get("codec_type") == "audio"
|
||||
and info.get("codec_tag_string") == "mp4a"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def convert_mp4_to_wav(file_path, output_path):
|
||||
"""Convert MP4 audio file to WAV format."""
|
||||
audio = AudioSegment.from_file(file_path, format="mp4")
|
||||
audio.export(output_path, format="wav")
|
||||
print(f"Converted {file_path} to {output_path}")
|
||||
|
||||
|
||||
def set_faster_whisper_model(model: str, auto_update: bool = False):
|
||||
whisper_model = None
|
||||
if model:
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
faster_whisper_kwargs = {
|
||||
"model_size_or_path": model,
|
||||
"device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu",
|
||||
"compute_type": "int8",
|
||||
"download_root": WHISPER_MODEL_DIR,
|
||||
"local_files_only": not auto_update,
|
||||
}
|
||||
|
||||
try:
|
||||
whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||
except Exception:
|
||||
log.warning(
|
||||
"WhisperModel initialization failed, attempting download with local_files_only=False"
|
||||
)
|
||||
faster_whisper_kwargs["local_files_only"] = False
|
||||
whisper_model = WhisperModel(**faster_whisper_kwargs)
|
||||
return whisper_model
|
||||
|
||||
|
||||
##########################################
|
||||
#
|
||||
# Audio API
|
||||
#
|
||||
##########################################
|
||||
|
||||
|
||||
class TTSConfigForm(BaseModel):
|
||||
OPENAI_API_BASE_URL: str
|
||||
OPENAI_API_KEY: str
|
||||
API_KEY: str
|
||||
ENGINE: str
|
||||
MODEL: str
|
||||
VOICE: str
|
||||
SPLIT_ON: str
|
||||
AZURE_SPEECH_REGION: str
|
||||
AZURE_SPEECH_OUTPUT_FORMAT: str
|
||||
|
||||
|
||||
class STTConfigForm(BaseModel):
|
||||
OPENAI_API_BASE_URL: str
|
||||
OPENAI_API_KEY: str
|
||||
ENGINE: str
|
||||
MODEL: str
|
||||
WHISPER_MODEL: str
|
||||
|
||||
|
||||
class AudioConfigUpdateForm(BaseModel):
|
||||
tts: TTSConfigForm
|
||||
stt: STTConfigForm
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_audio_config(request: Request, user=Depends(get_admin_user)):
|
||||
return {
|
||||
"tts": {
|
||||
"OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
|
||||
"API_KEY": request.app.state.config.TTS_API_KEY,
|
||||
"ENGINE": request.app.state.config.TTS_ENGINE,
|
||||
"MODEL": request.app.state.config.TTS_MODEL,
|
||||
"VOICE": request.app.state.config.TTS_VOICE,
|
||||
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
|
||||
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
|
||||
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
},
|
||||
"stt": {
|
||||
"OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
|
||||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/config/update")
|
||||
async def update_audio_config(
|
||||
request: Request, form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
request.app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
|
||||
request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
|
||||
request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY
|
||||
request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE
|
||||
request.app.state.config.TTS_MODEL = form_data.tts.MODEL
|
||||
request.app.state.config.TTS_VOICE = form_data.tts.VOICE
|
||||
request.app.state.config.TTS_SPLIT_ON = form_data.tts.SPLIT_ON
|
||||
request.app.state.config.TTS_AZURE_SPEECH_REGION = form_data.tts.AZURE_SPEECH_REGION
|
||||
request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT = (
|
||||
form_data.tts.AZURE_SPEECH_OUTPUT_FORMAT
|
||||
)
|
||||
|
||||
request.app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
|
||||
request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
|
||||
request.app.state.config.STT_ENGINE = form_data.stt.ENGINE
|
||||
request.app.state.config.STT_MODEL = form_data.stt.MODEL
|
||||
request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL
|
||||
|
||||
if request.app.state.config.STT_ENGINE == "":
|
||||
request.app.state.faster_whisper_model = set_faster_whisper_model(
|
||||
form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE
|
||||
)
|
||||
|
||||
return {
|
||||
"tts": {
|
||||
"OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY,
|
||||
"API_KEY": request.app.state.config.TTS_API_KEY,
|
||||
"ENGINE": request.app.state.config.TTS_ENGINE,
|
||||
"MODEL": request.app.state.config.TTS_MODEL,
|
||||
"VOICE": request.app.state.config.TTS_VOICE,
|
||||
"SPLIT_ON": request.app.state.config.TTS_SPLIT_ON,
|
||||
"AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION,
|
||||
"AZURE_SPEECH_OUTPUT_FORMAT": request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT,
|
||||
},
|
||||
"stt": {
|
||||
"OPENAI_API_BASE_URL": request.app.state.config.STT_OPENAI_API_BASE_URL,
|
||||
"OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY,
|
||||
"ENGINE": request.app.state.config.STT_ENGINE,
|
||||
"MODEL": request.app.state.config.STT_MODEL,
|
||||
"WHISPER_MODEL": request.app.state.config.WHISPER_MODEL,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_speech_pipeline(request):
|
||||
from transformers import pipeline
|
||||
from datasets import load_dataset
|
||||
|
||||
if request.app.state.speech_synthesiser is None:
|
||||
request.app.state.speech_synthesiser = pipeline(
|
||||
"text-to-speech", "microsoft/speecht5_tts"
|
||||
)
|
||||
|
||||
if request.app.state.speech_speaker_embeddings_dataset is None:
|
||||
request.app.state.speech_speaker_embeddings_dataset = load_dataset(
|
||||
"Matthijs/cmu-arctic-xvectors", split="validation"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/speech")
|
||||
async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
body = await request.body()
|
||||
name = hashlib.sha256(
|
||||
body
|
||||
+ str(request.app.state.config.TTS_ENGINE).encode("utf-8")
|
||||
+ str(request.app.state.config.TTS_MODEL).encode("utf-8")
|
||||
).hexdigest()
|
||||
|
||||
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
|
||||
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
|
||||
payload = None
|
||||
try:
|
||||
payload = json.loads(body.decode("utf-8"))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
payload["model"] = request.app.state.config.TTS_MODEL
|
||||
|
||||
try:
|
||||
# print(payload)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
|
||||
json=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}",
|
||||
**(
|
||||
{
|
||||
"X-OpenWebUI-User-Name": user.name,
|
||||
"X-OpenWebUI-User-Id": user.id,
|
||||
"X-OpenWebUI-User-Email": user.email,
|
||||
"X-OpenWebUI-User-Role": user.role,
|
||||
}
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS
|
||||
else {}
|
||||
),
|
||||
},
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(await r.read())
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
|
||||
try:
|
||||
if r.status != 200:
|
||||
res = await r.json()
|
||||
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
voice_id = payload.get("voice", "")
|
||||
|
||||
if voice_id not in get_available_voices(request):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid voice id",
|
||||
)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}",
|
||||
json={
|
||||
"text": payload["input"],
|
||||
"model_id": request.app.state.config.TTS_MODEL,
|
||||
"voice_settings": {"stability": 0.5, "similarity_boost": 0.5},
|
||||
},
|
||||
headers={
|
||||
"Accept": "audio/mpeg",
|
||||
"Content-Type": "application/json",
|
||||
"xi-api-key": request.app.state.config.TTS_API_KEY,
|
||||
},
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(await r.read())
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
|
||||
try:
|
||||
if r.status != 200:
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
elif request.app.state.config.TTS_ENGINE == "azure":
|
||||
try:
|
||||
payload = json.loads(body.decode("utf-8"))
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
region = request.app.state.config.TTS_AZURE_SPEECH_REGION
|
||||
language = request.app.state.config.TTS_VOICE
|
||||
locale = "-".join(request.app.state.config.TTS_VOICE.split("-")[:1])
|
||||
output_format = request.app.state.config.TTS_AZURE_SPEECH_OUTPUT_FORMAT
|
||||
|
||||
try:
|
||||
data = f"""<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="{locale}">
|
||||
<voice name="{language}">{payload["input"]}</voice>
|
||||
</speak>"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://{region}.tts.speech.microsoft.com/cognitiveservices/v1",
|
||||
headers={
|
||||
"Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY,
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": output_format,
|
||||
},
|
||||
data=data,
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(await r.read())
|
||||
|
||||
async with aiofiles.open(file_body_path, "w") as f:
|
||||
await f.write(json.dumps(payload))
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
detail = None
|
||||
|
||||
try:
|
||||
if r.status != 200:
|
||||
res = await r.json()
|
||||
if "error" in res:
|
||||
detail = f"External: {res['error'].get('message', '')}"
|
||||
except Exception:
|
||||
detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=getattr(r, "status", 500),
|
||||
detail=detail if detail else "Open WebUI: Server Connection Error",
|
||||
)
|
||||
|
||||
    elif request.app.state.config.TTS_ENGINE == "transformers":
        payload = None
        try:
            payload = json.loads(body.decode("utf-8"))
        except Exception as e:
            log.exception(e)
            raise HTTPException(status_code=400, detail="Invalid JSON payload")

        import torch
        import soundfile as sf

        load_speech_pipeline(request)

        embeddings_dataset = request.app.state.speech_speaker_embeddings_dataset

        speaker_index = 6799
        try:
            speaker_index = embeddings_dataset["filename"].index(
                request.app.state.config.TTS_MODEL
            )
        except Exception:
            pass

        speaker_embedding = torch.tensor(
            embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

        speech = request.app.state.speech_synthesiser(
            payload["input"],
            forward_params={"speaker_embeddings": speaker_embedding},
        )

        sf.write(file_path, speech["audio"], samplerate=speech["sampling_rate"])

        async with aiofiles.open(file_body_path, "w") as f:
            await f.write(json.dumps(payload))

        return FileResponse(file_path)

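# ---------------------------------------------------------------------------
# A minimal standalone sketch of the "transformers" branch above. Assumptions
# (not guaranteed to match what load_speech_pipeline actually configures): the
# "microsoft/speecht5_tts" checkpoint and the "Matthijs/cmu-arctic-xvectors"
# speaker-embedding dataset from the Hugging Face Hub.
def _example_speecht5_tts(text: str, out_path: str = "speech.wav"):
    import torch
    import soundfile as sf
    from transformers import pipeline
    from datasets import load_dataset

    synthesiser = pipeline("text-to-speech", "microsoft/speecht5_tts")
    dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

    # 6799 mirrors the default speaker index used in the handler above
    speaker_embedding = torch.tensor(dataset[6799]["xvector"]).unsqueeze(0)
    speech = synthesiser(
        text, forward_params={"speaker_embeddings": speaker_embedding}
    )
    sf.write(out_path, speech["audio"], samplerate=speech["sampling_rate"])
# ---------------------------------------------------------------------------
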
def transcribe(request: Request, file_path):
    log.debug(f"transcribe: {file_path}")
    filename = os.path.basename(file_path)
    file_dir = os.path.dirname(file_path)
    id = filename.split(".")[0]

    if request.app.state.config.STT_ENGINE == "":
        if request.app.state.faster_whisper_model is None:
            request.app.state.faster_whisper_model = set_faster_whisper_model(
                request.app.state.config.WHISPER_MODEL
            )

        model = request.app.state.faster_whisper_model
        segments, info = model.transcribe(file_path, beam_size=5)
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        transcript = "".join([segment.text for segment in list(segments)])
        data = {"text": transcript.strip()}

        # save the transcript to a json file
        transcript_file = f"{file_dir}/{id}.json"
        with open(transcript_file, "w") as f:
            json.dump(data, f)

        log.debug(data)
        return data
    elif request.app.state.config.STT_ENGINE == "openai":
        if is_mp4_audio(file_path):
            os.rename(file_path, file_path.replace(".wav", ".mp4"))
            # Convert MP4 audio file to WAV format
            convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)

        r = None
        try:
            r = requests.post(
                url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                headers={
                    "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}"
                },
                files={"file": (filename, open(file_path, "rb"))},
                data={"model": request.app.state.config.STT_MODEL},
            )

            r.raise_for_status()
            data = r.json()

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            return data
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error'].get('message', '')}"
                except Exception:
                    detail = f"External: {e}"

            raise Exception(detail if detail else "Open WebUI: Server Connection Error")

def compress_audio(file_path):
    if os.path.getsize(file_path) > MAX_FILE_SIZE:
        file_dir = os.path.dirname(file_path)
        # derive the cache id from the filename (e.g. "<uuid>.wav" -> "<uuid>")
        id = os.path.basename(file_path).split(".")[0]

        audio = AudioSegment.from_file(file_path)
        audio = audio.set_frame_rate(16000).set_channels(1)  # Compress audio
        compressed_path = f"{file_dir}/{id}_compressed.opus"
        audio.export(compressed_path, format="opus", bitrate="32k")
        log.debug(f"Compressed audio to {compressed_path}")

        if (
            os.path.getsize(compressed_path) > MAX_FILE_SIZE
        ):  # Still larger than MAX_FILE_SIZE after compression
            raise Exception(ERROR_MESSAGES.FILE_TOO_LARGE(size=f"{MAX_FILE_SIZE_MB}MB"))
        return compressed_path
    else:
        return file_path

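# ---------------------------------------------------------------------------
# Back-of-the-envelope check for the 32 kbps re-encode above: assuming the
# usual 25 MB MAX_FILE_SIZE (the OpenAI transcription upload limit), roughly
# 4,000 bytes per second of Opus audio means the cap holds about 109 minutes.
def _example_opus_capacity(max_file_size: int = 25 * 1024 * 1024) -> float:
    bytes_per_second = 32_000 / 8  # 32 kbps mono Opus
    return max_file_size / bytes_per_second / 60  # ~109.2 minutes
# ---------------------------------------------------------------------------
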
@router.post("/transcriptions")
|
||||
def transcription(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
|
||||
if file.content_type not in ["audio/mpeg", "audio/wav", "audio/ogg", "audio/x-m4a"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
|
||||
)
|
||||
|
||||
try:
|
||||
ext = file.filename.split(".")[-1]
|
||||
id = uuid.uuid4()
|
||||
|
||||
filename = f"{id}.{ext}"
|
||||
contents = file.file.read()
|
||||
|
||||
file_dir = f"{CACHE_DIR}/audio/transcriptions"
|
||||
os.makedirs(file_dir, exist_ok=True)
|
||||
file_path = f"{file_dir}/{filename}"
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
|
||||
try:
|
||||
try:
|
||||
file_path = compress_audio(file_path)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
data = transcribe(request, file_path)
|
||||
file_path = file_path.split("/")[-1]
|
||||
return {**data, "filename": file_path}
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
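# ---------------------------------------------------------------------------
# Hypothetical client call against the endpoint above; assumes the audio
# router is mounted under /api/v1/audio and that YOUR_TOKEN is a valid
# Open WebUI API token.
def _example_transcription_request():
    import requests

    with open("recording.wav", "rb") as f:
        r = requests.post(
            "http://localhost:8080/api/v1/audio/transcriptions",
            headers={"Authorization": "Bearer YOUR_TOKEN"},
            files={"file": ("recording.wav", f, "audio/wav")},
        )
    r.raise_for_status()
    return r.json()  # e.g. {"text": "...", "filename": "<uuid>.wav"}
# ---------------------------------------------------------------------------
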
def get_available_models(request: Request) -> list[dict]:
    available_models = []
    if request.app.state.config.TTS_ENGINE == "openai":
        available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            response = requests.get(
                "https://api.elevenlabs.io/v1/models",
                headers={
                    "xi-api-key": request.app.state.config.TTS_API_KEY,
                    "Content-Type": "application/json",
                },
                timeout=5,
            )
            response.raise_for_status()
            models = response.json()

            available_models = [
                {"name": model["name"], "id": model["model_id"]} for model in models
            ]
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")
    return available_models

@router.get("/models")
|
||||
async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
return {"models": get_available_models(request)}
|
||||
|
||||
|
||||
def get_available_voices(request) -> dict:
    """Returns {voice_id: voice_name} dict"""
    available_voices = {}
    if request.app.state.config.TTS_ENGINE == "openai":
        available_voices = {
            "alloy": "alloy",
            "echo": "echo",
            "fable": "fable",
            "onyx": "onyx",
            "nova": "nova",
            "shimmer": "shimmer",
        }
    elif request.app.state.config.TTS_ENGINE == "elevenlabs":
        try:
            available_voices = get_elevenlabs_voices(
                api_key=request.app.state.config.TTS_API_KEY
            )
        except Exception:
            # Avoided @lru_cache with exception
            pass
    elif request.app.state.config.TTS_ENGINE == "azure":
        try:
            region = request.app.state.config.TTS_AZURE_SPEECH_REGION
            url = f"https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
            headers = {
                "Ocp-Apim-Subscription-Key": request.app.state.config.TTS_API_KEY
            }

            response = requests.get(url, headers=headers)
            response.raise_for_status()
            voices = response.json()

            for voice in voices:
                available_voices[voice["ShortName"]] = (
                    f"{voice['DisplayName']} ({voice['ShortName']})"
                )
        except requests.RequestException as e:
            log.error(f"Error fetching voices: {str(e)}")

    return available_voices

@lru_cache
def get_elevenlabs_voices(api_key: str) -> dict:
    """
    Note, set the following in your .env file to use ElevenLabs:
    AUDIO_TTS_ENGINE=elevenlabs
    AUDIO_TTS_API_KEY=sk_...  # Your ElevenLabs API key
    AUDIO_TTS_VOICE=EXAVITQu4vr4xnSDxMaL  # From https://api.elevenlabs.io/v1/voices
    AUDIO_TTS_MODEL=eleven_multilingual_v2
    """

    try:
        # TODO: Add retries
        response = requests.get(
            "https://api.elevenlabs.io/v1/voices",
            headers={
                "xi-api-key": api_key,
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()
        voices_data = response.json()

        voices = {}
        for voice in voices_data.get("voices", []):
            voices[voice["voice_id"]] = voice["name"]
    except requests.RequestException as e:
        # Avoid @lru_cache with exception
        log.error(f"Error fetching voices: {str(e)}")
        raise RuntimeError(f"Error fetching voices: {str(e)}")

    return voices

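# ---------------------------------------------------------------------------
# Why the function above raises instead of returning {}: functools.lru_cache
# memoizes only successful returns, so a transient API failure is retried on
# the next call rather than pinning an empty voice map for the process
# lifetime. A tiny sketch of that behaviour:
def _example_lru_cache_retries():
    from functools import lru_cache

    state = {"calls": 0}

    @lru_cache
    def flaky():
        state["calls"] += 1
        if state["calls"] == 1:
            raise RuntimeError("transient failure")  # not cached
        return "ok"

    try:
        flaky()
    except RuntimeError:
        pass
    return flaky(), state["calls"]  # ("ok", 2): the failure was retried
# ---------------------------------------------------------------------------
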
@router.get("/voices")
|
||||
async def get_voices(request: Request, user=Depends(get_verified_user)):
|
||||
return {
|
||||
"voices": [
|
||||
{"id": k, "name": v} for k, v in get_available_voices(request).items()
|
||||
]
|
||||
}
|
||||
backend/open_webui/routers/auths.py
@@ -3,8 +3,9 @@ import uuid
 import time
 import datetime
 import logging
+from aiohttp import ClientSession

-from open_webui.apps.webui.models.auths import (
+from open_webui.models.auths import (
     AddUserForm,
     ApiKey,
     Auths,
@@ -17,7 +18,7 @@ from open_webui.apps.webui.models.auths import (
     UpdateProfileForm,
     UserResponse,
 )
-from open_webui.apps.webui.models.users import Users
+from open_webui.models.users import Users

 from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
 from open_webui.env import (
@@ -29,10 +30,14 @@ from open_webui.env import (
     SRC_LOG_LEVELS,
 )
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from fastapi.responses import Response
+from fastapi.responses import RedirectResponse, Response
+from open_webui.config import (
+    OPENID_PROVIDER_URL,
+    ENABLE_OAUTH_SIGNUP,
+)
 from pydantic import BaseModel
 from open_webui.utils.misc import parse_duration, validate_email_format
-from open_webui.utils.utils import (
+from open_webui.utils.auth import (
     create_api_key,
     create_token,
     get_admin_user,
@@ -246,11 +251,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
             )

         user = Auths.insert_new_auth(
-            mail,
-            str(uuid.uuid4()),
-            cn,
-            None,
-            role,
+            email=mail, password=str(uuid.uuid4()), name=cn, role=role
         )

         if not user:
@@ -502,8 +503,31 @@ async def signup(request: Request, response: Response, form_data: SignupForm):


 @router.get("/signout")
-async def signout(response: Response):
+async def signout(request: Request, response: Response):
     response.delete_cookie("token")
+
+    if ENABLE_OAUTH_SIGNUP.value:
+        oauth_id_token = request.cookies.get("oauth_id_token")
+        if oauth_id_token:
+            try:
+                async with ClientSession() as session:
+                    async with session.get(OPENID_PROVIDER_URL.value) as resp:
+                        if resp.status == 200:
+                            openid_data = await resp.json()
+                            logout_url = openid_data.get("end_session_endpoint")
+                            if logout_url:
+                                response.delete_cookie("oauth_id_token")
+                                return RedirectResponse(
+                                    url=f"{logout_url}?id_token_hint={oauth_id_token}"
+                                )
+                        else:
+                            raise HTTPException(
+                                status_code=resp.status,
+                                detail="Failed to fetch OpenID configuration",
+                            )
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
     return {"status": True}

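# ---------------------------------------------------------------------------
# The signout addition above implements OIDC RP-initiated logout: fetch the
# provider's discovery document, find end_session_endpoint, and redirect with
# id_token_hint. A minimal sketch of the discovery step (the discovery URL is
# a hypothetical example):
from aiohttp import ClientSession

async def _example_find_logout_url(provider_url: str):
    async with ClientSession() as session:
        async with session.get(provider_url) as resp:
            resp.raise_for_status()
            openid_data = await resp.json()
    return openid_data.get("end_session_endpoint")  # None if not advertised

# e.g. asyncio.run(_example_find_logout_url(
#     "https://accounts.example.com/.well-known/openid-configuration"))
# ---------------------------------------------------------------------------
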
@@ -523,7 +547,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)

     try:
-        print(form_data)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
             form_data.email.lower(),
@@ -590,8 +613,12 @@ async def get_admin_details(request: Request, user=Depends(get_current_user)):
 async def get_admin_config(request: Request, user=Depends(get_admin_user)):
     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
+        "WEBUI_URL": request.app.state.config.WEBUI_URL,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
+        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
+        "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
+        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
@@ -601,8 +628,12 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)):

 class AdminConfig(BaseModel):
     SHOW_ADMIN_DETAILS: bool
+    WEBUI_URL: str
     ENABLE_SIGNUP: bool
     ENABLE_API_KEY: bool
+    ENABLE_API_KEY_ENDPOINT_RESTRICTIONS: bool
+    API_KEY_ALLOWED_ENDPOINTS: str
+    ENABLE_CHANNELS: bool
     DEFAULT_USER_ROLE: str
     JWT_EXPIRES_IN: str
     ENABLE_COMMUNITY_SHARING: bool
@@ -614,8 +645,18 @@ async def update_admin_config(
     request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
 ):
     request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
+    request.app.state.config.WEBUI_URL = form_data.WEBUI_URL
     request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
+
     request.app.state.config.ENABLE_API_KEY = form_data.ENABLE_API_KEY
+    request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS = (
+        form_data.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS
+    )
+    request.app.state.config.API_KEY_ALLOWED_ENDPOINTS = (
+        form_data.API_KEY_ALLOWED_ENDPOINTS
+    )
+
+    request.app.state.config.ENABLE_CHANNELS = form_data.ENABLE_CHANNELS

     if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
         request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
@@ -633,8 +674,12 @@ async def update_admin_config(

     return {
         "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
+        "WEBUI_URL": request.app.state.config.WEBUI_URL,
         "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
         "ENABLE_API_KEY": request.app.state.config.ENABLE_API_KEY,
+        "ENABLE_API_KEY_ENDPOINT_RESTRICTIONS": request.app.state.config.ENABLE_API_KEY_ENDPOINT_RESTRICTIONS,
+        "API_KEY_ALLOWED_ENDPOINTS": request.app.state.config.API_KEY_ALLOWED_ENDPOINTS,
+        "ENABLE_CHANNELS": request.app.state.config.ENABLE_CHANNELS,
         "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
         "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
         "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
backend/open_webui/routers/channels.py (new file, 394 lines)
@@ -0,0 +1,394 @@
import json
import logging
from typing import Optional


from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks
from pydantic import BaseModel


from open_webui.socket.main import sio, get_user_ids_from_room
from open_webui.models.users import Users, UserNameResponse

from open_webui.models.channels import Channels, ChannelModel, ChannelForm
from open_webui.models.messages import Messages, MessageModel, MessageForm


from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import SRC_LOG_LEVELS


from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access, get_users_with_access
from open_webui.utils.webhook import post_webhook

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

router = APIRouter()

############################
# GetChatList
############################


@router.get("/", response_model=list[ChannelModel])
async def get_channels(user=Depends(get_verified_user)):
    if user.role == "admin":
        return Channels.get_channels()
    else:
        return Channels.get_channels_by_user_id(user.id)


############################
# CreateNewChannel
############################


@router.post("/create", response_model=Optional[ChannelModel])
async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user)):
    try:
        channel = Channels.insert_new_channel(form_data, user.id)
        return ChannelModel(**channel.model_dump())
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# GetChannelById
############################


@router.get("/{id}", response_model=Optional[ChannelModel])
async def get_channel_by_id(id: str, user=Depends(get_verified_user)):
    channel = Channels.get_channel_by_id(id)
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if user.role != "admin" and not has_access(
        user.id, type="read", access_control=channel.access_control
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    return ChannelModel(**channel.model_dump())


############################
# UpdateChannelById
############################


@router.post("/{id}/update", response_model=Optional[ChannelModel])
async def update_channel_by_id(
    id: str, form_data: ChannelForm, user=Depends(get_admin_user)
):
    channel = Channels.get_channel_by_id(id)
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    try:
        channel = Channels.update_channel_by_id(id, form_data)
        return ChannelModel(**channel.model_dump())
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# DeleteChannelById
############################


@router.delete("/{id}/delete", response_model=bool)
async def delete_channel_by_id(id: str, user=Depends(get_admin_user)):
    channel = Channels.get_channel_by_id(id)
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    try:
        Channels.delete_channel_by_id(id)
        return True
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# GetChannelMessages
############################


class MessageUserModel(MessageModel):
    user: UserNameResponse


@router.get("/{id}/messages", response_model=list[MessageUserModel])
async def get_channel_messages(
    id: str, skip: int = 0, limit: int = 50, user=Depends(get_verified_user)
):
    channel = Channels.get_channel_by_id(id)
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if user.role != "admin" and not has_access(
        user.id, type="read", access_control=channel.access_control
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    message_list = Messages.get_messages_by_channel_id(id, skip, limit)
    users = {}

    messages = []
    for message in message_list:
        if message.user_id not in users:
            user = Users.get_user_by_id(message.user_id)
            users[message.user_id] = user

        messages.append(
            MessageUserModel(
                **{
                    **message.model_dump(),
                    "user": UserNameResponse(**users[message.user_id].model_dump()),
                }
            )
        )

    return messages


############################
# PostNewMessage
############################


async def send_notification(webui_url, channel, message, active_user_ids):
    users = get_users_with_access("read", channel.access_control)

    for user in users:
        if user.id in active_user_ids:
            continue
        else:
            if user.settings:
                webhook_url = user.settings.ui.get("notifications", {}).get(
                    "webhook_url", None
                )

                if webhook_url:
                    post_webhook(
                        webhook_url,
                        f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}",
                        {
                            "action": "channel",
                            "message": message.content,
                            "title": channel.name,
                            "url": f"{webui_url}/channels/{channel.id}",
                        },
                    )

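# ---------------------------------------------------------------------------
# For reference, the event dict handed to post_webhook above; for a generic
# (non-Slack/Discord) webhook URL it is expected to arrive as the JSON body.
# The ids and URL here are hypothetical.
def _example_channel_webhook_payload() -> dict:
    return {
        "action": "channel",
        "message": "hello there",
        "title": "general",
        "url": "http://localhost:8080/channels/abc123",
    }
# ---------------------------------------------------------------------------
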
@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
|
||||
async def post_new_message(
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: MessageForm,
|
||||
background_tasks: BackgroundTasks,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
channel = Channels.get_channel_by_id(id)
|
||||
if not channel:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
)
|
||||
|
||||
if user.role != "admin" and not has_access(
|
||||
user.id, type="read", access_control=channel.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
try:
|
||||
message = Messages.insert_new_message(form_data, channel.id, user.id)
|
||||
|
||||
if message:
|
||||
event_data = {
|
||||
"channel_id": channel.id,
|
||||
"message_id": message.id,
|
||||
"data": {
|
||||
"type": "message",
|
||||
"data": {
|
||||
**message.model_dump(),
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
},
|
||||
},
|
||||
"user": UserNameResponse(**user.model_dump()).model_dump(),
|
||||
"channel": channel.model_dump(),
|
||||
}
|
||||
|
||||
await sio.emit(
|
||||
"channel-events",
|
||||
event_data,
|
||||
to=f"channel:{channel.id}",
|
||||
)
|
||||
|
||||
active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
||||
|
||||
background_tasks.add_task(
|
||||
send_notification,
|
||||
request.app.state.config.WEBUI_URL,
|
||||
channel,
|
||||
message,
|
||||
active_user_ids,
|
||||
)
|
||||
|
||||
return MessageModel(**message.model_dump())
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||
)
|
||||
|
||||
|
||||
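# ---------------------------------------------------------------------------
# Messages fan out over Socket.IO to every client joined to the
# "channel:{id}" room. A listener sketch with python-socketio; the server URL
# and the /ws/socket.io path are assumptions about a typical deployment, and
# authentication is omitted.
def _example_channel_listener():
    import asyncio
    import socketio

    sio_client = socketio.AsyncClient()

    @sio_client.on("channel-events")
    async def on_channel_event(data):
        inner = data["data"]
        if inner["type"] == "message":
            print(f"[{data['channel']['name']}] {inner['data']['content']}")

    async def main():
        await sio_client.connect(
            "http://localhost:8080", socketio_path="/ws/socket.io"
        )
        await sio_client.wait()

    asyncio.run(main())
# ---------------------------------------------------------------------------
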
############################
# UpdateMessageById
############################


@router.post(
    "/{id}/messages/{message_id}/update", response_model=Optional[MessageModel]
)
async def update_message_by_id(
    id: str, message_id: str, form_data: MessageForm, user=Depends(get_verified_user)
):
    channel = Channels.get_channel_by_id(id)
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if user.role != "admin" and not has_access(
        user.id, type="read", access_control=channel.access_control
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    message = Messages.get_message_by_id(message_id)
    if not message:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if message.channel_id != id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )

    try:
        message = Messages.update_message_by_id(message_id, form_data)
        if message:
            await sio.emit(
                "channel-events",
                {
                    "channel_id": channel.id,
                    "message_id": message.id,
                    "data": {
                        "type": "message:update",
                        "data": {
                            **message.model_dump(),
                            "user": UserNameResponse(**user.model_dump()).model_dump(),
                        },
                    },
                    "user": UserNameResponse(**user.model_dump()).model_dump(),
                    "channel": channel.model_dump(),
                },
                to=f"channel:{channel.id}",
            )

        return MessageModel(**message.model_dump())
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# DeleteMessageById
############################


@router.delete("/{id}/messages/{message_id}/delete", response_model=bool)
async def delete_message_by_id(
    id: str, message_id: str, user=Depends(get_verified_user)
):
    channel = Channels.get_channel_by_id(id)
    if not channel:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if user.role != "admin" and not has_access(
        user.id, type="read", access_control=channel.access_control
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
        )

    message = Messages.get_message_by_id(message_id)
    if not message:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
        )

    if message.channel_id != id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )

    try:
        Messages.delete_message_by_id(message_id)
        await sio.emit(
            "channel-events",
            {
                "channel_id": channel.id,
                "message_id": message.id,
                "data": {
                    "type": "message:delete",
                    "data": {
                        **message.model_dump(),
                        "user": UserNameResponse(**user.model_dump()).model_dump(),
                    },
                },
                "user": UserNameResponse(**user.model_dump()).model_dump(),
                "channel": channel.model_dump(),
            },
            to=f"channel:{channel.id}",
        )

        return True
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )

backend/open_webui/routers/chats.py
@@ -2,15 +2,15 @@ import json
 import logging
 from typing import Optional

-from open_webui.apps.webui.models.chats import (
+from open_webui.models.chats import (
     ChatForm,
     ChatImportForm,
     ChatResponse,
     Chats,
     ChatTitleIdResponse,
 )
-from open_webui.apps.webui.models.tags import TagModel, Tags
-from open_webui.apps.webui.models.folders import Folders
+from open_webui.models.tags import TagModel, Tags
+from open_webui.models.folders import Folders

 from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from open_webui.constants import ERROR_MESSAGES
@@ -19,7 +19,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
 from pydantic import BaseModel


-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.access_control import has_permission

 log = logging.getLogger(__name__)
@@ -607,7 +607,6 @@ async def add_tag_by_id_and_tag_name(
             detail=ERROR_MESSAGES.DEFAULT("Tag name cannot be 'None'"),
         )

-    print(tags, tag_id)
     if tag_id not in tags:
         Chats.add_chat_tag_by_id_and_user_id_and_tag_name(
             id, user.id, form_data.name
backend/open_webui/routers/configs.py
@@ -1,10 +1,12 @@
-from open_webui.config import BannerModel
 from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user

+from typing import Optional
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.config import get_config, save_config
+from open_webui.config import BannerModel


 router = APIRouter()
@@ -34,8 +36,32 @@ async def export_config(user=Depends(get_admin_user)):
     return get_config()


-class SetDefaultModelsForm(BaseModel):
-    models: str
+############################
+# SetDefaultModels
+############################
+class ModelsConfigForm(BaseModel):
+    DEFAULT_MODELS: Optional[str]
+    MODEL_ORDER_LIST: Optional[list[str]]
+
+
+@router.get("/models", response_model=ModelsConfigForm)
+async def get_models_config(request: Request, user=Depends(get_admin_user)):
+    return {
+        "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
+        "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
+    }
+
+
+@router.post("/models", response_model=ModelsConfigForm)
+async def set_models_config(
+    request: Request, form_data: ModelsConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.DEFAULT_MODELS = form_data.DEFAULT_MODELS
+    request.app.state.config.MODEL_ORDER_LIST = form_data.MODEL_ORDER_LIST
+    return {
+        "DEFAULT_MODELS": request.app.state.config.DEFAULT_MODELS,
+        "MODEL_ORDER_LIST": request.app.state.config.MODEL_ORDER_LIST,
+    }


 class PromptSuggestion(BaseModel):
@@ -47,21 +73,8 @@ class SetDefaultSuggestionsForm(BaseModel):
     suggestions: list[PromptSuggestion]


-############################
-# SetDefaultModels
-############################
-
-
-@router.post("/default/models", response_model=str)
-async def set_global_default_models(
-    request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
-):
-    request.app.state.config.DEFAULT_MODELS = form_data.models
-    return request.app.state.config.DEFAULT_MODELS
-
-
-@router.post("/default/suggestions", response_model=list[PromptSuggestion])
-async def set_global_default_suggestions(
+@router.post("/suggestions", response_model=list[PromptSuggestion])
+async def set_default_suggestions(
     request: Request,
     form_data: SetDefaultSuggestionsForm,
     user=Depends(get_admin_user),
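# ---------------------------------------------------------------------------
# Request sketch for the new symmetric GET/POST pair above; assumes the
# configs router is mounted under /api/v1/configs and that ADMIN_TOKEN is a
# valid admin API token (the model ids are hypothetical).
import requests

r = requests.post(
    "http://localhost:8080/api/v1/configs/models",
    headers={"Authorization": "Bearer ADMIN_TOKEN"},
    json={
        "DEFAULT_MODELS": "llama3.1:8b",
        "MODEL_ORDER_LIST": ["llama3.1:8b", "gpt-4o"],
    },
)
print(r.json())
# ---------------------------------------------------------------------------
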
backend/open_webui/routers/evaluations.py
@@ -2,8 +2,8 @@ from typing import Optional
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from pydantic import BaseModel

-from open_webui.apps.webui.models.users import Users, UserModel
-from open_webui.apps.webui.models.feedbacks import (
+from open_webui.models.users import Users, UserModel
+from open_webui.models.feedbacks import (
     FeedbackModel,
     FeedbackResponse,
     FeedbackForm,
@@ -11,7 +11,7 @@ from open_webui.apps.webui.models.feedbacks import (
 )

 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user

 router = APIRouter()

backend/open_webui/routers/files.py
@@ -5,27 +5,28 @@ from pathlib import Path
 from typing import Optional
 from pydantic import BaseModel
 import mimetypes
+from urllib.parse import quote

 from open_webui.storage.provider import Storage

-from open_webui.apps.webui.models.files import (
+from open_webui.models.files import (
     FileForm,
     FileModel,
     FileModelResponse,
     Files,
 )
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.routers.retrieval import process_file, ProcessFileForm

 from open_webui.config import UPLOAD_DIR
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.constants import ERROR_MESSAGES


-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request
 from fastapi.responses import FileResponse, StreamingResponse


-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -39,7 +40,9 @@ router = APIRouter()


 @router.post("/", response_model=FileModelResponse)
-def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
+def upload_file(
+    request: Request, file: UploadFile = File(...), user=Depends(get_verified_user)
+):
     log.info(f"file.content_type: {file.content_type}")
     try:
         unsanitized_filename = file.filename
@@ -68,7 +71,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
         )

     try:
-        process_file(ProcessFileForm(file_id=id))
+        process_file(request, ProcessFileForm(file_id=id))
         file_item = Files.get_file_by_id(id=id)
     except Exception as e:
         log.exception(e)
@@ -183,13 +186,15 @@ class ContentForm(BaseModel):

 @router.post("/{id}/data/content/update")
 async def update_file_data_content_by_id(
-    id: str, form_data: ContentForm, user=Depends(get_verified_user)
+    request: Request, id: str, form_data: ContentForm, user=Depends(get_verified_user)
 ):
     file = Files.get_file_by_id(id)

     if file and (file.user_id == user.id or user.role == "admin"):
         try:
-            process_file(ProcessFileForm(file_id=id, content=form_data.content))
+            process_file(
+                request, ProcessFileForm(file_id=id, content=form_data.content)
+            )
             file = Files.get_file_by_id(id=id)
         except Exception as e:
             log.exception(e)
@@ -218,11 +223,22 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):

     # Check if the file already exists in the cache
     if file_path.is_file():
-        print(f"file_path: {file_path}")
-        headers = {
-            "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-        }
+        # Handle Unicode filenames
+        filename = file.meta.get("name", file.filename)
+        encoded_filename = quote(filename)  # RFC5987 encoding
+
+        headers = {}
+        if file.meta.get("content_type") not in [
+            "application/pdf",
+            "text/plain",
+        ]:
+            headers = {
+                **headers,
+                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
+            }
+
         return FileResponse(file_path, headers=headers)
+
     else:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
@@ -279,16 +295,20 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):

     if file and (file.user_id == user.id or user.role == "admin"):
         file_path = file.path

+        # Handle Unicode filenames
+        filename = file.meta.get("name", file.filename)
+        encoded_filename = quote(filename)  # RFC5987 encoding
+        headers = {
+            "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
+        }
+
         if file_path:
             file_path = Storage.get_file(file_path)
             file_path = Path(file_path)

             # Check if the file already exists in the cache
             if file_path.is_file():
-                print(f"file_path: {file_path}")
-                headers = {
-                    "Content-Disposition": f'attachment; filename="{file.meta.get("name", file.filename)}"'
-                }
                 return FileResponse(file_path, headers=headers)
             else:
                 raise HTTPException(
@@ -307,7 +327,7 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
             return StreamingResponse(
                 generator(),
                 media_type="text/plain",
-                headers={"Content-Disposition": f"attachment; filename={file_name}"},
+                headers=headers,
             )
         else:
             raise HTTPException(
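# ---------------------------------------------------------------------------
# Why the switch to filename*=UTF-8''... above: the old quoted-string form
# breaks on non-ASCII names, while RFC 5987 percent-encoding round-trips them.
from urllib.parse import quote

filename = "résumé.pdf"
encoded_filename = quote(filename)  # 'r%C3%A9sum%C3%A9.pdf'
header = f"attachment; filename*=UTF-8''{encoded_filename}"
print(header)
# ---------------------------------------------------------------------------
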
backend/open_webui/routers/folders.py
@@ -8,12 +8,12 @@ from pydantic import BaseModel
 import mimetypes


-from open_webui.apps.webui.models.folders import (
+from open_webui.models.folders import (
     FolderForm,
     FolderModel,
     Folders,
 )
-from open_webui.apps.webui.models.chats import Chats
+from open_webui.models.chats import Chats

 from open_webui.config import UPLOAD_DIR
 from open_webui.env import SRC_LOG_LEVELS
@@ -24,7 +24,7 @@ from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
 from fastapi.responses import FileResponse, StreamingResponse


-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
backend/open_webui/routers/functions.py
@@ -2,17 +2,17 @@ import os
 from pathlib import Path
 from typing import Optional

-from open_webui.apps.webui.models.functions import (
+from open_webui.models.functions import (
     FunctionForm,
     FunctionModel,
     FunctionResponse,
     Functions,
 )
-from open_webui.apps.webui.utils import load_function_module_by_id, replace_imports
+from open_webui.utils.plugin import load_function_module_by_id, replace_imports
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user

 router = APIRouter()

backend/open_webui/routers/groups.py
@@ -2,7 +2,7 @@ import os
 from pathlib import Path
 from typing import Optional

-from open_webui.apps.webui.models.groups import (
+from open_webui.models.groups import (
     Groups,
     GroupForm,
     GroupUpdateForm,
@@ -12,7 +12,7 @@ from open_webui.apps.webui.models.groups import (
 from open_webui.config import CACHE_DIR
 from open_webui.constants import ERROR_MESSAGES
 from fastapi import APIRouter, Depends, HTTPException, Request, status
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_admin_user, get_verified_user

 router = APIRouter()

backend/open_webui/routers/images.py
@@ -9,38 +9,24 @@ from pathlib import Path
 from typing import Optional

 import requests
-from open_webui.apps.images.utils.comfyui import (
+
+
+from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+
+
+from open_webui.config import CACHE_DIR
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
+
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.images.comfyui import (
     ComfyUIGenerateImageForm,
     ComfyUIWorkflow,
     comfyui_generate_image,
 )
-from open_webui.config import (
-    AUTOMATIC1111_API_AUTH,
-    AUTOMATIC1111_BASE_URL,
-    AUTOMATIC1111_CFG_SCALE,
-    AUTOMATIC1111_SAMPLER,
-    AUTOMATIC1111_SCHEDULER,
-    CACHE_DIR,
-    COMFYUI_BASE_URL,
-    COMFYUI_WORKFLOW,
-    COMFYUI_WORKFLOW_NODES,
-    CORS_ALLOW_ORIGIN,
-    ENABLE_IMAGE_GENERATION,
-    IMAGE_GENERATION_ENGINE,
-    IMAGE_GENERATION_MODEL,
-    IMAGE_SIZE,
-    IMAGE_STEPS,
-    IMAGES_OPENAI_API_BASE_URL,
-    IMAGES_OPENAI_API_KEY,
-    AppConfig,
-)
-from open_webui.constants import ERROR_MESSAGES
-from open_webui.env import ENV, SRC_LOG_LEVELS, ENABLE_FORWARD_USER_INFO_HEADERS
-
-from fastapi import Depends, FastAPI, HTTPException, Request
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from open_webui.utils.utils import get_admin_user, get_verified_user

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -48,63 +34,31 @@ log.setLevel(SRC_LOG_LEVELS["IMAGES"])
 IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
 IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

-app = FastAPI(
-    docs_url="/docs" if ENV == "dev" else None,
-    openapi_url="/openapi.json" if ENV == "dev" else None,
-    redoc_url=None,
-)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=CORS_ALLOW_ORIGIN,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-app.state.config = AppConfig()
-
-app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
-app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
-
-app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
-app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
-
-app.state.config.MODEL = IMAGE_GENERATION_MODEL
-
-app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
-app.state.config.AUTOMATIC1111_CFG_SCALE = AUTOMATIC1111_CFG_SCALE
-app.state.config.AUTOMATIC1111_SAMPLER = AUTOMATIC1111_SAMPLER
-app.state.config.AUTOMATIC1111_SCHEDULER = AUTOMATIC1111_SCHEDULER
-app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
-app.state.config.COMFYUI_WORKFLOW = COMFYUI_WORKFLOW
-app.state.config.COMFYUI_WORKFLOW_NODES = COMFYUI_WORKFLOW_NODES
-
-app.state.config.IMAGE_SIZE = IMAGE_SIZE
-app.state.config.IMAGE_STEPS = IMAGE_STEPS
+router = APIRouter()


-@app.get("/config")
+@router.get("/config")
 async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
-        "enabled": app.state.config.ENABLED,
-        "engine": app.state.config.ENGINE,
+        "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
+        "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "openai": {
-            "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+            "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
         },
         "automatic1111": {
-            "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
-            "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
-            "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
-            "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
-            "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
+            "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
+            "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
+            "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
+            "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
+            "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
         },
         "comfyui": {
-            "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
-            "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
-            "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
+            "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
+            "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
+            "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
+            "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
         },
     }

@@ -117,13 +71,14 @@ class OpenAIConfigForm(BaseModel):
 class Automatic1111ConfigForm(BaseModel):
     AUTOMATIC1111_BASE_URL: str
     AUTOMATIC1111_API_AUTH: str
-    AUTOMATIC1111_CFG_SCALE: Optional[str]
+    AUTOMATIC1111_CFG_SCALE: Optional[str | float | int]
     AUTOMATIC1111_SAMPLER: Optional[str]
     AUTOMATIC1111_SCHEDULER: Optional[str]


 class ComfyUIConfigForm(BaseModel):
     COMFYUI_BASE_URL: str
+    COMFYUI_API_KEY: str
     COMFYUI_WORKFLOW: str
     COMFYUI_WORKFLOW_NODES: list[dict]

@@ -136,133 +91,157 @@ class ConfigForm(BaseModel):
     comfyui: ComfyUIConfigForm


-@app.post("/config/update")
-async def update_config(form_data: ConfigForm, user=Depends(get_admin_user)):
-    app.state.config.ENGINE = form_data.engine
-    app.state.config.ENABLED = form_data.enabled
+@router.post("/config/update")
+async def update_config(
+    request: Request, form_data: ConfigForm, user=Depends(get_admin_user)
+):
+    request.app.state.config.IMAGE_GENERATION_ENGINE = form_data.engine
+    request.app.state.config.ENABLE_IMAGE_GENERATION = form_data.enabled

-    app.state.config.OPENAI_API_BASE_URL = form_data.openai.OPENAI_API_BASE_URL
-    app.state.config.OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY
+    request.app.state.config.IMAGES_OPENAI_API_BASE_URL = (
+        form_data.openai.OPENAI_API_BASE_URL
+    )
+    request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY

-    app.state.config.AUTOMATIC1111_BASE_URL = (
+    request.app.state.config.AUTOMATIC1111_BASE_URL = (
         form_data.automatic1111.AUTOMATIC1111_BASE_URL
     )
-    app.state.config.AUTOMATIC1111_API_AUTH = (
+    request.app.state.config.AUTOMATIC1111_API_AUTH = (
         form_data.automatic1111.AUTOMATIC1111_API_AUTH
     )

-    app.state.config.AUTOMATIC1111_CFG_SCALE = (
+    request.app.state.config.AUTOMATIC1111_CFG_SCALE = (
         float(form_data.automatic1111.AUTOMATIC1111_CFG_SCALE)
         if form_data.automatic1111.AUTOMATIC1111_CFG_SCALE
         else None
     )
-    app.state.config.AUTOMATIC1111_SAMPLER = (
+    request.app.state.config.AUTOMATIC1111_SAMPLER = (
         form_data.automatic1111.AUTOMATIC1111_SAMPLER
         if form_data.automatic1111.AUTOMATIC1111_SAMPLER
         else None
     )
-    app.state.config.AUTOMATIC1111_SCHEDULER = (
+    request.app.state.config.AUTOMATIC1111_SCHEDULER = (
         form_data.automatic1111.AUTOMATIC1111_SCHEDULER
         if form_data.automatic1111.AUTOMATIC1111_SCHEDULER
         else None
     )

-    app.state.config.COMFYUI_BASE_URL = form_data.comfyui.COMFYUI_BASE_URL.strip("/")
-    app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
-    app.state.config.COMFYUI_WORKFLOW_NODES = form_data.comfyui.COMFYUI_WORKFLOW_NODES
+    request.app.state.config.COMFYUI_BASE_URL = (
+        form_data.comfyui.COMFYUI_BASE_URL.strip("/")
+    )
+    request.app.state.config.COMFYUI_WORKFLOW = form_data.comfyui.COMFYUI_WORKFLOW
+    request.app.state.config.COMFYUI_WORKFLOW_NODES = (
+        form_data.comfyui.COMFYUI_WORKFLOW_NODES
+    )

     return {
-        "enabled": app.state.config.ENABLED,
-        "engine": app.state.config.ENGINE,
+        "enabled": request.app.state.config.ENABLE_IMAGE_GENERATION,
+        "engine": request.app.state.config.IMAGE_GENERATION_ENGINE,
         "openai": {
-            "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-            "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+            "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY,
         },
         "automatic1111": {
-            "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
-            "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
-            "AUTOMATIC1111_CFG_SCALE": app.state.config.AUTOMATIC1111_CFG_SCALE,
-            "AUTOMATIC1111_SAMPLER": app.state.config.AUTOMATIC1111_SAMPLER,
-            "AUTOMATIC1111_SCHEDULER": app.state.config.AUTOMATIC1111_SCHEDULER,
+            "AUTOMATIC1111_BASE_URL": request.app.state.config.AUTOMATIC1111_BASE_URL,
+            "AUTOMATIC1111_API_AUTH": request.app.state.config.AUTOMATIC1111_API_AUTH,
+            "AUTOMATIC1111_CFG_SCALE": request.app.state.config.AUTOMATIC1111_CFG_SCALE,
+            "AUTOMATIC1111_SAMPLER": request.app.state.config.AUTOMATIC1111_SAMPLER,
+            "AUTOMATIC1111_SCHEDULER": request.app.state.config.AUTOMATIC1111_SCHEDULER,
         },
         "comfyui": {
-            "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
-            "COMFYUI_WORKFLOW": app.state.config.COMFYUI_WORKFLOW,
-            "COMFYUI_WORKFLOW_NODES": app.state.config.COMFYUI_WORKFLOW_NODES,
+            "COMFYUI_BASE_URL": request.app.state.config.COMFYUI_BASE_URL,
+            "COMFYUI_API_KEY": request.app.state.config.COMFYUI_API_KEY,
+            "COMFYUI_WORKFLOW": request.app.state.config.COMFYUI_WORKFLOW,
+            "COMFYUI_WORKFLOW_NODES": request.app.state.config.COMFYUI_WORKFLOW_NODES,
        },
     }


-def get_automatic1111_api_auth():
-    if app.state.config.AUTOMATIC1111_API_AUTH is None:
+def get_automatic1111_api_auth(request: Request):
+    if request.app.state.config.AUTOMATIC1111_API_AUTH is None:
         return ""
     else:
-        auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
+        auth1111_byte_string = request.app.state.config.AUTOMATIC1111_API_AUTH.encode(
+            "utf-8"
+        )
         auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
         auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
         return f"Basic {auth1111_base64_encoded_string}"

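# ---------------------------------------------------------------------------
# get_automatic1111_api_auth just turns a "user:password" string into an HTTP
# Basic header; equivalently (the credentials here are hypothetical):
import base64

auth = "user:password"
header = f"Basic {base64.b64encode(auth.encode('utf-8')).decode('utf-8')}"
print(header)  # Basic dXNlcjpwYXNzd29yZA==
# ---------------------------------------------------------------------------
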
@app.get("/config/url/verify")
|
||||
async def verify_url(user=Depends(get_admin_user)):
|
||||
if app.state.config.ENGINE == "automatic1111":
|
||||
@router.get("/config/url/verify")
|
||||
async def verify_url(request: Request, user=Depends(get_admin_user)):
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111":
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": get_automatic1111_api_auth()},
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": get_automatic1111_api_auth(request)},
|
||||
)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
except Exception:
|
||||
app.state.config.ENABLED = False
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||
elif app.state.config.ENGINE == "comfyui":
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
try:
|
||||
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
|
||||
r = requests.get(
|
||||
url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info"
|
||||
)
|
||||
r.raise_for_status()
|
||||
return True
|
||||
except Exception:
|
||||
app.state.config.ENABLED = False
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def set_image_model(model: str):
|
||||
def set_image_model(request: Request, model: str):
|
||||
log.info(f"Setting image model to {model}")
|
||||
app.state.config.MODEL = model
|
||||
if app.state.config.ENGINE in ["", "automatic1111"]:
|
||||
api_auth = get_automatic1111_api_auth()
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL = model
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE in ["", "automatic1111"]:
|
||||
api_auth = get_automatic1111_api_auth(request)
|
||||
r = requests.get(
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": api_auth},
|
||||
)
|
||||
options = r.json()
|
||||
if model != options["sd_model_checkpoint"]:
|
||||
options["sd_model_checkpoint"] = model
|
||||
r = requests.post(
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
json=options,
|
||||
headers={"authorization": api_auth},
|
||||
)
|
||||
return app.state.config.MODEL
|
||||
return request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
|
||||
|
||||
def get_image_model():
|
||||
if app.state.config.ENGINE == "openai":
|
||||
return app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
|
||||
elif app.state.config.ENGINE == "comfyui":
|
||||
return app.state.config.MODEL if app.state.config.MODEL else ""
|
||||
elif app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == "":
|
||||
def get_image_model(request):
|
||||
if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else "dall-e-2"
|
||||
)
|
||||
elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
|
||||
return (
|
||||
request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
if request.app.state.config.IMAGE_GENERATION_MODEL
|
||||
else ""
|
||||
)
|
||||
elif (
|
||||
request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
|
||||
or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
|
||||
):
|
||||
try:
|
||||
r = requests.get(
|
||||
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": get_automatic1111_api_auth()},
|
||||
url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
|
||||
headers={"authorization": get_automatic1111_api_auth(request)},
|
||||
)
|
||||
options = r.json()
|
||||
return options["sd_model_checkpoint"]
|
||||
except Exception as e:
|
||||
app.state.config.ENABLED = False
|
||||
request.app.state.config.ENABLE_IMAGE_GENERATION = False
|
||||
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
|
||||
|
||||
|
||||
@@ -272,23 +251,25 @@ class ImageConfigForm(BaseModel):
     IMAGE_STEPS: int


-@app.get("/image/config")
-async def get_image_config(user=Depends(get_admin_user)):
+@router.get("/image/config")
+async def get_image_config(request: Request, user=Depends(get_admin_user)):
     return {
-        "MODEL": app.state.config.MODEL,
-        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
-        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
+        "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
+        "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
+        "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
     }


-@app.post("/image/config/update")
-async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin_user)):
+@router.post("/image/config/update")
+async def update_image_config(
+    request: Request, form_data: ImageConfigForm, user=Depends(get_admin_user)
+):

-    set_image_model(form_data.MODEL)
+    set_image_model(request, form_data.MODEL)

     pattern = r"^\d+x\d+$"
     if re.match(pattern, form_data.IMAGE_SIZE):
-        app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
+        request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE
     else:
         raise HTTPException(
             status_code=400,
@@ -296,7 +277,7 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
         )

     if form_data.IMAGE_STEPS >= 0:
-        app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
+        request.app.state.config.IMAGE_STEPS = form_data.IMAGE_STEPS
     else:
         raise HTTPException(
             status_code=400,
@@ -304,29 +285,35 @@ async def update_image_config(form_data: ImageConfigForm, user=Depends(get_admin
         )

     return {
-        "MODEL": app.state.config.MODEL,
-        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
-        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
+        "MODEL": request.app.state.config.IMAGE_GENERATION_MODEL,
+        "IMAGE_SIZE": request.app.state.config.IMAGE_SIZE,
+        "IMAGE_STEPS": request.app.state.config.IMAGE_STEPS,
     }


-@app.get("/models")
-def get_models(user=Depends(get_verified_user)):
+@router.get("/models")
+def get_models(request: Request, user=Depends(get_verified_user)):
     try:
-        if app.state.config.ENGINE == "openai":
+        if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
             return [
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
-        elif app.state.config.ENGINE == "comfyui":
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             # TODO - get models from comfyui
-            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
+            headers = {
+                "Authorization": f"Bearer {request.app.state.config.COMFYUI_API_KEY}"
+            }
+            r = requests.get(
+                url=f"{request.app.state.config.COMFYUI_BASE_URL}/object_info",
+                headers=headers,
+            )
             info = r.json()

-            workflow = json.loads(app.state.config.COMFYUI_WORKFLOW)
+            workflow = json.loads(request.app.state.config.COMFYUI_WORKFLOW)
             model_node_id = None

-            for node in app.state.config.COMFYUI_WORKFLOW_NODES:
+            for node in request.app.state.config.COMFYUI_WORKFLOW_NODES:
                 if node["type"] == "model":
                     if node["node_ids"]:
                         model_node_id = node["node_ids"][0]
@@ -362,11 +349,12 @@ def get_models(user=Depends(get_verified_user)):
                 )
             )
         elif (
-            app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
+            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+            or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
         ):
             r = requests.get(
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
-                headers={"authorization": get_automatic1111_api_auth()},
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
+                headers={"authorization": get_automatic1111_api_auth(request)},
             )
             models = r.json()
             return list(
@@ -376,7 +364,7 @@ def get_models(user=Depends(get_verified_user)):
             )
         )
     except Exception as e:
-        app.state.config.ENABLED = False
+        request.app.state.config.ENABLE_IMAGE_GENERATION = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

@@ -448,18 +436,21 @@ def save_url_image(url):
     return None


-@app.post("/generations")
+@router.post("/generations")
 async def image_generations(
+    request: Request,
     form_data: GenerateImageForm,
     user=Depends(get_verified_user),
 ):
-    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
+    width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x")))

     r = None
     try:
-        if app.state.config.ENGINE == "openai":
+        if request.app.state.config.IMAGE_GENERATION_ENGINE == "openai":
             headers = {}
-            headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
+            headers["Authorization"] = (
+                f"Bearer {request.app.state.config.IMAGES_OPENAI_API_KEY}"
+            )
             headers["Content-Type"] = "application/json"

             if ENABLE_FORWARD_USER_INFO_HEADERS:
@@ -470,14 +461,16 @@ async def image_generations(

             data = {
                 "model": (
-                    app.state.config.MODEL
-                    if app.state.config.MODEL != ""
+                    request.app.state.config.IMAGE_GENERATION_MODEL
+                    if request.app.state.config.IMAGE_GENERATION_MODEL != ""
                     else "dall-e-2"
                 ),
                 "prompt": form_data.prompt,
                 "n": form_data.n,
                 "size": (
-                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
+                    form_data.size
+                    if form_data.size
+                    else request.app.state.config.IMAGE_SIZE
                 ),
                 "response_format": "b64_json",
             }
@@ -485,7 +478,7 @@ async def image_generations(
             # Use asyncio.to_thread for the requests.post call
             r = await asyncio.to_thread(
                 requests.post,
-                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
+                url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations",
                 json=data,
                 headers=headers,
             )
@@ -505,7 +498,7 @@ async def image_generations(

             return images

-        elif app.state.config.ENGINE == "comfyui":
+        elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui":
             data = {
                 "prompt": form_data.prompt,
                 "width": width,
@@ -513,8 +506,8 @@ async def image_generations(
                 "n": form_data.n,
             }

-            if app.state.config.IMAGE_STEPS is not None:
-                data["steps"] = app.state.config.IMAGE_STEPS
+            if request.app.state.config.IMAGE_STEPS is not None:
+                data["steps"] = request.app.state.config.IMAGE_STEPS

             if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt
@@ -523,18 +516,19 @@ async def image_generations(
                 **{
                     "workflow": ComfyUIWorkflow(
                         **{
-                            "workflow": app.state.config.COMFYUI_WORKFLOW,
-                            "nodes": app.state.config.COMFYUI_WORKFLOW_NODES,
+                            "workflow": request.app.state.config.COMFYUI_WORKFLOW,
+                            "nodes": request.app.state.config.COMFYUI_WORKFLOW_NODES,
                         }
                     ),
                     **data,
                 }
             )
             res = await comfyui_generate_image(
-                app.state.config.MODEL,
+                request.app.state.config.IMAGE_GENERATION_MODEL,
                 form_data,
                 user.id,
-                app.state.config.COMFYUI_BASE_URL,
+                request.app.state.config.COMFYUI_BASE_URL,
+                request.app.state.config.COMFYUI_API_KEY,
             )
             log.debug(f"res: {res}")

@@ -551,7 +545,8 @@ async def image_generations(
             log.debug(f"images: {images}")
             return images
         elif (
-            app.state.config.ENGINE == "automatic1111" or app.state.config.ENGINE == ""
+            request.app.state.config.IMAGE_GENERATION_ENGINE == "automatic1111"
+            or request.app.state.config.IMAGE_GENERATION_ENGINE == ""
         ):
             if form_data.model:
                 set_image_model(form_data.model)
@@ -563,27 +558,27 @@ async def image_generations(
                 "height": height,
             }

-            if app.state.config.IMAGE_STEPS is not None:
-                data["steps"] = app.state.config.IMAGE_STEPS
+            if request.app.state.config.IMAGE_STEPS is not None:
+                data["steps"] = request.app.state.config.IMAGE_STEPS

             if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt

-            if app.state.config.AUTOMATIC1111_CFG_SCALE:
-                data["cfg_scale"] = app.state.config.AUTOMATIC1111_CFG_SCALE
+            if request.app.state.config.AUTOMATIC1111_CFG_SCALE:
+                data["cfg_scale"] = request.app.state.config.AUTOMATIC1111_CFG_SCALE

-            if app.state.config.AUTOMATIC1111_SAMPLER:
-                data["sampler_name"] = app.state.config.AUTOMATIC1111_SAMPLER
+            if request.app.state.config.AUTOMATIC1111_SAMPLER:
+                data["sampler_name"] = request.app.state.config.AUTOMATIC1111_SAMPLER

-            if app.state.config.AUTOMATIC1111_SCHEDULER:
-                data["scheduler"] = app.state.config.AUTOMATIC1111_SCHEDULER
+            if request.app.state.config.AUTOMATIC1111_SCHEDULER:
+                data["scheduler"] = request.app.state.config.AUTOMATIC1111_SCHEDULER

             # Use asyncio.to_thread for the requests.post call
             r = await asyncio.to_thread(
                 requests.post,
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+                url=f"{request.app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                 json=data,
-                headers={"authorization": get_automatic1111_api_auth()},
+                headers={"authorization": get_automatic1111_api_auth(request)},
             )

             res = r.json()
@@ -1,22 +1,26 @@
 import json
-from typing import Optional, Union
+from typing import List, Optional
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 import logging

-from open_webui.apps.webui.models.knowledge import (
+from open_webui.models.knowledge import (
     Knowledges,
     KnowledgeForm,
     KnowledgeResponse,
+    KnowledgeUserResponse,
 )
-from open_webui.apps.webui.models.files import Files, FileModel
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.models.files import Files, FileModel
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.routers.retrieval import (
+    process_file,
+    ProcessFileForm,
+    process_files_batch,
+    BatchProcessFilesForm,
+)


 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_verified_user
+from open_webui.utils.access_control import has_access, has_permission

@@ -242,6 +246,7 @@ class KnowledgeFileIdForm(BaseModel):

 @router.post("/{id}/file/add", response_model=Optional[KnowledgeFilesResponse])
 def add_file_to_knowledge_by_id(
+    request: Request,
     id: str,
     form_data: KnowledgeFileIdForm,
     user=Depends(get_verified_user),
@@ -274,7 +279,9 @@ def add_file_to_knowledge_by_id(

     # Add content to the vector database
     try:
-        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
+        process_file(
+            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+        )
     except Exception as e:
         log.debug(e)
         raise HTTPException(
@@ -318,6 +325,7 @@ def add_file_to_knowledge_by_id(

 @router.post("/{id}/file/update", response_model=Optional[KnowledgeFilesResponse])
 def update_file_from_knowledge_by_id(
+    request: Request,
     id: str,
     form_data: KnowledgeFileIdForm,
     user=Depends(get_verified_user),
@@ -349,7 +357,9 @@ def update_file_from_knowledge_by_id(

     # Add content to the vector database
     try:
-        process_file(ProcessFileForm(file_id=form_data.file_id, collection_name=id))
+        process_file(
+            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+        )
     except Exception as e:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
@@ -508,3 +518,83 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})

     return knowledge
+
+
+############################
+# AddFilesToKnowledge
+############################
+
+
+@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
+def add_files_to_knowledge_batch(
+    id: str,
+    form_data: list[KnowledgeFileIdForm],
+    user=Depends(get_verified_user),
+):
+    """
+    Add multiple files to a knowledge base
+    """
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    # Get files content
+    print(f"files/batch/add - {len(form_data)} files")
+    files: List[FileModel] = []
+    for form in form_data:
+        file = Files.get_file_by_id(form.file_id)
+        if not file:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"File {form.file_id} not found",
+            )
+        files.append(file)
+
+    # Process files
+    try:
+        result = process_files_batch(
+            BatchProcessFilesForm(files=files, collection_name=id)
+        )
+    except Exception as e:
+        log.error(
+            f"add_files_to_knowledge_batch: Exception occurred: {e}", exc_info=True
+        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+    # Add successful files to knowledge base
+    data = knowledge.data or {}
+    existing_file_ids = data.get("file_ids", [])
+
+    # Only add files that were successfully processed
+    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
+    for file_id in successful_file_ids:
+        if file_id not in existing_file_ids:
+            existing_file_ids.append(file_id)
+
+    data["file_ids"] = existing_file_ids
+    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
+
+    # If there were any errors, include them in the response
+    if result.errors:
+        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Files.get_files_by_ids(existing_file_ids),
+            warnings={
+                "message": "Some files failed to process",
+                "errors": error_details,
+            },
+        )
+
+    return KnowledgeFilesResponse(
+        **knowledge.model_dump(), files=Files.get_files_by_ids(existing_file_ids)
+    )
@@ -3,9 +3,9 @@ from pydantic import BaseModel
 import logging
 from typing import Optional

-from open_webui.apps.webui.models.memories import Memories, MemoryModel
-from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.utils.utils import get_verified_user
+from open_webui.models.memories import Memories, MemoryModel
+from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
+from open_webui.utils.auth import get_verified_user
 from open_webui.env import SRC_LOG_LEVELS