mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-11 23:08:19 -05:00
Compare commits
1224 Commits
github-pag
...
SHARK-1.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcdcbc63cc | ||
|
|
9fc0b5b58e | ||
|
|
85ae2ee57e | ||
|
|
721757013f | ||
|
|
c1efb9cfe0 | ||
|
|
18666dffc0 | ||
|
|
06bfcbc7ad | ||
|
|
f2966673e9 | ||
|
|
bb431811fb | ||
|
|
97e112091f | ||
|
|
d35288e893 | ||
|
|
e80bc9f857 | ||
|
|
773e6ebebf | ||
|
|
dda7e8a163 | ||
|
|
7fdd1952ae | ||
|
|
0a6f6fad86 | ||
|
|
6853a33728 | ||
|
|
3887d83f5d | ||
|
|
8d9b5b3afa | ||
|
|
16c03e4b44 | ||
|
|
17dab8334d | ||
|
|
f692a012e1 | ||
|
|
3cc643b2de | ||
|
|
bf70e80d20 | ||
|
|
7159698496 | ||
|
|
7e12d1782a | ||
|
|
bb5f133e1c | ||
|
|
3af0c6c658 | ||
|
|
3322b7264f | ||
|
|
eeb7bdd143 | ||
|
|
2d6f48821d | ||
|
|
c74b55f24e | ||
|
|
1a723645fb | ||
|
|
dfdd3b1f78 | ||
|
|
6384780d16 | ||
|
|
db0c53ae59 | ||
|
|
ce9ce3a7c8 | ||
|
|
d72da3801f | ||
|
|
9c50edc664 | ||
|
|
a1b7110550 | ||
|
|
ff15fd74f6 | ||
|
|
552b2c3ee3 | ||
|
|
795fc33001 | ||
|
|
2910841fe6 | ||
|
|
396a054856 | ||
|
|
5c66948d4f | ||
|
|
ed3dda94c0 | ||
|
|
d31d28b082 | ||
|
|
78c607e1d3 | ||
|
|
666e601dd9 | ||
|
|
ca58908e5b | ||
|
|
1f5b39f56e | ||
|
|
2da31c4109 | ||
|
|
da50a16242 | ||
|
|
ce38d49f05 | ||
|
|
2f780f0d38 | ||
|
|
d051c3a4a7 | ||
|
|
1b11c82c9d | ||
|
|
80a33d427f | ||
|
|
4125a26294 | ||
|
|
905d0103ff | ||
|
|
192b3b2c61 | ||
|
|
8f9adc4a2a | ||
|
|
70817bb50a | ||
|
|
dd37c26d36 | ||
|
|
a708879c6c | ||
|
|
bb1b49eb6f | ||
|
|
f6d41affd9 | ||
|
|
c2163488d8 | ||
|
|
54bff4611d | ||
|
|
11510d5111 | ||
|
|
32cab73a29 | ||
|
|
392bade0bf | ||
|
|
91df5f0613 | ||
|
|
df20cf9c8a | ||
|
|
c4a908c3ea | ||
|
|
6285430d8a | ||
|
|
51afe19e20 | ||
|
|
31005bcf73 | ||
|
|
f41ad87ef6 | ||
|
|
d811524a00 | ||
|
|
51e1bd1c5d | ||
|
|
db89b1bdc1 | ||
|
|
2754e2e257 | ||
|
|
ab0e870c43 | ||
|
|
fb30e8c226 | ||
|
|
a07d542400 | ||
|
|
ad55cb696f | ||
|
|
488a172292 | ||
|
|
500c4f2306 | ||
|
|
92b694db4d | ||
|
|
322874f7f9 | ||
|
|
5001db3415 | ||
|
|
71846344a2 | ||
|
|
72e27c96fc | ||
|
|
7963abb8ec | ||
|
|
98244232dd | ||
|
|
679a452139 | ||
|
|
72c0a8abc8 | ||
|
|
ea920f2955 | ||
|
|
486202377a | ||
|
|
0c38c33d0a | ||
|
|
841773fa32 | ||
|
|
0361db46f9 | ||
|
|
a012433ffd | ||
|
|
5061193da3 | ||
|
|
bff48924be | ||
|
|
825b36cbdd | ||
|
|
134441957d | ||
|
|
7cd14fdc47 | ||
|
|
e6cb5cef57 | ||
|
|
66abee8e5b | ||
|
|
4797bb89f5 | ||
|
|
205e57683a | ||
|
|
2866d665ee | ||
|
|
71d25ec5d8 | ||
|
|
202ffff67b | ||
|
|
0b77059628 | ||
|
|
a208302bb9 | ||
|
|
b83d32fafe | ||
|
|
0a618e1863 | ||
|
|
a731eb6ed4 | ||
|
|
2004d16945 | ||
|
|
6e409bfb77 | ||
|
|
77727d149c | ||
|
|
66f6e79d68 | ||
|
|
3b825579a7 | ||
|
|
9f0a421764 | ||
|
|
c28682110c | ||
|
|
caf6cc5d8f | ||
|
|
8614a18474 | ||
|
|
86c1c0c215 | ||
|
|
8bb364bcb8 | ||
|
|
7abddd01ec | ||
|
|
2a451fa0c7 | ||
|
|
9c4610b9da | ||
|
|
a38cc9d216 | ||
|
|
1c382449ec | ||
|
|
7cc9b3f8e8 | ||
|
|
e54517e967 | ||
|
|
326327a799 | ||
|
|
785b65c7b0 | ||
|
|
0d16c81687 | ||
|
|
8dd7850c69 | ||
|
|
e930ba85b4 | ||
|
|
cd732e7a38 | ||
|
|
8e0f8b3227 | ||
|
|
b8210ef796 | ||
|
|
94594542a9 | ||
|
|
82f833e87d | ||
|
|
c9d6870105 | ||
|
|
4fec03a6cc | ||
|
|
9a27f51378 | ||
|
|
ad1a0f35ff | ||
|
|
6773278ec2 | ||
|
|
9a0efffcca | ||
|
|
61c6f153d9 | ||
|
|
effd42e8f5 | ||
|
|
b5fbb1a8a0 | ||
|
|
ded74d09cd | ||
|
|
79267931c1 | ||
|
|
9eceba69b7 | ||
|
|
ca609afb6a | ||
|
|
11bdce9790 | ||
|
|
684943a4a6 | ||
|
|
b817bb8455 | ||
|
|
780f520f02 | ||
|
|
c61b6f8d65 | ||
|
|
c854208d49 | ||
|
|
c5dcfc1f13 | ||
|
|
bde63ee8ae | ||
|
|
9681d494eb | ||
|
|
ede6bf83e2 | ||
|
|
2c2693fb7d | ||
|
|
1d31b2b2c6 | ||
|
|
d2f64eefa3 | ||
|
|
87ae14b6ff | ||
|
|
1ccafa1fc1 | ||
|
|
4c3d8a0a7f | ||
|
|
3601dc7c3b | ||
|
|
671881cf87 | ||
|
|
4e9be6be59 | ||
|
|
9c8cbaf498 | ||
|
|
9e348a114e | ||
|
|
51f90a4d56 | ||
|
|
310d5d0a49 | ||
|
|
9697981004 | ||
|
|
450c231171 | ||
|
|
07f6f4a2f7 | ||
|
|
610813c72f | ||
|
|
8e3860c9e6 | ||
|
|
e37d6720eb | ||
|
|
16160d9a7d | ||
|
|
79075a1a07 | ||
|
|
db990826d3 | ||
|
|
7ee3e4ba5d | ||
|
|
05889a8fe1 | ||
|
|
b87efe7686 | ||
|
|
82b462de3a | ||
|
|
d8f0f7bade | ||
|
|
79bd0b84a1 | ||
|
|
8738571d1e | ||
|
|
a4c354ce54 | ||
|
|
cc53efa89f | ||
|
|
9ae8bc921e | ||
|
|
32eb78f0f9 | ||
|
|
cb509343d9 | ||
|
|
6da391c9b1 | ||
|
|
9dee7ae652 | ||
|
|
343dfd901c | ||
|
|
57260b9c37 | ||
|
|
18e7d2d061 | ||
|
|
51a1009796 | ||
|
|
045c3c3852 | ||
|
|
0139dd58d9 | ||
|
|
c96571855a | ||
|
|
4f61d69d86 | ||
|
|
531d447768 | ||
|
|
16f46f8de9 | ||
|
|
c4723f469f | ||
|
|
d804f45a61 | ||
|
|
d22177f936 | ||
|
|
75e68f02f4 | ||
|
|
4dc9c59611 | ||
|
|
18801dcabc | ||
|
|
3c577f7168 | ||
|
|
f5e4fa6ffe | ||
|
|
48de445325 | ||
|
|
8e90f1b81a | ||
|
|
e8c1203be2 | ||
|
|
e4d7abb519 | ||
|
|
96185c9dc1 | ||
|
|
bc22a81925 | ||
|
|
5203679f1f | ||
|
|
bf073f8f37 | ||
|
|
cec6eda6b4 | ||
|
|
9e37e03741 | ||
|
|
9b8c4401b5 | ||
|
|
a9f95a218b | ||
|
|
872bd72d0b | ||
|
|
fd1c4db5d0 | ||
|
|
759664bb48 | ||
|
|
14fd0cdd87 | ||
|
|
a57eccc997 | ||
|
|
a686d7d89f | ||
|
|
ed484b8253 | ||
|
|
7fe57ebaaf | ||
|
|
c287fd2be8 | ||
|
|
51ec1a1360 | ||
|
|
bd30044c0b | ||
|
|
c9de2729b2 | ||
|
|
a5b13fcc2f | ||
|
|
6bb329c4af | ||
|
|
98fb6c52df | ||
|
|
206c1b70f4 | ||
|
|
cdb037ee54 | ||
|
|
ce2fd84538 | ||
|
|
4684afad34 | ||
|
|
8d65456b7a | ||
|
|
d6759a852b | ||
|
|
ab57af43c1 | ||
|
|
4d5c55dd9f | ||
|
|
07399ad65c | ||
|
|
776a9c2293 | ||
|
|
9d399eb988 | ||
|
|
927b662aa7 | ||
|
|
47f8a79c75 | ||
|
|
289f983f41 | ||
|
|
453e46562f | ||
|
|
5497af1f56 | ||
|
|
f3cb63fc9c | ||
|
|
d7092aafaa | ||
|
|
a415f3f70e | ||
|
|
c292e5c9d7 | ||
|
|
03c4d9e171 | ||
|
|
3662224c04 | ||
|
|
db3f222933 | ||
|
|
68b3021325 | ||
|
|
336469154d | ||
|
|
41e5088908 | ||
|
|
0a8f7673f4 | ||
|
|
c482ab78da | ||
|
|
4be80f7158 | ||
|
|
536aba1424 | ||
|
|
dd738a0e02 | ||
|
|
8927cb0a2c | ||
|
|
8c317e4809 | ||
|
|
b0136593df | ||
|
|
11f62d7fac | ||
|
|
14559dd620 | ||
|
|
e503a3e8d6 | ||
|
|
22a4254adf | ||
|
|
ab01f0f048 | ||
|
|
c471d17cca | ||
|
|
a2a436eb0c | ||
|
|
1adb51b29d | ||
|
|
aab2233e25 | ||
|
|
e20cd71314 | ||
|
|
5ec91143f5 | ||
|
|
7cf19230e2 | ||
|
|
1bcf6b2c5b | ||
|
|
91027f8719 | ||
|
|
a909fc2e78 | ||
|
|
247f69cf9d | ||
|
|
3b8f7cc231 | ||
|
|
6e8dbf72bd | ||
|
|
38e5b62d80 | ||
|
|
1c7eecc981 | ||
|
|
be417f0bf4 | ||
|
|
a517e217b0 | ||
|
|
9fcae4f808 | ||
|
|
788d469c5b | ||
|
|
8a59f7cc27 | ||
|
|
1c2ec3c7a2 | ||
|
|
af0f715e20 | ||
|
|
47ec7275e6 | ||
|
|
3a24cff901 | ||
|
|
1f72907886 | ||
|
|
06c8aabd01 | ||
|
|
55a12cc0c4 | ||
|
|
7dcbbde523 | ||
|
|
1b62dc4529 | ||
|
|
c5a47887f4 | ||
|
|
c72d0eaf87 | ||
|
|
c41f58042a | ||
|
|
043e5a5c7a | ||
|
|
a1b1ce935c | ||
|
|
bc6fee1a0c | ||
|
|
91ab594744 | ||
|
|
4015793f84 | ||
|
|
d63ce76dd8 | ||
|
|
1c32915570 | ||
|
|
6d286c0609 | ||
|
|
7392b22731 | ||
|
|
534de05791 | ||
|
|
5779e8c039 | ||
|
|
d496053590 | ||
|
|
6274a813c9 | ||
|
|
1d6a1f9f8a | ||
|
|
75672c0e28 | ||
|
|
74a7202173 | ||
|
|
27a08735db | ||
|
|
eaa49cce17 | ||
|
|
10657d6fb1 | ||
|
|
e3ab844cd1 | ||
|
|
5ce6001b41 | ||
|
|
501d0ca52e | ||
|
|
b444528715 | ||
|
|
6e6c90f62b | ||
|
|
8cdb38496e | ||
|
|
726d73d6ba | ||
|
|
4d55e51d46 | ||
|
|
6ef78ee7ba | ||
|
|
4002da7161 | ||
|
|
ecb5e8e5d8 | ||
|
|
28e0919321 | ||
|
|
28f4d44a6b | ||
|
|
97f7e79391 | ||
|
|
44a8f2f8db | ||
|
|
8822b9acd7 | ||
|
|
0ca3b9fce3 | ||
|
|
045f2bb147 | ||
|
|
a811b867b9 | ||
|
|
cdd505e2dd | ||
|
|
1b0f39107c | ||
|
|
b9b8955f74 | ||
|
|
6f7a85eee3 | ||
|
|
18c8e9e51e | ||
|
|
a202bb466a | ||
|
|
07c1e1d712 | ||
|
|
18daec78c8 | ||
|
|
1a8e2024d6 | ||
|
|
d61b6641fb | ||
|
|
88cc2423cc | ||
|
|
ccf944c1bd | ||
|
|
0def74f520 | ||
|
|
3fb72e192e | ||
|
|
855435ee24 | ||
|
|
6f9f868fc0 | ||
|
|
fb865f1b99 | ||
|
|
3e5c50f07b | ||
|
|
a544f30a8f | ||
|
|
1fe56d460a | ||
|
|
fafd713141 | ||
|
|
015d0132c3 | ||
|
|
20ddd96ef7 | ||
|
|
ee33cfd2d1 | ||
|
|
a3cba21d5b | ||
|
|
a7b6ec4095 | ||
|
|
d80b087d95 | ||
|
|
297a209608 | ||
|
|
b204113563 | ||
|
|
f60ab1f4fa | ||
|
|
b203779462 | ||
|
|
38570a9bbb | ||
|
|
a5c882f296 | ||
|
|
eb6d11cfed | ||
|
|
46184a81ac | ||
|
|
149165a2f0 | ||
|
|
bec82a665f | ||
|
|
9551490341 | ||
|
|
49b3ecdbca | ||
|
|
f53e3594c3 | ||
|
|
5562d1dfda | ||
|
|
c7b0c2961e | ||
|
|
44273b0791 | ||
|
|
0a4c8fcb3e | ||
|
|
2fec3c8169 | ||
|
|
5e7d5930dd | ||
|
|
b6dbd20250 | ||
|
|
34f1295349 | ||
|
|
1980d7b2c3 | ||
|
|
2cfacc5051 | ||
|
|
436f58ddc4 | ||
|
|
6b29bd17c8 | ||
|
|
2c3485ca3e | ||
|
|
f206ecc635 | ||
|
|
a187e05ae6 | ||
|
|
8c21960486 | ||
|
|
be62fce676 | ||
|
|
f23b778a6c | ||
|
|
436edf900d | ||
|
|
ed58c2553f | ||
|
|
f2ca58e844 | ||
|
|
1dbcc736eb | ||
|
|
a83808ddc5 | ||
|
|
a07fe80530 | ||
|
|
d0ba3ef8fa | ||
|
|
8400529c2c | ||
|
|
7eaee9c242 | ||
|
|
8230eebce5 | ||
|
|
6296ea4be9 | ||
|
|
4151ec3a8f | ||
|
|
a2467e8d43 | ||
|
|
e677178bcc | ||
|
|
7ef1bea953 | ||
|
|
ad89bb1413 | ||
|
|
218ed78c40 | ||
|
|
6046f36ab6 | ||
|
|
5915bf7de3 | ||
|
|
f0a4e59758 | ||
|
|
1ddef26af5 | ||
|
|
ba8eddb12f | ||
|
|
47b346d428 | ||
|
|
1b4f4f5f4d | ||
|
|
73cd7e8320 | ||
|
|
19c0ae3702 | ||
|
|
54e57f7771 | ||
|
|
6d64b8e273 | ||
|
|
a8ea0326f5 | ||
|
|
58e9194553 | ||
|
|
eb360e255d | ||
|
|
a6f88d7f72 | ||
|
|
8e571d165f | ||
|
|
3cddd01b10 | ||
|
|
64c2b2d96b | ||
|
|
f5ce121988 | ||
|
|
991f144598 | ||
|
|
09bea17e59 | ||
|
|
aefcf80b48 | ||
|
|
512235892e | ||
|
|
6602a2f5ba | ||
|
|
20114deea0 | ||
|
|
9acf519078 | ||
|
|
bdf37b5311 | ||
|
|
8ee2ac89f8 | ||
|
|
60cb48be2e | ||
|
|
86a215b063 | ||
|
|
d6e3a9a236 | ||
|
|
a0097a1ead | ||
|
|
a9bae00606 | ||
|
|
4731c1a835 | ||
|
|
4c07e47e8c | ||
|
|
e0cc2871bb | ||
|
|
649f39408b | ||
|
|
c142297d73 | ||
|
|
9e07360b00 | ||
|
|
7b74c86e42 | ||
|
|
fa833f8366 | ||
|
|
fcb059aa38 | ||
|
|
517c670f82 | ||
|
|
59df14f18b | ||
|
|
6c95ac0f37 | ||
|
|
7a4a51ae73 | ||
|
|
d816cc015e | ||
|
|
54ce3d48ca | ||
|
|
0e4a8ca240 | ||
|
|
6ca1298675 | ||
|
|
bbef7a6464 | ||
|
|
cdf2d61d53 | ||
|
|
6c14847d1f | ||
|
|
68ecdd2a73 | ||
|
|
3f4d444d18 | ||
|
|
e473d0375b | ||
|
|
e38d96850f | ||
|
|
fed63dfd4b | ||
|
|
eba4d06405 | ||
|
|
4cfba153d2 | ||
|
|
307c05f38d | ||
|
|
696df349cb | ||
|
|
cb54cb1348 | ||
|
|
9bdb86637d | ||
|
|
fb6f26517f | ||
|
|
aa8ada9da9 | ||
|
|
1db906a373 | ||
|
|
9d1d1617d8 | ||
|
|
7112789cb8 | ||
|
|
d6b8be2849 | ||
|
|
822171277c | ||
|
|
a5ae9d9f02 | ||
|
|
09e3f63d5b | ||
|
|
d60a5a9396 | ||
|
|
90df0ee365 | ||
|
|
133c1bcadd | ||
|
|
caadbe14e9 | ||
|
|
5f5823ccd9 | ||
|
|
d2f7e03b7e | ||
|
|
0b01bbe479 | ||
|
|
25c5fc44ae | ||
|
|
7330729c92 | ||
|
|
ce16cd5431 | ||
|
|
598dc5f79d | ||
|
|
1f8e332cbe | ||
|
|
17b9632659 | ||
|
|
bda92a54ab | ||
|
|
747ed383b1 | ||
|
|
1afe07c296 | ||
|
|
b70919b38d | ||
|
|
4e513d647f | ||
|
|
94cd2a0fed | ||
|
|
606029c01c | ||
|
|
1aa85222e9 | ||
|
|
1b3f468c04 | ||
|
|
35de7e27fa | ||
|
|
467f900759 | ||
|
|
0bd9d582c7 | ||
|
|
428cfe8dae | ||
|
|
f17915bedc | ||
|
|
1b49b5149a | ||
|
|
3002793301 | ||
|
|
d25ef5529f | ||
|
|
308856a947 | ||
|
|
151b4e142f | ||
|
|
e5a69a7c36 | ||
|
|
450b6cafc4 | ||
|
|
237d26baa2 | ||
|
|
67d6ee1104 | ||
|
|
98b069488e | ||
|
|
e0f227643a | ||
|
|
a0af3bb0cb | ||
|
|
2cd61a5b96 | ||
|
|
f49d41a807 | ||
|
|
2191fc8952 | ||
|
|
aea7796e60 | ||
|
|
a376619f1e | ||
|
|
02d52bb626 | ||
|
|
3b63645f79 | ||
|
|
d6f740b998 | ||
|
|
594c6b8ea2 | ||
|
|
96b1560da5 | ||
|
|
0ef6a0e234 | ||
|
|
641d535f44 | ||
|
|
5bb7846227 | ||
|
|
8f84258fb8 | ||
|
|
7619e76bbd | ||
|
|
9267eadbfa | ||
|
|
431132b8ee | ||
|
|
fb35e13e7a | ||
|
|
17a67897d1 | ||
|
|
da449b73aa | ||
|
|
0b0526699a | ||
|
|
4fac46f7bb | ||
|
|
49925950f1 | ||
|
|
807947c0c8 | ||
|
|
593428bda4 | ||
|
|
cede9b4fec | ||
|
|
c2360303f0 | ||
|
|
420366c1b8 | ||
|
|
d31bae488c | ||
|
|
c23fcf3748 | ||
|
|
7dbbb1726a | ||
|
|
8b8cc7fd33 | ||
|
|
e3c96a2b9d | ||
|
|
5e3f50647d | ||
|
|
7899e1803a | ||
|
|
d105246b9c | ||
|
|
90c958bca2 | ||
|
|
f99903e023 | ||
|
|
c6f44ef1b3 | ||
|
|
8dcd4d5aeb | ||
|
|
d319f4684e | ||
|
|
54d7b6d83e | ||
|
|
4a622532e5 | ||
|
|
650b2ada58 | ||
|
|
f87f8949f3 | ||
|
|
7dc9bf8148 | ||
|
|
ba48ff8d25 | ||
|
|
638840925c | ||
|
|
b661656c03 | ||
|
|
0225434389 | ||
|
|
7ffe20b1c2 | ||
|
|
d8f0c4655d | ||
|
|
7e8d3ec0df | ||
|
|
9c08eec565 | ||
|
|
2d2c523ac5 | ||
|
|
f17b3128c0 | ||
|
|
7c7e630099 | ||
|
|
2dd1491ec1 | ||
|
|
236357fb61 | ||
|
|
7bc38719de | ||
|
|
bdbe992769 | ||
|
|
e6b925e012 | ||
|
|
771120b76c | ||
|
|
a8ce7680db | ||
|
|
b6dcf2401b | ||
|
|
62b5a9fd49 | ||
|
|
2f133e9d5c | ||
|
|
f898a1d332 | ||
|
|
b94266d2b9 | ||
|
|
1b08242aaa | ||
|
|
691030fbab | ||
|
|
16ad7d57a3 | ||
|
|
c561ebf43c | ||
|
|
97fdff7f19 | ||
|
|
ce6d82eab2 | ||
|
|
b8f4b18951 | ||
|
|
b23d3aa584 | ||
|
|
495670d9b6 | ||
|
|
815e23a0b8 | ||
|
|
783538fe11 | ||
|
|
996c645f6a | ||
|
|
1f7d249a62 | ||
|
|
7f6c9a2dc2 | ||
|
|
93891984f3 | ||
|
|
cc0ef54e0e | ||
|
|
812152485d | ||
|
|
0816fb403a | ||
|
|
4f171772be | ||
|
|
a52331d4aa | ||
|
|
ad821a1fc8 | ||
|
|
116b128802 | ||
|
|
b118f183d1 | ||
|
|
911dff16f1 | ||
|
|
de59a66ae4 | ||
|
|
23f1468cc6 | ||
|
|
080350d311 | ||
|
|
7f3f92b9d5 | ||
|
|
be3cdec290 | ||
|
|
f09574538c | ||
|
|
b1113ab551 | ||
|
|
ef756389e3 | ||
|
|
cb17d017df | ||
|
|
798f231792 | ||
|
|
7136890da3 | ||
|
|
d567192fd3 | ||
|
|
dcc4025c78 | ||
|
|
c6c8ec36a1 | ||
|
|
1344c0659a | ||
|
|
973f6d20f4 | ||
|
|
8b5c9c51e7 | ||
|
|
bae208bcc4 | ||
|
|
b6c14ad468 | ||
|
|
0064cc2a6e | ||
|
|
0a0567e944 | ||
|
|
694b1d43a8 | ||
|
|
e7eb116bd2 | ||
|
|
596499a08c | ||
|
|
2a2e460df2 | ||
|
|
a9039b35ed | ||
|
|
a01154a507 | ||
|
|
1d9204282d | ||
|
|
5ff40a0d2d | ||
|
|
fab6d2e4e0 | ||
|
|
abab59c25f | ||
|
|
c25840b585 | ||
|
|
1b3f9125bb | ||
|
|
b5d9f5ba49 | ||
|
|
1c22aa9c8f | ||
|
|
e1d7fb879c | ||
|
|
e912c42bf0 | ||
|
|
e6841acf36 | ||
|
|
bc4459b6f4 | ||
|
|
9b544491e0 | ||
|
|
9c5415b598 | ||
|
|
040dbc317f | ||
|
|
65775046d8 | ||
|
|
b18bc36127 | ||
|
|
f01c526efd | ||
|
|
16168ab6b3 | ||
|
|
4233218629 | ||
|
|
b63fb36dc0 | ||
|
|
4e92304b89 | ||
|
|
2ae047f1a8 | ||
|
|
6d2a485264 | ||
|
|
4f045db024 | ||
|
|
5b33597b6d | ||
|
|
962470f610 | ||
|
|
ba8c116380 | ||
|
|
ad7330eae4 | ||
|
|
cf126e4839 | ||
|
|
c96d25c3e2 | ||
|
|
006aa0dae2 | ||
|
|
5b204bee86 | ||
|
|
d98b2afbe9 | ||
|
|
681332ef32 | ||
|
|
c3a4fdcbfc | ||
|
|
aac5de5b02 | ||
|
|
13a255afad | ||
|
|
3bffda52f9 | ||
|
|
d4e62ce557 | ||
|
|
9738483b18 | ||
|
|
143492fe94 | ||
|
|
ecc5c662c4 | ||
|
|
d973ba191d | ||
|
|
0198b183a2 | ||
|
|
0d44a3527b | ||
|
|
2147b6a397 | ||
|
|
6b5b4ba27b | ||
|
|
67005bf57c | ||
|
|
0430c741c6 | ||
|
|
1ce02e365d | ||
|
|
eae862adc2 | ||
|
|
dffa89524a | ||
|
|
2af1102441 | ||
|
|
c4b472842a | ||
|
|
750a7d806f | ||
|
|
bc7333f1e5 | ||
|
|
55ae50f991 | ||
|
|
a590c331ef | ||
|
|
8c241b06cb | ||
|
|
9c072c8068 | ||
|
|
ebd8b5122a | ||
|
|
055e484a40 | ||
|
|
912c4a1d12 | ||
|
|
c203b65bf1 | ||
|
|
307f0334ee | ||
|
|
5167df08b9 | ||
|
|
dd2e482214 | ||
|
|
87fd13d8eb | ||
|
|
dd423bc6de | ||
|
|
899cb9cc1f | ||
|
|
0464c7e558 | ||
|
|
f64e1fb926 | ||
|
|
ef7d31293d | ||
|
|
6d54eb68dc | ||
|
|
30eb10c990 | ||
|
|
591bbcd058 | ||
|
|
99aa77d036 | ||
|
|
9c13f1e635 | ||
|
|
24af983cfb | ||
|
|
67842a7525 | ||
|
|
3159a6f3e1 | ||
|
|
b2f3c96835 | ||
|
|
6582475955 | ||
|
|
41ee65b377 | ||
|
|
83fe477066 | ||
|
|
4ca84ee4ee | ||
|
|
c28cc4c919 | ||
|
|
e9864cb3f7 | ||
|
|
83c69ecd49 | ||
|
|
3595b4aaff | ||
|
|
3a9cfe113a | ||
|
|
c9966127da | ||
|
|
51300d33a7 | ||
|
|
5af124c5a5 | ||
|
|
eeb20b531a | ||
|
|
9dca842c22 | ||
|
|
1eb9436836 | ||
|
|
9604d9ce81 | ||
|
|
481d0553d8 | ||
|
|
60035cd63a | ||
|
|
d35f992ace | ||
|
|
157ae64f9d | ||
|
|
ffa17f6057 | ||
|
|
d695a43e32 | ||
|
|
01f6b4e6f0 | ||
|
|
7cf31a6ae4 | ||
|
|
fbd6224b04 | ||
|
|
8115b26079 | ||
|
|
820586ac68 | ||
|
|
4a7441ed07 | ||
|
|
383741f284 | ||
|
|
2bbc4e0e9f | ||
|
|
a7237244b0 | ||
|
|
1d38d49162 | ||
|
|
a783c089a9 | ||
|
|
e7907dc532 | ||
|
|
394413679d | ||
|
|
37189f14cb | ||
|
|
0b1ee81901 | ||
|
|
00cf73f9b8 | ||
|
|
5a5f285493 | ||
|
|
7f2ea454b6 | ||
|
|
7c14002118 | ||
|
|
3e9554f0a1 | ||
|
|
e11ffec544 | ||
|
|
8a47ddbe99 | ||
|
|
821108c7bd | ||
|
|
339738f8a3 | ||
|
|
9b90672f63 | ||
|
|
ba07e94a5e | ||
|
|
b3fc0f29cc | ||
|
|
5c7deb3611 | ||
|
|
15604e374f | ||
|
|
7cfc0fa55b | ||
|
|
a90812133b | ||
|
|
e26a70aa4f | ||
|
|
6a32a4e26c | ||
|
|
e853abf98b | ||
|
|
51e81e6ef8 | ||
|
|
e355000ceb | ||
|
|
e374074013 | ||
|
|
81e3d1c2c6 | ||
|
|
ab0cbb4475 | ||
|
|
1c64e40722 | ||
|
|
8cafe56eb4 | ||
|
|
3eceeb7b23 | ||
|
|
1a37675435 | ||
|
|
198ebede8d | ||
|
|
a504903dd5 | ||
|
|
842adef29c | ||
|
|
7edcaf5a06 | ||
|
|
c124b76328 | ||
|
|
e9c744ee5d | ||
|
|
83302930d8 | ||
|
|
a4634632ba | ||
|
|
d17e8dc5ad | ||
|
|
9fe63de4d4 | ||
|
|
8111f8bf35 | ||
|
|
fcd62513cf | ||
|
|
c3c701e654 | ||
|
|
6bf991edf6 | ||
|
|
9644e78545 | ||
|
|
c911189ef0 | ||
|
|
1118b4b651 | ||
|
|
4be75d4418 | ||
|
|
fb6beae27c | ||
|
|
fee73b0b63 | ||
|
|
9bbffa519e | ||
|
|
c3a641f0ab | ||
|
|
aafe7c4701 | ||
|
|
9a0b082cf8 | ||
|
|
8265e34a29 | ||
|
|
8ef8ae097f | ||
|
|
c3d14293c0 | ||
|
|
d55d8be504 | ||
|
|
03543030d3 | ||
|
|
fc6b474b92 | ||
|
|
a5db785dd7 | ||
|
|
1c1c5cd611 | ||
|
|
6ed02f70ec | ||
|
|
cb78cd8ac0 | ||
|
|
0c4590b45a | ||
|
|
d2e2ee6efa | ||
|
|
6a380a0b48 | ||
|
|
e5d5acbf1f | ||
|
|
00e38abbf0 | ||
|
|
e3e4ea5443 | ||
|
|
a3e4ea3228 | ||
|
|
56f16d6baf | ||
|
|
7a55ab900e | ||
|
|
137643fe72 | ||
|
|
d6e59c6241 | ||
|
|
458eb5d34c | ||
|
|
8259f08864 | ||
|
|
b3ab0a1843 | ||
|
|
f09f217478 | ||
|
|
e842c8c19b | ||
|
|
f6c3112d44 | ||
|
|
7059610632 | ||
|
|
2d272930d9 | ||
|
|
6c470d8131 | ||
|
|
30b29ce8cd | ||
|
|
1a9933002f | ||
|
|
c4a9365aa1 | ||
|
|
9d3af37104 | ||
|
|
7b3d57cff7 | ||
|
|
a802270da9 | ||
|
|
dd194a8758 | ||
|
|
6de02de221 | ||
|
|
85259750bf | ||
|
|
1249f0007d | ||
|
|
db0514d3fa | ||
|
|
dce42a7fad | ||
|
|
ec0b380194 | ||
|
|
7f27b61c98 | ||
|
|
f0b3557b02 | ||
|
|
2a1d1c1001 | ||
|
|
df7eb80e5b | ||
|
|
b9d947ce6f | ||
|
|
e6589d2454 | ||
|
|
0f5ac6afcf | ||
|
|
bc1bb1d188 | ||
|
|
3af2dd10ce | ||
|
|
dd22c65855 | ||
|
|
48137ced19 | ||
|
|
6eb47c12d1 | ||
|
|
5a1fc6675a | ||
|
|
6f80825814 | ||
|
|
f0dd48ed2a | ||
|
|
15e2df0db0 | ||
|
|
4ad0109769 | ||
|
|
ee0009d4b8 | ||
|
|
9d851c3346 | ||
|
|
5d117af8ae | ||
|
|
bb41c2d15e | ||
|
|
eba138ee4a | ||
|
|
3b2bbb74f8 | ||
|
|
dbc0f81211 | ||
|
|
d0b613d22e | ||
|
|
72f29b67d5 | ||
|
|
9570045cc3 | ||
|
|
e4efdb5cbb | ||
|
|
187f0fa70c | ||
|
|
472185c3e4 | ||
|
|
f94a571773 | ||
|
|
183e447d35 | ||
|
|
12f844d93a | ||
|
|
47a119a37f | ||
|
|
ee56559b9a | ||
|
|
00e594deea | ||
|
|
6ad9b213b9 | ||
|
|
e4375e8195 | ||
|
|
487bf8e29b | ||
|
|
fea1694e74 | ||
|
|
4102c124a9 | ||
|
|
135bad3280 | ||
|
|
b604f36881 | ||
|
|
782b449c71 | ||
|
|
017dcab685 | ||
|
|
e60b4568c6 | ||
|
|
4ee3d95a5a | ||
|
|
f18725bacc | ||
|
|
f6064a2b84 | ||
|
|
2e90cb7b95 | ||
|
|
2c09d63cd9 | ||
|
|
cc6fbdb0c3 | ||
|
|
ecfdec12f3 | ||
|
|
45af40fd14 | ||
|
|
d11cf42501 | ||
|
|
c3c1e3b055 | ||
|
|
7c5e3b1d99 | ||
|
|
ed6cec71e7 | ||
|
|
d6bcdd069c | ||
|
|
a26347826d | ||
|
|
5d1c099b31 | ||
|
|
220bee1365 | ||
|
|
1261074d95 | ||
|
|
136021424c | ||
|
|
fee4ba3746 | ||
|
|
a5b70335d4 | ||
|
|
5cf4976054 | ||
|
|
1aa3255061 | ||
|
|
b01f29f10d | ||
|
|
2673abca88 | ||
|
|
7eeb7f0715 | ||
|
|
37262a2479 | ||
|
|
de6e304959 | ||
|
|
234475bbc7 | ||
|
|
abbd9f7cfc | ||
|
|
dfd6ba67b3 | ||
|
|
1595254eab | ||
|
|
6964c5eeba | ||
|
|
2befe771b3 | ||
|
|
b133a035a4 | ||
|
|
726c062327 | ||
|
|
9083672de3 | ||
|
|
cdbaf880af | ||
|
|
9434981cdc | ||
|
|
8b3706f557 | ||
|
|
0d5173833d | ||
|
|
bf1178eb79 | ||
|
|
abcd3fa94a | ||
|
|
62aa1614b6 | ||
|
|
7027356126 | ||
|
|
5ebe13a13d | ||
|
|
c3bed9a2b7 | ||
|
|
f865222882 | ||
|
|
e2fe2e4095 | ||
|
|
0532a95f08 | ||
|
|
ff536f6015 | ||
|
|
097d0f27bb | ||
|
|
2257f87edf | ||
|
|
a17800da00 | ||
|
|
059c1b3a19 | ||
|
|
9a36816d27 | ||
|
|
7986b9b20b | ||
|
|
b2b3a0a62b | ||
|
|
3173b7d1d9 | ||
|
|
9d716d70d6 | ||
|
|
e1901a8608 | ||
|
|
7d0cbd8d90 | ||
|
|
59358361f9 | ||
|
|
7fea2d3b68 | ||
|
|
b6d3ff26bd | ||
|
|
523e63f5c1 | ||
|
|
10630ab597 | ||
|
|
2bc6de650d | ||
|
|
ffef1681e3 | ||
|
|
d935006a4a | ||
|
|
660cb5946e | ||
|
|
10160a066a | ||
|
|
72976a2ece | ||
|
|
831f206cd0 | ||
|
|
72648aa9f2 | ||
|
|
35e623deaf | ||
|
|
6263636738 | ||
|
|
535d012ded | ||
|
|
c73eed2e51 | ||
|
|
30fdc99f37 | ||
|
|
acb905f0cc | ||
|
|
bba06d0142 | ||
|
|
a14a47af12 | ||
|
|
73457336bc | ||
|
|
a14c53ad31 | ||
|
|
e7e763551a | ||
|
|
2928179331 | ||
|
|
24a16a4cfe | ||
|
|
6aed4423b2 | ||
|
|
6508e3fcc9 | ||
|
|
a15cb140ae | ||
|
|
898bc9e009 | ||
|
|
e67ea31ee2 | ||
|
|
986c126a5c | ||
|
|
0eee7616b9 | ||
|
|
5ddce749b8 | ||
|
|
d946cffabc | ||
|
|
fe618811ee | ||
|
|
09c45bfb80 | ||
|
|
e9e9ccd379 | ||
|
|
a9b27c78a3 | ||
|
|
bc17c29b2e | ||
|
|
aaf60bdee6 | ||
|
|
d913453e57 | ||
|
|
08e373aef4 | ||
|
|
4cb50a3d06 | ||
|
|
b03038222d | ||
|
|
5f5e0766dd | ||
|
|
48ec11c514 | ||
|
|
8ae76d18b5 | ||
|
|
e5be1790e5 | ||
|
|
e64aa40b17 | ||
|
|
eb8114ece8 | ||
|
|
616ee9b824 | ||
|
|
57c94f8f80 | ||
|
|
2a59c4f670 | ||
|
|
192ff487c4 | ||
|
|
b62ee3fcb9 | ||
|
|
0225292a44 | ||
|
|
589a7ed02f | ||
|
|
b3a42cd0b1 | ||
|
|
e3e1ca7cc6 | ||
|
|
57e417d174 | ||
|
|
1699db79b5 | ||
|
|
dab9403b8f | ||
|
|
9a14298146 | ||
|
|
40eea21863 | ||
|
|
d2475ec169 | ||
|
|
b3bcf4bf44 | ||
|
|
6049f86bc4 | ||
|
|
ff649b52ef | ||
|
|
e9e138c757 | ||
|
|
1096936a15 | ||
|
|
29cc478525 | ||
|
|
05e9eb40b5 | ||
|
|
c4444ff695 | ||
|
|
27b34f3929 | ||
|
|
2b8d784660 | ||
|
|
18f447d8d8 | ||
|
|
d7e1078d68 | ||
|
|
6be592653f | ||
|
|
8859853b41 | ||
|
|
3c46021102 | ||
|
|
bba8646669 | ||
|
|
b0dc19a910 | ||
|
|
df79ebd0f2 | ||
|
|
e19a97f316 | ||
|
|
482ffd6275 | ||
|
|
5117e50602 | ||
|
|
83b138208d | ||
|
|
1870cb4557 | ||
|
|
42ad5b9c5c | ||
|
|
333975eb8f | ||
|
|
aa0195e4ef | ||
|
|
56109fe09b | ||
|
|
e74046478b | ||
|
|
aa5a60812f | ||
|
|
ebb60019aa | ||
|
|
6393dc5d14 | ||
|
|
8c158f2452 | ||
|
|
8c3eabdcee | ||
|
|
8aa0ce6a24 | ||
|
|
a27ee141b3 | ||
|
|
1106456651 | ||
|
|
8856878cbd | ||
|
|
a9bac0287d | ||
|
|
efbd3dc778 | ||
|
|
a0d0eaa408 | ||
|
|
e2bf734b67 | ||
|
|
a333a90441 | ||
|
|
6dc0057d3d | ||
|
|
0f9e69d48c | ||
|
|
e6a7c019ab | ||
|
|
1d32eabd14 | ||
|
|
53d03f06a6 | ||
|
|
a2d8c40455 | ||
|
|
4f7d950c8d | ||
|
|
cac54b8c26 | ||
|
|
cd0e881d7d | ||
|
|
fee406e220 | ||
|
|
128342f47f | ||
|
|
024487c5fe | ||
|
|
879ba27ccb | ||
|
|
6d6d9627e7 | ||
|
|
af4bc82543 | ||
|
|
439a18bcc3 | ||
|
|
e12a1e0444 | ||
|
|
4400b0d3c3 | ||
|
|
5dff28ff99 | ||
|
|
d5ac841a1a | ||
|
|
232ce12e9b | ||
|
|
9a8638a6d0 | ||
|
|
a5445866b8 | ||
|
|
e8ded71a7b | ||
|
|
a14c615def | ||
|
|
3903b6ff0c | ||
|
|
41bf262482 | ||
|
|
645b658da0 | ||
|
|
6ee8f61fbe | ||
|
|
3c4c4231ce | ||
|
|
d0eef19eba | ||
|
|
6ca2eb3ad7 | ||
|
|
74aeb55733 | ||
|
|
3eb7965ca0 | ||
|
|
04f20070d1 | ||
|
|
88937fcb2f | ||
|
|
f80b85f10c | ||
|
|
32a2ec432d | ||
|
|
f4821d0d39 | ||
|
|
fdf2aa54ef | ||
|
|
275c032264 | ||
|
|
d88979fe19 | ||
|
|
e67bcffea7 | ||
|
|
005ded3c6f | ||
|
|
d624940e12 | ||
|
|
7763403b0e | ||
|
|
88c58244b9 | ||
|
|
0754c6ea20 | ||
|
|
7b1f04d121 | ||
|
|
d8a9bee244 | ||
|
|
ac0ea6bd3c | ||
|
|
45677c1e23 | ||
|
|
d9f4a9954a | ||
|
|
ec461a4456 | ||
|
|
559928e93b | ||
|
|
a526f7d5b8 | ||
|
|
749a2c2dec | ||
|
|
29a317dbb6 | ||
|
|
2f36de319a | ||
|
|
2005bce419 | ||
|
|
8a02d7729d | ||
|
|
1cdf301c14 | ||
|
|
9a86e5c476 | ||
|
|
32d3f4bd5f | ||
|
|
18689afc1a | ||
|
|
64d6da75c7 | ||
|
|
1e95e4b502 | ||
|
|
c63009a6db | ||
|
|
88f8718635 | ||
|
|
a081733a42 | ||
|
|
06ccfb0533 | ||
|
|
b18d75e3f7 | ||
|
|
3e7efaa048 | ||
|
|
a3fdfc81db | ||
|
|
f4c91df1df | ||
|
|
32e1ba8c0d | ||
|
|
1939376d72 | ||
|
|
25931d48a3 | ||
|
|
024c5e153a | ||
|
|
83f34b645d | ||
|
|
3f9f450e0d | ||
|
|
fd89b06641 | ||
|
|
f8dc996004 | ||
|
|
e6a964088b | ||
|
|
e3e767c7eb | ||
|
|
239c19eb12 | ||
|
|
7f37599a60 | ||
|
|
77c9a2c5ea | ||
|
|
fd7baae548 | ||
|
|
01fdf5ee16 | ||
|
|
e52f533c16 | ||
|
|
fbd77dc936 | ||
|
|
cdc6dd19e3 | ||
|
|
fd578a48a9 | ||
|
|
9956099516 | ||
|
|
f97b8fffed | ||
|
|
7b9e309724 | ||
|
|
1d33913d48 | ||
|
|
a48eaaed20 | ||
|
|
2741b8be53 | ||
|
|
4f906a265c | ||
|
|
0dff8d7af0 | ||
|
|
4f0d0d8167 | ||
|
|
d513060b21 | ||
|
|
d1a25ce4f3 | ||
|
|
51c98695b2 | ||
|
|
b448770ec2 | ||
|
|
5fe22a7980 | ||
|
|
38ae6b5af4 | ||
|
|
0bfe30d75d | ||
|
|
7be1d7d0be | ||
|
|
0d74c873f0 | ||
|
|
139aff2938 | ||
|
|
a3f733490c | ||
|
|
8a11f138d1 | ||
|
|
3405607917 | ||
|
|
7c99a6bd33 | ||
|
|
3fba8ce0e6 | ||
|
|
f3bde3c7fc | ||
|
|
21fee8ef33 | ||
|
|
0e217d6180 | ||
|
|
00a8ce75d1 | ||
|
|
8f3f00cd99 | ||
|
|
13bae2538a |
5
.flake8
Normal file
5
.flake8
Normal file
@@ -0,0 +1,5 @@
|
||||
[flake8]
|
||||
count = 1
|
||||
show-source = 1
|
||||
select = E9,F63,F7,F82
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
|
||||
2
.github/workflows/gh-pages-releases.yml
vendored
2
.github/workflows/gh-pages-releases.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
- run: git fetch --all
|
||||
- run: git switch github-pages
|
||||
- run: git config --global user.email "none@none.com"
|
||||
- run: git config --global user.name "nod-team"
|
||||
- run: git config --global user.name "nod-ai"
|
||||
- run: mv /tmp/index.html package-index/index.html
|
||||
- run: git add package-index/index.html
|
||||
|
||||
|
||||
108
.github/workflows/nightly.yml
vendored
108
.github/workflows/nightly.yml
vendored
@@ -9,40 +9,30 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: a100
|
||||
windows-build:
|
||||
runs-on: 7950X
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
backend: [IREE, SHARK]
|
||||
python-version: ["3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Setup pip cache
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
|
||||
- name: Compute version
|
||||
shell: powershell
|
||||
run: |
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
tag_name="${package_version}"
|
||||
echo "package_version=${package_version}" >> $GITHUB_ENV
|
||||
echo "tag_name=${tag_name}" >> $GITHUB_ENV
|
||||
- name: Set Environment Variables
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
$package_version = $(Get-Date -UFormat "%Y%m%d")+"."+${{ github.run_number }}
|
||||
$package_version_ = $(Get-Date -UFormat "%Y%m%d")+"_"+${{ github.run_number }}
|
||||
$tag_name=$package_version
|
||||
echo "package_version=$package_version" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "package_version_=$package_version_" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
echo "tag_name=$tag_name" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
@@ -54,76 +44,30 @@ jobs:
|
||||
body: |
|
||||
Automatic snapshot release of nod.ai SHARK.
|
||||
draft: true
|
||||
prerelease: false
|
||||
- name: Find Torch-MLIR Release
|
||||
run: |
|
||||
TM_HTML_URL="$(python3 -c "import urllib.request, json, sys; u=json.loads(urllib.request.urlopen('https://api.github.com/repos/llvm/torch-mlir/releases/latest').read().decode()).get('html_url', False); print(u) if u else sys.exit(1);")"
|
||||
TM_RELEASE_DIR=${TM_HTML_URL/"tag"/"expanded_assets"}
|
||||
echo "TM_RELEASE_DIR=${TM_RELEASE_DIR}" >> $GITHUB_ENV
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
echo "Torch-MLIR Release DIR is ${{ env.TM_RELEASE_DIR }}"
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
|
||||
- name: Build and validate the IREE package
|
||||
if: ${{ matrix.backend == 'IREE' }}
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
|
||||
source iree.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/iree-org/iree/releases
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
if !(grep -Fxq " failed" pytest_results.txt)
|
||||
then
|
||||
export SHA=$(git log -1 --format='%h')
|
||||
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/$SHA
|
||||
gsutil -m cp -r gs://shark_tank/$SHA/* gs://shark_tank/latest/
|
||||
fi
|
||||
rm -rf ./wheelhouse/nodai*
|
||||
prerelease: true
|
||||
|
||||
- name: Build and validate the SHARK Runtime package
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
- name: Build Package
|
||||
shell: powershell
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f ${{ env.TM_RELEASE_DIR }} -f https://github.com/nod-ai/SHARK-Runtime/releases
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" tank/test_models.py |
|
||||
tail -n 1 |
|
||||
tee -a pytest_results.txt
|
||||
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: upload-release-assets
|
||||
uses: dwenegar/upload-release-assets@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
|
||||
with:
|
||||
release_id: ${{ steps.create_release.outputs.id }}
|
||||
assets_path: ${GITHUB_WORKSPACE}/wheelhouse/nodai_*.whl
|
||||
assets_path: ./dist/nodai*
|
||||
#asset_content_type: application/vnd.microsoft.portable-executable
|
||||
|
||||
- name: Publish Release
|
||||
if: ${{ matrix.backend == 'SHARK' }}
|
||||
id: publish_release
|
||||
uses: eregon/publish-release@v1
|
||||
env:
|
||||
|
||||
77
.github/workflows/test-models.yml
vendored
77
.github/workflows/test-models.yml
vendored
@@ -6,21 +6,37 @@ name: Validate Models on Shark Runtime
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- 'shark/examples/**'
|
||||
workflow_dispatch:
|
||||
|
||||
# Ensure that only a single job or workflow using the same
|
||||
# concurrency group will run at a time. This would cancel
|
||||
# any in-progress jobs in the same github workflow and github
|
||||
# ref (e.g. refs/heads/main or refs/pull/<pr_number>/merge).
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build-validate:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
matrix:
|
||||
os: [icelake, a100, MacStudio, ubuntu-latest]
|
||||
os: [7950x, icelake, a100, MacStudio, ubuntu-latest]
|
||||
suite: [cpu,cuda,vulkan]
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.11"]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
suite: lint
|
||||
- os: MacStudio
|
||||
suite: metal
|
||||
exclude:
|
||||
- os: ubuntu-latest
|
||||
suite: vulkan
|
||||
@@ -40,6 +56,10 @@ jobs:
|
||||
suite: cuda
|
||||
- os: a100
|
||||
suite: cpu
|
||||
- os: 7950x
|
||||
suite: cpu
|
||||
- os: 7950x
|
||||
suite: cuda
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
@@ -47,6 +67,7 @@ jobs:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set Environment Variables
|
||||
if: matrix.os != '7950x'
|
||||
run: |
|
||||
echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
@@ -78,36 +99,66 @@ jobs:
|
||||
run: |
|
||||
# black format check
|
||||
black --version
|
||||
black --line-length 79 --check .
|
||||
black --check .
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude lit.cfg.py
|
||||
flake8 . --statistics
|
||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude lit.cfg.py
|
||||
flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \
|
||||
--statistics --exclude lit.cfg.py
|
||||
|
||||
- name: Validate Models on CPU
|
||||
if: matrix.suite == 'cpu'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cpu
|
||||
pytest --benchmark=native --update_tank -k cpu
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
|
||||
python build_tools/vicuna_testing.py
|
||||
|
||||
- name: Validate Models on NVIDIA GPU
|
||||
if: matrix.suite == 'cuda'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k cuda
|
||||
pytest --benchmark=native --update_tank -k cuda
|
||||
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
|
||||
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
|
||||
# Disabled due to black image bug
|
||||
# python build_tools/stable_diffusion_testing.py --device=cuda
|
||||
|
||||
- name: Validate Vulkan Models
|
||||
if: matrix.suite == 'vulkan'
|
||||
- name: Validate Vulkan Models (MacOS)
|
||||
if: matrix.suite == 'metal' && matrix.os == 'MacStudio'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
|
||||
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/data/anush" tank/test_models.py -k vulkan
|
||||
echo $PATH
|
||||
pip list | grep -E "torch|iree"
|
||||
# disabled due to a low-visibility memory issue with pytest on macos.
|
||||
# pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal
|
||||
|
||||
- name: Validate Vulkan Models (a100)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
|
||||
run: |
|
||||
cd $GITHUB_WORKSPACE
|
||||
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
pytest --update_tank -k vulkan
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
|
||||
|
||||
- name: Validate Vulkan Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
pytest -k vulkan -s --ci
|
||||
|
||||
- name: Validate Stable Diffusion Models (Windows)
|
||||
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
python build_tools/stable_diffusion_testing.py --device=vulkan
|
||||
|
||||
40
.gitignore
vendored
40
.gitignore
vendored
@@ -2,6 +2,8 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.mlir
|
||||
*.vmfb
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
@@ -31,7 +33,6 @@ MANIFEST
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
@@ -158,12 +159,47 @@ cython_debug/
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
# vscode related
|
||||
.vscode
|
||||
|
||||
# Shark related artefacts
|
||||
*venv/
|
||||
shark_tmp/
|
||||
*.vmfb
|
||||
*.safetensors
|
||||
.use-iree
|
||||
tank/dict_configs.py
|
||||
*.csv
|
||||
reproducers/
|
||||
|
||||
# ORT related artefacts
|
||||
cache_models/
|
||||
onnx_models/
|
||||
|
||||
# Generated images
|
||||
generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
/models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
# Stencil annotators.
|
||||
stencil_annotator/
|
||||
|
||||
# For DocuChat
|
||||
apps/language_models/langchain/user_path/
|
||||
db_dir_UserData
|
||||
|
||||
# Embedded browser cache and other artefacts
|
||||
apps/stable_diffusion/web/EBWebView/
|
||||
|
||||
# Llama2 tokenizer configs
|
||||
llama2_tokenizer_configs/
|
||||
|
||||
# Webview2 runtime artefacts
|
||||
EBWebView/
|
||||
|
||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -1,4 +1,4 @@
|
||||
[submodule "inference/thirdparty/shark-runtime"]
|
||||
path = inference/thirdparty/shark-runtime
|
||||
url =https://github.com/nod-ai/SHARK-Runtime.git
|
||||
url =https://github.com/nod-ai/SRT.git
|
||||
branch = shark-06032022
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
[style]
|
||||
based_on_style = google
|
||||
column_limit = 80
|
||||
472
README.md
472
README.md
@@ -1,29 +1,161 @@
|
||||
# SHARK
|
||||
|
||||
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
|
||||
High Performance Machine Learning Distribution
|
||||
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
|
||||
[](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
|
||||
|
||||
## Communication Channels
|
||||
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
<details>
|
||||
<summary>Installation (Linux and macOS)</summary>
|
||||
<summary>Prerequisites - Drivers </summary>
|
||||
|
||||
#### Install your Windows hardware drivers
|
||||
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
|
||||
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
|
||||
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
#### Linux Drivers
|
||||
* MESA / RADV drivers won't work with FP16. Please use the latest AMDGPU-PRO drivers (non-pro OSS drivers also won't work) or the latest NVIDIA Linux drivers.
|
||||
|
||||
Other users: please ensure you have the latest vendor drivers and the Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home), and if you are using Vulkan, check that `vulkaninfo` works in a terminal window.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
|
||||
|
||||
Install the Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
|
||||
|
||||
Download the [stable release](https://github.com/nod-ai/shark/releases/latest)
|
||||
|
||||
Double click the .exe and you should have the [UI](http://localhost:8080/) in the browser.
|
||||
|
||||
If you have custom models put them in a `models/` directory where the .exe is.
|
||||
|
||||
Enjoy.
|
||||
|
||||
<details>
|
||||
<summary>More installation notes</summary>
|
||||
* We recommend that you download the EXE into a new folder whenever you download a new EXE version. If you download it into the same folder as a previous install, you must delete the old `*.vmfb` files with `rm *.vmfb`. You can also use the `--clear_all` flag once to clean all the old files.
|
||||
* If you recently updated the driver or this binary (EXE file), we recommend you clear all the local artifacts with `--clear_all`
|
||||
|
||||
## Running
|
||||
|
||||
* Open a Command Prompt or Powershell terminal, change folder (`cd`) to the .exe folder. Then run the EXE from the command prompt. That way, if an error occurs, you'll be able to cut-and-paste it to ask for help. (if it always works for you without error, you may simply double-click the EXE)
|
||||
* The first run may take a few minutes while the models are downloaded and compiled. Your patience is appreciated. The download could be about 5GB.
|
||||
* You will likely see a Windows Defender message asking you to give permission to open a web server port. Accept it.
|
||||
* Open a browser to access the Stable Diffusion web server. By default, the port is 8080, so you can go to http://localhost:8080/.
|
||||
|
||||
## Stopping
|
||||
|
||||
* Select the command prompt that's running the EXE. Press CTRL-C and wait a moment or close the terminal.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Advanced Installation (Only for developers)</summary>
|
||||
|
||||
## Advanced Installation (Windows, Linux and macOS) for developers
|
||||
|
||||
## Check out the code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
cd SHARK
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
|
||||
### Windows 10/11 Users
|
||||
|
||||
* Install the latest Python 3.11.x version from [here](https://www.python.org/downloads/windows/)
|
||||
|
||||
* Install Git for Windows from [here](https://git-scm.com/download/win)
|
||||
|
||||
#### Allow the install script to run in Powershell
|
||||
```powershell
|
||||
set-executionpolicy remotesigned
|
||||
```
|
||||
|
||||
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
|
||||
```powershell
|
||||
./setup_venv.ps1 #You can re-run this script to get the latest version
|
||||
```
|
||||
|
||||
### Linux / macOS Users
|
||||
|
||||
```shell
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - WebUI
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
|
||||
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
|
||||
```
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
(shark.venv) > cd apps/stable_diffusion/web
|
||||
(shark.venv) > python index.py
|
||||
```
|
||||
|
||||
#### Access Stable Diffusion on http://localhost:8080/?__theme=dark
|
||||
|
||||
|
||||
<img width="1607" alt="webui" src="https://user-images.githubusercontent.com/74956/204939260-b8308bc2-8dc4-47f6-9ac0-f60b66edab99.png">
|
||||
|
||||
|
||||
|
||||
### Run Stable Diffusion on your device - Commandline
|
||||
|
||||
#### Windows 10/11 Users
|
||||
```powershell
|
||||
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\main.py --app="txt2img" --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
|
||||
```
|
||||
|
||||
#### Linux / macOS Users
|
||||
```shell
|
||||
python3.11 apps/stable_diffusion/scripts/main.py --app=txt2img --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
|
||||
```
|
||||
|
||||
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
|
||||
</details>
|
||||
|
||||
The output on an AMD 7900XTX would look something like:
|
||||
|
||||
```shell
|
||||
Average step time: 47.19188690185547ms/it
|
||||
Clip Inference time (ms) = 109.531
|
||||
VAE Inference time (ms): 78.590
|
||||
|
||||
Total image generation time: 2.5788655281066895sec
|
||||
```
|
||||
|
||||
Here are some samples generated:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
|
||||
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Binary Installation</summary>
|
||||
|
||||
### Setup a new pip Virtual Environment
|
||||
|
||||
This step sets up a new VirtualEnv for Python
|
||||
|
||||
```shell
|
||||
python --version #Check you have 3.7->3.10 on Linux or 3.10 on macOS
|
||||
python --version #Check you have 3.11 on Linux, macOS or Windows Powershell
|
||||
python -m venv shark_venv
|
||||
source shark_venv/bin/activate
|
||||
source shark_venv/bin/activate # Use shark_venv/Scripts/activate on Windows
|
||||
|
||||
# If you are using conda create and activate a new conda env
|
||||
|
||||
@@ -35,12 +167,17 @@ python -m pip install --upgrade pip
|
||||
|
||||
### Install SHARK
|
||||
|
||||
This step pip installs SHARK and related packages on Linux Python 3.7, 3.8, 3.9, 3.10 and macOS Python 3.10
|
||||
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://github.com/nod-ai/shark-runtime/releases --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
If you are on an Intel macOS machine you need this [workaround](https://github.com/nod-ai/SHARK/issues/102) for an upstream issue.
|
||||
|
||||
### Run shark tank model tests.
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
```
|
||||
See tank/README.md for a more detailed walkthrough of our pytest suite and CLI.
|
||||
|
||||
### Download and run Resnet50 sample
|
||||
|
||||
@@ -61,33 +198,31 @@ python ./minilm_jit.py --device="cpu" #use cuda or vulkan or metal
|
||||
</details>
|
||||
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Source Installation</summary>
|
||||
<summary>Development, Testing and Benchmarks</summary>
|
||||
|
||||
## Check out the code
|
||||
If you want to use Python 3.11 with the TF import tools, you can set environment variables like:
|
||||
Set `USE_IREE=1` to use upstream IREE
|
||||
```
|
||||
# PYTHON=python3.11 VENV_DIR=0617_venv IMPORTER=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
### Run any of the hundreds of SHARK tank models via the test framework
|
||||
```shell
|
||||
git clone https://github.com/nod-ai/SHARK.git
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
```
|
||||
|
||||
## Setup your Python VirtualEnvironment and Dependencies
|
||||
```shell
|
||||
# Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...).
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
```
|
||||
For example if you want to use Python3.10 and upstream IREE with TF Import tools you can use the environment variables like:
|
||||
```
|
||||
# PYTHON=python3.10 VENV_DIR=0617_venv IMPORTER=1 USE_IREE=1 ./setup_venv.sh
|
||||
```
|
||||
|
||||
|
||||
### How to use your locally built IREE / Torch-MLIR with SHARK
|
||||
If you are a *Torch-mlir developer or an IREE developer* and want to test local changes you can uninstall
|
||||
the provided packages with `pip uninstall torch-mlir` and / or `pip uninstall iree-compiler iree-runtime` and build locally
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://google.github.io/iree/bindings/python/)
|
||||
with Python bindings and set your PYTHONPATH as mentioned [here](https://github.com/iree-org/iree/tree/main/docs/api_docs/python#install-iree-binaries)
|
||||
for IREE and [here](https://github.com/llvm/torch-mlir/blob/main/development.md#setup-python-environment-to-export-the-built-python-packages)
|
||||
for Torch-MLIR.
|
||||
|
||||
### How to use your locally built Torch-MLIR with SHARK
|
||||
How to use your locally built Torch-MLIR with SHARK:
|
||||
```shell
|
||||
1.) Run `./setup_venv.sh in SHARK` and activate `shark.venv` virtual env.
|
||||
2.) Run `pip uninstall torch-mlir`.
|
||||
@@ -102,82 +237,44 @@ for Torch-MLIR.
|
||||
```
|
||||
Now SHARK will use your locally built Torch-MLIR repo.
|
||||
|
||||
### Run a demo script
|
||||
```shell
|
||||
python -m shark.examples.shark_inference.resnet50_script --device="cpu" # Use gpu | vulkan
|
||||
# Or a pytest
|
||||
pytest tank/test_models.py -k "MiniLM"
|
||||
|
||||
## Benchmarking Dispatches
|
||||
|
||||
To produce benchmarks of individual dispatches, you can add `--dispatch_benchmarks=All --dispatch_benchmarks_dir=<output_dir>` to your pytest command line argument.
|
||||
If you only want to compile specific dispatches, you can specify them with a space-separated string instead of `"All"`, e.g. `--dispatch_benchmarks="0 1 2 10"`.
|
||||
|
||||
For example, to generate and run dispatch benchmarks for MiniLM on CUDA:
|
||||
```
|
||||
pytest -k "MiniLM and torch and static and cuda" --dispatch_benchmarks=All -s --dispatch_benchmarks_dir=./my_dispatch_benchmarks
|
||||
```
|
||||
The given command will populate `<dispatch_benchmarks_dir>/<model_name>/` with an `ordered_dispatches.txt` that lists and orders the dispatches and their latencies, as well as folders for each dispatch that contain .mlir, .vmfb, and results of the benchmark for that dispatch.
|
||||
|
||||
If you instead want to incorporate this into a Python script, you can pass the `dispatch_benchmarks` and `dispatch_benchmarks_dir` arguments when initializing `SharkInference`, and the benchmarks will be generated at compile time, e.g.:
|
||||
|
||||
```
|
||||
shark_module = SharkInference(
|
||||
mlir_model,
|
||||
device=args.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
dispatch_benchmarks="all",
|
||||
dispatch_benchmarks_dir="results"
|
||||
)
|
||||
```
|
||||
|
||||
Output will include:
|
||||
- An ordered list, `ordered_dispatches.txt`, of all the dispatches with their runtimes
|
||||
- Inside the specified directory, there will be a directory for each dispatch (there will be mlir files for all dispatches, but only compiled binaries and benchmark data for the specified dispatches)
|
||||
- An .mlir file containing the dispatch benchmark
|
||||
- A compiled .vmfb file containing the dispatch benchmark
|
||||
- An .mlir file containing just the hal executable
|
||||
- A compiled .vmfb file of the hal executable
|
||||
- A .txt file containing benchmark output
|
||||
|
||||
|
||||
See tank/README.md for further instructions on how to run model tests and benchmarks from the SHARK tank.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Testing and Benchmarks</summary>
|
||||
|
||||
### Run all model tests on CPU/GPU/VULKAN/Metal
|
||||
```shell
|
||||
pytest tank/test_models.py
|
||||
|
||||
# If on Linux for multithreading on CPU (faster results):
|
||||
pytest tank/test_models.py -n auto
|
||||
```
|
||||
|
||||
### Running specific tests
|
||||
```shell
|
||||
|
||||
# Search for test cases by including a keyword that matches all or part of the test case's name;
|
||||
pytest tank/test_models.py -k "keyword"
|
||||
|
||||
# Test cases are named uniformly by format test_module_<model_name_underscores_only>_<torch/tf>_<static/dynamic>_<device>.
|
||||
|
||||
# Example: Test all models on nvidia gpu:
|
||||
pytest tank/test_models.py -k "cuda"
|
||||
|
||||
# Example: Test all tensorflow resnet models on Vulkan backend:
|
||||
pytest tank/test_models.py -k "resnet and tf and vulkan"
|
||||
|
||||
# Exclude a test case:
|
||||
pytest tank/test_models.py -k "not ..."
|
||||
|
||||
### Run benchmarks on SHARK tank pytests and generate bench_results.csv with results.
|
||||
|
||||
(the following requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
pytest --benchmark tank/test_models.py
|
||||
|
||||
# Just do static GPU benchmarks for PyTorch tests:
|
||||
pytest --benchmark tank/test_models.py -k "pytorch and static and cuda"
|
||||
|
||||
```
|
||||
|
||||
### Benchmark Resnet50, MiniLM on CPU
|
||||
|
||||
(requires source installation with `IMPORTER=1 ./setup_venv.sh`)
|
||||
|
||||
```shell
|
||||
# We suggest running the following commands as root before running benchmarks on CPU:
|
||||
|
||||
cat /sys/devices/system/cpu/cpu*/topology/thread_siblings_list | awk -F, '{print $2}' | sort -n | uniq | ( while read X ; do echo $X ; echo 0 > /sys/devices/system/cpu/cpu$X/online ; done )
|
||||
echo 1 > /sys/devices/system/cpu/intel_pstate/no_turbo
|
||||
|
||||
# Benchmark canonical Resnet50 on CPU via pytest
|
||||
pytest --benchmark tank/test_models -k "resnet50 and tf_static_cpu"
|
||||
|
||||
# Benchmark canonical MiniLM on CPU via pytest
|
||||
pytest --benchmark tank/test_models -k "MiniLM and cpu"
|
||||
|
||||
# Benchmark MiniLM on CPU via transformer-benchmarks:
|
||||
git clone --recursive https://github.com/nod-ai/transformer-benchmarks.git
|
||||
cd transformer-benchmarks
|
||||
./perf-ci.sh -n
|
||||
# Check detail.csv for MLIR/IREE results.
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>API Reference</summary>
|
||||
|
||||
@@ -199,7 +296,7 @@ torch_mlir, func_name = mlir_importer.import_mlir(tracing_required=True)
|
||||
# SharkInference accepts mlir in linalg, mhlo, and tosa dialect.
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
shark_module = SharkInference(torch_mlir, func_name, device="cpu", mlir_dialect="linalg")
|
||||
shark_module = SharkInference(torch_mlir, device="cpu", mlir_dialect="linalg")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((input))
|
||||
|
||||
@@ -222,166 +319,37 @@ mhlo_ir = r"""builtin.module {
|
||||
|
||||
arg0 = np.ones((1, 4)).astype(np.float32)
|
||||
arg1 = np.ones((4, 1)).astype(np.float32)
|
||||
shark_module = SharkInference(mhlo_ir, func_name="forward", device="cpu", mlir_dialect="mhlo")
|
||||
shark_module = SharkInference(mhlo_ir, device="cpu", mlir_dialect="mhlo")
|
||||
shark_module.compile()
|
||||
result = shark_module.forward((arg0, arg1))
|
||||
```
|
||||
</details>
|
||||
|
||||
## Examples Using the REST API
|
||||
|
||||
* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md)
|
||||
* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md)
|
||||
|
||||
## Supported and Validated Models
|
||||
|
||||
<details>
|
||||
<summary>PyTorch Models</summary>
|
||||
SHARK is maintained to support the latest innovations in ML Models:
|
||||
|
||||
### Huggingface PyTorch Models
|
||||
| TF HuggingFace Models | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BLOOM | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Stable Diffusion | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Vision Transformer | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ResNet50 | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
| Hugging Face Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Albert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| BigBird | :green_heart: (AOT) | | | |
|
||||
| DistilBERT | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| GPT2 | :broken_heart: (AOT) | | | |
|
||||
| MobileBert | :green_heart: (JIT) | :green_heart: | :green_heart: | :green_heart: |
|
||||
For a complete list of the models supported in SHARK, please refer to [tank/README.md](https://github.com/nod-ai/SHARK/blob/main/tank/README.md).
|
||||
|
||||
### Torchvision Models
|
||||
## Communication Channels
|
||||
|
||||
| TORCHVISION Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|--------------------|----------------------|----------|----------|-------------|
|
||||
| AlexNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DenseNet121 | :green_heart: (Script) | | | |
|
||||
| MNasNet1_0 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV2 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| MobileNetV3 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Unet | :broken_heart: (Script) | | | |
|
||||
| Resnet18 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet50 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnet101 | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnext50_32x4d | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ShuffleNet_v2 | :broken_heart: (Script) | | | |
|
||||
| SqueezeNet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| EfficientNet | :green_heart: (Script) | | | |
|
||||
| Regnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Resnest | :broken_heart: (Script) | | | |
|
||||
| Vision Transformer | :green_heart: (Script) | | | |
|
||||
| VGG 16 | :green_heart: (Script) | :green_heart: | :green_heart: | |
|
||||
| Wide Resnet | :green_heart: (Script) | :green_heart: | :green_heart: | :green_heart: |
|
||||
| RAFT | :broken_heart: (JIT) | | | |
|
||||
|
||||
For more information refer to [MODEL TRACKING SHEET](https://docs.google.com/spreadsheets/d/15PcjKeHZIrB5LfDyuw7DGEEE8XnQEX2aX8lm8qbxV8A/edit#gid=0)
|
||||
|
||||
### PyTorch Training Models
|
||||
|
||||
| Models | Torch-MLIR lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>JAX Models</summary>
|
||||
|
||||
|
||||
### JAX Models
|
||||
|
||||
| Models | JAX-MHLO lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| DALL-E | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TFLite Models</summary>
|
||||
|
||||
### TFLite Models
|
||||
|
||||
| Models | TOSA/LinAlg | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :broken_heart: | :broken_heart: | | |
|
||||
| FullyConnected | :green_heart: | :green_heart: | | |
|
||||
| albert | :green_heart: | :green_heart: | | |
|
||||
| asr_conformer | :green_heart: | :green_heart: | | |
|
||||
| bird_classifier | :green_heart: | :green_heart: | | |
|
||||
| cartoon_gan | :green_heart: | :green_heart: | | |
|
||||
| craft_text | :green_heart: | :green_heart: | | |
|
||||
| deeplab_v3 | :green_heart: | :green_heart: | | |
|
||||
| densenet | :green_heart: | :green_heart: | | |
|
||||
| east_text_detector | :green_heart: | :green_heart: | | |
|
||||
| efficientnet_lite0_int8 | :green_heart: | :green_heart: | | |
|
||||
| efficientnet | :green_heart: | :green_heart: | | |
|
||||
| gpt2 | :green_heart: | :green_heart: | | |
|
||||
| image_stylization | :green_heart: | :green_heart: | | |
|
||||
| inception_v4 | :green_heart: | :green_heart: | | |
|
||||
| inception_v4_uint8 | :green_heart: | :green_heart: | | |
|
||||
| lightning_fp16 | :green_heart: | :green_heart: | | |
|
||||
| lightning_i8 | :green_heart: | :green_heart: | | |
|
||||
| lightning | :green_heart: | :green_heart: | | |
|
||||
| magenta | :green_heart: | :green_heart: | | |
|
||||
| midas | :green_heart: | :green_heart: | | |
|
||||
| mirnet | :green_heart: | :green_heart: | | |
|
||||
| mnasnet | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_edgetpu_s_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilebert | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_float | :green_heart: | :green_heart: | | |
|
||||
| mobilebert_tf2_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_ssd_quant | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v2_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v3-large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| mobilenet_v35-int8 | :green_heart: | :green_heart: | | |
|
||||
| nasnet | :green_heart: | :green_heart: | | |
|
||||
| person_detect | :green_heart: | :green_heart: | | |
|
||||
| posenet | :green_heart: | :green_heart: | | |
|
||||
| resnet_50_int8 | :green_heart: | :green_heart: | | |
|
||||
| rosetta | :green_heart: | :green_heart: | | |
|
||||
| spice | :green_heart: | :green_heart: | | |
|
||||
| squeezenet | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v1_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_fpnlite_uint8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2_int8 | :green_heart: | :green_heart: | | |
|
||||
| ssd_mobilenet_v2 | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large | :green_heart: | :green_heart: | | |
|
||||
| ssd_spaghettinet_large_uint8 | :green_heart: | :green_heart: | | |
|
||||
| visual_wake_words_i8 | :green_heart: | :green_heart: | | |
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>TF Models</summary>
|
||||
|
||||
### Tensorflow Models (Inference)
|
||||
|
||||
| Hugging Face Models | tf-mhlo lowerable | SHARK-CPU | SHARK-CUDA | SHARK-METAL |
|
||||
|---------------------|----------------------|----------|----------|-------------|
|
||||
| BERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| albert-base-v2 | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| DistilBERT | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| CamemBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| ConvBert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| Deberta | | | | |
|
||||
| electra | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| funnel | | | | |
|
||||
| layoutlm | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| longformer | | | | |
|
||||
| mobile-bert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| remembert | | | | |
|
||||
| tapas | | | | |
|
||||
| flaubert | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| xlm-roberta | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
| mpnet | :green_heart: | :green_heart: | :green_heart: | :green_heart: |
|
||||
|
||||
</details>
|
||||
* [SHARK Discord server](https://discord.gg/RUqY2h2s9u): Real time discussions with the SHARK team and other users
|
||||
* [GitHub issues](https://github.com/nod-ai/SHARK/issues): Feature requests, bugs etc
|
||||
|
||||
## Related Projects
|
||||
|
||||
|
||||
16
apps/language_models/README.md
Normal file
16
apps/language_models/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
## CodeGen Setup using SHARK-server
|
||||
|
||||
### Setup Server
|
||||
- clone SHARK and set up the venv
|
||||
- host the server using `python apps/stable_diffusion/web/index.py --api --server_port=<PORT>`
|
||||
- default server address is `http://0.0.0.0:8080`
|
||||
|
||||
### Setup Client
|
||||
1. fauxpilot-vscode (VSCode Extension):
|
||||
- Code for the extension can be found [here](https://github.com/Venthe/vscode-fauxpilot)
|
||||
- PreReq: VSCode extension (will need [`nodejs` and `npm`](https://nodejs.org/en/download) to compile and run the extension)
|
||||
- Compile and Run the extension on VSCode (press F5 on VSCode), this opens a new VSCode window with the extension running
|
||||
- Open VSCode settings, search for fauxpilot in settings and modify `server : http://<IP>:<PORT>`, `Model : codegen` , `Max Lines : 30`
|
||||
|
||||
2. Others (REST API curl, OpenAI Python bindings) as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md)
|
||||
- using the GitHub Copilot VSCode extension with SHARK-server needs more work to be functional.
|
||||
18
apps/language_models/langchain/README.md
Normal file
18
apps/language_models/langchain/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# Langchain
|
||||
|
||||
## How to run the model
|
||||
|
||||
1.) Install all the dependencies by running:
|
||||
```shell
|
||||
pip install -r apps/language_models/langchain/langchain_requirements.txt
|
||||
sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
||||
```
|
||||
|
||||
2.) Create a folder named `user_path` in the `apps/language_models/langchain/` directory.
|
||||
|
||||
Now, you are ready to use the model.
|
||||
|
||||
3.) To run the model, run the following command:
|
||||
```shell
|
||||
python apps/language_models/langchain/gen.py --cli=True
|
||||
```
|
||||
186
apps/language_models/langchain/cli.py
Normal file
186
apps/language_models/langchain/cli.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from evaluate_params import eval_func_param_names
|
||||
from gen import Langchain
|
||||
from prompter import non_hf_types
|
||||
from utils import clear_torch_cache, NullContext, get_kwargs
|
||||
|
||||
|
||||
def run_cli(  # for local function:
    base_model=None,
    lora_weights=None,
    inference_server=None,
    debug=None,
    chat_context=None,
    examples=None,
    memory_restriction_level=None,
    # for get_model:
    score_model=None,
    load_8bit=None,
    load_4bit=None,
    load_half=None,
    load_gptq=None,
    use_safetensors=None,
    infer_devices=None,
    tokenizer_base_model=None,
    gpu_id=None,
    local_files_only=None,
    resume_download=None,
    use_auth_token=None,
    trust_remote_code=None,
    offload_folder=None,
    compile_model=None,
    # for some evaluate args
    stream_output=None,
    prompt_type=None,
    prompt_dict=None,
    temperature=None,
    top_p=None,
    top_k=None,
    num_beams=None,
    max_new_tokens=None,
    min_new_tokens=None,
    early_stopping=None,
    max_time=None,
    repetition_penalty=None,
    num_return_sequences=None,
    do_sample=None,
    chat=None,
    langchain_mode=None,
    langchain_action=None,
    document_choice=None,
    top_k_docs=None,
    chunk=None,
    chunk_size=None,
    # for evaluate kwargs
    src_lang=None,
    tgt_lang=None,
    concurrency_count=None,
    save_dir=None,
    sanitize_bot_response=None,
    model_state0=None,
    max_max_new_tokens=None,
    is_public=None,
    max_max_time=None,
    raise_generate_gpu_exceptions=None,
    load_db_if_exists=None,
    dbs=None,
    user_path=None,
    detect_user_path_changes_every_query=None,
    use_openai_embedding=None,
    use_openai_model=None,
    hf_embedding_model=None,
    db_type=None,
    n_jobs=None,
    first_para=None,
    text_limit=None,
    verbose=None,
    cli=None,
    reverse_docs=None,
    use_cache=None,
    auto_reduce_chunks=None,
    max_chunks=None,
    model_lock=None,
    force_langchain_evaluate=None,
    model_state_none=None,
    # unique to this function:
    cli_loop=None,
):
    """Run an interactive command-line prompt loop against the model.

    Loads the score model and the main model, then repeatedly reads an
    instruction from stdin, evaluates it via ``Langchain.evaluate`` and
    prints the response. Typing ``exit`` ends the loop; when ``cli_loop``
    is falsy the loop stops after the first instruction.

    NOTE: local variable names in this function are significant. The
    function forwards ``**locals()`` to ``get_kwargs`` and looks names up
    in ``locals()`` inside the loop, so renaming a local changes what is
    passed on.

    Returns:
        A list with one string per processed instruction (the response
        text followed by a newline).
    """
    Langchain.check_locals(**locals())

    score_model = ""  # FIXME: For now, so user doesn't have to pass
    # is_available must be *called*; the bare function object is always truthy.
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    device = "cpu" if n_gpus == 0 else "cuda"
    context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device

    with context_class(device):
        from functools import partial

        # get score model
        smodel, stokenizer, sdevice = Langchain.get_score_model(
            reward_type=True,
            **get_kwargs(
                Langchain.get_score_model,
                exclude_names=["reward_type"],
                **locals()
            )
        )

        model, tokenizer, device = Langchain.get_model(
            reward_type=False,
            **get_kwargs(
                Langchain.get_model, exclude_names=["reward_type"], **locals()
            )
        )
        model_dict = dict(
            base_model=base_model,
            tokenizer_base_model=tokenizer_base_model,
            lora_weights=lora_weights,
            inference_server=inference_server,
            prompt_type=prompt_type,
            prompt_dict=prompt_dict,
        )
        model_state = dict(model=model, tokenizer=tokenizer, device=device)
        model_state.update(model_dict)
        my_db_state = [None]
        # Bind everything except the per-call evaluate parameters, which are
        # supplied positionally for each instruction below.
        fun = partial(
            Langchain.evaluate,
            model_state,
            my_db_state,
            **get_kwargs(
                Langchain.evaluate,
                exclude_names=["model_state", "my_db_state"]
                + eval_func_param_names,
                **locals()
            )
        )

        example1 = examples[-1]  # pick reference example
        all_generations = []
        while True:
            clear_torch_cache()
            instruction = input("\nEnter an instruction: ")
            if instruction == "exit":
                break

            # Start from the reference example and overwrite the slots that
            # depend on the user's input; slots are positional, indexed by
            # their place in eval_func_param_names.
            eval_vars = copy.deepcopy(example1)
            eval_vars[eval_func_param_names.index("instruction")] = eval_vars[
                eval_func_param_names.index("instruction_nochat")
            ] = instruction
            eval_vars[eval_func_param_names.index("iinput")] = eval_vars[
                eval_func_param_names.index("iinput_nochat")
            ] = ""  # no input yet
            eval_vars[eval_func_param_names.index("context")] = (
                ""  # no context yet
            )

            # grab other parameters, like langchain_mode
            for k in eval_func_param_names:
                if k in locals():
                    eval_vars[eval_func_param_names.index(k)] = locals()[k]

            gener = fun(*tuple(eval_vars))
            outr = ""
            res_old = ""
            for gen_output in gener:
                res = gen_output["response"]
                extra = gen_output["sources"]
                if base_model not in non_hf_types or base_model in ["llama"]:
                    if not stream_output:
                        print(res)
                    else:
                        # then stream output for gradio that has full output each generation, so need here to show only new chars
                        diff = res[len(res_old) :]
                        print(diff, end="", flush=True)
                        res_old = res
                    outr = res  # don't accumulate
                else:
                    outr += res  # just is one thing
                    if extra:
                        # show sources at end after model itself had streamed to std rest of response
                        print(extra, flush=True)
            all_generations.append(outr + "\n")
            if not cli_loop:
                break
    return all_generations
|
||||
2189
apps/language_models/langchain/create_data.py
Normal file
2189
apps/language_models/langchain/create_data.py
Normal file
File diff suppressed because it is too large
Load Diff
103
apps/language_models/langchain/enums.py
Normal file
103
apps/language_models/langchain/enums.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PromptType(Enum):
    """Integer identifiers for the supported prompt formats."""

    # Sentinel for a user-supplied prompt.
    custom = -1

    # Built-in prompt formats, numbered consecutively from zero.
    plain = 0
    instruct = 1
    quality = 2
    human_bot = 3
    dai_faq = 4
    summarize = 5
    simple_instruct = 6
    instruct_vicuna = 7
    instruct_with_end = 8
    human_bot_orig = 9
    prompt_answer = 10
    open_assistant = 11
    wizard_lm = 12
    wizard_mega = 13
    instruct_vicuna2 = 14
    instruct_vicuna3 = 15
    wizard2 = 16
    wizard3 = 17
    instruct_simple = 18
    wizard_vicuna = 19
    openai = 20
    openai_chat = 21
    gptj = 22
    prompt_answer_openllama = 23
    vicuna11 = 24
    mptinstruct = 25
    mptchat = 26
    falcon = 27
|
||||
|
||||
|
||||
class DocumentChoices(Enum):
    """Options for which documents/sources a query should use."""

    All_Relevant = 0
    All_Relevant_Only_Sources = 1
    Only_All_Sources = 2
    Just_LLM = 3
|
||||
|
||||
|
||||
# Names of the DocumentChoices members listed as "non-query" commands.
non_query_commands = [
    choice.name
    for choice in (
        DocumentChoices.All_Relevant_Only_Sources,
        DocumentChoices.Only_All_Sources,
    )
]
|
||||
|
||||
|
||||
class LangChainMode(Enum):
    """String identifiers for the available LangChain modes."""

    # Modes without a named document collection.
    DISABLED = "Disabled"
    CHAT_LLM = "ChatLLM"
    LLM = "LLM"
    ALL = "All"

    # Modes named after a document collection.
    WIKI = "wiki"
    WIKI_FULL = "wiki_full"
    USER_DATA = "UserData"
    MY_DATA = "MyData"
    GITHUB_H2OGPT = "github h2oGPT"
    H2O_DAI_DOCS = "DriverlessAI docs"
|
||||
|
||||
|
||||
class LangChainAction(Enum):
    """String identifiers for the action LangChain should perform."""

    QUERY = "Query"
    # WIP:
    # SUMMARIZE_MAP = "Summarize_map_reduce"
    SUMMARIZE_MAP = "Summarize"
    SUMMARIZE_ALL = "Summarize_all"
    SUMMARIZE_REFINE = "Summarize_refine"
|
||||
|
||||
|
||||
# One shared placeholder string under three names.
no_model_str = "[None/Remove]"
no_lora_str = no_model_str
no_server_str = no_lora_str

# from site-packages/langchain/llms/openai.py
# but needed since ChatOpenAI doesn't have this information
model_token_mapping = {
    # GPT-4 family
    "gpt-4": 8192,
    "gpt-4-0314": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
    # GPT-3.5 family
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-16k": 16384,  # 16 * 1024
    "gpt-3.5-turbo-0301": 4096,
    # Completion models
    "text-ada-001": 2049,
    "ada": 2049,
    "text-babbage-001": 2040,
    "babbage": 2049,
    "text-curie-001": 2049,
    "curie": 2049,
    "davinci": 2049,
    "text-davinci-003": 4097,
    "text-davinci-002": 4097,
    # Code models
    "code-davinci-002": 8001,
    "code-davinci-001": 8001,
    "code-cushman-002": 2048,
    "code-cushman-001": 2048,
}

# Marker strings for the start and end of a sources listing.
source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"
|
||||
53
apps/language_models/langchain/evaluate_params.py
Normal file
53
apps/language_models/langchain/evaluate_params.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Evaluate parameters that have no default value.
no_default_param_names = [
    "instruction",
    "iinput",
    "context",
    "instruction_nochat",
    "iinput_nochat",
]

# Generation hyper-parameters.
gen_hyper = [
    "temperature",
    "top_p",
    "top_k",
    "num_beams",
    "max_new_tokens",
    "min_new_tokens",
    "early_stopping",
    "max_time",
    "repetition_penalty",
    "num_return_sequences",
    "do_sample",
]

# Ordered parameter names of the evaluate function; the order matters to
# callers that look positions up with ``.index(name)``.
eval_func_param_names = [
    "instruction",
    "iinput",
    "context",
    "stream_output",
    "prompt_type",
    "prompt_dict",
    *gen_hyper,
    "chat",
    "instruction_nochat",
    "iinput_nochat",
    "langchain_mode",
    "langchain_action",
    "top_k_docs",
    "chunk",
    "chunk_size",
    "document_choice",
]

# form evaluate defaults for submit_nochat_api
eval_func_param_names_defaults = list(eval_func_param_names)
for k in no_default_param_names:
    try:
        eval_func_param_names_defaults.remove(k)
    except ValueError:
        # name was not present; nothing to drop
        continue

eval_extra_columns = ["prompt", "response", "score"]
|
||||
846
apps/language_models/langchain/expanded_pipelines.py
Normal file
846
apps/language_models/langchain/expanded_pipelines.py
Normal file
@@ -0,0 +1,846 @@
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator, validator
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
    """Return the current value of the module-level ``langchain.verbose``."""
    verbose_flag = langchain.verbose
    return verbose_flag
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Render ``doc`` through ``prompt`` and return the resulting string.

    The values offered to the prompt are the document's ``page_content``
    plus every entry of ``doc.metadata`` (a metadata key named
    ``page_content`` takes precedence).

    Raises:
        ValueError: if the prompt needs an input variable that is neither
            ``page_content`` nor present in the document's metadata.
    """
    available = {"page_content": doc.page_content, **doc.metadata}
    wanted = prompt.input_variables
    absent = set(wanted).difference(available)
    if absent:
        needed_from_metadata = []
        for name in wanted:
            if name != "page_content":
                needed_from_metadata.append(name)
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{needed_from_metadata}. Received document with missing metadata: "
            f"{list(absent)}."
        )
    chosen = {}
    for name in wanted:
        chosen[name] = available[name]
    return prompt.format(**chosen)
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
    """Base interface that all chains should implement.

    This variant's ``__call__`` does not delegate to ``_call``. It handles
    two cases itself: a "combine documents" stage (first call) and an LLM
    stage (re-entered with ``inputs["is_first"] == False``). The LLM stage
    runs the pipeline's ``preprocess`` and ``forward`` steps directly and
    returns the raw model outputs.
    """

    memory: Optional[BaseMemory] = None
    callbacks: Callbacks = Field(default=None, exclude=True)
    callback_manager: Optional[BaseCallbackManager] = Field(
        default=None, exclude=True
    )
    verbose: bool = Field(
        default_factory=_get_verbosity
    )  # Whether to print the response text
    tags: Optional[List[str]] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def _chain_type(self) -> str:
        """Type tag written by ``dict()``; subclasses must override to be saveable."""
        raise NotImplementedError("Saving not supported for this chain type.")

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            # Move the deprecated value over to the supported field.
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    @validator("verbose", pre=True, always=True)
    def set_verbose(cls, verbose: Optional[bool]) -> bool:
        """If verbose is None, set it.

        This allows users to pass in None as verbose to access the global setting.
        """
        if verbose is None:
            return _get_verbosity()
        else:
            return verbose

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Output keys this chain expects."""

    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
        """Check that all inputs are present."""
        missing_keys = set(self.input_keys).difference(inputs)
        if missing_keys:
            raise ValueError(f"Missing some input keys: {missing_keys}")

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Check that all expected output keys are present."""
        missing_keys = set(self.output_keys).difference(outputs)
        if missing_keys:
            raise ValueError(f"Missing some output keys: {missing_keys}")

    @abstractmethod
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and return the output."""

    def __call__(
        self,
        inputs: Union[Dict[str, Any], Any],
        return_only_outputs: bool = False,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        include_run_info: bool = False,
    ) -> Dict[str, Any]:
        """Run the logic of this chain and add to output if desired.

        Args:
            inputs: Dictionary of inputs. It must contain the key
                ``"input_documents"``; it is read unconditionally.
            return_only_outputs: boolean for whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
            callbacks: Callbacks to use for this chain run. If not provided, will
                use the callbacks provided to the chain.
            tags: Optional tags passed to the callback manager.
            include_run_info: Whether to include run info in the response. Defaults
                to False.

        Returns:
            In the LLM stage (``inputs["is_first"]`` present and falsy): the
            raw pipeline ``forward`` output with ``"process"`` set to False.
            In the combine-documents stage: that same raw output when the
            inner chain returned one, otherwise the prepared outputs dict.

        Raises:
            ValueError: if an expected input key is missing, or if caching
                is requested on the LLM but no global cache is configured.
        """
        input_docs = inputs["input_documents"]
        missing_keys = set(self.input_keys).difference(inputs)
        if missing_keys:
            raise ValueError(f"Missing some input keys: {missing_keys}")

        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose, tags, self.tags
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            inputs,
        )

        if "is_first" in inputs.keys() and not inputs["is_first"]:
            # LLM stage: this instance is expected to provide `prompt`,
            # `llm` and `output_key`.
            run_manager_ = run_manager
            input_list = [inputs]
            stop = None
            prompts = []
            for inputs in input_list:
                selected_inputs = {
                    k: inputs[k] for k in self.prompt.input_variables
                }
                prompt = self.prompt.format_prompt(**selected_inputs)
                _colored_text = get_colored_text(prompt.to_string(), "green")
                _text = "Prompt after formatting:\n" + _colored_text
                if run_manager_:
                    run_manager_.on_text(_text, end="\n", verbose=self.verbose)
                if "stop" in inputs and inputs["stop"] != stop:
                    raise ValueError(
                        "If `stop` is present in any inputs, should be present in all."
                    )
                prompts.append(prompt)

            prompt_strings = [p.to_string() for p in prompts]
            prompts = prompt_strings
            callbacks = run_manager_.get_child() if run_manager_ else None
            tags = None

            # Run the LLM on the given prompt and input.
            # If string is passed in directly no errors will be raised but outputs will
            # not make sense.
            if not isinstance(prompts, list):
                raise ValueError(
                    "Argument 'prompts' is expected to be of type List[str], received"
                    f" argument of type {type(prompts)}."
                )
            params = self.llm.dict()
            params["stop"] = stop
            options = {"stop": stop}
            disregard_cache = self.llm.cache is not None and not self.llm.cache
            callback_manager = CallbackManager.configure(
                callbacks,
                self.llm.callbacks,
                self.llm.verbose,
                tags,
                self.llm.tags,
            )
            if langchain.llm_cache is None or disregard_cache:
                # This happens when langchain.cache is None, but the LLM's cache is True.
                # The flag lives on the LLM; Chain has no `cache` field, so
                # it must be read from `self.llm`.
                if self.llm.cache is not None and self.llm.cache:
                    raise ValueError(
                        "Asked to cache, but no cache found at `langchain.cache`."
                    )
                run_manager_ = callback_manager.on_llm_start(
                    dumpd(self),
                    prompts,
                    invocation_params=params,
                    options=options,
                )

                generations = []
                for prompt in prompts:
                    inputs_ = prompt
                    num_workers = None
                    batch_size = None

                    if num_workers is None:
                        if self.llm.pipeline._num_workers is None:
                            num_workers = 0
                        else:
                            num_workers = self.llm.pipeline._num_workers
                    if batch_size is None:
                        if self.llm.pipeline._batch_size is None:
                            batch_size = 1
                        else:
                            batch_size = self.llm.pipeline._batch_size

                    preprocess_params = {}
                    generate_kwargs = {}
                    preprocess_params.update(generate_kwargs)
                    forward_params = generate_kwargs
                    postprocess_params = {}
                    # Fuse __init__ params and __call__ params without modifying the __init__ ones.
                    preprocess_params = {
                        **self.llm.pipeline._preprocess_params,
                        **preprocess_params,
                    }
                    forward_params = {
                        **self.llm.pipeline._forward_params,
                        **forward_params,
                    }
                    postprocess_params = {
                        **self.llm.pipeline._postprocess_params,
                        **postprocess_params,
                    }

                    self.llm.pipeline.call_count += 1
                    if (
                        self.llm.pipeline.call_count > 10
                        and self.llm.pipeline.framework == "pt"
                        and self.llm.pipeline.device.type == "cuda"
                    ):
                        warnings.warn(
                            "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
                            " dataset",
                            UserWarning,
                        )

                    model_inputs = self.llm.pipeline.preprocess(
                        inputs_, **preprocess_params
                    )
                    model_outputs = self.llm.pipeline.forward(
                        model_inputs, **forward_params
                    )
                    # Mark the result as not post-processed and hand the raw
                    # outputs back; the combine-documents stage checks this flag.
                    model_outputs["process"] = False
                    return model_outputs
                # NOTE(review): `prompts` always holds exactly one entry here,
                # so the loop above returns on its first pass and the code
                # below is not reached on that path. The original nesting of
                # these statements could not be confirmed — verify.
                output = LLMResult(generations=generations)
                run_manager_.on_llm_end(output)
                if run_manager_:
                    output.run = RunInfo(run_id=run_manager_.run_id)
                response = output

            outputs = [
                # Get the text of the top generated string.
                {self.output_key: generation[0].text}
                for generation in response.generations
            ][0]
            run_manager.on_chain_end(outputs)
            final_outputs: Dict[str, Any] = self.prep_outputs(
                inputs, outputs, return_only_outputs
            )
            if include_run_info:
                final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
            return final_outputs
        else:
            # Combine-documents stage: this instance is expected to provide
            # `input_key`, `document_prompt`, `llm_chain`,
            # `document_variable_name` and `document_separator`.
            _run_manager = (
                run_manager or CallbackManagerForChainRun.get_noop_manager()
            )
            docs = inputs[self.input_key]
            # Other keys are assumed to be needed for LLM prediction
            other_keys = {
                k: v for k, v in inputs.items() if k != self.input_key
            }
            doc_strings = [
                format_document(doc, self.document_prompt) for doc in docs
            ]
            # Join the documents together to put them in the prompt.
            inputs = {
                k: v
                for k, v in other_keys.items()
                if k in self.llm_chain.prompt.input_variables
            }
            inputs[self.document_variable_name] = self.document_separator.join(
                doc_strings
            )
            # Route the nested call into the LLM stage and carry the
            # documents along, since __call__ reads them unconditionally.
            inputs["is_first"] = False
            inputs["input_documents"] = input_docs

            # Call predict on the LLM.
            output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
            if "process" in output.keys() and not output["process"]:
                # Raw model outputs: return them untouched.
                return output
            output = output[self.llm_chain.output_key]
            extra_return_dict = {}
            extra_return_dict[self.output_key] = output
            outputs = extra_return_dict
            run_manager.on_chain_end(outputs)
            final_outputs: Dict[str, Any] = self.prep_outputs(
                inputs, outputs, return_only_outputs
            )
            if include_run_info:
                final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
            return final_outputs

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """Validate and prep outputs.

        Saves the exchange to memory when memory is configured. Returns
        ``outputs`` alone when ``return_only_outputs`` is true, otherwise
        ``inputs`` merged with ``outputs`` (outputs win on key clashes).
        """
        self._validate_outputs(outputs)
        if self.memory is not None:
            self.memory.save_context(inputs, outputs)
        if return_only_outputs:
            return outputs
        else:
            return {**inputs, **outputs}

    def prep_inputs(
        self, inputs: Union[Dict[str, Any], Any]
    ) -> Dict[str, str]:
        """Validate and prep inputs.

        A non-dict value is wrapped under the chain's single free input key;
        memory variables, if any, are merged in before validation.

        Raises:
            ValueError: if a non-dict input is given but the chain does not
                have exactly one free input key, or if keys are missing.
        """
        if not isinstance(inputs, dict):
            _input_keys = set(self.input_keys)
            if self.memory is not None:
                # If there are multiple input keys, but some get set by memory so that
                # only one is not set, we can still figure out which key it is.
                _input_keys = _input_keys.difference(
                    self.memory.memory_variables
                )
            if len(_input_keys) != 1:
                raise ValueError(
                    f"A single string input was passed in, but this chain expects "
                    f"multiple inputs ({_input_keys}). When a chain expects "
                    f"multiple inputs, please call it by passing in a dictionary, "
                    "eg `chain({'foo': 1, 'bar': 2})`"
                )
            inputs = {list(_input_keys)[0]: inputs}
        if self.memory is not None:
            external_context = self.memory.load_memory_variables(inputs)
            inputs = dict(inputs, **external_context)
        self._validate_inputs(inputs)
        return inputs

    def apply(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> List[Dict[str, str]]:
        """Call the chain on all inputs in the list."""
        return [self(inputs, callbacks=callbacks) for inputs in input_list]

    def run(
        self,
        *args: Any,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Run the chain as text in, text out or multiple variables, text out.

        Accepts either exactly one positional argument or keyword arguments,
        never both, and returns the value stored under the chain's single
        output key.

        Raises:
            ValueError: if the chain does not have exactly one output key,
                or if the arguments do not match one of the accepted forms.
        """
        if len(self.output_keys) != 1:
            raise ValueError(
                f"`run` not supported when there is not exactly "
                f"one output key. Got {self.output_keys}."
            )

        if args and not kwargs:
            if len(args) != 1:
                raise ValueError(
                    "`run` supports only one positional argument."
                )
            return self(args[0], callbacks=callbacks, tags=tags)[
                self.output_keys[0]
            ]

        if kwargs and not args:
            return self(kwargs, callbacks=callbacks, tags=tags)[
                self.output_keys[0]
            ]

        if not kwargs and not args:
            raise ValueError(
                "`run` supported with either positional arguments or keyword arguments,"
                " but none were provided."
            )

        raise ValueError(
            f"`run` supported with either positional arguments or keyword arguments"
            f" but not both. Got args: {args} and kwargs: {kwargs}."
        )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of chain.

        Raises:
            ValueError: if the chain has memory attached.
        """
        if self.memory is not None:
            raise ValueError("Saving of memory is not yet supported.")
        _dict = super().dict()
        _dict["_type"] = self._chain_type
        return _dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the chain.

        Args:
            file_path: Path to file to save the chain to.

        Raises:
            ValueError: if the file suffix is neither ``.json`` nor ``.yaml``.

        Example:
        .. code-block:: python

            chain.save(file_path="path/chain.yaml")
        """
        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        chain_dict = self.dict()

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(chain_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(chain_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
    """Abstract base for chains that fold several documents into one result."""

    input_key: str = "input_documents"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """One-element list holding ``input_key``.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """One-element list holding ``output_key``.

        :meta private:
        """
        return [self.output_key]

    def prompt_length(
        self, docs: List[Document], **kwargs: Any
    ) -> Optional[int]:
        """Prompt length for the given documents.

        This base implementation always returns None, meaning the method
        does not depend on the prompt length.
        """
        return None

    def _call(
        self,
        inputs: Dict[str, List[Document]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Format the documents, join them, and run ``llm_chain`` on the result.

        Returns a dict with the LLM chain's output stored under ``output_key``.
        """
        if run_manager:
            manager = run_manager
        else:
            manager = CallbackManagerForChainRun.get_noop_manager()

        documents = inputs[self.input_key]

        # Everything except the documents is a candidate LLM input.
        remaining = {}
        for key, value in inputs.items():
            if key != self.input_key:
                remaining[key] = value

        rendered = []
        for document in documents:
            rendered.append(format_document(document, self.document_prompt))

        # Keep only what the prompt asks for, then add the joined documents.
        llm_inputs = {}
        for key, value in remaining.items():
            if key in self.llm_chain.prompt.input_variables:
                llm_inputs[key] = value
        llm_inputs[self.document_variable_name] = self.document_separator.join(
            rendered
        )

        # Call predict on the LLM.
        chain_result = self.llm_chain(llm_inputs, callbacks=manager.get_child())
        text = chain_result[self.llm_chain.output_key]

        result = {}
        result[self.output_key] = text
        return result
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Generation(Serializable):
    """One generated output."""

    # Generated text output.
    text: str

    # Raw generation info response from the provider.
    # May include things like reason for finishing (e.g. in OpenAI).
    generation_info: Optional[Dict[str, Any]] = None
    # TODO: add log probs
|
||||
|
||||
|
||||
# Accepted task names. Not referenced in the visible part of this file;
# presumably Hugging Face pipeline task identifiers — TODO confirm.
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class LLMChain(Chain):
    """Chain to run queries against LLMs.

    The chain formats its prompt with the supplied inputs, sends the
    resulting prompt(s) to ``llm`` and returns the text of the top generation
    under ``output_key``.

    Example:
        .. code-block:: python

            from langchain import LLMChain, OpenAI, PromptTemplate
            prompt_template = "Tell me a {adjective} joke"
            prompt = PromptTemplate(
                input_variables=["adjective"], template=prompt_template
            )
            llm = LLMChain(llm=OpenAI(), prompt=prompt)
    """

    @property
    def lc_serializable(self) -> bool:
        return True

    prompt: BasePromptTemplate
    """Prompt object to use."""
    llm: BaseLanguageModel
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the chain on a single input dict and return its single output dict."""
        response = self.generate([inputs], run_manager=run_manager)
        return self.create_outputs(response)[0]

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Format one prompt per input dict and send them to the LLM in one call.

        ``apply`` relies on this method; it was referenced but not defined in
        this class.
        NOTE(review): the base ``Chain`` is not visible here -- confirm it does
        not already provide ``generate`` with different semantics.
        """
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        return self.llm.generate_prompt(
            prompts,
            stop,
            callbacks=run_manager.get_child() if run_manager else None,
        )

    def prep_prompts(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[List[PromptValue], Optional[List[str]]]:
        """Prepare prompts from inputs.

        Returns:
            ``(prompts, stop)`` where ``stop`` is taken from the first input
            dict (``None`` when absent).

        Raises:
            ValueError: if an input carries a ``stop`` value that differs from
                the one found on the first input.
        """
        stop = None
        # An empty batch has nothing to format; avoid indexing input_list[0].
        if len(input_list) == 0:
            return [], stop
        if "stop" in input_list[0]:
            stop = input_list[0]["stop"]
        prompts = []
        for inputs in input_list:
            selected_inputs = {
                k: inputs[k] for k in self.prompt.input_variables
            }
            prompt = self.prompt.format_prompt(**selected_inputs)
            _colored_text = get_colored_text(prompt.to_string(), "green")
            _text = "Prompt after formatting:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            if "stop" in inputs and inputs["stop"] != stop:
                raise ValueError(
                    "If `stop` is present in any inputs, should be present in all."
                )
            prompts.append(prompt)
        return prompts, stop

    def apply(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains.

        Reports chain start / error / end to the configured callback manager
        and returns one output dict per input dict.
        """
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            {"input_list": input_list},
        )
        try:
            response = self.generate(input_list, run_manager=run_manager)
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_chain_error(e)
            raise
        outputs = self.create_outputs(response)
        run_manager.on_chain_end({"outputs": outputs})
        return outputs

    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
        """Create outputs from response."""
        return [
            # Get the text of the top generated string.
            {self.output_key: generation[0].text}
            for generation in response.generations
        ]

    def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
        """Run the chain with ``kwargs`` as inputs and return the output text.

        ``predict_and_parse`` relies on this method; it was referenced but not
        defined in this class.
        NOTE(review): confirm the base ``Chain`` does not already define it.
        """
        return self(kwargs, callbacks=callbacks)[self.output_key]

    def predict_and_parse(
        self, callbacks: Callbacks = None, **kwargs: Any
    ) -> Union[str, List[str], Dict[str, Any]]:
        """Call predict and then parse the results."""
        result = self.predict(callbacks=callbacks, **kwargs)
        if self.prompt.output_parser is not None:
            return self.prompt.output_parser.parse(result)
        else:
            return result

    def apply_and_parse(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
        """Call apply and then parse the results."""
        result = self.apply(input_list, callbacks=callbacks)
        return self._parse_result(result)

    def _parse_result(
        self, result: List[Dict[str, str]]
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
        """Run the prompt's output parser over each result, if one is set."""
        if self.prompt.output_parser is not None:
            return [
                self.prompt.output_parser.parse(res[self.output_key])
                for res in result
            ]
        else:
            return result

    @property
    def _chain_type(self) -> str:
        return "llm_chain"

    @classmethod
    def from_string(cls, llm: BaseLanguageModel, template: str) -> "LLMChain":
        """Create LLMChain from LLM and template."""
        prompt_template = PromptTemplate.from_template(template)
        return cls(llm=llm, prompt=prompt_template)
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
    """Build the fallback per-document prompt: the page content and nothing else."""
    template = "{page_content}"
    variables = ["page_content"]
    return PromptTemplate(input_variables=variables, template=template)
|
||||
|
||||
|
||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by formatting each one and joining them into one prompt variable."""

    llm_chain: LLMChain
    """LLM wrapper to use after formatting documents."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    document_separator: str = "\n\n"
    """The string with which to join the formatted documents"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Fill in or validate ``document_variable_name`` against the prompt's variables.

        Raises:
            ValueError: when the name is omitted and the prompt has more than
                one (or zero) input variables, or when the given name is not
                one of the prompt's input variables.
        """
        prompt_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" in values:
            # Explicit name: it must be something the prompt actually accepts.
            if values["document_variable_name"] not in prompt_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {prompt_variables}"
                )
            return values
        # No name given: only unambiguous when the prompt has exactly one variable.
        if len(prompt_variables) != 1:
            raise ValueError(
                "document_variable_name must be provided if there are "
                "multiple llm_chain_variables"
            )
        values["document_variable_name"] = prompt_variables[0]
        return values

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Build the llm_chain inputs: accepted kwargs plus the joined documents."""
        # Keep only the keyword arguments the prompt declares.
        inputs = {}
        for key, value in kwargs.items():
            if key in self.llm_chain.prompt.input_variables:
                inputs[key] = value
        # Render every document, then stuff them all into one variable.
        rendered = [format_document(doc, self.document_prompt) for doc in docs]
        inputs[self.document_variable_name] = self.document_separator.join(
            rendered
        )
        return inputs

    def prompt_length(
        self, docs: List[Document], **kwargs: Any
    ) -> Optional[int]:
        """Get the prompt length by formatting the prompt."""
        formatted = self.llm_chain.prompt.format(
            **self._get_inputs(docs, **kwargs)
        )
        return self.llm_chain.llm.get_num_tokens(formatted)

    @property
    def _chain_type(self) -> str:
        return "stuff_documents_chain"
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
    """Structural type of a factory that builds a combine-documents chain."""

    def __call__(
        self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Build and return the combine-documents chain for ``llm``."""
        ...
|
||||
|
||||
|
||||
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
    document_variable_name: str = "context",
    verbose: Optional[bool] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Assemble a ``StuffDocumentsChain`` around a freshly built ``LLMChain``.

    When ``prompt`` is falsy the prompt is chosen by
    ``stuff_prompt.PROMPT_SELECTOR`` for the given ``llm``. Extra keyword
    arguments are forwarded to ``StuffDocumentsChain``.
    """
    if prompt:
        chosen_prompt = prompt
    else:
        chosen_prompt = stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)

    inner_chain = LLMChain(
        llm=llm,
        prompt=chosen_prompt,
        verbose=verbose,
        callback_manager=callback_manager,
        callbacks=callbacks,
    )
    # TODO: document prompt
    # NOTE(review): `callbacks` reaches only the inner LLMChain, not the outer
    # chain built below -- confirm whether that is intended.
    outer_chain = StuffDocumentsChain(
        llm_chain=inner_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,
        callback_manager=callback_manager,
        **kwargs,
    )
    return outer_chain
|
||||
|
||||
|
||||
def load_qa_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load question answering chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Only "stuff" is
            registered in this module's loader table.
        verbose: Whether chains should be run in verbose mode or not. It is
            forwarded to the selected loader.
        callback_manager: Callback manager to use for the chain.
        **kwargs: Forwarded unchanged to the selected loader.

    Returns:
        A chain to use for question answering.

    Raises:
        ValueError: if ``chain_type`` has no registered loader.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
    }
    loader = loader_mapping.get(chain_type)
    if loader is None:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader(
        llm, verbose=verbose, callback_manager=callback_manager, **kwargs
    )
|
||||
1951
apps/language_models/langchain/gen.py
Normal file
1951
apps/language_models/langchain/gen.py
Normal file
File diff suppressed because it is too large
Load Diff
380
apps/language_models/langchain/gpt4all_llm.py
Normal file
380
apps/language_models/langchain/gpt4all_llm.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Dict, Any, Optional, List
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from pydantic import root_validator
|
||||
from langchain.llms import gpt4all
|
||||
from dotenv import dotenv_values
|
||||
|
||||
from utils import FakeTokenizer
|
||||
|
||||
|
||||
def _coerce_number(value):
    """Return ``value`` as an int or float when it parses as a finite number, else unchanged."""
    if isinstance(value, bool):
        return value
    try:
        as_float = float(value)
    except (TypeError, ValueError, OverflowError):
        # Not numeric (paths, names, None, ...): keep as-is.
        return value
    # Leave "nan"/"inf"-like strings alone; they are more likely names than numbers.
    if as_float != as_float or as_float in (float("inf"), float("-inf")):
        return value
    return int(as_float) if as_float.is_integer() else as_float


def _load_gpt4all_model(model_kwargs, suffix, model_type, env_file):
    """Build a GPT4All model from the ``model_name_<suffix>`` / ``model_path_<suffix>`` settings."""
    name_key = "model_name_" + suffix
    path_key = "model_path_" + suffix
    if name_key not in model_kwargs and path_key not in model_kwargs:
        raise ValueError(
            "No %s or %s in %s" % (name_key, path_key, env_file)
        )

    from gpt4all import GPT4All as GPT4AllModel

    if name_key in model_kwargs:
        model_name = model_kwargs.pop(name_key)
        return GPT4AllModel(model_name=model_name, model_type=model_type)

    # Only a full path was configured: split it into directory + file name,
    # the same way H2OGPT4All.validate_environment does.
    full_path = str(model_kwargs.pop(path_key))
    model_dir, delimiter, model_name = full_path.rpartition("/")
    model_dir += delimiter
    return GPT4AllModel(
        model_name=model_name,
        model_path=model_dir or None,
        model_type=model_type,
        allow_download=False,
    )


def get_model_tokenizer_gpt4all(base_model, **kwargs):
    """Load a CPU model for ``base_model`` using defaults overridden by ``.env_gpt4all``.

    Args:
        base_model: one of ``"llama"``, ``"gpt4all_llama"`` or ``"gptj"``.
        **kwargs: optional ``temperature``, ``top_p`` and ``top_k`` defaults.

    Returns:
        ``(model, FakeTokenizer(), "cpu")``.

    Raises:
        ValueError: if ``base_model`` is unknown or the required model
            name/path entry is missing from the env file.
    """
    # defaults (some of these are generation parameters, so need to be passed in at generation time)
    model_kwargs = dict(
        # cpu_count() may return None; never ask for fewer than one thread.
        n_threads=max(1, (os.cpu_count() or 2) // 2),
        temp=kwargs.get("temperature", 0.2),
        top_p=kwargs.get("top_p", 0.75),
        top_k=kwargs.get("top_k", 40),
        n_ctx=2048 - 256,
    )
    env_gpt4all_file = ".env_gpt4all"
    model_kwargs.update(dotenv_values(env_gpt4all_file))
    # make int or float if can to satisfy types for class
    model_kwargs = {k: _coerce_number(v) for k, v in model_kwargs.items()}

    if base_model == "llama":
        if "model_path_llama" not in model_kwargs:
            raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
        model_path = model_kwargs.pop("model_path_llama")
        # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python
        from llama_cpp import Llama

        # llama sets some things at init model time, not generation time
        func_names = list(inspect.signature(Llama.__init__).parameters)
        model_kwargs = {
            k: v for k, v in model_kwargs.items() if k in func_names
        }
        model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"])
        model = Llama(model_path=model_path, **model_kwargs)
    elif base_model == "gpt4all_llama":
        model = _load_gpt4all_model(
            model_kwargs, "gpt4all_llama", "llama", env_gpt4all_file
        )
    elif base_model == "gptj":
        model = _load_gpt4all_model(
            model_kwargs, "gptj", "gptj", env_gpt4all_file
        )
    else:
        raise ValueError("No such base_model %s" % base_model)
    return model, FakeTokenizer(), "cpu"
|
||||
|
||||
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
    """Streaming handler whose per-token hook deliberately does nothing."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # streaming to std already occurs without this
        # sys.stdout.write(token)
        # sys.stdout.flush()
        return None
|
||||
|
||||
|
||||
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=()):
    """Merge constructor defaults, project defaults and user settings for ``cls``.

    Precedence, lowest to highest: defaults declared in ``cls``'s signature,
    then ``default_kwargs``, then ``env_kwargs``. Keys that are not parameters
    of ``cls`` are dropped from the result.

    Args:
        env_kwargs: user-provided overrides.
        default_kwargs: project-level defaults.
        cls: callable whose signature defines the accepted keys.
        exclude_list: parameter names whose signature default is not seeded.
            A later override for such a name is still kept if it is a
            parameter of ``cls``.

    Returns:
        A new dict of keyword arguments accepted by ``cls``.
    """
    # Inspect the signature once; it is used both for defaults and filtering.
    params = inspect.signature(cls).parameters
    # default from class
    # NOTE(review): parameters without a default contribute
    # inspect.Parameter.empty here unless overridden below -- confirm callers
    # always override them.
    model_kwargs = {
        name: param.default
        for name, param in params.items()
        if name not in exclude_list
    }
    # from our defaults
    model_kwargs.update(default_kwargs)
    # from user defaults
    model_kwargs.update(env_kwargs)
    # ensure only valid keys
    return {k: v for k, v in model_kwargs.items() if k in params}
|
||||
|
||||
|
||||
def get_llm_gpt4all(
    model_name,
    model=None,
    max_new_tokens=256,
    temperature=0.1,
    repetition_penalty=1.0,
    top_k=40,
    top_p=0.7,
    streaming=False,
    callbacks=None,
    prompter=None,
    verbose=False,
):
    """Construct the LLM wrapper for ``model_name``.

    Settings come from three layers merged by ``get_model_kwargs``: the
    wrapper class's own defaults, the defaults derived from this function's
    arguments, and the entries of ``.env_gpt4all``.

    Args:
        model_name: ``"llama"``, ``"gpt4all_llama"`` or ``"gptj"``.
        model: model path or preloaded model; when ``None`` the
            ``model_path_<model_name>`` entry of the env file is used.
        prompter: required; passed through to the wrapper.

    Returns:
        An ``H2OLlamaCpp`` (for ``"llama"``) or ``H2OGPT4All`` instance.

    Raises:
        RuntimeError: for an unknown ``model_name``.
    """
    assert prompter is not None
    env_gpt4all_file = ".env_gpt4all"
    env_kwargs = dotenv_values(env_gpt4all_file)
    n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens)
    default_kwargs = {
        "context_erase": 0.5,
        "n_batch": 1,
        "n_ctx": n_ctx,
        "n_predict": max_new_tokens,
        "repeat_last_n": 64 if repetition_penalty != 1.0 else 0,
        "repeat_penalty": repetition_penalty,
        "temp": temperature,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "use_mlock": True,
        "verbose": verbose,
    }

    if model_name not in ("llama", "gpt4all_llama", "gptj"):
        raise RuntimeError("No such model_name %s" % model_name)

    uses_llama_cpp = model_name == "llama"
    cls = H2OLlamaCpp if uses_llama_cpp else H2OGPT4All

    # The env entry is consumed only when no model was handed in.
    if model is None:
        model_path = env_kwargs.pop("model_path_" + model_name)
    else:
        model_path = model

    model_kwargs = get_model_kwargs(
        env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"]
    )

    overrides = dict(
        callbacks=callbacks,
        streaming=streaming,
        prompter=prompter,
    )
    if uses_llama_cpp:
        overrides["model_path"] = model_path
    else:
        overrides["model"] = model_path
        overrides["backend"] = (
            "llama" if model_name == "gpt4all_llama" else "gptj"
        )
    model_kwargs.update(overrides)

    llm = cls(**model_kwargs)
    if uses_llama_cpp:
        llm.client.verbose = verbose
    return llm
|
||||
|
||||
|
||||
class H2OGPT4All(gpt4all.GPT4All):
    """GPT4All wrapper that accepts a path or a preloaded model and applies instruct prompting."""

    model: Any
    """Path to the pre-trained GPT4All model file."""
    # A non-string value is used directly as the client (see validate_environment).
    prompter: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment."""
        try:
            supplied = values["model"]
            if not isinstance(supplied, str):
                # Already a model object: adopt it and read its backend type.
                values["client"] = supplied
                try:
                    values["backend"] = values["client"].model_type
                except AttributeError:
                    # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
                    values["backend"] = values["client"].model.model_type
            else:
                from gpt4all import GPT4All as GPT4AllModel

                # Split "dir/file" into the directory (with trailing slash) and file name.
                directory, slash, file_name = supplied.rpartition("/")
                directory = directory + slash

                values["client"] = GPT4AllModel(
                    model_name=file_name,
                    model_path=directory or None,
                    model_type=values["backend"],
                    allow_download=False,
                )
                thread_count = values["n_threads"]
                if thread_count is not None:
                    # set n_threads
                    values["client"].model.set_thread_count(thread_count)
        except ImportError:
            raise ValueError(
                "Could not import gpt4all python package. "
                "Please install it with `pip install gpt4all`."
            )
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> str:
        """Truncate the prompt, wrap it with the prompter, and delegate to the parent ``_call``."""
        # Roughly 4 chars per token if natural language
        char_budget = self.n_ctx * 4
        prompt = prompt[-char_budget:]

        # use instruct prompting
        prompt = self.prompter.generate_prompt(
            dict(context="", instruction=prompt, input="")
        )

        verbose = False
        if verbose:
            print("_call prompt: %s" % prompt, flush=True)
        # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
        return super()._call(prompt, stop=stop, run_manager=run_manager)
|
||||
|
||||
|
||||
from langchain.llms import LlamaCpp
|
||||
|
||||
|
||||
class H2OLlamaCpp(LlamaCpp):
    """LlamaCpp wrapper that accepts a path or a preloaded client and applies instruct prompting."""

    model_path: Any
    """Path to the model file, or an already-constructed client object."""
    prompter: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that llama-cpp-python library is installed.

        When ``model_path`` is a string a ``Llama`` client is created from it;
        otherwise the given object is used as the client directly.

        Raises:
            ModuleNotFoundError: if ``llama_cpp`` cannot be imported.
            ValueError: if constructing the ``Llama`` client fails.
        """
        if isinstance(values["model_path"], str):
            model_path = values["model_path"]
            model_param_names = [
                "lora_path",
                "lora_base",
                "n_ctx",
                "n_parts",
                "seed",
                "f16_kv",
                "logits_all",
                "vocab_only",
                "use_mlock",
                "n_threads",
                "n_batch",
                "use_mmap",
                "last_n_tokens_size",
            ]
            model_params = {k: values[k] for k in model_param_names}
            # For backwards compatibility, only include if non-null.
            if values["n_gpu_layers"] is not None:
                model_params["n_gpu_layers"] = values["n_gpu_layers"]

            try:
                from llama_cpp import Llama

                values["client"] = Llama(model_path, **model_params)
            except ImportError:
                raise ModuleNotFoundError(
                    "Could not import llama-cpp-python library. "
                    "Please install the llama-cpp-python library to "
                    "use this embedding model: pip install llama-cpp-python"
                )
            except Exception as e:
                raise ValueError(
                    f"Could not load Llama model from path: {model_path}. "
                    f"Received error {e}"
                ) from e
        else:
            values["client"] = values["model_path"]
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs,
    ) -> str:
        """Truncate the prompt to the context size, apply the prompter and generate text.

        Returns the generated text: accumulated from ``self.stream`` when
        ``self.streaming`` is set, otherwise taken from a single client call.
        """
        # Local debug switch; flip to True to trace truncation and the final prompt.
        verbose = False
        # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
        # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal
        prompt = prompt[-self.n_ctx * 4 :]
        prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
        num_prompt_tokens = len(prompt_tokens)
        if num_prompt_tokens > self.n_ctx:
            # conservative by using int(); never 0, because prompt[-0:] would
            # keep the whole prompt instead of truncating it
            chars_per_token = max(1, int(len(prompt) / num_prompt_tokens))
            prompt = prompt[-self.n_ctx * chars_per_token :]
            if verbose:
                print(
                    "reducing tokens, assuming average of %s chars/token"
                    % chars_per_token,
                    flush=True,
                )
                prompt_tokens2 = self.client.tokenize(
                    b" " + prompt.encode("utf-8")
                )
                num_prompt_tokens2 = len(prompt_tokens2)
                print(
                    "reduced tokens from %d -> %d"
                    % (num_prompt_tokens, num_prompt_tokens2),
                    flush=True,
                )

        # use instruct prompting
        data_point = dict(context="", instruction=prompt, input="")
        prompt = self.prompter.generate_prompt(data_point)

        if verbose:
            print("_call prompt: %s" % prompt, flush=True)

        if self.streaming:
            text_callback = None
            if run_manager:
                text_callback = partial(
                    run_manager.on_llm_new_token, verbose=self.verbose
                )
            # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
            if text_callback:
                text_callback(prompt)
            text = ""
            for token in self.stream(
                prompt=prompt, stop=stop, run_manager=run_manager
            ):
                text_chunk = token["choices"][0]["text"]
                # self.stream already calls text_callback
                # if text_callback:
                #    text_callback(text_chunk)
                text += text_chunk
            return text
        else:
            params = self._get_parameters(stop)
            params = {**params, **kwargs}
            result = self.client(prompt=prompt, **params)
            return result["choices"][0]["text"]
|
||||
3137
apps/language_models/langchain/gpt_langchain.py
Normal file
3137
apps/language_models/langchain/gpt_langchain.py
Normal file
File diff suppressed because it is too large
Load Diff
93
apps/language_models/langchain/gradio_utils/grclient.py
Normal file
93
apps/language_models/langchain/gradio_utils/grclient.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import traceback
|
||||
from typing import Callable
|
||||
import os
|
||||
|
||||
from gradio_client.client import Job
|
||||
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
from gradio_client import Client
|
||||
|
||||
|
||||
class GradioClient(Client):
    """
    Subclass of the gradio client.
    Automatically refreshes the client when the gradio server's hash changes.
    """

    def __init__(self, *args, **kwargs):
        # Keep the constructor arguments so refresh_client() can rebuild the client.
        self.args = args
        self.kwargs = kwargs
        super().__init__(*args, **kwargs)
        self.server_hash = self.get_server_hash()

    def get_server_hash(self):
        """
        Get server hash using super without any refresh action triggered
        Returns: git hash of gradio server
        """
        return super().submit(api_name="/system_hash").result()

    def refresh_client_if_should(self):
        """Rebuild the client if the server hash changed, else just reset the session."""
        # get current hash in order to update api_name -> fn_index map in case gradio server changed
        # FIXME: Could add cli api as hash
        server_hash = self.get_server_hash()
        if self.server_hash != server_hash:
            self.refresh_client()
            self.server_hash = server_hash
        else:
            self.reset_session()

    def refresh_client(self):
        """
        Ensure every client call is independent
        Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
        Returns:
        """
        # need session hash to be new every time, to avoid "generator already executing"
        self.reset_session()

        # Build a fresh client and adopt all of its state.
        client = Client(*self.args, **self.kwargs)
        for k, v in client.__dict__.items():
            setattr(self, k, v)

    def _submit_once(self, *args, api_name, fn_index, result_callbacks):
        """Submit through the parent client without any refresh or retry logic."""
        return super().submit(
            *args,
            api_name=api_name,
            fn_index=fn_index,
            result_callbacks=result_callbacks,
        )

    @staticmethod
    def _immediate_failure(job):
        """Return the job's exception if it has already failed, else ``None`` (never blocks)."""
        future = job.future
        if future.done() and not future.cancelled():
            return future.exception()
        return None

    def submit(
        self,
        *args,
        api_name: str | None = None,
        fn_index: int | None = None,
        result_callbacks: Callable | list[Callable] | None = None,
    ) -> Job:
        """Submit a job, refreshing the client and retrying once on failure.

        ``result_callbacks`` is forwarded to the parent ``submit`` on every
        attempt (it was previously accepted but silently dropped).

        Returns:
            The ``Job`` from the last attempt, even if that attempt failed.
        """
        # Note predict calls submit
        submit_kwargs = dict(
            api_name=api_name,
            fn_index=fn_index,
            result_callbacks=result_callbacks,
        )
        try:
            self.refresh_client_if_should()
            job = self._submit_once(*args, **submit_kwargs)
        except Exception as e:
            print("Hit e=%s" % str(e), flush=True)
            # force reconfig in case only that
            self.refresh_client()
            job = self._submit_once(*args, **submit_kwargs)

        # see if immediately failed
        e = self._immediate_failure(job)
        if e is not None:
            print(
                "GR job failed: %s %s"
                % (str(e), "".join(traceback.format_tb(e.__traceback__))),
                flush=True,
            )
            # force reconfig in case only that
            self.refresh_client()
            job = self._submit_once(*args, **submit_kwargs)
            e2 = self._immediate_failure(job)
            if e2 is not None:
                print(
                    "GR job failed again: %s\n%s"
                    % (
                        str(e2),
                        "".join(traceback.format_tb(e2.__traceback__)),
                    ),
                    flush=True,
                )

        return job
|
||||
765
apps/language_models/langchain/h2oai_pipeline.py
Normal file
765
apps/language_models/langchain/h2oai_pipeline.py
Normal file
@@ -0,0 +1,765 @@
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from stopping import get_stopping
|
||||
from prompter import Prompter, PromptType
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx, save_mlir
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
from typing import List, Tuple
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
    """Shape function: output keeps lhs's leading dims and ends with rhs[0].

    Supports a rank-3 or rank-2 ``lhs`` with a rank-2 ``rhs``; any other
    combination raises ``ValueError``.
    """
    if len(rhs) == 2:
        if len(lhs) == 3:
            return [lhs[0], lhs[1], rhs[0]]
        if len(lhs) == 2:
            return [lhs[0], rhs[0]]
    raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
    """Dtype function: the result takes the dtype carried in ``lhs_rank_dtype``."""
    # output dtype is the dtype of the lhs float input
    return lhs_rank_dtype[1]
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
    """Marker function with an empty body; it always returns ``None``."""
    return None
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
|
||||
if not args.run_docuchat_web:
|
||||
args.device = global_device
|
||||
args.precision = global_precision
|
||||
tensor_device = "cpu" if args.device == "cpu" else "cuda"
|
||||
|
||||
|
||||
class H2OGPTModel(torch.nn.Module):
    """Torch module wrapping the h2oGPT Falcon-7B causal LM, optionally weight-quantized."""

    def __init__(self, device, precision):
        """Load the pretrained model for ``device`` and quantize weights for int4/int8.

        Args:
            device: ``"cpu"`` selects float32 and a CPU device map; anything
                else maps the whole model to device index 0.
            precision: ``"fp32"`` forces float32; ``"int4"`` / ``"int8"``
                additionally trigger weight quantization.
        """
        super().__init__()
        on_cpu = device == "cpu"
        if precision == "fp32" or on_cpu:
            torch_dtype = torch.float32
        else:
            torch_dtype = torch.float16
        device_map = {"": "cpu"} if on_cpu else {"": 0}

        hf_repo = "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3"
        config = AutoConfig.from_pretrained(
            hf_repo,
            use_auth_token=False,
            trust_remote_code=True,
            offload_folder="offline_folder",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            hf_repo,
            config=config,
            local_files_only=False,
            torch_dtype=torch_dtype,
            resume_download=True,
            use_auth_token=False,
            trust_remote_code=True,
            offload_folder="offline_folder",
            device_map=device_map,
        )

        if precision in ["int4", "int8"]:
            print("Applying weight quantization..")
            bits = 4 if precision == "int4" else 8
            # Quantization is applied to self.model.transformer.h only.
            quantize_model(
                self.model.transformer.h,
                dtype=torch.float32,
                weight_bit_width=bits,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=128,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")

    def forward(self, input_ids, attention_mask):
        """Run the model and return ``logits[:, -1, :]``.

        NOTE(review): presumably the logits of the final sequence position --
        confirm the logits layout against the model.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=None,
            use_cache=True,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        return output.logits[:, -1, :]
|
||||
|
||||
|
||||
class H2OGPTSHARKModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
model_name = "h2ogpt_falcon_7b"
|
||||
extended_model_name = (
|
||||
model_name + "_" + args.precision + "_" + args.device
|
||||
)
|
||||
vmfb_path = Path(extended_model_name + ".vmfb")
|
||||
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
|
||||
shark_module = None
|
||||
|
||||
need_to_compile = False
|
||||
if not vmfb_path.exists():
|
||||
need_to_compile = True
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Trying to download pre-compiled vmfb from shark tank.")
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print(
|
||||
"Pre-compiled vmfb downloaded from shark tank successfully."
|
||||
)
|
||||
need_to_compile = False
|
||||
|
||||
if need_to_compile:
|
||||
if not mlir_path.exists():
|
||||
print("Trying to download pre-generated mlir from shark tank.")
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
# Generating the mlir
|
||||
bytecode = self.get_bytecode(tensor_device, args.precision)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
|
||||
if shark_module is None:
|
||||
if vmfb_path.exists():
|
||||
print("Compiled vmfb found. Loading it from: ", vmfb_path)
|
||||
shark_module = SharkInference(
|
||||
None, device=args.device, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.load_module(str(vmfb_path))
|
||||
print("Compiled vmfb loaded successfully.")
|
||||
else:
|
||||
raise ValueError("Unable to download/generate a vmfb.")
|
||||
|
||||
self.model = shark_module
|
||||
|
||||
def get_bytecode(self, device, precision):
    """Trace the h2oGPT model and lower it to MLIR bytecode.

    The model is wrapped in ``H2OGPTModel``, traced by ``import_with_fx`` on
    fixed-shape (1 x 400) example inputs, compiled with ``torch_mlir`` and
    serialised to bytecode, which is finally handed to ``save_mlir``.

    Args:
        device: Device passed to ``H2OGPTModel`` and used for the example inputs.
        precision: Precision string. ``"int4"`` and ``"int8"`` take the
            quantised lowering path; any other value is lowered straight to
            linalg-on-tensors.

    Returns:
        The value returned by ``save_mlir`` for the generated bytecode.

    Note:
        This method deletes ``self.src_model``; the intermediate objects are
        deleted as soon as they are no longer needed, presumably to limit
        peak memory (TODO confirm).
    """
    wrapped_model = H2OGPTModel(device, precision)

    # Example inputs only fix the traced shapes/dtypes; the values are random.
    example_ids = torch.randint(low=1, high=10000, size=(1, 400)).to(
        device=device
    )
    example_mask = torch.ones(1, 400, dtype=torch.int64).to(device=device)
    example_inputs = (example_ids, example_mask)

    print("[DEBUG] generating torchscript graph")
    traced_graph = import_with_fx(
        wrapped_model,
        example_inputs,
        is_f16=False,
        precision=precision,
        f16_input_mask=[False, False],
        mlir_type="torchscript",
    )
    del wrapped_model
    del self.src_model

    print("[DEBUG] generating torch mlir")
    if precision not in ("int4", "int8"):
        # Unquantised path: lower directly to linalg-on-tensors.
        lowered = torch_mlir.compile(
            traced_graph,
            [*example_inputs],
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )
    else:
        from torch_mlir.compiler_utils import (
            run_pipeline_with_repro_report,
        )

        # Quantised path: stop at the torch dialect, keeping the custom
        # group-quant matmul op legal, then run the lowering pipeline by hand.
        lowered = torch_mlir.compile(
            traced_graph,
            [*example_inputs],
            output_type=torch_mlir.OutputType.TORCH,
            backend_legal_ops=["quant.matmul_rhs_group_quant"],
            extra_library=brevitas_matmul_rhs_group_quant_library,
            use_tracing=False,
            verbose=False,
        )
        print("[DEBUG] converting torch to linalg")
        run_pipeline_with_repro_report(
            lowered,
            "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
            description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
        )
    del traced_graph

    print("[DEBUG] converting to bytecode")
    buffer = BytesIO()
    lowered.operation.write_bytecode(buffer)
    raw_bytecode = buffer.getvalue()
    del lowered

    return save_mlir(
        raw_bytecode,
        model_name=f"h2ogpt_{precision}",
        frontend="torch",
    )
|
||||
|
||||
def forward(self, input_ids, attention_mask):
    """Run the compiled module's ``"forward"`` function on the given inputs.

    Both tensors are moved to the CPU before being handed to ``self.model``;
    the result is converted with ``torch.from_numpy`` and moved to the
    module-level ``tensor_device``.

    Args:
        input_ids: Token-id tensor.
        attention_mask: Attention-mask tensor matching ``input_ids``.

    Returns:
        The model output as a torch tensor on ``tensor_device``.
    """
    cpu_inputs = (
        input_ids.to(device="cpu"),
        attention_mask.to(device="cpu"),
    )
    raw_output = self.model("forward", cpu_inputs)
    return torch.from_numpy(raw_output).to(device=tensor_device)
|
||||
|
||||
|
||||
def decode_tokens(tokenizer, res_tokens):
    """Decode a list of generated tokens into text.

    Entries that are not plain ``int`` are replaced *in place* by
    ``int(entry[0])``, so the caller's list holds only ints afterwards and
    repeated calls on a growing list only convert the new entries.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        res_tokens: Mutable list of ints and/or indexable single-token values.

    Returns:
        The decoded string, with special tokens skipped.
    """
    for position, token in enumerate(res_tokens):
        if type(token) is not int:
            res_tokens[position] = int(token[0])

    return tokenizer.decode(res_tokens, skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
    """Sample tokens one by one from the compiled model, streaming the text.

    This is a generator. After every sampled token it yields the decoded text
    of all tokens produced so far; once generation stops it prints the average
    time per token and yields the full decoded text one final time.

    The input window has a fixed length: each new token is appended to
    ``input_ids`` and the oldest position is dropped, so the tensors handed to
    ``h2ogpt_shark_model.forward`` keep the shape they arrived with.

    Args:
        h2ogpt_shark_model: Object whose ``forward(input_ids, attention_mask)``
            returns the next-token logits.
        model: Hugging Face model, used only for its config and its generation
            helper methods (input preparation, logits processors, ...).
        tokenizer: Tokenizer used to decode the sampled tokens.
        **generate_kwargs: Must contain ``input_ids`` and ``attention_mask``.
            ``max_time`` is discarded if present, ``stopping_criteria`` is
            optional, and the remaining entries are applied to the generation
            config.

    Yields:
        str: The decoded text generated so far.
    """
    import time

    # Both keys are optional here; the original indexing raised KeyError when
    # a caller left them out even though a missing criteria list is handled.
    generate_kwargs.pop("max_time", None)
    stopping_criteria = generate_kwargs.get("stopping_criteria")

    generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
        device=tensor_device
    )
    generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
        device=tensor_device
    )

    generation_config = copy.deepcopy(
        GenerationConfig.from_model_config(model.config)
    )
    # update() applies the kwargs it recognises and returns the rest.
    model_kwargs = generation_config.update(**generate_kwargs)

    logits_processor = LogitsProcessorList()
    if stopping_criteria is None:
        stopping_criteria = StoppingCriteriaList()

    eos_token_id = generation_config.eos_token_id
    generation_config.pad_token_id = eos_token_id

    (
        inputs_tensor,
        model_input_name,
        model_kwargs,
    ) = model._prepare_model_inputs(
        None, generation_config.bos_token_id, model_kwargs
    )

    model_kwargs["output_attentions"] = generation_config.output_attentions
    model_kwargs["output_hidden_states"] = (
        generation_config.output_hidden_states
    )
    model_kwargs["use_cache"] = generation_config.use_cache

    if model_input_name == "input_ids":
        input_ids = inputs_tensor
    else:
        input_ids = model_kwargs.pop("input_ids")

    input_ids_seq_length = input_ids.shape[-1]

    generation_config.max_length = (
        generation_config.max_new_tokens + input_ids_seq_length
    )

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=None,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria,
    )

    logits_warper = model._get_logits_warper(generation_config)

    (
        input_ids,
        model_kwargs,
    ) = model._expand_inputs_for_generation(
        input_ids=input_ids,
        expand_size=generation_config.num_return_sequences,  # 1
        is_encoder_decoder=model.config.is_encoder_decoder,  # False
        **model_kwargs,
    )

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = (
        torch.tensor(eos_token_id).to(device=tensor_device)
        if eos_token_id is not None
        else None
    )

    pad_token_id = generation_config.pad_token_id

    output_scores = generation_config.output_scores  # False
    return_dict_in_generate = (
        generation_config.return_dict_in_generate  # False
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(
        input_ids.shape[0],
        dtype=torch.long,
        device=input_ids.device,
    )

    tokens_generated = 0
    start = time.time()
    print("\n")

    res_tokens = []
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs
        )

        outputs = h2ogpt_shark_model.forward(
            model_inputs["input_ids"], model_inputs["attention_mask"]
        )

        if args.precision == "fp16":
            outputs = outputs.to(dtype=torch.float32)
        next_token_logits = outputs

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError(
                    "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
                )
            next_token = next_token * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        # Count the token as soon as it exists so the per-token timing below
        # is correct (and non-zero) even when generation stops immediately.
        tokens_generated += 1

        # Slide the fixed-size window: append the new token, drop the oldest.
        input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
        input_ids = input_ids[:, 1:]

        model_kwargs["past_key_values"] = None
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            attention_mask = torch.cat(
                [
                    attention_mask,
                    attention_mask.new_ones((attention_mask.shape[0], 1)),
                ],
                dim=-1,
            )
            model_kwargs["attention_mask"] = attention_mask[:, 1:]

        new_word = tokenizer.decode(
            next_token.cpu().numpy(),
            add_special_tokens=False,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        res_tokens.append(next_token)
        if new_word == "<0x0A>":
            print("\n", end="", flush=True)
        else:
            print(f"{new_word}", end=" ", flush=True)

        # decode_tokens converts the stored tensors to ints in place, so each
        # call only pays for the newly appended token.
        part_str = decode_tokens(tokenizer, res_tokens)
        yield part_str

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_token.tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1))
                .prod(dim=0)
            )
        # stop when each sentence is finished
        if unfinished_sequences.max() == 0 or stopping_criteria(
            input_ids, scores
        ):
            break

    end = time.time()
    print(
        "\n\nTime taken is {:.2f} seconds/token\n".format(
            (end - start) / max(tokens_generated, 1)
        )
    )

    torch.cuda.empty_cache()
    gc.collect()

    res_str = decode_tokens(tokenizer, res_tokens)
    yield res_str
|
||||
|
||||
|
||||
def pad_or_truncate_inputs(
    input_ids,
    attention_mask,
    max_padding_length=400,
    do_truncation=False,
    pad_token_id=11,
):
    """Force ``input_ids`` / ``attention_mask`` to a fixed sequence length.

    Shorter inputs are left-padded (pad ids in ``input_ids``, zeros in the
    mask); longer inputs keep only their last ``max_padding_length`` columns.
    Padding is created with the batch size, dtype and device of the tensor it
    is joined to, so batched inputs and inputs on any device are supported.

    Args:
        input_ids: 2-D tensor indexed as (batch, sequence).
        attention_mask: 2-D tensor with the same shape as ``input_ids``.
        max_padding_length: Target sequence length.
        do_truncation: Kept for backward compatibility. It only selects the
            truncation branch when the length already equals the target, in
            which case nothing is removed.
        pad_token_id: Token id used for left padding (defaults to the
            previously hard-coded value 11).

    Returns:
        Tuple ``(input_ids, attention_mask)`` whose sequence length is
        ``max_padding_length``.
    """
    seq_len = input_ids.shape[1]

    if seq_len < max_padding_length:
        pad_shape = (input_ids.shape[0], max_padding_length - seq_len)
        id_padding = torch.full(
            pad_shape,
            pad_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        mask_padding = torch.zeros(
            pad_shape,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        return (
            torch.cat([id_padding, input_ids], dim=1),
            torch.cat([mask_padding, attention_mask], dim=1),
        )

    if seq_len > max_padding_length or do_truncation:
        # Keep the tail of the sequence, i.e. the most recent tokens.
        excess = seq_len - max_padding_length
        return input_ids[:, excess:], attention_mask[:, excess:]

    return input_ids, attention_mask
|
||||
|
||||
|
||||
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that adds h2oGPT prompt handling.

    ``preprocess`` shortens over-long prompts and wraps them with the
    configured prompter, ``postprocess`` extracts the bot response from the
    generated text, and ``_forward`` returns the prepared, fixed-length inputs
    as a dict instead of calling ``model.generate``.
    """

    def __init__(
        self,
        *args,
        debug=False,
        chat=False,
        stream_output=False,
        sanitize_bot_response=False,
        use_prompter=True,
        prompter=None,
        prompt_type=None,
        prompt_dict=None,
        max_input_tokens=2048 - 256,
        **kwargs,
    ):
        """
        HF-like pipeline, but handle instruction prompting and stopping (for some models)

        :param args: positional arguments forwarded to TextGenerationPipeline
        :param debug: forwarded to Prompter when one is created here
        :param chat: forwarded to Prompter when one is created here
        :param stream_output: forwarded to Prompter when one is created here
        :param sanitize_bot_response: passed to prompter.get_response in postprocess
        :param use_prompter: Whether to use prompter.  If pass prompt_type, will make prompter
        :param prompter: prompter, can pass if have already
        :param prompt_type: prompt_type, e.g. human_bot.  See prompt_type to model mapping in from prompter.py.
                            If use_prompter, then will make prompter and use it.
        :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom
        :param max_input_tokens: stored on the instance; not used for generate
        :param kwargs: keyword arguments forwarded to TextGenerationPipeline
        """
        super().__init__(*args, **kwargs)
        self.prompt_text = None
        self.use_prompter = use_prompter
        self.prompt_type = prompt_type
        self.prompt_dict = prompt_dict
        self.prompter = prompter
        if self.use_prompter:
            if self.prompter is not None:
                assert self.prompter.prompt_type is not None
            else:
                self.prompter = Prompter(
                    self.prompt_type,
                    self.prompt_dict,
                    debug=debug,
                    chat=chat,
                    stream_output=stream_output,
                )
            self.human = self.prompter.humanstr
            self.bot = self.prompter.botstr
            self.can_stop = True
        else:
            self.prompter = None
            self.human = None
            self.bot = None
            self.can_stop = False
        self.sanitize_bot_response = sanitize_bot_response
        self.max_input_tokens = (
            max_input_tokens  # not for generate, so ok that not kwargs
        )

    @staticmethod
    def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
        """Shorten ``prompt_text`` so it fits the tokenizer's length limit.

        The tail of the prompt is kept. When the tokenizer has no
        ``model_max_length`` attribute the prompt is returned untouched.

        :param prompt_text: prompt string to limit
        :param tokenizer: callable tokenizer returning a mapping with "input_ids"
        :param max_prompt_length: optional tighter token limit
        :return: tuple (prompt_text, num_prompt_tokens); the count is the one
                 measured by the last tokenization (it predates the final cut
                 if every trial had to truncate) and is None when no limit is
                 known
        """
        verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0")))

        if hasattr(tokenizer, "model_max_length"):
            # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
            model_max_length = tokenizer.model_max_length
            if max_prompt_length is not None:
                model_max_length = min(model_max_length, max_prompt_length)
            # cut at some upper likely limit to avoid excessive tokenization etc
            # upper bound of 10 chars/token, e.g. special chars sometimes are long
            if len(prompt_text) > model_max_length * 10:
                len0 = len(prompt_text)
                prompt_text = prompt_text[-model_max_length * 10 :]
                if verbose:
                    print(
                        "Cut of input: %s -> %s" % (len0, len(prompt_text)),
                        flush=True,
                    )
        else:
            # unknown
            model_max_length = None

        num_prompt_tokens = None
        if model_max_length is not None:
            # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
            # For https://github.com/h2oai/h2ogpt/issues/192
            for trial in range(0, 3):
                prompt_tokens = tokenizer(prompt_text)["input_ids"]
                num_prompt_tokens = len(prompt_tokens)
                if num_prompt_tokens > model_max_length:
                    # conservative by using int(); never below 1, because a
                    # slice starting at -0 would keep the whole prompt and the
                    # trial would shrink nothing
                    chars_per_token = max(
                        1, int(len(prompt_text) / num_prompt_tokens)
                    )
                    # keep tail, where question is if using langchain
                    prompt_text = prompt_text[
                        -model_max_length * chars_per_token :
                    ]
                    if verbose:
                        print(
                            "reducing %s tokens, assuming average of %s chars/token for %s characters"
                            % (
                                num_prompt_tokens,
                                chars_per_token,
                                len(prompt_text),
                            ),
                            flush=True,
                        )
                else:
                    if verbose:
                        print(
                            "using %s tokens with %s chars"
                            % (num_prompt_tokens, len(prompt_text)),
                            flush=True,
                        )
                    break

        return prompt_text, num_prompt_tokens

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        **generate_kwargs,
    ):
        """Limit the prompt, apply the prompter, then defer to the base class.

        The final prompt is remembered in ``self.prompt_text`` so that
        ``postprocess`` can strip it from the generated text.
        """
        (
            prompt_text,
            num_prompt_tokens,
        ) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

        data_point = dict(context="", instruction=prompt_text, input="")
        if self.prompter is not None:
            prompt_text = self.prompter.generate_prompt(data_point)
        self.prompt_text = prompt_text
        # handle_long_generation is passed through unchanged; the prompt has
        # already been limited above.
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Run the base postprocess and reduce each record to the response.

        Each record's ``"generated_text"`` is replaced by the extracted
        response, and the prompt/response pair is printed.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            if self.use_prompter:
                outputs = rec["generated_text"]
                outputs = self.prompter.get_response(
                    outputs,
                    prompt=self.prompt_text,
                    sanitize_bot_response=self.sanitize_bot_response,
                )
            elif self.bot and self.human:
                outputs = (
                    rec["generated_text"]
                    .split(self.bot)[1]
                    .split(self.human)[0]
                )
            else:
                outputs = rec["generated_text"]
            rec["generated_text"] = outputs
            print(
                "prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs),
                flush=True,
            )
        return records

    def _forward(self, model_inputs, **generate_kwargs):
        """Attach stopping criteria (when a prompter is in use) and delegate."""
        if self.can_stop:
            stopping_criteria = get_stopping(
                self.prompt_type,
                self.prompt_dict,
                self.tokenizer,
                self.device,
                human=self.human,
                bot=self.bot,
                model_max_length=self.tokenizer.model_max_length,
            )
            generate_kwargs["stopping_criteria"] = stopping_criteria
        # return super()._forward(model_inputs, **generate_kwargs)
        return self.__forward(model_inputs, **generate_kwargs)

    # FIXME: Copy-paste of original _forward, but removed copy.deepcopy()
    # FIXME: https://github.com/h2oai/h2ogpt/issues/172
    def __forward(self, model_inputs, **generate_kwargs):
        """Prepare fixed-length inputs and return them with the generate kwargs.

        Unlike the upstream ``_forward`` this does not run generation. It
        returns a dict holding the model, the tokenizer, the padded/truncated
        ``input_ids`` and ``attention_mask`` and every remaining generate
        kwarg. Note that ``generate_kwargs`` is modified without a deep copy.
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        # NOTE(review): with an empty prompt both values become None, which
        # pad_or_truncate_inputs below cannot handle -- confirm whether empty
        # prompts are expected to reach this path.
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
        # Removed from model_inputs exactly as before; the value is not needed.
        model_inputs.pop("prompt_text")

        ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        # generate_kwargs = copy.deepcopy(generate_kwargs)
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens
                is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = (
                    generate_kwargs.get("max_length")
                    or self.model.config.max_length
                )
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens
                is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        # pad or truncate the input_ids and attention_mask
        max_padding_length = 400
        input_ids, attention_mask = pad_or_truncate_inputs(
            input_ids, attention_mask, max_padding_length=max_padding_length
        )

        return_dict = {
            "model": self.model,
            "tokenizer": self.tokenizer,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # generate kwargs win over the defaults above on key collisions
        return_dict = {**return_dict, **generate_kwargs}
        return return_dict
|
||||
248
apps/language_models/langchain/image_captions.py
Normal file
248
apps/language_models/langchain/image_captions.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py
|
||||
But accepts preloaded model to avoid slowness in use and CUDA forking issues
|
||||
|
||||
Loader that loads image captions
|
||||
By default, the loader utilizes the pre-trained BLIP image captioning model.
|
||||
https://huggingface.co/Salesforce/blip-image-captioning-base
|
||||
|
||||
"""
|
||||
|
||||
from typing import List, Union, Any, Tuple
|
||||
|
||||
import requests
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders import ImageCaptionLoader
|
||||
|
||||
from utils import get_device, NullContext
|
||||
|
||||
import pkg_resources
|
||||
|
||||
# bitsandbytes is optional: record whether it is installed so that 8-bit
# loading is only requested when the package is actually available.
try:
    assert pkg_resources.get_distribution("bitsandbytes") is not None
except (pkg_resources.DistributionNotFound, AssertionError):
    have_bitsandbytes = False
else:
    have_bitsandbytes = True
|
||||
|
||||
|
||||
class H2OImageCaptionLoader(ImageCaptionLoader):
    """Loader that loads the captions of an image.

    The processor and model are loaded lazily (see ``load_model``) and kept on
    the instance, so one loader can caption many images without reloading.
    """

    def __init__(
        self,
        path_images: Union[str, List[str]] = None,
        blip_processor: str = None,
        blip_model: str = None,
        caption_gpu=True,
        load_in_8bit=True,
        # True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8
        load_half=False,
        load_gptq="",
        use_safetensors=False,
        min_new_tokens=20,
        max_tokens=50,
    ):
        """Store the configuration; no model is loaded here.

        Args:
            path_images: One image path/URL or a list of them.
            blip_processor: Processor checkpoint name.
            blip_model: Model checkpoint name. If either this or
                ``blip_processor`` is missing, both fall back to
                ``Salesforce/blip-image-captioning-base``.
            caption_gpu: Use CUDA for captioning when it is available.
            load_in_8bit: Request 8-bit loading (BLIP-2 only, and only when
                bitsandbytes is installed).
            load_half: Load BLIP-2 in half precision.
            load_gptq: Stored on the instance; not used in this class.
            use_safetensors: Stored on the instance; not used in this class.
            min_new_tokens: Minimum number of tokens to generate beyond the prompt.
            max_tokens: Upper bound passed to ``generate`` as ``max_length``.
        """
        # The original tested blip_model twice, so a missing processor name
        # slipped through and failed later in load_model.
        if blip_processor is None or blip_model is None:
            blip_processor = "Salesforce/blip-image-captioning-base"
            blip_model = "Salesforce/blip-image-captioning-base"

        super().__init__(path_images, blip_processor, blip_model)
        self.blip_processor = blip_processor
        self.blip_model = blip_model
        self.processor = None
        self.model = None
        self.caption_gpu = caption_gpu
        self.context_class = NullContext
        self.device = "cpu"
        self.load_in_8bit = (
            load_in_8bit and have_bitsandbytes
        )  # only for blip2
        self.load_half = load_half
        self.load_gptq = load_gptq
        self.use_safetensors = use_safetensors
        self.gpu_id = "auto"
        # default prompt
        self.prompt = "image of"
        self.min_new_tokens = min_new_tokens
        self.max_tokens = max_tokens

    def set_context(self):
        """Switch to the CUDA device context when requested and available."""
        if get_device() == "cuda" and self.caption_gpu:
            import torch

            # is_available must be called; the bare attribute is always truthy.
            n_gpus = (
                torch.cuda.device_count() if torch.cuda.is_available() else 0
            )
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = "cuda"

    def _device_map(self):
        """Return the ``device_map`` matching the device chosen by set_context."""
        if self.device != "cuda":
            return {"": "cpu"}
        # blip2 has issues with multi-GPU. Error says need to somehow set language model in device map
        # device_map = 'auto'
        return {"": 0 if self.gpu_id == "auto" else self.gpu_id}

    def load_model(self):
        """Load the processor and model named at construction time.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            ValueError: If ``transformers`` is not installed.
        """
        try:
            import transformers
        except ImportError as exc:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from exc
        self.set_context()
        # Derived from the device actually selected, so a GPU is never
        # targeted on a machine without CUDA.
        device_map = self._device_map()

        import torch

        context_class_cast = (
            NullContext if self.device == "cpu" else torch.autocast
        )
        with torch.no_grad():
            with self.context_class(self.device):
                with context_class_cast(self.device):
                    if "blip2" in self.blip_processor.lower():
                        from transformers import (
                            Blip2Processor,
                            Blip2ForConditionalGeneration,
                        )

                        if self.load_half and not self.load_in_8bit:
                            # Only the model has weights to cast; the
                            # processor has no half() method.
                            self.processor = Blip2Processor.from_pretrained(
                                self.blip_processor, device_map=device_map
                            )
                            self.model = (
                                Blip2ForConditionalGeneration.from_pretrained(
                                    self.blip_model, device_map=device_map
                                ).half()
                            )
                        else:
                            self.processor = Blip2Processor.from_pretrained(
                                self.blip_processor,
                                load_in_8bit=self.load_in_8bit,
                                device_map=device_map,
                            )
                            self.model = (
                                Blip2ForConditionalGeneration.from_pretrained(
                                    self.blip_model,
                                    load_in_8bit=self.load_in_8bit,
                                    device_map=device_map,
                                )
                            )
                    else:
                        from transformers import (
                            BlipForConditionalGeneration,
                            BlipProcessor,
                        )

                        self.load_half = False  # not supported
                        # Blip doesn't support device_map='auto'; the map
                        # computed above is always an explicit placement.
                        self.processor = BlipProcessor.from_pretrained(
                            self.blip_processor, device_map=device_map
                        )
                        self.model = (
                            BlipForConditionalGeneration.from_pretrained(
                                self.blip_model, device_map=device_map
                            )
                        )
        return self

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Load from a list of image files
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every configured image and return one Document per image.

        The model is loaded on first use. ``prompt`` overrides the default
        prompt stored on the instance.
        """
        if self.processor is None or self.model is None:
            self.load_model()
        results = []
        for path_image in self.image_paths:
            caption, metadata = self._get_captions_and_metadata(
                model=self.model,
                processor=self.processor,
                path_image=path_image,
                prompt=prompt,
            )
            doc = Document(page_content=caption, metadata=metadata)
            results.append(doc)

        return results

    def _get_captions_and_metadata(
        self, model: Any, processor: Any, path_image: str, prompt=None
    ) -> Tuple[str, dict]:
        """
        Helper function for getting the captions and metadata of an image

        Returns the caption (with the prompt and anything before it removed
        when the prompt occurs in the decoded text) and a metadata dict
        holding the image path.

        Raises ValueError if PIL is missing or the image cannot be read.
        """
        if prompt is None:
            prompt = self.prompt
        try:
            from PIL import Image
        except ImportError as exc:
            raise ValueError(
                "`PIL` package not found, please install with `pip install pillow`"
            ) from exc

        try:
            if path_image.startswith("http://") or path_image.startswith(
                "https://"
            ):
                image = Image.open(
                    requests.get(path_image, stream=True).raw
                ).convert("RGB")
            else:
                image = Image.open(path_image).convert("RGB")
        except Exception as exc:
            raise ValueError(
                f"Could not get image data for {path_image}"
            ) from exc

        import torch

        context_class_cast = (
            NullContext if self.device == "cpu" else torch.autocast
        )
        with torch.no_grad():
            with self.context_class(self.device):
                with context_class_cast(self.device):
                    inputs = processor(image, prompt, return_tensors="pt")
                    if self.load_half:
                        # The processor output has no half(); its to() casts
                        # only the floating-point tensors.
                        inputs = inputs.to(torch.float16)
                    # rough estimate of the prompt's token count (4 chars/token)
                    min_length = len(prompt) // 4 + self.min_new_tokens
                    # Local value: one long prompt must not permanently raise
                    # the configured limit for later calls.
                    max_length = max(self.max_tokens, min_length)
                    output = model.generate(
                        **inputs,
                        min_length=min_length,
                        max_length=max_length,
                    )

                    caption: str = processor.decode(
                        output[0], skip_special_tokens=True
                    )
                    prompti = caption.find(prompt)
                    if prompti >= 0:
                        caption = caption[prompti + len(prompt) :]
                    metadata: dict = {"image_path": path_image}

        return caption, metadata
|
||||
120
apps/language_models/langchain/langchain_requirements.txt
Normal file
120
apps/language_models/langchain/langchain_requirements.txt
Normal file
@@ -0,0 +1,120 @@
|
||||
# for generate (gradio server) and finetune
|
||||
datasets==2.13.0
|
||||
sentencepiece==0.1.99
|
||||
huggingface_hub==0.16.4
|
||||
appdirs==1.4.4
|
||||
fire==0.5.0
|
||||
docutils==0.20.1
|
||||
evaluate==0.4.0
|
||||
rouge_score==0.1.2
|
||||
sacrebleu==2.3.1
|
||||
scikit-learn==1.2.2
|
||||
alt-profanity-check==1.2.2
|
||||
better-profanity==0.7.0
|
||||
numpy==1.24.3
|
||||
pandas==2.0.2
|
||||
matplotlib==3.7.1
|
||||
loralib==0.1.1
|
||||
bitsandbytes==0.39.0
|
||||
accelerate==0.20.3
|
||||
peft==0.4.0
|
||||
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
|
||||
transformers==4.30.2
|
||||
tokenizers==0.13.3
|
||||
APScheduler==3.10.1
|
||||
|
||||
# optional for generate
|
||||
pynvml==11.5.0
|
||||
psutil==5.9.5
|
||||
boto3==1.26.101
|
||||
botocore==1.29.101
|
||||
|
||||
# optional for finetune
|
||||
tensorboard==2.13.0
|
||||
neptune==1.2.0
|
||||
|
||||
# for gradio client
|
||||
gradio_client==0.2.10
|
||||
beautifulsoup4==4.12.2
|
||||
markdown==3.4.3
|
||||
|
||||
# data and testing
|
||||
pytest==7.2.2
|
||||
pytest-xdist==3.2.1
|
||||
nltk==3.8.1
|
||||
textstat==0.7.3
|
||||
# pandoc==2.3
|
||||
pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64"
|
||||
pypandoc_binary==1.11; platform_machine == "x86_64"
|
||||
pypandoc_binary==1.11; sys_platform == "win32"
|
||||
openpyxl==3.1.2
|
||||
lm_dataformat==0.0.20
|
||||
bioc==2.0
|
||||
|
||||
# falcon
|
||||
einops==0.6.1
|
||||
instructorembedding==1.0.1
|
||||
|
||||
# for gpt4all .env file, but avoid worrying about imports
|
||||
python-dotenv==1.0.0
|
||||
|
||||
text-generation==0.6.0
|
||||
# for tokenization when don't have HF tokenizer
|
||||
tiktoken==0.4.0
|
||||
# optional: for OpenAI endpoint or embeddings (requires key)
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.329
|
||||
pypdf==3.17.0
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
# for HF embeddings
|
||||
sentence_transformers==2.2.2
|
||||
|
||||
# local vector db
|
||||
chromadb==0.3.25
|
||||
# server vector db
|
||||
#pymilvus==2.2.8
|
||||
|
||||
# weak url support, if can't install opencv etc. If you comment-in this one, then comment-out the unstructured[local-inference] pin below
|
||||
# unstructured==0.8.1
|
||||
|
||||
# strong support for images
|
||||
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
|
||||
unstructured[local-inference]==0.7.4
|
||||
#pdf2image==1.16.3
|
||||
#pytesseract==0.3.10
|
||||
pillow
|
||||
|
||||
pdfminer.six==20221105
|
||||
urllib3
|
||||
requests_file
|
||||
|
||||
#pdf2image==1.16.3
|
||||
#pytesseract==0.3.10
|
||||
tabulate==0.9.0
|
||||
# FYI pandoc already part of requirements.txt
|
||||
|
||||
# JSONLoader, but makes some trouble for some users
|
||||
# jq==1.4.1
|
||||
|
||||
# to check licenses
|
||||
# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
|
||||
pip-licenses==4.3.0
|
||||
|
||||
# weaviate vector db
|
||||
weaviate-client==3.22.1
|
||||
|
||||
gpt4all==1.0.5
|
||||
llama-cpp-python==0.1.73
|
||||
|
||||
arxiv==1.4.8
|
||||
pymupdf==1.22.5 # AGPL license
|
||||
# extract-msg==0.41.1 # GPL3
|
||||
|
||||
# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320
|
||||
playwright==1.36.0
|
||||
# requires Chrome binary to be in path
|
||||
selenium==4.10.0
|
||||
124
apps/language_models/langchain/llama_flash_attn_monkey_patch.py
Normal file
124
apps/language_models/langchain/llama_flash_attn_monkey_patch.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
|
||||
|
||||
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[
    torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]
]:
    """Flash-attention replacement for the LLaMA attention ``forward``.

    Input shape: Batch x Time x Channel
    attention_mask: [bsz, q_len], used directly as the key padding mask.

    ``past_key_value``, ``output_attentions`` and ``use_cache`` are not
    supported and are rejected with an assertion.

    :return: ``(attn_output, None, None)``; attention weights and cache are
        never produced.
    """
    batch, seq_len, _ = hidden_states.size()

    def _project(proj):
        # [bsz, q_len, nh, hd] -> [bsz, nh, q_len, hd]
        return (
            proj(hidden_states)
            .view(batch, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

    q = _project(self.q_proj)
    k = _project(self.k_proj)
    v = _project(self.v_proj)

    kv_len = k.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    rot_cos, rot_sin = self.rotary_emb(v, seq_len=kv_len)
    # q and k stay [bsz, nh, t, hd] after the rotary embedding
    q, k = apply_rotary_pos_emb(q, k, rot_cos, rot_sin, position_ids)
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # Pack into the layout flash attention expects:
    # stack -> [bsz, nh, 3, q_len, hd], transpose -> [bsz, q_len, 3, nh, hd]
    packed = torch.stack([q, k, v], dim=2).transpose(1, 3)

    # _prepare_decoder_attention_mask is disabled in LlamaModel, so the
    # attention_mask arrives untouched and doubles as the key_padding_mask.
    if attention_mask is not None:
        # Padded batch: drop padding tokens, attend, then re-pad.
        n_heads = packed.shape[-2]
        flat = rearrange(packed, "b s three h d -> b s (three h d)")
        flat_unpad, indices, cu_q_lens, max_s = unpad_input(
            flat, attention_mask
        )
        flat_unpad = rearrange(
            flat_unpad, "nnz (three h d) -> nnz three h d", three=3, h=n_heads
        )
        attn_unpad = flash_attn_unpadded_qkvpacked_func(
            flat_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        repadded = pad_input(
            rearrange(attn_unpad, "nnz h d -> nnz (h d)"),
            indices,
            batch,
            seq_len,
        )
        attn = rearrange(repadded, "b s (h d) -> b s h d", h=n_heads)
    else:
        # No padding: every sequence has exactly seq_len tokens.
        packed = rearrange(packed, "b s ... -> (b s) ...")
        cu_q_lens = torch.arange(
            0,
            (batch + 1) * seq_len,
            step=seq_len,
            dtype=torch.int32,
            device=packed.device,
        )
        attn = flash_attn_unpadded_qkvpacked_func(
            packed, cu_q_lens, seq_len, 0.0, softmax_scale=None, causal=True
        )
        attn = rearrange(attn, "(b s) ... -> b s ...", b=batch)

    merged = rearrange(attn, "b s h d -> b s (h d)")
    return self.o_proj(merged), None, None
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# [bsz, seq_len]
|
||||
return attention_mask
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn():
    """Monkey-patch transformers' LLaMA classes to use flash attention.

    Replaces ``LlamaModel._prepare_decoder_attention_mask`` and
    ``LlamaAttention.forward`` with the functions defined in this module.
    """
    print(
        "Replacing original LLaMa attention with flash attention", flush=True
    )
    llama_module = transformers.models.llama.modeling_llama
    llama_module.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    llama_module.LlamaAttention.forward = forward
|
||||
109
apps/language_models/langchain/loaders.py
Normal file
109
apps/language_models/langchain/loaders.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import functools
|
||||
|
||||
|
||||
def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""):
    """Pick the model loader and tokenizer class for ``model_name``.

    NOTE: Some models need specific new prompt_type
    E.g. t5_xxl_true_nli_mixture has input format:
    "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".

    :param model_name: model identifier; matched by substring to select the
        loader family.
    :param reward_type: truthy to select a sequence-classification loader
        when no earlier name pattern matched.
    :param llama_type: force (truthy) or suppress (falsy, not None) the LLaMA
        loader; ``None`` means "detect from ``model_name``".
    :param load_gptq: truthy to load through auto_gptq; checked before
        anything else.
    :return: ``(model_loader, tokenizer_loader)``. ``model_loader`` is a
        callable (usually a ``from_pretrained``); for the two "samsum"
        summarization models it is ``transformers.pipeline`` and the second
        element is the string ``"summarization"``.
    """
    if load_gptq:
        from transformers import AutoTokenizer
        from auto_gptq import AutoGPTQForCausalLM

        use_triton = False
        # Fix: the partial used to be built and thrown away, so callers got
        # the bare from_quantized without these defaults bound. Explicit
        # keyword arguments from the caller still override them.
        model_loader = functools.partial(
            AutoGPTQForCausalLM.from_quantized,
            quantize_config=None,
            use_triton=use_triton,
        )
        return model_loader, AutoTokenizer

    if llama_type is None:
        llama_type = "llama" in model_name.lower()
    if llama_type:
        from transformers import LlamaForCausalLM, LlamaTokenizer

        return LlamaForCausalLM.from_pretrained, LlamaTokenizer

    lowered = model_name.lower()

    if "distilgpt2" in lowered:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        return AutoModelForCausalLM.from_pretrained, AutoTokenizer

    if "gpt2" in lowered:
        from transformers import GPT2LMHeadModel, GPT2Tokenizer

        return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer

    if "mbart-" in lowered:
        from transformers import (
            MBartForConditionalGeneration,
            MBart50TokenizerFast,
        )

        return (
            MBartForConditionalGeneration.from_pretrained,
            MBart50TokenizerFast,
        )

    if "t5" == lowered or "t5-" in lowered or "flan-" in lowered:
        from transformers import AutoTokenizer, T5ForConditionalGeneration

        return T5ForConditionalGeneration.from_pretrained, AutoTokenizer

    # NOTE: the next two checks are case-sensitive (unlike the ones above);
    # kept as-is so existing model names keep resolving the same way.
    if "bigbird" in model_name:
        from transformers import (
            BigBirdPegasusForConditionalGeneration,
            AutoTokenizer,
        )

        return (
            BigBirdPegasusForConditionalGeneration.from_pretrained,
            AutoTokenizer,
        )

    if (
        "bart-large-cnn-samsum" in model_name
        or "flan-t5-base-samsum" in model_name
    ):
        from transformers import pipeline

        return pipeline, "summarization"

    if reward_type or "OpenAssistant/reward-model".lower() in lowered:
        from transformers import (
            AutoModelForSequenceClassification,
            AutoTokenizer,
        )

        return (
            AutoModelForSequenceClassification.from_pretrained,
            AutoTokenizer,
        )

    # Default: generic causal LM.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_loader = AutoModelForCausalLM
    tokenizer_loader = AutoTokenizer
    return model_loader.from_pretrained, tokenizer_loader
|
||||
|
||||
|
||||
def get_tokenizer(
    tokenizer_loader,
    tokenizer_base_model,
    local_files_only,
    resume_download,
    use_auth_token,
):
    """Load a tokenizer configured for left-padded batched generation.

    :param tokenizer_loader: object exposing ``from_pretrained``.
    :param tokenizer_base_model: first positional argument handed to
        ``from_pretrained``.
    :param local_files_only: forwarded to ``from_pretrained``.
    :param resume_download: forwarded to ``from_pretrained``.
    :param use_auth_token: forwarded to ``from_pretrained``.
    :return: the loaded tokenizer with ``pad_token_id`` set to 0 and
        ``padding_side`` set to "left".
    """
    load_options = dict(
        local_files_only=local_files_only,
        resume_download=resume_download,
        use_auth_token=use_auth_token,
        padding_side="left",
    )
    tok = tokenizer_loader.from_pretrained(tokenizer_base_model, **load_options)

    # pad id 0 is chosen to differ from the eos token
    tok.pad_token_id = 0
    # Generation predicts the next token from the logits of the right-most
    # token, so padding has to go on the left to allow batched inference,
    # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
    tok.padding_side = "left"
    return tok
|
||||
203
apps/language_models/langchain/make_db.py
Normal file
203
apps/language_models/langchain/make_db.py
Normal file
@@ -0,0 +1,203 @@
|
||||
import os
|
||||
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
get_some_dbs_from_hf,
|
||||
all_db_zips,
|
||||
some_db_zips,
|
||||
create_or_update_db,
|
||||
)
|
||||
from utils import get_ngpus_vis
|
||||
|
||||
|
||||
def glob_to_db(
    user_path,
    chunk=True,
    chunk_size=512,
    verbose=False,
    fail_any_exception=False,
    n_jobs=-1,
    url=None,
    enable_captions=True,
    captions_model=None,
    caption_loader=None,
    enable_ocr=False,
):
    """Collect document sources by delegating to ``path_to_docs``.

    Every argument is forwarded unchanged (``user_path`` positionally, the
    rest by keyword).

    :return: whatever ``path_to_docs`` returns.
    """
    ingest_options = dict(
        verbose=verbose,
        fail_any_exception=fail_any_exception,
        n_jobs=n_jobs,
        chunk=chunk,
        chunk_size=chunk_size,
        url=url,
        enable_captions=enable_captions,
        captions_model=captions_model,
        caption_loader=caption_loader,
        enable_ocr=enable_ocr,
    )
    return path_to_docs(user_path, **ingest_options)
|
||||
|
||||
|
||||
def make_db_main(
    use_openai_embedding: bool = False,
    hf_embedding_model: str = None,
    persist_directory: str = "db_dir_UserData",
    user_path: str = "user_path",
    url: str = None,
    add_if_exists: bool = True,
    collection_name: str = "UserData",
    verbose: bool = False,
    chunk: bool = True,
    chunk_size: int = 512,
    fail_any_exception: bool = False,
    download_all: bool = False,
    download_some: bool = False,
    download_one: str = None,
    download_dest: str = "./",
    n_jobs: int = -1,
    enable_captions: bool = True,
    captions_model: str = "Salesforce/blip-image-captioning-base",
    pre_load_caption_model: bool = False,
    caption_gpu: bool = True,
    enable_ocr: bool = False,
    db_type: str = "chroma",
):
    """
    # To make UserData db for generate.py, put pdfs, etc. into path user_path and run:
    python make_db.py

    # once db is made, can use in generate.py like:

    python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData

    or zip-up the db_dir_UserData and share:

    zip -r db_dir_UserData.zip db_dir_UserData

    # To get all db files (except large wiki_full) do:
    python make_db.py --download_some=True

    # To get a single db file from HF:
    python make_db.py --download_one=db_dir_DriverlessAI_docs.zip

    :param use_openai_embedding: Whether to use OpenAI embedding
    :param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2"
    :param persist_directory: where to persist db
    :param user_path: where to pull documents from (None means url is not None. If url is not None, this is ignored.)
    :param url: url to generate documents from (None means user_path is not None)
    :param add_if_exists: Add to db if already exists, but will not add duplicate sources
    :param collection_name: Collection name for new db if not adding
    :param verbose: whether to show verbose messages
    :param chunk: whether to chunk data
    :param chunk_size: chunk size for chunking
    :param fail_any_exception: whether to fail if any exception hit during ingestion of files
    :param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF
    :param download_some: whether to download some small example databases from h2o.ai HF
    :param download_one: name of one chosen example database to download from h2o.ai HF
    :param download_dest: Destination for downloads
    :param n_jobs: Number of cores to use for ingesting multiple files
    :param enable_captions: Whether to enable captions on images
    :param captions_model: See generate.py
    :param pre_load_caption_model: See generate.py
    :param caption_gpu: Caption images on GPU if present
    :param enable_ocr: Whether to enable OCR on images
    :param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' are supported.
    :return: tuple ``(db, collection_name)``. ``db`` is None when one of the
        download_* options was used (nothing is built in that case),
        otherwise the object returned by create_or_update_db.
    :raises AssertionError: if both user_path and url are None, if user_path
        is not a directory (when no url is given), if no sources were found,
        or if the db could not be created.
    """
    db = None

    # match behavior of main() in generate.py for non-HF case
    n_gpus = get_ngpus_vis()
    if hf_embedding_model is None:
        if n_gpus == 0:
            # if no GPUs, use simpler embedding model to avoid cost in time
            hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
        else:
            hf_embedding_model = "hkunlp/instructor-large"

    # Download-only modes: fetch the requested archives and return without
    # building a db. Precedence is download_all > download_some > download_one.
    download_requested = True
    db_zips = None
    if download_all:
        print("Downloading all (and unzipping): %s" % all_db_zips, flush=True)
        db_zips = all_db_zips
    elif download_some:
        print(
            "Downloading some (and unzipping): %s" % some_db_zips, flush=True
        )
        db_zips = some_db_zips
    elif download_one:
        print("Downloading %s (and unzipping)" % download_one, flush=True)
        db_zips = [[download_one, "", "Unknown License"]]
    else:
        download_requested = False

    if download_requested:
        get_some_dbs_from_hf(download_dest, db_zips=db_zips)
        if verbose:
            print("DONE", flush=True)
        return db, collection_name

    if enable_captions and pre_load_caption_model:
        # preload, else can be too slow or if on GPU have cuda context issues
        # Inside ingestion, this will disable parallel loading of multiple other kinds of docs
        # However, if have many images, all those images will be handled more quickly by preloaded model on GPU
        from image_captions import H2OImageCaptionLoader

        caption_loader = H2OImageCaptionLoader(
            None,
            blip_model=captions_model,
            blip_processor=captions_model,
            caption_gpu=caption_gpu,
        ).load_model()
    elif enable_captions:
        caption_loader = "gpu" if caption_gpu else "cpu"
    else:
        caption_loader = False

    if verbose:
        print("Getting sources", flush=True)
    assert (
        user_path is not None or url is not None
    ), "Can't have both user_path and url as None"
    if not url:
        assert os.path.isdir(user_path), (
            "user_path=%s does not exist" % user_path
        )
    sources = glob_to_db(
        user_path,
        chunk=chunk,
        chunk_size=chunk_size,
        verbose=verbose,
        fail_any_exception=fail_any_exception,
        n_jobs=n_jobs,
        url=url,
        enable_captions=enable_captions,
        captions_model=captions_model,
        caption_loader=caption_loader,
        enable_ocr=enable_ocr,
    )
    # Report documents that carry a truthy "exception" entry, then drop every
    # document that has an "exception" key at all before building the db.
    exceptions = [x for x in sources if x.metadata.get("exception")]
    print("Exceptions: %s" % exceptions, flush=True)
    sources = [x for x in sources if "exception" not in x.metadata]

    assert len(sources) > 0, "No sources found"
    db = create_or_update_db(
        db_type,
        persist_directory,
        collection_name,
        sources,
        use_openai_embedding,
        add_if_exists,
        verbose,
        hf_embedding_model,
    )

    assert db is not None
    if verbose:
        print("DONE", flush=True)
    return db, collection_name
|
||||
1103
apps/language_models/langchain/prompter.py
Normal file
1103
apps/language_models/langchain/prompter.py
Normal file
File diff suppressed because it is too large
Load Diff
406
apps/language_models/langchain/read_wiki_full.py
Normal file
406
apps/language_models/langchain/read_wiki_full.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""Load Data from a MediaWiki dump xml."""
|
||||
|
||||
import ast
|
||||
import glob
|
||||
import pickle
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
import os
|
||||
import bz2
|
||||
import csv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.document_loaders import MWDumpLoader
|
||||
|
||||
# path where downloaded wiki files exist, to be processed
|
||||
root_path = "/data/jon/h2o-llm"
|
||||
|
||||
|
||||
def unescape(x):
    """Best-effort un-escaping of a value read from the page-views CSV.

    First tries ``ast.literal_eval`` (so a quoted Python literal becomes the
    value it denotes); if that fails, drops non-ASCII characters and decodes
    backslash escapes with the ``unicode_escape`` codec. If both attempts
    fail, the input is returned unchanged.

    :param x: usually a string; other objects are returned as-is when neither
        conversion applies.
    :return: the converted value, or ``x`` itself.
    """
    try:
        return ast.literal_eval(x)
    except Exception:
        # Not a Python literal; fall through to the escape-decoding attempt.
        # ``Exception`` (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate.
        pass
    try:
        return x.encode("ascii", "ignore").decode("unicode_escape")
    except Exception:
        # Deliberate best-effort: leave the value untouched.
        return x
|
||||
|
||||
|
||||
def get_views():
    """Load the title -> views mapping from the reduced page-views CSV.

    Reads ``wiki_page_views_more_5000month.csv`` from the current working
    directory. Every title is passed through ``unescape`` and stored twice:
    once as-is and once with underscores replaced by spaces.

    :return: dict mapping title string to its ``views`` value.
    """
    # frame = pd.read_csv('wiki_page_views_more_1000month.csv')
    frame = pd.read_csv("wiki_page_views_more_5000month.csv")
    raw_counts = frame.set_index("title")["views"].to_dict()

    by_title = {}
    for title, count in raw_counts.items():
        by_title[str(unescape(str(title)))] = count

    # views has _ but pages has " "
    spaced = {}
    for title, count in by_title.items():
        spaced[title.replace("_", " ")] = count
    by_title.update(spaced)
    return by_title
|
||||
|
||||
|
||||
class MWDumpDirectLoader(MWDumpLoader):
    """MediaWiki dump loader that parses dump XML held in ``self.data``.

    Pages can be filtered by popularity (membership in the shared page-view
    table) and by the number of words in the title.
    """

    def __init__(
        self,
        data: str,
        encoding: Optional[str] = "utf8",
        title_words_limit=None,
        use_views=True,
        verbose=True,
    ):
        """Store the dump data and the filtering options.

        :param data: handed to ``mwxml.Dump.from_page_xml`` by ``load``.
        :param encoding: stored on the instance; not otherwise used here.
        :param title_words_limit: if not None, revisions of pages whose title
            has more words than this are skipped.
        :param use_views: if True, keep only pages present in the module-level
            ``global_views`` mapping.
        :param verbose: print a line for every kept or skipped page.
        """
        self.data = data
        self.encoding = encoding
        self.title_words_limit = title_words_limit
        self.verbose = verbose
        # self.views = get_views() would also work, but the global shared
        # mapping is faster to use.
        self.views = global_views if use_views else None

    def load(self) -> List[Document]:
        """Parse the dump and return one Document per kept revision."""
        import mwparserfromhell
        import mwxml

        dump = mwxml.Dump.from_page_xml(self.data)
        collected = []

        for page in dump.pages:
            have_views = self.views is not None
            if have_views and page.title not in self.views:
                if self.verbose:
                    print("Skipped %s low views" % page.title, flush=True)
                continue

            for revision in page:
                limit = self.title_words_limit
                if limit is not None:
                    # count words with "_" treated as a space
                    word_count = len(page.title.replace("_", " ").split(" "))
                    if word_count > limit:
                        if self.verbose:
                            print("Skipped %s" % page.title, flush=True)
                        continue

                if self.verbose:
                    if have_views:
                        print(
                            "Kept %s views: %s"
                            % (page.title, self.views[page.title]),
                            flush=True,
                        )
                    else:
                        print("Kept %s" % page.title, flush=True)

                wikicode = mwparserfromhell.parse(revision.text)
                plain_text = wikicode.strip_code(
                    normalize=True, collapse=True, keep_template_params=False
                )
                title_url = str(page.title).replace(" ", "_")
                page_views = self.views[page.title] if have_views else -1
                fields = {
                    "title": page.title,
                    "source": "https://en.wikipedia.org/wiki/" + title_url,
                    "id": page.id,
                    "redirect": page.redirect,
                    "views": page_views,
                }
                # entries whose value is None are left out of the metadata
                metadata = {
                    key: value
                    for key, value in fields.items()
                    if value is not None
                }
                collected.append(
                    Document(page_content=plain_text, metadata=metadata)
                )

        return collected
|
||||
|
||||
|
||||
def search_index(search_term, index_filename):
    """Locate the compressed block holding ``search_term`` in a multistream index.

    Each index row is read as ``offset:page_id:title``. The block starts at
    the offset of the first row whose title equals ``search_term`` and its
    length is the distance to the next different offset in the file.

    :param search_term: exact page title to look for.
    :param index_filename: path to the decompressed index text file.
    :return: ``(start_byte, data_length)``. ``(0, 0)`` if the title is not
        found; ``data_length`` is 0 if no later row has a different offset
        (i.e. the title is in the last block of the index).
    """
    start_byte = data_length = 0
    found = False
    # ``with`` guarantees the handle is closed even if a row fails to parse.
    with open(index_filename, "r") as index_file:
        for fields in csv.reader(index_file, delimiter=":"):
            if not fields:
                # blank line: nothing to inspect
                continue
            if not found:
                # Titles may themselves contain ":" (e.g. "Category:Foo"),
                # which the reader splits; re-join everything after the id.
                if len(fields) >= 3 and ":".join(fields[2:]) == search_term:
                    start_byte = int(fields[0])
                    found = True
            elif int(fields[0]) != start_byte:
                data_length = int(fields[0]) - start_byte
                break
    return start_byte, data_length
|
||||
|
||||
|
||||
def get_start_bytes(index_filename):
    """Return the sorted, de-duplicated offsets listed in an index file.

    The first ``:``-separated field of every non-blank row is parsed as an
    integer offset.

    :param index_filename: path to the decompressed index text file.
    :return: ascending list of unique integer offsets.
    """
    # ``with`` closes the file even if a row cannot be parsed; blank rows
    # (which the csv reader yields as empty lists) are skipped.
    with open(index_filename, "r") as index_file:
        offsets = {
            int(fields[0])
            for fields in csv.reader(index_file, delimiter=":")
            if fields
        }
    return sorted(offsets)
|
||||
|
||||
|
||||
def get_wiki_filenames():
    """Build the paths of the wiki index and multistream dump under ``root_path``.

    requires
    wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2

    :return: ``(index_filename, wiki_filename)``.
    """
    dump_dir = os.path.join(
        root_path, "enwiki-20230401-pages-articles-multistream"
    )
    index_path = os.path.join(
        dump_dir, "enwiki-20230401-pages-articles-multistream-index.txt"
    )
    dump_path = os.path.join(
        dump_dir, "enwiki-20230401-pages-articles-multistream.xml.bz2"
    )
    return index_path, dump_path
|
||||
|
||||
|
||||
def get_documents_by_search_term(search_term):
    """Load the documents of the dump block that contains ``search_term``.

    Looks the title up in the index, reads that byte range from the
    multistream dump, decompresses it and parses it with MWDumpDirectLoader
    using the loader's default options.

    :param search_term: exact page title to search for in the index.
    :return: list of documents produced by the loader.
    """
    index_filename, wiki_filename = get_wiki_filenames()
    offset, length = search_index(search_term, index_filename)

    with open(wiki_filename, "rb") as dump_file:
        dump_file.seek(offset)
        compressed = dump_file.read(length)

    xml_bytes = bz2.BZ2Decompressor().decompress(compressed)
    return MWDumpDirectLoader(xml_bytes.decode()).load()
|
||||
|
||||
|
||||
def get_one_chunk(
    wiki_filename,
    start_byte,
    end_byte,
    return_file=True,
    title_words_limit=None,
    use_views=True,
):
    """Parse one compressed block of the multistream dump.

    :param wiki_filename: path to the multistream ``.bz2`` dump.
    :param start_byte: offset at which the block starts.
    :param end_byte: offset at which the block ends (exclusive).
    :param return_file: if True, pickle the documents into ``temp_wiki/`` and
        return that file's path instead of the documents themselves.
    :param title_words_limit: forwarded to MWDumpDirectLoader.
    :param use_views: forwarded to MWDumpDirectLoader.
    :return: path of the pickle file, or the list of documents.
    """
    with open(wiki_filename, "rb") as dump_file:
        dump_file.seek(start_byte)
        compressed = dump_file.read(end_byte - start_byte)
    xml_bytes = bz2.BZ2Decompressor().decompress(compressed)

    documents1 = MWDumpDirectLoader(
        xml_bytes.decode(),
        title_words_limit=title_words_limit,
        use_views=use_views,
    ).load()

    if not return_file:
        return documents1

    # Hand the result back through a uniquely named pickle file.
    base_tmp = "temp_wiki"
    os.makedirs(base_tmp, exist_ok=True)
    filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
    with open(filename, "wb") as out:
        pickle.dump(documents1, out)
    return filename
|
||||
|
||||
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
# Page-view table, loaded once when this module is imported (this reads the
# CSV in get_views). MWDumpDirectLoader.__init__ reuses this shared mapping
# instead of calling get_views() per loader.
global_views = get_views()
|
||||
|
||||
|
||||
def get_all_documents(small_test=2, n_jobs=None, use_views=True):
    """Parse blocks of the multistream wiki dump in parallel.

    :param small_test: if truthy, only the first ``small_test`` blocks are
        processed; if falsy, every block listed in the index is processed.
    :param n_jobs: number of worker processes. Defaults to 5 for a small
        test, otherwise a quarter of the CPUs (at least 1).
    :param use_views: forwarded to ``get_one_chunk``.
    :return: flat list of all documents from the processed blocks.
    """
    print("DO get all wiki docs: %s" % small_test, flush=True)
    index_filename, wiki_filename = get_wiki_filenames()
    # Consecutive offsets delimit the blocks: block i spans
    # [offsets[i], offsets[i + 1]).
    start_bytes = get_start_bytes(index_filename)
    end_bytes = start_bytes[1:]
    start_bytes = start_bytes[:-1]

    if small_test:
        start_bytes = start_bytes[:small_test]
        end_bytes = end_bytes[:small_test]
        if n_jobs is None:
            n_jobs = 5
    elif n_jobs is None:
        # os.cpu_count() may be None, and ``// 4`` yields 0 on machines with
        # fewer than 4 CPUs; joblib rejects n_jobs == 0, so clamp to 1.
        n_jobs = max(1, (os.cpu_count() or 1) // 4)

    # default loky backend leads to name space conflict problems
    return_file = True  # large return from joblib hangs
    documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")(
        delayed(get_one_chunk)(
            wiki_filename,
            start_byte,
            end_byte,
            return_file=return_file,
            use_views=use_views,
        )
        for start_byte, end_byte in zip(start_bytes, end_bytes)
    )
    if return_file:
        # then documents really are files
        files = list(documents)
        documents = []
        for fil in files:
            # These pickles are the files returned by get_one_chunk above,
            # not external data.
            with open(fil, "rb") as f:
                documents.extend(pickle.load(f))
            os.remove(fil)
    else:
        # Flatten the per-block lists (also fine when there are no blocks).
        documents = [doc for chunk_docs in documents for doc in chunk_docs]
    assert isinstance(documents, list)

    print("DONE get all wiki docs", flush=True)
    return documents
|
||||
|
||||
|
||||
def test_by_search_term():
    """Each searched title must resolve to a block yielding 100 documents."""
    for search_term in ("Apollo", "Abstract (law)", "Artificial languages"):
        assert len(get_documents_by_search_term(search_term)) == 100
|
||||
|
||||
|
||||
def test_start_bytes():
    """The index must list the expected number of distinct offsets."""
    index_filename, _wiki_filename = get_wiki_filenames()
    offsets = get_start_bytes(index_filename)
    assert len(offsets) == 227850
|
||||
|
||||
|
||||
def test_get_all_documents():
    """Check document counts for the first 20 blocks, with and without the
    page-view filter."""
    small_test = 20  # 227850
    # Clamp to at least one worker: ``cpu_count() // 4`` is 0 on machines
    # with fewer than 4 CPUs (and cpu_count() may be None), which joblib
    # rejects.
    n_jobs = max(1, (os.cpu_count() or 1) // 4)

    unfiltered = get_all_documents(
        small_test=small_test, n_jobs=n_jobs, use_views=False
    )
    assert len(unfiltered) == small_test * 100

    filtered = get_all_documents(
        small_test=small_test, n_jobs=n_jobs, use_views=True
    )
    assert len(filtered) == 429
|
||||
|
||||
|
||||
def get_one_pageviews(fil):
    """Extract the "en" rows of one page-views file into a temporary CSV.

    The input is read as space-separated columns
    ``region title views foo`` without a header and without quote handling.
    Only rows whose region is "en" are kept, and the output CSV holds the
    title (as index) and the views column.

    :param fil: path of the page-views file to read.
    :return: path of the CSV written under ``temp_wiki_pageviews/``.
    """
    frame = pd.read_csv(
        fil,
        sep=" ",
        header=None,
        names=["region", "title", "views", "foo"],
        quoting=csv.QUOTE_NONE,
    )
    frame.index = frame["title"]
    english = frame[frame["region"] == "en"]
    # title is already the index, so the column can go as well
    english = english.drop(columns=["region", "foo", "title"])

    out_dir = "temp_wiki_pageviews"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, str(uuid.uuid4()) + ".tmp.csv")
    english.to_csv(out_path, index=True)
    return out_path
|
||||
|
||||
|
||||
def test_agg_pageviews(gen_files=False):
    """Aggregate per-file page views into ``wiki_page_views.csv``.

    :param gen_files: if True, first convert every downloaded
        ``pageviews*.gz`` file with ``get_one_pageviews`` in parallel; if
        False, reuse the CSVs already present under
        ``root_path/temp_wiki_pageviews``.
    """
    if gen_files:
        path = os.path.join(
            root_path,
            "wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04",
        )
        files = glob.glob(os.path.join(path, "pageviews*.gz"))
        # files = files[:2]  # test
        # Clamp to at least one worker: ``cpu_count() // 2`` is 0 on a
        # single-CPU machine (and cpu_count() may be None), which joblib
        # rejects.
        n_jobs = max(1, (os.cpu_count() or 1) // 2)
        csv_files = Parallel(
            n_jobs=n_jobs, verbose=10, backend="multiprocessing"
        )(delayed(get_one_pageviews)(fil) for fil in files)
    else:
        # to continue without redoing above
        csv_files = glob.glob(
            os.path.join(root_path, "temp_wiki_pageviews/*.csv")
        )

    df_list = []
    for csv_file in csv_files:
        print(csv_file)
        df_list.append(pd.read_csv(csv_file))
    df = pd.concat(df_list, axis=0)
    # sum the views of each title across all files
    df = df.groupby("title")["views"].sum().reset_index()
    df.to_csv("wiki_page_views.csv", index=True)
|
||||
|
||||
|
||||
def test_reduce_pageview():
    """Plot the views distribution and write the reduced page-views CSV.

    Reads ``wiki_page_views.csv``, drops rows with 1e7 views or more, saves a
    histogram, then keeps only rows above 5000 views, writes them to
    ``wiki_page_views_more_5000month.csv`` and saves a second histogram.
    """

    def _save_histogram(frame, csv_name):
        # Log-scale histogram titled with mean/median, saved next to the CSV
        # name with a .png extension.
        plt.hist(frame["views"], bins=100, log=True)
        views_avg = np.mean(frame["views"])
        views_median = np.median(frame["views"])
        plt.title("Views avg: %s median: %s" % (views_avg, views_median))
        plt.savefig(csv_name.replace(".csv", ".png"))
        plt.close()

    source_name = "wiki_page_views.csv"
    frame = pd.read_csv(source_name)
    frame = frame[frame["views"] < 1e7]
    _save_histogram(frame, source_name)

    views_limit = 5000
    frame = frame[frame["views"] > views_limit]
    reduced_name = "wiki_page_views_more_5000month.csv"
    frame.to_csv(reduced_name, index=True)
    _save_histogram(frame, reduced_name)
|
||||
|
||||
|
||||
@pytest.mark.skip("Only if doing full processing again, some manual steps")
def test_do_wiki_full_all():
    """Walk through the full wiki preparation pipeline (skipped by default).

    Install other requirements for wiki specific conversion first:
    pip install -r reqs_optional/requirements_optional_wikiprocessing.txt

    Use "Transmission" in Ubuntu to get the wiki dump using torrent:
    See: https://meta.wikimedia.org/wiki/Data_dump_torrents
    E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337
    """
    # Get index
    fetch_index_cmd = "wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2"
    os.system(fetch_index_cmd)

    # Test that can use LangChain to get docs from subset of wiki as sampled
    # out of full wiki directly using bzip multistream
    test_get_all_documents()

    # Check can search wiki multistream
    test_by_search_term()

    # Test can get all start bytes in index
    test_start_bytes()

    # Get page views, e.g. for entire month of April 2023
    fetch_pageviews_cmd = "wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/"
    os.system(fetch_pageviews_cmd)

    # Aggregate page views from many files into single file
    test_agg_pageviews(gen_files=True)

    # Reduce page views to some limit, so processing of full wiki is not too large
    test_reduce_pageview()

    # Start generate.py with requesting wiki_full in prep. This will use page
    # views as referenced in get_views.
    # Note get_views as global() function done once is required to avoid very
    # slow processing
    # WARNING: Requires a lot of memory to handle, used up to 300GB system RAM
    # at peak
    """
    python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log
    """
|
||||
121
apps/language_models/langchain/stopping.py
Normal file
121
apps/language_models/langchain/stopping.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import torch
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
from enums import PromptType
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation after stop sequences were seen often enough, or when
    the sequence reaches ``model_max_length``.

    Each stop sequence has its own hit counter; stop ``i`` triggers once its
    counter reaches ``encounters[i % len(encounters)]``.
    """

    def __init__(
        self, stops=[], encounters=[], device="cuda", model_max_length=None
    ):
        """
        :param stops: token-id tensors, one per stop sequence; each is moved
            to ``device``.
        :param encounters: required hit counts; its length must evenly divide
            ``len(stops)`` (stops cycle through it).
        :param device: device the stop tensors are moved to.
        :param model_max_length: if not None, always stop once the sequence
            has at least this many tokens.
        :raises AssertionError: if ``len(stops)`` is not a multiple of
            ``len(encounters)``, or stops are given without encounters.
        """
        super().__init__()
        # Copy, so neither the shared default lists nor the caller's lists
        # are aliased by the instance.
        stops = list(stops)
        encounters = list(encounters)
        if encounters:
            assert (
                len(stops) % len(encounters) == 0
            ), "Number of stops and encounters must match"
        else:
            # Previously this path computed ``len(stops) % 0`` and raised
            # ZeroDivisionError, even for the default arguments.
            assert not stops, "Number of stops and encounters must match"
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.num_stops = [0] * len(stops)
        self.model_max_length = model_max_length

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True if generation should stop for the first sequence in
        ``input_ids``."""
        sequence = input_ids[0]
        for stopi, stop in enumerate(self.stops):
            stop_len = len(stop)
            if stop_len == 0 or stop_len > sequence.shape[0]:
                # The sequence cannot end with this stop yet. Comparing
                # anyway would either broadcast (a false match) or fail on
                # mismatched sizes.
                continue
            if torch.all((stop == sequence[-stop_len:])).item():
                self.num_stops[stopi] += 1
                if (
                    self.num_stops[stopi]
                    >= self.encounters[stopi % len(self.encounters)]
                ):
                    # print("Stopped", flush=True)
                    return True
        if (
            self.model_max_length is not None
            and sequence.shape[0] >= self.model_max_length
        ):
            # critical limit
            return True
        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
        return False
|
||||
|
||||
|
||||
def get_stopping(
    prompt_type,
    prompt_dict,
    tokenizer,
    device,
    human="<human>:",
    bot="<bot>:",
    model_max_length=None,
):
    """Build the StoppingCriteriaList for a prompt type.

    For the human_bot, instruct_vicuna and instruct_with_end prompt types a
    StoppingCriteriaSub with the tokenized stop words is returned; for every
    other prompt type the list is empty.

    Args:
        prompt_type: name of a PromptType member.
        prompt_dict: currently unused.
        tokenizer: callable tokenizer; must support ``return_tensors="pt"``.
        device: device the stop tensors are placed on.
        human: human turn marker used for the human_bot prompt type.
        bot: bot turn marker used for the human_bot prompt type.
        model_max_length: forwarded to StoppingCriteriaSub.
    """
    # FIXME: prompt_dict unused currently
    if prompt_type not in [
        PromptType.human_bot.name,
        PromptType.instruct_vicuna.name,
        PromptType.instruct_with_end.name,
    ]:
        return StoppingCriteriaList()

    if prompt_type == PromptType.human_bot.name:
        # stopping only starts once output is beyond prompt
        # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
        stop_words = [human, bot, "\n" + human, "\n" + bot]
        encounters = [1, 2]
    elif prompt_type == PromptType.instruct_vicuna.name:
        # even below is not enough, generic strings and many ways to encode
        stop_words = [
            "### Human:",
            "\n### Human:",
            "\n### Human:\n",
            "### Assistant:",
            "\n### Assistant:",
            "\n### Assistant:\n",
        ]
        encounters = [1, 2]
    else:
        # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
        stop_words = ["### End"]
        encounters = [1]

    # Keep every stop word paired with its token ids.  The word is needed
    # again below, and dropping empty tokenizations from the ids alone would
    # shift the pairing.
    tokenized = []
    for stop_word in stop_words:
        ids = tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze()
        # handle single token case
        if len(ids.shape) == 0:
            ids = torch.tensor([ids])
        if ids.shape[0] > 0:
            tokenized.append((stop_word, ids))

    # avoid padding in front of tokens
    if (
        tokenizer._pad_token
    ):  # use hidden variable to avoid annoying properly logger bug
        tokenized = [
            (
                word,
                ids[1:]
                if ids[0] == tokenizer.pad_token_id and len(ids) > 1
                else ids,
            )
            for word, ids in tokenized
        ]

    # handle fake \n added
    stop_words_ids = [
        ids[1:] if word.startswith("\n") else ids for word, ids in tokenized
    ]

    # build stopper
    return StoppingCriteriaList(
        [
            StoppingCriteriaSub(
                stops=stop_words_ids,
                encounters=encounters,
                device=device,
                model_max_length=model_max_length,
            )
        ]
    )
|
||||
1070
apps/language_models/langchain/utils.py
Normal file
1070
apps/language_models/langchain/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
69
apps/language_models/langchain/utils_langchain.py
Normal file
69
apps/language_models/langchain/utils_langchain.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
import time
|
||||
import queue
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
class StreamingGradioCallbackHandler(BaseCallbackHandler):
    """
    LangChain-backend counterpart of H2OTextIteratorStreamer (which is for the HF backend).

    Tokens reported through the LLM callbacks are pushed onto a queue and
    handed out again through the iterator protocol.
    """

    def __init__(self, timeout: Optional[float] = None, block=True):
        super().__init__()
        self.text_queue = queue.SimpleQueue()
        # Sentinel put on the queue to mark the end of a generation.
        self.stop_signal = None
        # Set to True from outside to abort iteration.
        self.do_stop = False
        self.timeout = timeout
        self.block = block

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Drop anything still queued when a new LLM run starts."""
        pending = self.text_queue
        while not pending.empty():
            try:
                pending.get(block=False)
            except queue.Empty:
                pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Queue a newly streamed token."""
        self.text_queue.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Queue the stop sentinel when the LLM finishes."""
        self.text_queue.put(self.stop_signal)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Queue the stop sentinel when the LLM errors."""
        self.text_queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            if self.do_stop:
                print("hit stop", flush=True)
                # raise (rather than break) so the parent can see the iteration ended
                raise StopIteration()
            try:
                item = self.text_queue.get(
                    block=self.block, timeout=self.timeout
                )
            except queue.Empty:
                # nothing yet: back off briefly and poll again
                time.sleep(0.01)
                continue
            if item == self.stop_signal:
                raise StopIteration()
            return item
|
||||
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
442
apps/language_models/scripts/llama_ir_conversion_utils.py
Normal file
@@ -0,0 +1,442 @@
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
from argparse import RawTextHelpFormatter
|
||||
import re, gc
|
||||
|
||||
"""
|
||||
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
|
||||
Following are the various ways this script can be used :-
|
||||
a. To convert a single Linalg IR to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR>
|
||||
b. To convert two Linalg IRs to dynamic IR:
|
||||
--dynamic --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
c. To combine two Linalg IRs into one:
|
||||
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
d. To convert both IRs into dynamic as well as combine the IRs:
|
||||
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
|
||||
|
||||
NOTE: For dynamic you'll also need to provide the following set of flags:-
|
||||
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
|
||||
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
|
||||
--precision (DEFAULT: 'int4')
|
||||
You may use --save_dynamic to also save the dynamic IR in option d above.
|
||||
Else for option a. and b. the dynamic IR(s) will get saved by default.
|
||||
"""
|
||||
|
||||
|
||||
def _drain_mlir_lines(lines, forward_name, constants, skip_global_seed=False):
    """Consume `lines`, returning the `#map` definitions and the remaining body lines."""
    maps, body = [], []
    # Reverse once so lines can be released with O(1) pop() calls while still
    # being visited in file order (pop(0) is O(n) per call).
    lines.reverse()
    while lines:
        line = lines.pop()
        if re.search(r"#map\d*\s*=", line):
            maps.append(line)
        elif skip_global_seed and "global_seed" in line:
            continue
        elif re.search("arith.constant", line):
            constants.add(line)
        elif not re.search("module", line):
            body.append(re.sub("forward", forward_name, line))
    # The final remaining line is dropped, as before.
    return maps, body[:-1]


def _suffix_map_names(maps, body, suffix):
    """Append `suffix` to every map name, in `maps` (in place) and in `body` (returned)."""
    for i, map_line in enumerate(maps):
        map_var = map_line.split(" ")[0]
        # The negative lookahead keeps e.g. `#map1` from matching `#map10`.
        pattern = re.compile(rf"{map_var}(?!\d)")
        renamed = map_var + suffix
        maps[i] = pattern.sub(renamed, map_line)
        body = [pattern.sub(renamed, func_line) for func_line in body]
    return body


def combine_mlir_scripts(
    first_vicuna_mlir,
    second_vicuna_mlir,
    output_name,
    return_ir=True,
):
    """Merge two MLIR texts into one module written to `output_name`.

    Lines containing `arith.constant` from both inputs are collected into
    `ml_program.global` definitions that are loaded at the top of each
    function.  `forward` is renamed to `first_vicuna_forward` /
    `second_vicuna_forward`, and the `#map` names get a `_0` / `_1` suffix so
    the two inputs cannot clash.

    Args:
        first_vicuna_mlir: text of the first IR.
        second_vicuna_mlir: text of the second IR; lines containing
            `global_seed` are dropped.
        output_name: path of the combined file to write.
        return_ir: when true, the written file is read back and returned.

    Returns:
        The combined file content as bytes if `return_ir` is true, else None.
    """
    print(f"[DEBUG] combining first and second mlir")
    print(f"[DEBUG] output_name = {output_name}")
    constants = set()

    print(f"[DEBUG] processing first vicuna mlir")
    first_vicuna_mlir = first_vicuna_mlir.splitlines()
    maps1, f1 = _drain_mlir_lines(
        first_vicuna_mlir, "first_vicuna_forward", constants
    )
    del first_vicuna_mlir
    gc.collect()
    f1 = _suffix_map_names(maps1, f1, "_0")

    print(f"[DEBUG] processing second vicuna mlir")
    second_vicuna_mlir = second_vicuna_mlir.splitlines()
    maps2, f2 = _drain_mlir_lines(
        second_vicuna_mlir,
        "second_vicuna_forward",
        constants,
        skip_global_seed=True,
    )
    del second_vicuna_mlir
    gc.collect()
    f2 = _suffix_map_names(maps2, f2, "_1")

    module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
    module_end = "}"

    global_vars = []
    global_var_loading = []

    print(f"[DEBUG] processing constants")
    counter = 0
    constants = list(constants)
    constants.reverse()  # pop() below then visits them in list order
    while constants:
        constant = constants.pop()
        # Split on the first "=" only: the constant body may contain "=" too.
        vname, vbody = constant.split("=", 1)
        vname = re.sub("%", "", vname).strip()
        vbody = re.sub("arith.constant", "", vbody).strip()
        if len(vbody.split(":")) < 2:
            print(constant)
        vdtype = vbody.split(":")[-1].strip()
        if "c1_i64" in vname:
            print(constant)
            counter += 1
        if counter == 2:
            # second c1_i64-like constant: skip it
            counter = 0
            print("detected duplicate")
            continue
        # Constants whose name contains "true" are always declared as i1.
        dtype = "i1" if "true" in vname else vdtype
        global_vars.append(
            f"ml_program.global private @{vname}({vbody}) : {dtype}"
        )
        global_var_loading.append(
            f"\t\t%{vname} = ml_program.global_load_const @{vname} : {dtype}"
        )

    print(f"[DEBUG] processing f1")
    new_f1 = []
    for line in f1:
        new_f1.append(line)
        if "func.func" in line:
            # load every global right after the function header
            new_f1.extend(global_var_loading)

    print(f"[DEBUG] processing f2")
    new_f2 = []
    for line in f2:
        new_f2.append(line)
        if "func.func" in line:
            for global_var in global_var_loading:
                if (
                    "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
                    in global_var
                ):
                    print(global_var)
                new_f2.append(global_var)

    f1 = new_f1
    f2 = new_f2
    del new_f1
    del new_f2
    gc.collect()

    print(
        [
            "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
            for x in [maps1, maps2, global_vars, f1, f2]
        ]
    )

    # doing it this way rather than assembling the whole string
    # to prevent OOM with 64GiB RAM when encoding the file.
    print(f"[DEBUG] Saving mlir to {output_name}")
    with open(output_name, "w+") as f_:
        f_.writelines(line + "\n" for line in maps1)
        f_.writelines(line + "\n" for line in maps2)
        f_.writelines(line + "\n" for line in [module_start])
        f_.writelines(line + "\n" for line in global_vars)
        f_.writelines(line + "\n" for line in f1)
        f_.writelines(line + "\n" for line in f2)
        f_.writelines(line + "\n" for line in [module_end])

    del maps1, maps2, module_start, global_vars, f1, f2, module_end
    gc.collect()

    if return_ir:
        print(f"[DEBUG] Reading combined mlir back in")
        with open(output_name, "rb") as f:
            return f.read()
|
||||
|
||||
|
||||
def write_in_dynamic_inputs0(module, dynamic_input_size):
    """Rewrite the first IR so the static `dynamic_input_size` dimension becomes dynamic.

    Args:
        module: MLIR text.
        dynamic_input_size: the static size (int) to replace with `?` / `%dim`.

    Returns:
        The rewritten MLIR text.
    """
    print("[DEBUG] writing dynamic inputs to first vicuna")
    # Current solution for ensuring mlir files support dynamic inputs
    # TODO: find a more elegant way to implement this
    # The lookbehind stops e.g. "319x" from being rewritten to "3?x" when the
    # static size is 19.
    static_dim = re.compile(rf"(?<!\d){dynamic_input_size}x")
    new_lines = []
    module = module.splitlines()
    # Reverse once so lines can be released with O(1) pop() calls while still
    # being visited in file order (pop(0) is O(n) per call).
    module.reverse()
    while module:
        line = module.pop()
        line = static_dim.sub("?x", line)
        if "?x" in line:
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
        line = re.sub(f" {dynamic_input_size},", " %dim,", line)
        if "tensor.empty" in line and "?x?" in line:
            line = re.sub(
                r"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
            )
        if "arith.cmpi" in line:
            line = re.sub(f"c{dynamic_input_size}", "dim", line)
        if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
            # %dim must be defined before its first use
            new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
        if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
            # already emitted above; drop any existing definition
            continue

        new_lines.append(line)
    return "\n".join(new_lines)
|
||||
|
||||
|
||||
def write_in_dynamic_inputs1(module, model_name, precision):
    """Rewrite the second IR so the static 19 / 20 dimensions become dynamic.

    Args:
        module: MLIR text.
        model_name: selects the past-key-value tensor shape (`llama2_13b`,
            `llama2_70b`, anything else uses the 32-head shape).
        precision: "fp16", "int4" and "int8" use f16 elements, anything else f32.

    Returns:
        The rewritten MLIR text.
    """
    print("[DEBUG] writing dynamic inputs to second vicuna")

    def remove_constant_dim(line):
        # 19 -> %dim, 20 -> %dimp1
        if "c19_i64" in line:
            line = re.sub("c19_i64", "dim_i64", line)
        if "19x" in line:
            line = re.sub("19x", "?x", line)
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dim)", line)
        if "tensor.empty" in line and "?x?" in line:
            line = re.sub(
                r"tensor.empty\(%dim\)",
                "tensor.empty(%dim, %dim)",
                line,
            )
        if "arith.cmpi" in line:
            line = re.sub("c19", "dim", line)
        if " 19," in line:
            line = re.sub(" 19,", " %dim,", line)
        if "x20x" in line or "<20x" in line:
            line = re.sub("20x", "?x", line)
            line = re.sub(r"tensor.empty\(\)", "tensor.empty(%dimp1)", line)
        if " 20," in line:
            line = re.sub(" 20,", " %dimp1,", line)
        return line

    module = module.splitlines()
    new_lines = []

    if "llama2_13b" in model_name:
        pkv_tensor_shape = "tensor<1x40x?x128x"
    elif "llama2_70b" in model_name:
        pkv_tensor_shape = "tensor<1x8x?x128x"
    else:
        pkv_tensor_shape = "tensor<1x32x?x128x"
    if precision in ["fp16", "int4", "int8"]:
        pkv_tensor_shape += "f16>"
    else:
        pkv_tensor_shape += "f32>"

    # Consume the list while iterating so processed lines can be freed.
    # Reversing once makes each pop() O(1) while keeping file order.
    module.reverse()
    while module:
        line = module.pop()
        if "%c19_i64 = arith.constant 19 : i64" in line:
            # replace the constant with the runtime size read from %arg1
            new_lines.append("%c2 = arith.constant 2 : index")
            new_lines.append(
                f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
            )
            new_lines.append(
                "%dim_i64 = arith.index_cast %dim_4_int : index to i64"
            )
            continue
        if "%c2 = arith.constant 2 : index" in line:
            # already emitted above
            continue
        if "%c20_i64 = arith.constant 20 : i64" in line:
            # 20 becomes (dynamic size + 1)
            new_lines.append("%c1_i64 = arith.constant 1 : i64")
            new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
            new_lines.append(
                "%dimp1 = arith.index_cast %c20_i64 : i64 to index"
            )
            continue
        line = remove_constant_dim(line)
        new_lines.append(line)

    return "\n".join(new_lines)
|
||||
|
||||
|
||||
def save_dynamic_ir(ir_to_save, output_file):
    """Write `ir_to_save` plus a trailing newline to `output_file`.

    Nothing is written (and no file is created) when `ir_to_save` is empty or None.
    """
    if ir_to_save:
        # We only get string output from the dynamic conversion utility;
        # print() applies the same str() conversion and newline as before.
        with open(output_file, "w") as f:
            print(ir_to_save, file=f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command line entry point: convert the given IR(s) to dynamic and/or
    # combine them, as selected by --dynamic / --combine.
    parser = argparse.ArgumentParser(
        prog="llama ir utility",
        description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
        + "\tFollowing are the various ways this script can be used :-\n"
        + "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
        + "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
        + "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
        + "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
        + "\t\tc. To combine two Linalg IRs into one:\n"
        + "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
        + "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
        + "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
        + "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
        + "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
        + "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
        + "\t\t\t--precision (DEFAULT: 'int4')\n"
        + "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
        + "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument(
        "--precision",
        "-p",
        default="int4",
        choices=["fp32", "fp16", "int8", "int4"],
        help="Precision of the concerned IR",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="llama2_7b",
        choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
        help="Specify which model to run.",
    )
    parser.add_argument(
        "--first_ir_path",
        default=None,
        help="path to first llama mlir file",
    )
    parser.add_argument(
        "--second_ir_path",
        default=None,
        help="path to second llama mlir file",
    )
    parser.add_argument(
        "--dynamic_input_size",
        type=int,
        default=19,
        help="Specify the static input size to replace with dynamic dim.",
    )
    parser.add_argument(
        "--dynamic",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Converts the IR(s) to dynamic",
    )
    parser.add_argument(
        "--save_dynamic",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Save the individual IR(s) after converting to dynamic",
    )
    parser.add_argument(
        "--combine",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Combines the two IRs into one",
    )

    args, unknown = parser.parse_known_args()

    dynamic = args.dynamic
    combine = args.combine
    assert (
        dynamic or combine
    ), "neither `dynamic` nor `combine` flag is turned on"
    first_ir_path = args.first_ir_path
    second_ir_path = args.second_ir_path
    assert first_ir_path or second_ir_path, "no input ir has been provided"
    if combine:
        assert (
            first_ir_path and second_ir_path
        ), "you will need to provide both IRs to combine"
    precision = args.precision
    model_name = args.model_name
    dynamic_input_size = args.dynamic_input_size
    save_dynamic = args.save_dynamic

    print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
    print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")

    # Without --combine the dynamic IRs are the only output, so always save them.
    if dynamic and not combine:
        save_dynamic = True

    first_ir = None
    first_dynamic_ir_name = None
    second_ir = None
    second_dynamic_ir_name = None
    if first_ir_path:
        first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
        with open(first_ir_path, "r") as f:
            first_ir = f.read()
    if second_ir_path:
        second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
        with open(second_ir_path, "r") as f:
            second_ir = f.read()
    if dynamic:
        first_ir = (
            write_in_dynamic_inputs0(first_ir, dynamic_input_size)
            if first_ir
            else None
        )
        second_ir = (
            write_in_dynamic_inputs1(second_ir, model_name, precision)
            if second_ir
            else None
        )
        if save_dynamic:
            # save_dynamic_ir is a no-op for an IR that was not provided
            save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
            save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")

    if combine:
        combine_mlir_scripts(
            first_ir,
            second_ir,
            f"{model_name}_{precision}.mlir",
            return_ir=False,
        )
|
||||
211
apps/language_models/scripts/stablelm.py
Normal file
211
apps/language_models/scripts/stablelm.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
StoppingCriteria,
|
||||
)
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the last token of the first sequence is a stop id."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        last_token = input_ids[0][-1]
        return any(
            last_token == stop_id for stop_id in (50278, 50279, 50277, 1, 0)
        )
|
||||
|
||||
|
||||
def shouldStop(tokens):
    """Return True when the last entry of `tokens[0]` equals one of the stop ids."""
    last_token = tokens[0][-1]
    return any(
        last_token == stop_id for stop_id in (50278, 50279, 50277, 1, 0)
    )
|
||||
|
||||
|
||||
# Fixed token length: generate_new_token pads/truncates every prompt to this size.
MAX_SEQUENCE_LENGTH = 256
|
||||
|
||||
|
||||
def user(message, history):
    """Return an empty string and a new history with `[message, ""]` appended."""
    # `history` itself is not modified: a new list is built.
    new_turn = [message, ""]
    return "", history + [new_turn]
|
||||
|
||||
|
||||
def compile_stableLM(
    model,
    model_inputs,
    model_name,
    model_vmfb_name,
    device="cuda",
    precision="fp32",
    debug=False,
):
    """Return a compiled SharkInference module for StableLM, reusing cached artifacts.

    Lookup order: an existing vmfb, then an existing `<model_name>.mlir`,
    otherwise the model is lowered with torch_mlir and the bytecode is saved
    to `<model_name>.mlir`.  A freshly compiled module is saved as a vmfb.

    Args:
        model: torch module to lower (only used when no cached file exists).
        model_inputs: example inputs for lowering.
        model_name: base name for the `.mlir` / default `.vmfb` files.
        model_vmfb_name: explicit vmfb path, or None for `<model_name>_<device>.vmfb`.
        device: target device passed to SHARK.
        precision: currently unused in this function.
        debug: forwarded to `save_module`.
    """
    from shark.shark_inference import SharkInference

    # TODO: vmfb and mlir name should include precision and device
    vmfb_path = (
        Path(model_name + f"_{device}.vmfb")
        if model_vmfb_name is None
        else Path(model_vmfb_name)
    )
    shark_module = get_vmfb_from_path(
        vmfb_path, device, mlir_dialect="tm_tensor"
    )
    if shark_module is not None:
        return shark_module

    mlir_path = Path(model_name + ".mlir")
    print(
        f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
    )
    if mlir_path.exists():
        with open(mlir_path, "rb") as f:
            bytecode = f.read()
    else:
        ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
        module = torch_mlir.compile(
            ts_graph,
            [*model_inputs],
            torch_mlir.OutputType.LINALG_ON_TENSORS,
            use_tracing=False,
            verbose=False,
        )
        bytecode_stream = BytesIO()
        module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()
        # The context manager closes the file even if the write fails.
        with open(mlir_path, "wb") as f_:
            f_.write(bytecode)
            print("Saved mlir")

    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
    )
    shark_module.compile()

    path = shark_module.save_module(
        vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
    )
    print("Saved vmfb at ", str(path))

    return shark_module
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
    """Wrapper whose forward takes positional tensors and returns only the logits."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Call the wrapped model with keyword inputs and return `output.logits`."""
        result = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        return result.logits
|
||||
|
||||
|
||||
# System prompt text for StableLM Tuned (Alpha version)
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
|
||||
|
||||
def get_tokenizer():
    """Load the `stabilityai/stablelm-tuned-alpha-3b` tokenizer and set `<PAD>` as its pad token."""
    tokenizer = AutoTokenizer.from_pretrained(
        "stabilityai/stablelm-tuned-alpha-3b"
    )
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    print("Sucessfully loaded the tokenizer to the memory")
    return tokenizer
|
||||
|
||||
|
||||
# sharkStableLM = compile_stableLM
|
||||
# (
|
||||
# None,
|
||||
# tuple([input_ids, attention_mask]),
|
||||
# "stableLM_linalg_f32_seqLen256",
|
||||
# "/home/shark/vivek/stableLM_shark_f32_seqLen256"
|
||||
# )
|
||||
def generate(
    new_text,
    max_new_tokens,
    sharkStableLM,
    tokenizer=None,
):
    """Generate up to `max_new_tokens` tokens, one at a time, continuing `new_text`.

    Each decoded token is printed as it is produced and appended to the
    prompt for the next step.  Generation ends early when
    `generate_new_token` reports a stop or returns an empty string.

    Args:
        new_text: prompt text to continue.
        max_new_tokens: upper bound on the number of generation steps.
        sharkStableLM: compiled model passed through to `generate_new_token`.
        tokenizer: tokenizer to use; loaded with `get_tokenizer()` when None.

    Returns:
        List of the decoded token strings, in generation order.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer()

    pieces = []
    for _ in range(max_new_tokens):
        step = generate_new_token(
            sharkStableLM, tokenizer, {"new_text": new_text}
        )
        piece = step["detok"]
        if step["stop_generation"]:
            break
        print(piece, end="", flush=True)
        pieces.append(piece)
        if piece == "":
            # an empty decode is recorded, then ends generation
            break
        new_text = new_text + piece
    return pieces
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, params):
    """Run one forward pass and pick the next token for `params["new_text"]`.

    Args:
        shark_model: callable invoked as `shark_model("forward", [input_ids, attention_mask])`;
            its result is passed to `torch.from_numpy`.
        tokenizer: tokenizer used to encode the prompt and decode the new token.
        params: dict with the key "new_text".

    Returns:
        Dict with "new_token" (the chosen token id), "detok" (its decoded
        text) and "stop_generation" (bool).
    """
    encoded = tokenizer(
        [params["new_text"]],
        padding="max_length",
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    # Number of attended positions; the next-token prediction is read at the last of them.
    attended = torch.sum(encoded.attention_mask)
    logits = torch.from_numpy(
        shark_model("forward", [encoded.input_ids, encoded.attention_mask])
    )
    top1 = torch.topk(logits, 1)
    # NOTE(review): shouldStop inspects the last sequence position, while the
    # emitted token below is taken at position attended - 1 -- confirm that
    # these are meant to differ.
    stop_generation = True if shouldStop(top1.indices) else False
    new_token = top1.indices[0][int(attended) - 1]
    detok = tokenizer.decode(
        new_token,
        skip_special_tokens=True,
    )
    return {
        "new_token": new_token,
        "detok": detok,
        "stop_generation": stop_generation,
    }
|
||||
2484
apps/language_models/scripts/vicuna.py
Normal file
2484
apps/language_models/scripts/vicuna.py
Normal file
File diff suppressed because it is too large
Load Diff
94
apps/language_models/shark_llama_cli.spec
Normal file
94
apps/language_models/shark_llama_cli.spec
Normal file
@@ -0,0 +1,94 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
# Data files and package metadata bundled into the executable.
datas = []
datas += collect_data_files('torch')

# Packages whose distribution metadata is copied.
for pkg in (
    'torch',
    'tqdm',
    'regex',
    'requests',
    'packaging',
    'filelock',
    'numpy',
    'tokenizers',
    'importlib_metadata',
    'torch-mlir',
    'omegaconf',
    'safetensors',
    'huggingface-hub',
    'sentencepiece',
    'pyyaml',
):
    datas += copy_metadata(pkg)

# Packages whose data files are collected.
for pkg in (
    'tokenizers',
    'tiktoken',
    'accelerate',
    'diffusers',
    'transformers',
    'opencv-python',
    'pytorch_lightning',
    'skimage',
    'gradio',
    'gradio_client',
    'iree',
    'google-cloud-storage',
    'py-cpuinfo',
):
    datas += collect_data_files(pkg)

# These two are collected together with their .py files.
for pkg in ('shark', 'timm'):
    datas += collect_data_files(pkg, include_py_files=True)

for pkg in (
    'tqdm',
    'tkinter',
    'webview',
    'sentencepiece',
    'jsonschema',
    'jsonschema_specifications',
    'cpuinfo',
    'langchain',
):
    datas += collect_data_files(pkg)

binaries = []

block_cipher = None

# Modules listed explicitly as hidden imports; test submodules are left out.
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
for pkg in ('skimage', 'iree'):
    hiddenimports += [x for x in collect_submodules(pkg) if "tests" not in x]

a = Analysis(
    ['scripts/vicuna.py'],
    pathex=['.'],
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='shark_llama_cli',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
|
||||
22
apps/language_models/src/model_wrappers/falcon_model.py
Normal file
22
apps/language_models/src/model_wrappers/falcon_model.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
|
||||
class FalconModel(torch.nn.Module):
    """Wrapper whose forward returns the first model output at the last sequence position."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model without past key values and slice out the last position."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=None,
            use_cache=True,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        first_output = outputs[0]
        return first_output[:, -1, :]
|
||||
675
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
675
apps/language_models/src/model_wrappers/falcon_sharded_model.py
Normal file
@@ -0,0 +1,675 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class WordEmbeddingsLayer(torch.nn.Module):
    """Thin module wrapper that delegates to a word-embedding layer."""

    def __init__(self, word_embedding_layer):
        super().__init__()
        self.model = word_embedding_layer

    def forward(self, input_ids):
        """Return ``self.model.forward(input=input_ids)`` unchanged."""
        return self.model.forward(input=input_ids)
|
||||
|
||||
|
||||
class CompiledWordEmbeddingsLayer(torch.nn.Module):
    """Adapter that runs a compiled embedding callable on NumPy input."""

    def __init__(self, compiled_word_embedding_layer):
        super().__init__()
        self.model = compiled_word_embedding_layer

    def forward(self, input_ids):
        """Invoke the compiled "forward" entry point and add a leading axis.

        The ids are detached and converted to NumPy, passed to
        ``self.model("forward", ...)``, and the result is reshaped to
        ``[1, shape[0], shape[1]]`` before being wrapped in a torch tensor.
        """
        ids_np = input_ids.detach().numpy()
        embedded = self.model("forward", ids_np)
        dim0, dim1 = embedded.shape[0], embedded.shape[1]
        return torch.tensor(embedded.reshape([1, dim0, dim1]))
|
||||
|
||||
|
||||
class LNFEmbeddingLayer(torch.nn.Module):
    """Thin module wrapper that delegates to the ``ln_f`` layer it is given."""

    def __init__(self, ln_f):
        super().__init__()
        self.model = ln_f

    def forward(self, hidden_states):
        """Return ``self.model.forward(input=hidden_states)`` unchanged."""
        return self.model.forward(input=hidden_states)
|
||||
|
||||
|
||||
class CompiledLNFEmbeddingLayer(torch.nn.Module):
    """Adapter that runs a compiled ``ln_f`` callable on NumPy input."""

    def __init__(self, ln_f):
        super().__init__()
        self.model = ln_f

    def forward(self, hidden_states):
        """Call ``self.model("forward", (array,))`` and return a torch tensor."""
        states_np = hidden_states.detach().numpy()
        # The compiled callable receives its inputs as a one-element tuple.
        compiled_result = self.model("forward", (states_np,))
        return torch.tensor(compiled_result)
|
||||
|
||||
|
||||
class LMHeadEmbeddingLayer(torch.nn.Module):
    """Thin module wrapper that delegates to the LM-head layer it is given."""

    def __init__(self, embedding_layer):
        super().__init__()
        self.model = embedding_layer

    def forward(self, hidden_states):
        """Return ``self.model.forward(input=hidden_states)`` unchanged."""
        return self.model.forward(input=hidden_states)
|
||||
|
||||
|
||||
class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
    """Adapter that runs a compiled LM-head callable on NumPy input."""

    def __init__(self, lm_head):
        super().__init__()
        self.model = lm_head

    def forward(self, hidden_states):
        """Call ``self.model("forward", (array,))`` and return a torch tensor."""
        states_np = hidden_states.detach().numpy()
        # The compiled callable receives its inputs as a one-element tuple.
        logits_np = self.model("forward", (states_np,))
        return torch.tensor(logits_np)
|
||||
|
||||
|
||||
class FourWayShardingDecoderLayer(torch.nn.Module):
    """Run a group of decoder layers in sequence and flatten their caches.

    ``decoder_layer_model`` is iterated in ``forward``; each element is
    called as a decoder layer. ``falcon_variant`` is stored but not read by
    this class.
    """

    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model
        self.falcon_variant = falcon_variant

    def forward(self, hidden_states, attention_mask):
        """Chain the layers and return ``(hidden, k0, v0, k1, v1, ...)``.

        Each layer is called with ``hidden_states``, ``alibi=None``,
        ``attention_mask`` and ``use_cache=True``. Element 0 of its output
        becomes the next layer's ``hidden_states``; items 0 and 1 of its
        last output element are appended to the flat result.

        The previous implementation destructured the collected pairs into
        forty individually named variables, which only worked for exactly
        20 layers. This version produces the same tuple for 20 layers and
        also accepts any other number of layers.

        Returns:
            A tuple of length ``1 + 2 * number_of_layers``.
        """
        flat_cache = []
        for layer in self.model:
            outputs = layer(
                hidden_states=hidden_states,
                alibi=None,
                attention_mask=attention_mask,
                use_cache=True,
            )
            hidden_states = outputs[0]
            present = outputs[-1]
            flat_cache.append(present[0])
            flat_cache.append(present[1])
        return (hidden_states, *flat_cache)
|
||||
|
||||
|
||||
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
    """Adapter that calls a compiled decoder group and regroups its outputs.

    The constructor arguments are stored as attributes; only ``model`` is
    read by ``forward``.
    """

    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision, model
    ):
        super().__init__()
        self.model = model
        self.precision = precision
        self.device = device
        self.falcon_variant = falcon_variant
        self.device_index = device_idx
        self.layer_id = layer_id

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run the compiled "forward" entry point and pair up its outputs.

        ``hidden_states`` and ``attention_mask`` are cast to float32,
        detached, converted to NumPy and passed as a 2-tuple. Outputs
        0..40 are wrapped in torch tensors and returned as
        ``(out[0], (out[1], out[2]), ..., (out[39], out[40]))``.
        ``position_ids``, ``head_mask``, ``use_cache`` and
        ``output_attentions`` are accepted but not read.

        Raises:
            ValueError: if ``self.model`` is None, or if ``alibi`` or
                ``layer_past`` is not None.
        """
        import gc

        torch.cuda.empty_cache()
        gc.collect()

        if self.model is None:
            raise ValueError("Layer vmfb not found")

        states_np = hidden_states.to(torch.float32).detach().numpy()
        mask_np = attention_mask.to(torch.float32).detach().numpy()

        if not (alibi is None and layer_past is None):
            raise ValueError("Past Key Values and alibi should be None")

        output = self.model("forward", (states_np, mask_np))

        # Indices 1..40 are consumed two at a time, in ascending order.
        head = torch.tensor(output[0])
        cache_pairs = tuple(
            (torch.tensor(output[k]), torch.tensor(output[k + 1]))
            for k in range(1, 41, 2)
        )
        return (head, *cache_pairs)
|
||||
|
||||
|
||||
class TwoWayShardingDecoderLayer(torch.nn.Module):
    """Run a group of decoder layers in sequence and flatten their caches.

    ``decoder_layer_model`` is iterated in ``forward``; each element is
    called as a decoder layer. ``falcon_variant`` is stored but not read by
    this class.
    """

    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model
        self.falcon_variant = falcon_variant

    def forward(self, hidden_states, attention_mask):
        """Chain the layers and return ``(hidden, k0, v0, k1, v1, ...)``.

        Each layer is called with ``hidden_states``, ``alibi=None``,
        ``attention_mask`` and ``use_cache=True``. Element 0 of its output
        becomes the next layer's ``hidden_states``; items 0 and 1 of its
        last output element are appended to the flat result.

        The previous implementation destructured the collected pairs into
        eighty individually named variables, which only worked for exactly
        40 layers. This version produces the same tuple for 40 layers and
        also accepts any other number of layers.

        Returns:
            A tuple of length ``1 + 2 * number_of_layers``.
        """
        flat_cache = []
        for layer in self.model:
            outputs = layer(
                hidden_states=hidden_states,
                alibi=None,
                attention_mask=attention_mask,
                use_cache=True,
            )
            hidden_states = outputs[0]
            present = outputs[-1]
            flat_cache.append(present[0])
            flat_cache.append(present[1])
        return (hidden_states, *flat_cache)
|
||||
|
||||
|
||||
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
    """Adapter that calls a compiled decoder group and regroups its outputs.

    The constructor arguments are stored as attributes; only ``model`` is
    read by ``forward``.
    """

    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision, model
    ):
        super().__init__()
        self.model = model
        self.precision = precision
        self.device = device
        self.falcon_variant = falcon_variant
        self.device_index = device_idx
        self.layer_id = layer_id

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: Optional[torch.Tensor],
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run the compiled "forward" entry point and pair up its outputs.

        ``hidden_states`` and ``attention_mask`` are cast to float32,
        detached, converted to NumPy and passed as a 2-tuple. Outputs
        0..80 are wrapped in torch tensors and returned as
        ``(out[0], (out[1], out[2]), ..., (out[79], out[80]))``.
        ``position_ids``, ``head_mask``, ``use_cache`` and
        ``output_attentions`` are accepted but not read.

        Raises:
            ValueError: if ``self.model`` is None, or if ``alibi`` or
                ``layer_past`` is not None.
        """
        import gc

        torch.cuda.empty_cache()
        gc.collect()

        if self.model is None:
            raise ValueError("Layer vmfb not found")

        states_np = hidden_states.to(torch.float32).detach().numpy()
        mask_np = attention_mask.to(torch.float32).detach().numpy()

        if not (alibi is None and layer_past is None):
            raise ValueError("Past Key Values and alibi should be None")

        output = self.model("forward", (states_np, mask_np))

        # Indices 1..80 are consumed two at a time, in ascending order.
        head = torch.tensor(output[0])
        cache_pairs = tuple(
            (torch.tensor(output[k]), torch.tensor(output[k + 1]))
            for k in range(1, 81, 2)
        )
        return (head, *cache_pairs)
|
||||
|
||||
|
||||
class ShardedFalconModel:
    """Plain (non-Module) holder that swaps components into a base model.

    On construction the given replacements are assigned onto
    ``model.transformer`` (``h``, ``word_embeddings``, ``ln_f``) and onto
    ``model.lm_head``.
    """

    def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
        super().__init__()
        self.model = model
        # ``layers`` is wrapped in a ModuleList before being installed.
        self.model.transformer.h = torch.nn.ModuleList(layers)
        self.model.transformer.word_embeddings = word_embeddings
        self.model.transformer.ln_f = ln_f
        self.model.lm_head = lm_head

    def forward(
        self,
        input_ids,
        attention_mask=None,
    ):
        """Return ``.logits[:, -1, :]`` of the wrapped model's forward call."""
        model_output = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return model_output.logits[:, -1, :]
|
||||
503
apps/language_models/src/model_wrappers/minigpt4.py
Normal file
503
apps/language_models/src/model_wrappers/minigpt4.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import torch
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
    """LayerNorm that computes in float32 and restores the input dtype."""

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` as float32, then cast back to ``x.dtype``."""
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(x.dtype)
|
||||
|
||||
|
||||
class VisionModel(torch.nn.Module):
    """Compose a visual encoder with the layer norm applied to its output.

    When ``precision`` is "int4" or "int8", ``quantize_model`` is applied
    first to ``ln_vision`` and then to ``visual_encoder``, using a weight
    bit width of 4 or 8 respectively and ``weight_group_size`` as the group
    size.
    """

    def __init__(
        self,
        ln_vision,
        visual_encoder,
        precision="fp32",
        weight_group_size=128,
    ):
        super().__init__()
        self.ln_vision = ln_vision
        self.visual_encoder = visual_encoder
        if precision not in ["int4", "int8"]:
            return

        bit_width = 4 if precision == "int4" else 8
        # Same quantization settings for both submodules, ln_vision first.
        targets = (
            (
                "Vision Model applying weight quantization to ln_vision",
                self.ln_vision,
            ),
            (
                "Vision Model applying weight quantization to visual_encoder",
                self.visual_encoder,
            ),
        )
        for announcement, submodule in targets:
            print(announcement)
            quantize_model(
                submodule,
                dtype=torch.float32,
                weight_bit_width=bit_width,
                weight_param_method="stats",
                weight_scale_precision="float_scale",
                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
            )
            print("Weight quantization applied.")

    def forward(self, image):
        """Return ``ln_vision(visual_encoder(image))``."""
        encoded = self.visual_encoder(image)
        return self.ln_vision(encoded)
|
||||
|
||||
|
||||
class QformerBertModel(torch.nn.Module):
    """Wrapper returning only ``last_hidden_state`` of the wrapped model."""

    def __init__(self, qformer_bert):
        super().__init__()
        self.qformer_bert = qformer_bert

    def forward(self, query_tokens, image_embeds, image_atts):
        """Call the wrapped model with query and encoder inputs.

        ``query_tokens`` is passed as ``query_embeds``, ``image_embeds`` as
        ``encoder_hidden_states`` and ``image_atts`` as
        ``encoder_attention_mask``, with ``return_dict=True``.
        """
        call_kwargs = {
            "query_embeds": query_tokens,
            "encoder_hidden_states": image_embeds,
            "encoder_attention_mask": image_atts,
            "return_dict": True,
        }
        return self.qformer_bert(**call_kwargs).last_hidden_state
|
||||
|
||||
|
||||
class FirstLlamaModel(torch.nn.Module):
    """Wrapper for the call that has no incoming key/value cache.

    When ``precision`` is "int4" or "int8", ``quantize_model`` is applied to
    the wrapped model with a weight bit width of 4 or 8 respectively.
    """

    def __init__(self, model, precision="fp32", weight_group_size=128):
        super().__init__()
        self.model = model
        print("SHARK: Loading LLAMA Done")
        if precision not in ["int4", "int8"]:
            return

        print("First Llama applying weight quantization")
        bit_width = 4 if precision == "int4" else 8
        quantize_model(
            self.model,
            dtype=torch.float32,
            weight_bit_width=bit_width,
            weight_param_method="stats",
            weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=weight_group_size,
            quantize_weight_zero_point=False,
        )
        print("Weight quantization applied.")

    def forward(self, inputs_embeds, position_ids, attention_mask):
        """Run the model with ``past_key_values=None`` and flatten the result.

        Prints the shape and dtype of each input, then returns
        ``(logits, pkv[0][0], pkv[0][1], pkv[1][0], ...)`` built from the
        model output's ``logits`` and ``past_key_values``.
        """
        banner = "************************************"
        print(banner)
        for label, tensor in (
            ("inputs_embeds: ", inputs_embeds),
            ("position_ids: ", position_ids),
            ("attention_mask: ", attention_mask),
        ):
            print(label, tensor.shape, " dtype: ", tensor.dtype)
        print(banner)

        output = self.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            past_key_values=None,
            use_cache=True,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        flattened = [output.logits]
        for layer_cache in output.past_key_values:
            flattened += [layer_cache[0], layer_cache[1]]
        return tuple(flattened)
|
||||
|
||||
|
||||
class SecondLlamaModel(torch.nn.Module):
    """Wrapper for the call that receives a flattened key/value cache.

    When ``precision`` is "int4" or "int8", ``quantize_model`` is applied to
    the wrapped model with a weight bit width of 4 or 8 respectively.
    """

    def __init__(self, model, precision="fp32", weight_group_size=128):
        super().__init__()
        self.model = model
        print("SHARK: Loading LLAMA Done")
        if precision not in ["int4", "int8"]:
            return

        print("Second Llama applying weight quantization")
        bit_width = 4 if precision == "int4" else 8
        quantize_model(
            self.model,
            dtype=torch.float32,
            weight_bit_width=bit_width,
            weight_param_method="stats",
            weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=weight_group_size,
            quantize_weight_zero_point=False,
        )
        print("Weight quantization applied.")

    def forward(
        self,
        input_ids,
        position_ids,
        attention_mask,
        i1, i2, i3, i4, i5, i6, i7, i8,
        i9, i10, i11, i12, i13, i14, i15, i16,
        i17, i18, i19, i20, i21, i22, i23, i24,
        i25, i26, i27, i28, i29, i30, i31, i32,
        i33, i34, i35, i36, i37, i38, i39, i40,
        i41, i42, i43, i44, i45, i46, i47, i48,
        i49, i50, i51, i52, i53, i54, i55, i56,
        i57, i58, i59, i60, i61, i62, i63, i64,
    ):
        """Run the model with the 64 cache arguments regrouped into 32 pairs.

        ``i1``..``i64`` are paired consecutively as ``(i1, i2)``,
        ``(i3, i4)``, ..., ``(i63, i64)`` and passed as ``past_key_values``.
        Prints the shape and dtype of the inputs, then returns
        ``(logits, pkv[0][0], pkv[0][1], pkv[1][0], ...)`` built from the
        model output's ``logits`` and ``past_key_values``.
        """
        flat_cache = (
            i1, i2, i3, i4, i5, i6, i7, i8,
            i9, i10, i11, i12, i13, i14, i15, i16,
            i17, i18, i19, i20, i21, i22, i23, i24,
            i25, i26, i27, i28, i29, i30, i31, i32,
            i33, i34, i35, i36, i37, i38, i39, i40,
            i41, i42, i43, i44, i45, i46, i47, i48,
            i49, i50, i51, i52, i53, i54, i55, i56,
            i57, i58, i59, i60, i61, i62, i63, i64,
        )

        banner = "************************************"
        print(banner)
        print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
        for label, tensor in (
            ("position_ids: ", position_ids),
            ("attention_mask: ", attention_mask),
        ):
            print(label, tensor.shape, " dtype: ", tensor.dtype)
        print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
        print("past_key_values dtype: ", i1.dtype)
        print(banner)

        # Even slots pair with the odd slot that follows them.
        paired_cache = tuple(zip(flat_cache[0::2], flat_cache[1::2]))

        output = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=paired_cache,
            use_cache=True,
            attention_mask=attention_mask,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        flattened = [output.logits]
        for layer_cache in output.past_key_values:
            flattened += [layer_cache[0], layer_cache[1]]
        return tuple(flattened)
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
    """Enumerates the separator layouts used when building a prompt."""

    # One separator string follows every message.
    SINGLE = auto()
    # Two separator strings alternate between consecutive messages.
    TWO = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class Conversation:
    """Holds a system prompt, role names and the running message history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the system text and all messages into one prompt string.

        With ``SINGLE`` every non-empty message is followed by ``sep``; with
        ``TWO`` the separator alternates between ``sep`` and ``sep2`` by
        message index. An empty message contributes only ``"<role>:"``.

        Raises:
            ValueError: if ``sep_style`` is neither ``SINGLE`` nor ``TWO``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            prompt = self.system + self.sep
            for role, message in self.messages:
                if not message:
                    prompt += role + ":"
                else:
                    prompt += role + ": " + message + self.sep
            return prompt

        if self.sep_style == SeparatorStyle.TWO:
            separators = [self.sep, self.sep2]
            prompt = self.system + separators[0]
            for index, (role, message) in enumerate(self.messages):
                if not message:
                    prompt += role + ":"
                else:
                    prompt += role + ": " + message + separators[index % 2]
            return prompt

        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append ``[role, message]`` to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Pair up messages after ``offset`` as ``[first, second]`` rows.

        Even-indexed messages open a new row with ``None`` in the second
        slot; odd-indexed messages fill the second slot of the latest row.
        """
        rows = []
        visible = self.messages[self.offset :]
        for position, (_role, text) in enumerate(visible):
            if position % 2:
                rows[-1][-1] = text
            else:
                rows.append([text, None])
        return rows

    def copy(self):
        """Return a new Conversation whose message list is rebuilt pair by pair.

        ``skip_next`` is not forwarded, so the copy gets its default value.
        """
        cloned_messages = [[role, text] for role, text in self.messages]
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=cloned_messages,
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return a dict with system, roles, messages, offset, sep, sep2, conv_id."""
        exported = ("system", "roles", "messages", "offset", "sep", "sep2", "conv_id")
        return {field: getattr(self, field) for field in exported}
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the sequence ends with any of the stop sequences.

    Args:
        stops: iterable of stop sequences compared element-wise against the
            tail of ``input_ids[0]``. Defaults to no stop sequences.
        encounters: accepted for compatibility with existing callers; it is
            not read by this class.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # A fresh list per instance avoids sharing one mutable default
        # between every instance constructed without ``stops``.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """Return True when the first sequence ends with a stop sequence.

        Only ``input_ids[0]`` is inspected; ``scores`` is not read.
        """
        generated = input_ids[0]
        for stop in self.stops:
            stop_len = len(stop)
            # A stop sequence longer than what has been generated cannot
            # match. Without this guard the comparison would either raise
            # on mismatched shapes or broadcast and report a false match.
            if stop_len > len(generated):
                continue
            if torch.all((stop == generated[-stop_len:])).item():
                return True

        return False
|
||||
|
||||
|
||||
# Conversation template: single "###" separator, Human/Assistant roles,
# empty history, offset of 2.
CONV_VISION = Conversation(
    system=(
        "Give the following image: <Img>ImageContent</Img>. "
        "You will be able to see the image once I provide it to you. Please answer my questions."
    ),
    roles=("Human", "Assistant"),
    messages=[],
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
|
||||
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
15
apps/language_models/src/model_wrappers/stablelm_model.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
|
||||
|
||||
class StableLMModel(torch.nn.Module):
    """Wrapper whose ``forward`` returns the wrapped model's ``logits``."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Call the model with ``input_ids`` and ``attention_mask`` keywords."""
        model_output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return model_output.logits
|
||||
876
apps/language_models/src/model_wrappers/vicuna4.py
Normal file
876
apps/language_models/src/model_wrappers/vicuna4.py
Normal file
@@ -0,0 +1,876 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import iree.runtime
|
||||
import itertools
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
FirstVicunaLayer,
|
||||
SecondVicunaLayer,
|
||||
CompiledVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
LMHead,
|
||||
LMHeadCompiled,
|
||||
VicunaEmbedding,
|
||||
VicunaEmbeddingCompiled,
|
||||
VicunaNorm,
|
||||
VicunaNormCompiled,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna7B,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaRMSNorm,
|
||||
_make_causal_mask,
|
||||
_expand_mask,
|
||||
)
|
||||
from torch import nn
|
||||
from time import time
|
||||
|
||||
|
||||
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Note that `forward` iterates over `self.compressedlayers`, not
    `self.layers`. That attribute is not created in `__init__`, so it must
    already exist on the instance before `forward` is called.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(config)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self,
        attention_mask,
        input_shape,
        inputs_embeds,
        past_key_values_length,
    ):
        """Combine the causal mask and the expanded padding mask.

        A causal mask is only built when more than one position is being
        processed (`input_shape[-1] > 1`). When `attention_mask` is given it is
        expanded and added to the causal mask (or used alone if there is none).
        Returns `None` when neither mask applies.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the decoder over `self.compressedlayers` and time the step.

        Exactly one of `input_ids` / `inputs_embeds` must be given, otherwise
        `ValueError` is raised. `past_key_values` is indexed as a flat
        sequence in which compressed layer `idx` owns entries
        `[8 * idx, 8 * (idx + 1))`; the returned cache is flattened the same
        way. Flags left as `None` fall back to the values in `self.config`.

        Returns a `BaseModelOutputWithPast` when `return_dict` is true,
        otherwise a tuple of the non-`None` values among
        (hidden_states, cache, all_hidden_states, all_self_attns).
        """
        t1 = time()
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = (
            use_cache if use_cache is not None else self.config.use_cache
        )

        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            # The cached length is read from dim 2 of the first cached tensor.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = (
                seq_length_with_past + past_key_values_length
            )

        if position_ids is None:
            device = (
                input_ids.device
                if input_ids is not None
                else inputs_embeds.device
            )
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.compressedlayers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Each compressed layer consumes 8 consecutive cache entries.
            past_key_value = (
                past_key_values[8 * idx : 8 * (idx + 1)]
                if past_key_values is not None
                else None
            )

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer.forward(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Everything after the hidden state is kept as this layer's cache.
                next_decoder_cache += (layer_outputs[1:],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Best-effort conversion of the final hidden state to a NumPy array.
        # If the conversion is not possible the value is passed on unchanged.
        try:
            hidden_states = np.asarray(hidden_states, hidden_states.dtype)
        except Exception:
            pass

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Flatten the per-layer cache groups into one tuple. Without caching
        # there is nothing to flatten; chaining over None would raise.
        next_cache = (
            tuple(itertools.chain.from_iterable(next_decoder_cache))
            if use_cache
            else None
        )
        print(f"Token generated in {time() - t1} seconds")
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
||||
|
||||
|
||||
class EightLayerLayerSV(torch.nn.Module):
    """Runs eight decoder layers in sequence, each with its own cached key/value pair.

    The cache is passed in and returned as flat positional tensors
    (`pkvN0`, `pkvN1` for layer N) rather than as nested tuples.
    """

    def __init__(self, layers):
        super().__init__()
        assert len(layers) == 8
        self.layers = layers

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        pkv00,
        pkv01,
        pkv10,
        pkv11,
        pkv20,
        pkv21,
        pkv30,
        pkv31,
        pkv40,
        pkv41,
        pkv50,
        pkv51,
        pkv60,
        pkv61,
        pkv70,
        pkv71,
    ):
        """Apply the layers one after another.

        Returns a flat 17-tuple: the final hidden state followed by the two
        updated cache entries of each layer, in layer order.
        """
        cached_pairs = (
            (pkv00, pkv01),
            (pkv10, pkv11),
            (pkv20, pkv21),
            (pkv30, pkv31),
            (pkv40, pkv41),
            (pkv50, pkv51),
            (pkv60, pkv61),
            (pkv70, pkv71),
        )
        updated_pairs = []
        for layer, (cached_first, cached_second) in zip(
            self.layers, cached_pairs
        ):
            result = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=(cached_first, cached_second),
                use_cache=True,
            )
            hidden_states = result[0]
            # The last element of the layer result carries the updated cache.
            updated_pairs.append((result[-1][0], result[-1][1]))

        # Unpacking into exactly eight names keeps the strict layer-count check.
        p0, p1, p2, p3, p4, p5, p6, p7 = updated_pairs
        return (hidden_states, *p0, *p1, *p2, *p3, *p4, *p5, *p6, *p7)
|
||||
|
||||
|
||||
class EightLayerLayerFV(torch.nn.Module):
    """Runs eight decoder layers in sequence without any incoming cache.

    Every layer is called with `past_key_value=None`; the cache entries the
    layers produce are returned as flat positional values.
    """

    def __init__(self, layers):
        super().__init__()
        assert len(layers) == 8
        self.layers = layers

    def forward(self, hidden_states, attention_mask, position_ids):
        """Apply the layers one after another.

        Returns a flat 17-tuple: the final hidden state followed by the two
        cache entries produced by each layer, in layer order.
        """
        produced_pairs = []
        for layer in self.layers:
            result = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                use_cache=True,
            )
            hidden_states = result[0]
            # The last element of the layer result carries the new cache.
            produced_pairs.append((result[-1][0], result[-1][1]))

        # Unpacking into exactly eight names keeps the strict layer-count check.
        p0, p1, p2, p3, p4, p5, p6, p7 = produced_pairs
        return (hidden_states, *p0, *p1, *p2, *p3, *p4, *p5, *p6, *p7)
|
||||
|
||||
|
||||
class CompiledEightLayerLayerSV(torch.nn.Module):
    """Wrapper that forwards an eight-layer cached step to a compiled module.

    `model` is invoked as `model("forward", inputs, send_to_host=False)`.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions=False,
        use_cache=True,
    ):
        """Detach all inputs, call the compiled module and regroup its result.

        Args:
            hidden_states, attention_mask, position_ids: tensors; each is
                detached before being passed on.
            past_key_value: exactly eight 2-element pairs of tensors.
            output_attentions, use_cache: accepted for interface
                compatibility; not used here.

        Returns:
            A 9-tuple: `output[0]` followed by eight pairs
            `(output[i][0], output[i][1])` for i in 1..8.
        """
        hidden_states = hidden_states.detach()
        attention_mask = attention_mask.detach()
        position_ids = position_ids.detach()

        # Unpacking into eight names enforces the expected cache length.
        p0, p1, p2, p3, p4, p5, p6, p7 = past_key_value
        flat_cache = []
        for first, second in (p0, p1, p2, p3, p4, p5, p6, p7):
            # Fixed: the method is `detach`; the previous `detatch` spelling
            # raised AttributeError on every call.
            flat_cache.append(first.detach())
            flat_cache.append(second.detach())

        output = self.model(
            "forward",
            (hidden_states, attention_mask, position_ids, *flat_cache),
            send_to_host=False,
        )
        # NOTE(review): the result is read as nested pairs (output[i][0],
        # output[i][1]), while CompiledEightLayerLayer reads a flat
        # 17-element result. Confirm which layout the compiled module returns.
        return (
            output[0],
            *((output[i][0], output[i][1]) for i in range(1, 9)),
        )
|
||||
|
||||
|
||||
def forward_compressed(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """Decoder forward pass that walks `self.compressedlayers`.

    This is a free function taking the model as `self` (presumably meant to
    be bound onto a model instance -- confirm against the caller). Flags left
    as `None` fall back to `self.config`. Exactly one of `input_ids` /
    `inputs_embeds` is required, otherwise `ValueError` is raised.
    `past_key_values` is sliced so that compressed layer `idx` receives
    entries `[8 * idx, 8 * (idx + 1))`.

    Returns a `BaseModelOutputWithPast` when `return_dict` is true, otherwise
    a tuple of the non-`None` values among
    (hidden_states, cache, all_hidden_states, all_self_attns).
    """
    # Resolve unspecified flags from the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    have_ids = input_ids is not None
    have_embeds = inputs_embeds is not None
    if have_ids and have_embeds:
        raise ValueError(
            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
        )
    if not have_ids and not have_embeds:
        raise ValueError(
            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
        )
    if have_ids:
        batch_size, seq_length = input_ids.shape
    else:
        batch_size, seq_length, _ = inputs_embeds.shape

    # The cached length is read from dim 2 of the first cached tensor.
    past_key_values_length = (
        0 if past_key_values is None else past_key_values[0][0].shape[2]
    )
    seq_length_with_past = seq_length + past_key_values_length

    if position_ids is not None:
        position_ids = position_ids.view(-1, seq_length).long()
    else:
        device = inputs_embeds.device if input_ids is None else input_ids.device
        position_ids = (
            torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            .unsqueeze(0)
            .view(-1, seq_length)
        )

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask,
        (batch_size, seq_length),
        inputs_embeds,
        past_key_values_length,
    )

    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None
    cache_slot = 2 if output_attentions else 1

    for idx, decoder_layer in enumerate(self.compressedlayers):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if past_key_values is None:
            past_key_value = None
        else:
            past_key_value = past_key_values[8 * idx : 8 * (idx + 1)]

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = next_decoder_cache + (
                layer_outputs[cache_slot],
            )

        if output_attentions:
            all_self_attns = all_self_attns + (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_dict:
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    candidates = (hidden_states, next_cache, all_hidden_states, all_self_attns)
    return tuple(v for v in candidates if v is not None)
|
||||
|
||||
|
||||
class CompiledEightLayerLayer(torch.nn.Module):
    """Wrapper that dispatches an eight-layer step to a compiled module.

    Without a cache the compiled entry point `"first_vicuna_forward"` is
    called; with a cache `"second_vicuna_forward"` is called. Both are invoked
    with `send_to_host=False` and their flat 17-element result is regrouped
    into a hidden state plus eight (key, value) pairs.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @staticmethod
    def _regroup(output):
        """Turn a flat 17-element result into (hidden, (o1, o2), ..., (o15, o16))."""
        return (output[0],) + tuple(
            (output[i], output[i + 1]) for i in range(1, 17, 2)
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        """Run one compiled step.

        Args:
            hidden_states: input activations; may be a tensor or another
                array-like object (conversions below are best-effort).
            attention_mask, position_ids: tensors passed to the compiled
                module after `detach()`.
            past_key_value: `None` for the first step, otherwise exactly
                eight 2-element pairs.
            output_attentions, use_cache: accepted for interface
                compatibility; not used here.

        Returns:
            A 9-tuple: the hidden state followed by eight 2-element pairs.
        """
        if past_key_value is None:
            # Best-effort conversion to NumPy; if it is not possible the
            # value is passed on unchanged.
            try:
                hidden_states = np.asarray(hidden_states, hidden_states.dtype)
            except Exception:
                pass
            attention_mask = attention_mask.detach()
            position_ids = position_ids.detach()

            output = self.model(
                "first_vicuna_forward",
                (hidden_states, attention_mask, position_ids),
                send_to_host=False,
            )
            return self._regroup(output)

        # Nested unpacking enforces eight pairs of exactly two entries each.
        (
            (k0, v0),
            (k1, v1),
            (k2, v2),
            (k3, v3),
            (k4, v4),
            (k5, v5),
            (k6, v6),
            (k7, v7),
        ) = past_key_value
        cache = [k0, v0, k1, v1, k2, v2, k3, v3, k4, v4, k5, v5, k6, v6, k7, v7]
        del k0, v0, k1, v1, k2, v2, k3, v3, k4, v4, k5, v5, k6, v6, k7, v7

        # Best-effort detach: the first object without a working `detach`
        # ends the attempt, and anything not yet detached stays as it was.
        try:
            hidden_states = hidden_states.detach()
            attention_mask = attention_mask.detach()
            position_ids = position_ids.detach()
            for slot, entry in enumerate(cache):
                cache[slot] = entry.detach()
        except Exception:
            pass

        t1 = time()
        if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
            # Bring a device array back as a detached torch tensor.
            hidden_states = np.array(hidden_states, hidden_states.dtype)
            hidden_states = torch.tensor(hidden_states)
            hidden_states = hidden_states.detach()

        output = self.model(
            "second_vicuna_forward",
            (hidden_states, attention_mask, position_ids, *cache),
            send_to_host=False,
        )
        print(f"{time() - t1}")
        # Drop the local references to the old cache entries before
        # building the result.
        del cache
        return self._regroup(output)
|
||||
1175
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
1175
apps/language_models/src/model_wrappers/vicuna_model.py
Normal file
File diff suppressed because it is too large
Load Diff
1173
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
1173
apps/language_models/src/model_wrappers/vicuna_model_gpu.py
Normal file
File diff suppressed because it is too large
Load Diff
247
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
247
apps/language_models/src/model_wrappers/vicuna_sharded_model.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
    """Wraps a single layer for the first step, where no cache is supplied."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states, attention_mask, position_ids):
        """Call the wrapped layer with caching enabled.

        Returns a flat 3-tuple: the layer's first output followed by the two
        entries of its last output (the produced cache pair).
        """
        result = self.model(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
        )
        return (result[0], result[-1][0], result[-1][1])
|
||||
|
||||
|
||||
class SecondVicunaLayer(torch.nn.Module):
    """Wraps a single layer for later steps, where a cache pair is supplied."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value0,
        past_key_value1,
    ):
        """Call the wrapped layer with the given cache pair and caching enabled.

        Returns a flat 3-tuple: the layer's first output followed by the two
        entries of its last output (the updated cache pair).
        """
        cached = (past_key_value0, past_key_value1)
        result = self.model(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=cached,
            use_cache=True,
        )
        return (result[0], result[-1][0], result[-1][1])
|
||||
|
||||
|
||||
class ShardedVicunaModel(torch.nn.Module):
    """Replaces the parts of a wrapped model with externally supplied modules.

    On construction the wrapped model's config is set to `use_cache=True` and
    `output_attentions=False`, and its norm, embedding, LM head and layer list
    are overwritten with the objects passed in.
    """

    def __init__(self, model, layers, lmhead, embedding, norm):
        super().__init__()
        self.model = model
        self.layers = layers
        self.norm = norm
        self.embedding = embedding
        self.lmhead = lmhead

        inner = self.model.model
        inner.config.use_cache = True
        inner.config.output_attentions = False
        inner.norm = self.norm
        inner.embed_tokens = self.embedding
        self.model.lm_head = self.lmhead
        inner.layers = torch.nn.modules.container.ModuleList(self.layers)

    def forward(
        self,
        input_ids,
        is_first=True,
        past_key_values=None,
        attention_mask=None,
    ):
        """Delegate to the wrapped model's `forward`.

        `is_first` is accepted but not used.
        """
        return self.model.forward(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
|
||||
|
||||
|
||||
class LMHead(torch.nn.Module):
    """Thin wrapper that applies the wrapped callable to the hidden states."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states):
        """Return `self.model(hidden_states)` unchanged."""
        return self.model(hidden_states)
|
||||
|
||||
|
||||
class LMHeadCompiled(torch.nn.Module):
    """Runs the LM head through a compiled module.

    `shark_module` is invoked as `shark_module("forward", (hidden_states,))`
    and its result is wrapped in a new `torch.Tensor`.
    """

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, hidden_states):
        """Detach the input, run the compiled module and return a tensor."""
        # Previously the detached tensor was stored in an unused local and
        # the original (possibly grad-tracking) tensor was passed instead.
        hidden_states = hidden_states.detach()

        output = self.model("forward", (hidden_states,))
        return torch.tensor(output)
|
||||
|
||||
|
||||
class VicunaNorm(torch.nn.Module):
    """Thin wrapper that applies the wrapped callable to the hidden states."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states):
        """Return `self.model(hidden_states)` unchanged."""
        return self.model(hidden_states)
|
||||
|
||||
|
||||
class VicunaNormCompiled(torch.nn.Module):
    """Runs the final norm through a compiled module.

    `shark_module` is invoked as
    `shark_module("forward", (hidden_states,), send_to_host=True)` and its
    result is wrapped in a new `torch.Tensor`.
    """

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, hidden_states):
        """Detach the input when possible, run the compiled module, return a tensor."""
        # Best-effort: inputs without a working `detach` are passed through
        # as they are. The result is now kept; before it was thrown away.
        try:
            hidden_states = hidden_states.detach()
        except Exception:
            pass
        output = self.model("forward", (hidden_states,), send_to_host=True)
        return torch.tensor(output)
|
||||
|
||||
|
||||
class VicunaEmbedding(torch.nn.Module):
    """Thin wrapper that applies the wrapped callable to the input ids."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        """Return `self.model(input_ids)` unchanged."""
        return self.model(input_ids)
|
||||
|
||||
|
||||
class VicunaEmbeddingCompiled(torch.nn.Module):
    """Runs the token embedding through a compiled module.

    `shark_module` is invoked as
    `shark_module("forward", (input_ids,), send_to_host=True)` and its result
    is wrapped in a new `torch.Tensor`.
    """

    def __init__(self, shark_module):
        super().__init__()
        self.model = shark_module

    def forward(self, input_ids):
        """Detach the ids, run the compiled module and return a tensor."""
        # Keep the detached tensor; the previous code called detach() and
        # discarded the result.
        input_ids = input_ids.detach()
        output = self.model("forward", (input_ids,), send_to_host=True)
        return torch.tensor(output)
|
||||
|
||||
|
||||
class CompiledVicunaLayer(torch.nn.Module):
    """Runs one decoder layer through a compiled module.

    `breakpoints` (or `None`) lists the 1-based layer positions at which the
    result is requested on the host and converted to torch tensors; `idx` is
    this layer's 0-based position.
    """

    def __init__(self, shark_module, idx, breakpoints):
        super().__init__()
        self.model = shark_module
        self.idx = idx
        self.breakpoints = breakpoints

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        past_key_value=None,
        output_attentions=False,
        use_cache=True,
    ):
        """Dispatch to the first-step or cached entry point.

        Without `past_key_value` the `"first_vicuna_forward"` entry point is
        called with three inputs; otherwise `"second_vicuna_forward"` is
        called with the two cache entries appended. `output_attentions` and
        `use_cache` are accepted but not used.

        Returns `(out0, (out1, out2))`; the three values are torch tensors at
        a breakpoint and the module's raw results otherwise.
        """
        at_breakpoint = (
            self.breakpoints is not None
            and (self.idx + 1) in self.breakpoints
        )

        if past_key_value is None:
            entry_point = "first_vicuna_forward"
            inputs = (hidden_states, attention_mask, position_ids)
        else:
            entry_point = "second_vicuna_forward"
            inputs = (
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value[0],
                past_key_value[1],
            )

        output = self.model(entry_point, inputs, send_to_host=at_breakpoint)

        results = [output[0], output[1], output[2]]
        if at_breakpoint:
            results = [torch.tensor(item) for item in results]

        return (results[0], (results[1], results[2]))
|
||||
44
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
44
apps/language_models/src/pipelines/SharkLLMBase.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class SharkLLMBase(ABC):
    """Abstract base for language-model pipelines.

    Stores the model name, an optional model path and the token limit, and
    initialises `shark_model` to `None`, `device` to "cpu" and `precision`
    to "fp32". Subclasses must override every abstract hook below.
    """

    def __init__(
        self,
        model_name,
        hf_model_path=None,
        max_num_tokens=512,
    ) -> None:
        # Caller-supplied settings.
        self.model_name = model_name
        self.hf_model_path = hf_model_path
        self.max_num_tokens = max_num_tokens
        # Defaults set by the base class.
        self.shark_model = None
        self.device = "cpu"
        self.precision = "fp32"

    @classmethod
    @abstractmethod
    def compile(self):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def generate(self, prompt):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def generate_new_token(self, params):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def get_tokenizer(self):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def get_src_model(self):
        """Abstract hook; the base implementation does nothing."""

    def load_init_from_config(self):
        """Optional hook; does nothing in the base class."""
        return None
|
||||
1147
apps/language_models/src/pipelines/falcon_pipeline.py
Normal file
1147
apps/language_models/src/pipelines/falcon_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
1449
apps/language_models/src/pipelines/minigpt4_pipeline.py
Normal file
1449
apps/language_models/src/pipelines/minigpt4_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
1297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Normal file
1297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
class BaseProcessor:
    """Base processor whose default transform returns its input unchanged."""

    def __init__(self):
        def _identity(x):
            return x

        self.transform = _identity

    def __call__(self, item):
        """Apply `self.transform` to `item`."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build an instance; the base class ignores `cfg`."""
        return cls()

    def build(self, **kwargs):
        """Create a config from the keyword arguments and pass it to `from_config`."""
        return self.from_config(OmegaConf.create(kwargs))
|
||||
|
||||
|
||||
class BlipImageBaseProcessor(BaseProcessor):
    """Image processor base that prepares a `transforms.Normalize` step.

    When `mean` / `std` are not given, the per-channel constants hard-coded
    below are used.
    """

    def __init__(self, mean=None, std=None):
        # Initialise the base class so `self.transform` always exists;
        # without this a direct instance failed in `__call__`.
        super().__init__()

        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)
|
||||
|
||||
|
||||
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation-time image pipeline: bicubic resize, tensor conversion, normalisation."""

    def __init__(self, image_size=224, mean=None, std=None):
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        )
        self.transform = transforms.Compose(
            [resize, transforms.ToTensor(), self.normalize]
        )

    def __call__(self, item):
        """Apply the composed transform to `item`."""
        return self.transform(item)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from `cfg`.

        Reads `image_size` (default 224), `mean` and `std` (default `None`);
        an empty config is created when `cfg` is `None`.
        """
        cfg = OmegaConf.create() if cfg is None else cfg

        size = cfg.get("image_size", 224)
        channel_mean = cfg.get("mean", None)
        channel_std = cfg.get("std", None)

        return cls(image_size=size, mean=channel_mean, std=channel_std)
|
||||
@@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
data_type: images
|
||||
build_info:
|
||||
storage: /path/to/cc_sbu_align/
|
||||
@@ -0,0 +1,33 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
drop_path_rate: 0
|
||||
use_grad_checkpoint: False
|
||||
vit_precision: "fp16"
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
|
||||
# Q-Former
|
||||
num_query_token: 32
|
||||
|
||||
# Vicuna
|
||||
llama_model: "lmsys/vicuna-7b-v1.3"
|
||||
|
||||
# generation configs
|
||||
prompt: ""
|
||||
|
||||
preprocess:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
image_size: 224
|
||||
eval:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
eval:
|
||||
name: "blip_caption"
|
||||
@@ -0,0 +1,25 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
model_type: pretrain_vicuna
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
max_txt_len: 160
|
||||
end_sym: "###"
|
||||
low_resource: False
|
||||
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
|
||||
prompt_template: '###Human: {} ###Assistant: '
|
||||
ckpt: 'prerained_minigpt4_7b.pth'
|
||||
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
|
||||
run:
|
||||
task: image_text_pretrain
|
||||
631
apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Normal file
631
apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Normal file
@@ -0,0 +1,631 @@
|
||||
# Based on EVA, BEIT, timm and DeiT code bases
|
||||
# https://github.com/baaivision/EVA
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/facebookresearch/deit/
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
import math
|
||||
import requests
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 224, 224),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.9,
|
||||
"interpolation": "bicubic",
|
||||
"mean": (0.5, 0.5, 0.5),
|
||||
"std": (0.5, 0.5, 0.5),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
class DropPath(nn.Module):
    """Stochastic-depth layer: applies ``drop_path`` per sample in a residual branch."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability forwarded to `drop_path`; kept exactly as given (may be None).
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply `drop_path` using the stored probability and the module's training flag."""
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        """Show the drop probability in the module's repr."""
        return f"p={self.drop_prob}"
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Linear -> Dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; a falsy value falls back
            to ``in_features``.
        out_features: Width of the output; a falsy value falls back to
            ``in_features``.
        act_layer: Zero-argument factory for the activation module.
        drop: Dropout probability applied to the output.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the two projections; dropout is applied only after the second one."""
        # No dropout between the activation and fc2 (it is disabled in the
        # original code as well).
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head self-attention with an optional learned relative position bias.

    Args:
        dim: Width of the input and output tokens.
        num_heads: Number of attention heads.
        qkv_bias: When true, learn a bias for queries and values (keys get zeros).
        qk_scale: Overrides the default ``head_dim ** -0.5`` scale when truthy.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the projected output.
        window_size: ``(Wh, Ww)``; when truthy, a bias table and index buffer
            are created for ``Wh * Ww + 1`` tokens.
        attn_head_dim: Per-head width; defaults to ``dim // num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        # The fused projection has no bias of its own; the bias is assembled
        # in forward() from q_bias / v_bias.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            # (2*Wh-1)*(2*Ww-1) in-window offsets plus three extra rows for
            # cls->token, token->cls and cls->cls.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )
            self.register_buffer(
                "relative_position_index",
                self._build_relative_position_index(
                    window_size, self.num_relative_distance
                ),
            )
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    @staticmethod
    def _build_relative_position_index(window_size, num_relative_distance):
        """Return the (Wh*Ww+1, Wh*Ww+1) lookup into the relative bias table."""
        rows = torch.arange(window_size[0])
        cols = torch.arange(window_size[1])
        grid = torch.stack(torch.meshgrid([rows, cols]))  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        # Pairwise offsets, moved to the last axis: Wh*Ww, Wh*Ww, 2
        offsets = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()
        offsets[:, :, 0] += window_size[0] - 1  # shift to start from 0
        offsets[:, :, 1] += window_size[1] - 1
        offsets[:, :, 0] *= 2 * window_size[1] - 1

        token_count = window_size[0] * window_size[1] + 1
        index = torch.zeros(size=(token_count, token_count), dtype=offsets.dtype)
        index[1:, 1:] = offsets.sum(-1)
        # Row 0 / column 0 belong to the class token and use the three extra rows.
        index[0, 0:] = num_relative_distance - 3
        index[0:, 0] = num_relative_distance - 2
        index[0, 0] = num_relative_distance - 1
        return index

    def forward(self, x, rel_pos_bias=None):
        """Attend over ``x`` of shape (B, N, C); ``rel_pos_bias`` is added to the logits when given."""
        batch, tokens, _ = x.shape

        if self.q_bias is None:
            fused_bias = None
        else:
            # Layout matches the fused weight: [q | k | v], with zeros for k.
            fused_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(
            2, 0, 3, 1, 4
        )
        # Indexed explicitly rather than tuple-unpacked (kept for TorchScript).
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(side, side, -1)  # Wh*Ww+1, Wh*Ww+1, nH
            bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww+1, Wh*Ww+1
            attn = attn + bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = self.attn_drop(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(out))
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each in a residual branch.

    When ``init_values`` is a positive number, each branch is multiplied by a
    learnable per-channel scale (``gamma_1`` / ``gamma_2``) initialised to it.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )
        # Stochastic depth on both residual branches; a no-op when the rate is 0.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        use_layer_scale = init_values is not None and init_values > 0
        if use_layer_scale:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        """Apply the attention branch, then the MLP branch, each added back to ``x``."""
        attn_branch = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
        if self.gamma_1 is not None:
            attn_branch = self.gamma_1 * attn_branch
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        # gamma_1 and gamma_2 are always set (or unset) together in __init__.
        if self.gamma_1 is not None:
            mlp_branch = self.gamma_2 * mlp_branch
        return x + self.drop_path(mlp_branch)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """Split an image into patches and project each patch to ``embed_dim``.

    Attributes:
        img_size: Image size after ``to_2tuple``.
        patch_size: Patch size after ``to_2tuple``.
        patch_shape: Number of patches along each image axis.
        num_patches: Product of the two ``patch_shape`` entries.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_along_0 = img_size[0] // patch_size[0]
        patches_along_1 = img_size[1] // patch_size[1]

        self.patch_shape = (patches_along_0, patches_along_1)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = patches_along_1 * patches_along_0

        # A strided convolution whose kernel equals the patch size embeds
        # each patch with one filter application.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x, **kwargs):
        """Embed ``x`` of shape (B, C, H, W) into (B, num_patches, embed_dim).

        Raises:
            AssertionError: If H/W differ from the configured image size.
        """
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        size_matches = H == self.img_size[0] and W == self.img_size[1]
        assert (
            size_matches
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        embedded = self.proj(x)
        return embedded.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
    """Learned relative position bias for a ``(Wh, Ww)`` window plus one class token.

    ``forward()`` takes no input and returns a tensor of shape
    (num_heads, Wh*Ww+1, Wh*Ww+1) gathered from the bias table.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        height, width = window_size[0], window_size[1]
        # In-window offsets plus three extra rows: cls->token, token->cls, cls->cls.
        self.num_relative_distance = (2 * height - 1) * (2 * width - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )

        # Pairwise relative position index for every token inside the window.
        grid = torch.stack(
            torch.meshgrid([torch.arange(height), torch.arange(width)])
        )  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        pairwise = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        pairwise = pairwise.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        pairwise[:, :, 0] += height - 1  # shift to start from 0
        pairwise[:, :, 1] += width - 1
        pairwise[:, :, 0] *= 2 * width - 1

        side = height * width + 1
        lookup = torch.zeros(size=(side, side), dtype=pairwise.dtype)
        lookup[1:, 1:] = pairwise.sum(-1)
        lookup[0, 0:] = self.num_relative_distance - 3
        lookup[0:, 0] = self.num_relative_distance - 2
        lookup[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", lookup)

    def forward(self):
        """Gather the bias table through the index buffer and put heads first."""
        side = self.window_size[0] * self.window_size[1] + 1
        flat_index = self.relative_position_index.view(-1)
        gathered = self.relative_position_bias_table[flat_index].view(
            side, side, -1
        )  # Wh*Ww+1, Wh*Ww+1, nH
        return gathered.permute(2, 0, 1).contiguous()  # nH, Wh*Ww+1, Wh*Ww+1
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
    """Vision Transformer backbone that returns per-token features.

    No classification head or final norm is built by ``__init__``; ``forward``
    returns the output of the last block for every token (class token first).

    Args:
        img_size: Input image size passed to ``PatchEmbed``.
        patch_size: Patch size passed to ``PatchEmbed``.
        in_chans: Number of input channels.
        num_classes: Stored on the instance; only used by ``reset_classifier``.
        embed_dim: Token width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden MLP width as a multiple of ``embed_dim``.
        qkv_bias, qk_scale: Forwarded to every block's attention.
        drop_rate: Dropout after the position embedding and inside blocks.
        attn_drop_rate: Dropout on attention weights.
        drop_path_rate: Final stochastic-depth rate; per-block rates rise
            linearly from 0 to this value.
        norm_layer: Factory for the normalisation layers.
        init_values: Layer-scale initial value forwarded to every block.
        use_abs_pos_emb: Create a learned absolute position embedding.
        use_rel_pos_bias: Give every block its own relative position bias.
        use_shared_rel_pos_bias: Create one bias module shared by all blocks.
        use_mean_pooling: Accepted for interface compatibility; unused here.
        init_scale: Accepted for interface compatibility; unused here.
        use_checkpoint: Run blocks through ``torch.utils.checkpoint``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_values=None,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        use_shared_rel_pos_bias=False,
        use_mean_pooling=True,
        init_scale=0.001,
        use_checkpoint=False,
    ):
        super().__init__()
        self.image_size = img_size
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            # One extra position for the class token.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches + 1, embed_dim)
            )
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.patch_shape, num_heads=num_heads
            )
        else:
            self.rel_pos_bias = None
        self.use_checkpoint = use_checkpoint

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.use_rel_pos_bias = use_rel_pos_bias
        block_window = (
            self.patch_embed.patch_shape if use_rel_pos_bias else None
        )
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=block_window,
                )
                for i in range(depth)
            ]
        )

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        """Divide each block's output projections by ``sqrt(2 * layer_number)`` (1-based)."""
        for layer_number, layer in enumerate(self.blocks, start=1):
            divisor = math.sqrt(2.0 * layer_number)
            layer.attn.proj.weight.data.div_(divisor)
            layer.mlp.fc2.weight.data.div_(divisor)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        """Return the classification head, or ``None`` if none has been attached.

        ``__init__`` does not create ``self.head``; it only exists after
        ``reset_classifier`` has been called, so a plain attribute access
        would raise ``AttributeError`` on a fresh model.
        """
        return getattr(self, "head", None)

    def reset_classifier(self, num_classes, global_pool=""):
        """Attach a new head: Linear for ``num_classes > 0``, Identity otherwise.

        ``global_pool`` is accepted but not used.
        """
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes)
            if num_classes > 0
            else nn.Identity()
        )

    def _embed_tokens(self, x):
        """Patch-embed ``x``, prepend the class token, add positions, apply dropout."""
        x = self.patch_embed(x)
        batch_size = x.size(0)
        # Broadcast the single learned class token over the batch.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        return self.pos_drop(x)

    def _shared_rel_pos_bias(self):
        """Evaluate the shared relative position bias module, if one exists."""
        if self.rel_pos_bias is None:
            return None
        return self.rel_pos_bias()

    def forward_features(self, x):
        """Return the token features produced by the last block."""
        x = self._embed_tokens(x)
        rel_pos_bias = self._shared_rel_pos_bias()
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
            else:
                x = blk(x, rel_pos_bias)
        return x

    def forward(self, x):
        """Same as ``forward_features``; no head is applied."""
        return self.forward_features(x)

    def get_intermediate_layers(self, x):
        """Return a list with the output of every block, in order.

        Checkpointing is not used on this path.
        """
        x = self._embed_tokens(x)
        rel_pos_bias = self._shared_rel_pos_bias()
        features = []
        for blk in self.blocks:
            x = blk(x, rel_pos_bias)
            features.append(x)
        return features
|
||||
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's position embedding to the model's patch grid, in place.

    Args:
        model: Object with ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: State dict; ``checkpoint_model["pos_embed"]`` is
            replaced only when it exists and its (square) grid size differs
            from the model's.

    Returns:
        None.
    """
    if "pos_embed" not in checkpoint_model:
        return

    source_embed = checkpoint_model["pos_embed"].float()
    embedding_size = source_embed.shape[-1]
    num_patches = model.patch_embed.num_patches
    # Tokens that are not patches (e.g. the class token) are kept as they are.
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # Grids are treated as square: height == width.
    orig_size = int((source_embed.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches**0.5)
    if orig_size == new_size:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_size, orig_size, new_size, new_size)
    )
    extra_tokens = source_embed[:, :num_extra_tokens]
    # Only the position tokens are interpolated: (1, N, C) -> (1, C, H, W).
    grid = source_embed[:, num_extra_tokens:].reshape(
        -1, orig_size, orig_size, embedding_size
    )
    grid = grid.permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid,
        size=(new_size, new_size),
        mode="bicubic",
        align_corners=False,
    )
    # Back to (1, H*W, C) before re-attaching the extra tokens.
    resized_tokens = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat(
        (extra_tokens, resized_tokens), dim=1
    )
|
||||
|
||||
|
||||
def convert_weights_to_fp16(model: nn.Module):
    """Visit every Conv1d/Conv2d/Linear layer of ``model``.

    The ``.half()`` cast is disabled in this version: each weight and bias
    is re-assigned to itself, so parameter dtypes are left unchanged.
    NOTE(review): callers passing ``precision="fp16"`` still get full
    precision weights - confirm this is intended.
    """

    castable_layers = (nn.Conv1d, nn.Conv2d, nn.Linear)

    def _visit(layer):
        if not isinstance(layer, castable_layers):
            return
        # Was: layer.weight.data.half()
        layer.weight.data = layer.weight.data
        if layer.bias is not None:
            # Was: layer.bias.data.half()
            layer.bias.data = layer.bias.data

    model.apply(_visit)
|
||||
|
||||
|
||||
def create_eva_vit_g(
    img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
    """Build the EVA ViT-g backbone and load its published checkpoint.

    The checkpoint is stored as ``eva_vit_g.pth`` in the current working
    directory. It is downloaded only when that file does not exist yet, is
    streamed to disk instead of being buffered in memory, and is written to
    a temporary ``.part`` file first so an interrupted download never leaves
    a truncated checkpoint behind.

    Args:
        img_size: Input image size for the model.
        drop_path_rate: Final stochastic-depth rate.
        use_checkpoint: Forwarded to ``VisionTransformer``.
        precision: When ``"fp16"``, ``convert_weights_to_fp16`` is applied.

    Returns:
        The ``VisionTransformer`` with the checkpoint loaded (non-strict).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    import os

    model = VisionTransformer(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408 // 88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"

    local_filename = "eva_vit_g.pth"
    if not os.path.exists(local_filename):
        partial_filename = local_filename + ".part"
        # stream=True keeps the body out of memory; the timeout bounds the
        # connect time and each read, not the total download time.
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail loudly instead of falling through to torch.load with a
            # missing file.
            response.raise_for_status()
            with open(partial_filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_filename, local_filename)
        print("File downloaded successfully.")

    state_dict = torch.load(local_filename, map_location="cpu")
    interpolate_pos_embed(model, state_dict)

    # strict=False: keys that do not match the model are tolerated.
    model.load_state_dict(state_dict, strict=False)

    if precision == "fp16":
        # model.to("cuda")
        convert_weights_to_fp16(model)
    return model
|
||||
@@ -0,0 +1,4 @@
|
||||
<Img><ImageHere></Img> Describe this image in detail.
|
||||
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
|
||||
<Img><ImageHere></Img> Please provide a detailed description of the picture.
|
||||
<Img><ImageHere></Img> Could you describe the contents of this image for me?
|
||||
300
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
300
apps/language_models/src/pipelines/stablelm_pipeline.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
    description="runs a StableLM model",
    prog="stablelm runner",
)

# Every command-line option as (flags, keyword arguments), registered in order.
for _flags, _settings in (
    (
        ("--precision", "-p"),
        dict(default="fp16", choices=["fp32", "fp16", "int4"]),
    ),
    (
        ("--device", "-d"),
        dict(default="cuda", help="vulkan, cpu, cuda"),
    ),
    (
        ("--stablelm_vmfb_path",),
        dict(default=None, help="path to StableLM's vmfb"),
    ),
    (
        ("--stablelm_mlir_path",),
        dict(default=None, help="path to StableLM's mlir file"),
    ),
    (
        ("--use_precompiled_model",),
        dict(
            default=True,
            action=argparse.BooleanOptionalAction,
            help="use the precompiled vmfb",
        ),
    ),
    (
        ("--load_mlir_from_shark_tank",),
        dict(
            default=True,
            action=argparse.BooleanOptionalAction,
            help="download precompile mlir from shark tank",
        ),
    ),
    (
        ("--hf_auth_token",),
        dict(
            type=str,
            default=None,
            help="Specify your own huggingface authentication token for stablelm-3B model.",
        ),
    ),
):
    parser.add_argument(*_flags, **_settings)
del _flags, _settings
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: fire when the newest token of the first sequence is a stop id."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Only the first row of the batch is inspected; `scores` is unused.
        newest_token = input_ids[0][-1]
        return any(
            newest_token == stop_id for stop_id in (50278, 50279, 50277, 1, 0)
        )
|
||||
|
||||
|
||||
class SharkStableLM(SharkLLMBase):
    """StableLM pipeline: import to MLIR, compile with SHARK, generate greedily.

    Args:
        model_name: Prefix for the cached ``.mlir`` / ``.vmfb`` file names.
        hf_model_path: Hugging Face repository of the model and tokenizer.
        max_num_tokens: Upper bound on tokens produced by ``generate``.
        device: SHARK device string.
        precision: ``"fp32"``, ``"fp16"`` or ``"int4"``.
        debug: Forwarded to ``save_module``. Strings are accepted for
            backward compatibility and parsed (``"true"``/``"1"``/``"yes"``
            enable it); the former default ``"False"`` was a truthy string.
        hf_auth_token: Hugging Face token. When omitted, the module-level
            ``args.hf_auth_token`` is used if the script was run from the
            command line.

    Raises:
        ValueError: If no token is available and ``precision`` is not ``"int4"``.
    """

    # Token ids that end generation.
    _STOP_IDS = (50278, 50279, 50277, 1, 0)

    def __init__(
        self,
        model_name,
        hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
        max_num_tokens=256,
        device="cuda",
        precision="fp32",
        debug=False,
        hf_auth_token=None,
    ) -> None:
        super().__init__(model_name, hf_model_path, max_num_tokens)
        self.max_sequence_len = 256
        self.device = device

        if hf_auth_token is None:
            # `args` only exists when this file is executed as a script;
            # looking it up defensively keeps the class importable.
            cli_args = globals().get("args")
            hf_auth_token = getattr(cli_args, "hf_auth_token", None)
        if precision != "int4" and hf_auth_token is None:
            raise ValueError(
                f"HF auth token required for {hf_model_path}. Pass it using "
                "the --hf_auth_token flag. You can ask for access to the "
                f"model here: https://huggingface.co/{hf_model_path}."
            )
        self.hf_auth_token = hf_auth_token

        self.precision = precision
        if isinstance(debug, str):
            debug = debug.strip().lower() in ("1", "true", "yes")
        self.debug = debug
        self.tokenizer = self.get_tokenizer()
        self.shark_model = self.compile()

    def shouldStop(self, tokens):
        """Return True when ``tokens[0][-1]`` equals one of the stop ids."""
        last = tokens[0][-1]
        return any(last == stop_id for stop_id in self._STOP_IDS)

    def get_src_model(self):
        """Load the Hugging Face model in float32.

        For ``int4`` the model path is switched to a GPTQ checkpoint and the
        quantization config is passed along.
        """
        kwargs = {}
        if self.precision == "int4":
            self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
            from transformers import GPTQConfig

            kwargs["quantization_config"] = GPTQConfig(
                bits=4, disable_exllama=True
            )
            kwargs["device_map"] = "cpu"
        print("[DEBUG] Loading Model")
        model = AutoModelForCausalLM.from_pretrained(
            self.hf_model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
            use_auth_token=self.hf_auth_token,
            **kwargs,
        )
        print("[DEBUG] Model loaded successfully")
        return model

    def get_model_inputs(self):
        """Return random ``(input_ids, attention_mask)`` of shape (1, max_sequence_len) for tracing."""
        input_ids = torch.randint(3, (1, self.max_sequence_len))
        attention_mask = torch.randint(3, (1, self.max_sequence_len))
        return input_ids, attention_mask

    def compile(self):
        """Return a ready SHARK module, reusing cached artifacts when present.

        Order of preference: an existing ``.vmfb``; else an existing
        ``.mlir`` that is compiled; else the model is imported, lowered to
        MLIR, written to disk and compiled.
        """
        tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"

        # TODO: vmfb and mlir name should include precision and device
        vmfb_path = Path(tmp_model_name + f"_{self.device}.vmfb")
        shark_module = get_vmfb_from_path(
            vmfb_path, self.device, mlir_dialect="tm_tensor"
        )
        if shark_module is not None:
            return shark_module

        mlir_path = Path(tmp_model_name + ".mlir")
        print(
            f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
        )
        if not mlir_path.exists():
            model = StableLMModel(self.get_src_model())
            model_inputs = self.get_model_inputs()
            from shark.shark_importer import import_with_fx

            ts_graph = import_with_fx(
                model,
                model_inputs,
                is_f16=self.precision == "fp16",
                precision=self.precision,
                f16_input_mask=[False, False],
                mlir_type="torchscript",
            )
            module = torch_mlir.compile(
                ts_graph,
                [*model_inputs],
                torch_mlir.OutputType.LINALG_ON_TENSORS,
                use_tracing=False,
                verbose=False,
            )
            bytecode_stream = BytesIO()
            module.operation.write_bytecode(bytecode_stream)
            bytecode = bytecode_stream.getvalue()
            # The context manager closes the file even if the write fails.
            with open(mlir_path, "wb") as f_:
                f_.write(bytecode)
            print("Saved mlir at: ", mlir_path)
            del bytecode

        from shark.shark_inference import SharkInference

        shark_module = SharkInference(
            mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
        )
        shark_module.compile()

        path = shark_module.save_module(
            vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
        )
        print("Saved vmfb at ", str(path))

        return shark_module

    def get_tokenizer(self):
        """Load the tokenizer and register ``<PAD>`` as its padding token."""
        tok = AutoTokenizer.from_pretrained(
            self.hf_model_path,
            use_auth_token=self.hf_auth_token,
        )
        tok.add_special_tokens({"pad_token": "<PAD>"})
        return tok

    def generate(self, prompt):
        """Generate up to ``max_num_tokens`` pieces of text, echoing them as they arrive.

        Generation ends early on a stop signal or an empty decoded piece.

        Returns:
            The list of decoded pieces appended to the prompt.
        """
        import time

        words_list = []
        count = 0
        start = time.perf_counter()
        for _ in range(self.max_num_tokens):
            count += 1
            generated_token_op = self.generate_new_token({"new_text": prompt})
            detok = generated_token_op["detok"]
            if generated_token_op["stop_generation"]:
                break

            print(detok, end="", flush=True)  # this is for CLI and DEBUG
            words_list.append(detok)
            if detok == "":
                break
            # The whole text is re-tokenized on the next step.
            prompt = prompt + detok
        elapsed = time.perf_counter() - start
        # Guard against a zero interval on coarse clocks or empty runs.
        rate = count / elapsed if elapsed > 0 else 0.0
        print("\n\nTime taken is {:.2f} tokens/second\n".format(rate))
        return words_list

    def generate_new_token(self, params):
        """Run one forward pass on ``params["new_text"]`` and decode the next token.

        Returns:
            Dict with ``new_token``, ``detok`` (decoded text) and
            ``stop_generation``.
        """
        new_text = params["new_text"]
        model_inputs = self.tokenizer(
            [new_text],
            padding="max_length",
            max_length=self.max_sequence_len,
            truncation=True,
            return_tensors="pt",
        )
        # Number of non-padding positions; the prediction is read at the
        # last of them.
        sum_attentionmask = torch.sum(model_inputs.attention_mask)
        output = self.shark_model(
            "forward", [model_inputs.input_ids, model_inputs.attention_mask]
        )
        output = torch.from_numpy(output)
        next_toks = torch.topk(output, 1)
        # NOTE(review): the stop check looks at the final sequence position,
        # while `new_token` is read at the last unpadded position - confirm
        # this is intended.
        stop_generation = bool(self.shouldStop(next_toks.indices))
        new_token = next_toks.indices[0][int(sum_attentionmask) - 1]
        detok = self.tokenizer.decode(
            new_token,
            skip_special_tokens=True,
        )
        return {
            "new_token": new_token,
            "detok": detok,
            "stop_generation": stop_generation,
        }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Interactive driver: build the pipeline once, then answer prompts until
    # the user declines to continue.
    args = parser.parse_args()

    stable_lm = SharkStableLM(
        model_name="stablelm_zephyr_3b",
        hf_model_path="stabilityai/stablelm-zephyr-3b",
        device=args.device,
        precision=args.precision,
    )

    default_prompt_text = "The weather is always wonderful"
    continue_execution = True

    print("\n-----\nScript executing for the following config: \n")
    print("StableLM Model: ", stable_lm.hf_model_path)
    print("Precision: ", args.precision)
    print("Device: ", args.device)

    while continue_execution:
        use_default_prompt = input(
            "\nDo you wish to use the default prompt text? Y/N ?: "
        )
        # Anything other than Y/y asks for a custom prompt.
        if use_default_prompt not in ["Y", "y"]:
            prompt = input("Please enter the prompt text: ")
        else:
            prompt = default_prompt_text
        print("\nPrompt Text: ", prompt)

        res_str = stable_lm.generate(prompt)
        # Release cached memory between runs.
        torch.cuda.empty_cache()
        import gc

        gc.collect()
        print(
            "\n\n-----\nHere's the complete formatted result: \n\n",
            prompt + "".join(res_str),
        )
        continue_execution = input(
            "\nDo you wish to run script one more time? Y/N ?: "
        )
        continue_execution = continue_execution in ["Y", "y"]
|
||||
48
apps/language_models/utils.py
Normal file
48
apps/language_models/utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
|
||||
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
    """Load a compiled ``.vmfb`` module from disk if it exists.

    Args:
        vmfb_path: ``Path`` or ``str`` pointing at the vmfb file.
        device: Device string passed to ``SharkInference``.
        mlir_dialect: MLIR dialect passed to ``SharkInference``.
        device_id: Optional device index passed as ``device_idx``.

    Returns:
        ``None`` when the file does not exist, otherwise a ``SharkInference``
        module with the vmfb loaded.
    """
    vmfb_path = Path(vmfb_path)

    # Check the cheap condition first so a missing file returns without
    # importing the SHARK runtime.
    if not vmfb_path.exists():
        return None

    from shark.shark_inference import SharkInference

    print("Loading vmfb from: ", vmfb_path)
    print("Device from get_vmfb_from_path - ", device)
    shark_module = SharkInference(
        None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
    )
    shark_module.load_module(vmfb_path)
    print("Successfully loaded vmfb")
    return shark_module
|
||||
|
||||
|
||||
def get_vmfb_from_config(
    shark_container,
    model,
    precision,
    device,
    vmfb_path,
    padding=None,
    device_id=None,
):
    """Download a prebuilt ``.vmfb`` from the shark_tank bucket and load it.

    The object name is ``{model}_{precision}_{device}`` with an optional
    ``_{padding}`` suffix, inside ``gs://shark_tank/{shark_container}/``.

    Args:
        shark_container: Folder name inside the ``shark_tank`` bucket.
        model: Model name used as the first part of the file name.
        precision: Precision tag used in the file name.
        device: Device tag used in the file name and for loading.
        vmfb_path: Local destination for the download, ``Path`` or ``str``.
        padding: Optional suffix appended to the file name when truthy.
        device_id: Optional device index forwarded to ``get_vmfb_from_path``.

    Returns:
        Whatever ``get_vmfb_from_path`` returns for the downloaded file.
    """
    # Accept str as well as Path, like get_vmfb_from_path does; the
    # .absolute() call below would fail on a plain string.
    vmfb_path = Path(vmfb_path)

    vmfb_url = f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
    if padding:
        vmfb_url += f"_{padding}"
    vmfb_url += ".vmfb"

    download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
    return get_vmfb_from_path(
        vmfb_path, device, "tm_tensor", device_id=device_id
    )
|
||||
91
apps/shark_studio/api/llm.py
Normal file
91
apps/shark_studio/api/llm.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from turbine_models.custom_models import stateless_llama
|
||||
from shark.iree_utils.compile_utils import get_iree_compiled_module
|
||||
from apps.shark_studio.api.utils import get_resource_path
|
||||
import iree.runtime as ireert
|
||||
import gc
|
||||
import torch
|
||||
|
||||
# Registry of supported chat models, keyed by short model name.
llm_model_map = {
    "llama2_7b": {
        # Called by LanguageModel as
        # initializer(hf_model_name, hf_auth_token, compile_to="torch").
        "initializer": stateless_llama.export_transformer_model,
        # Model identifier passed to the initializer.
        "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
        # Token id that makes LanguageModel.chat stop generating.
        "stop_token": 2,
        # Upper bound on the number of generation steps in LanguageModel.chat.
        "max_tokens": 4096,
    }
}
|
||||
|
||||
|
||||
class LanguageModel:
    """A model from ``llm_model_map`` exported to torch IR and compiled.

    Construction runs the model's ``initializer``, writes the returned IR to
    a temp file, and compiles that file with ``get_iree_compiled_module``.
    ``chat`` then generates text one token at a time.
    """

    def __init__(
        self, model_name, hf_auth_token=None, device=None, precision="fp32"
    ):
        """Export and compile ``model_name``.

        Args:
            model_name: Key into ``llm_model_map``.
            hf_auth_token: Token forwarded to the model's initializer.
            device: Device name forwarded to the compiler.
            precision: Stored on the instance; not used by the visible code.

        Raises:
            KeyError: If ``model_name`` is not in ``llm_model_map``.
        """
        model_config = llm_model_map[model_name]
        print(model_config)
        self.hf_model_name = model_config["hf_model_name"]
        self.torch_ir, self.tokenizer = model_config["initializer"](
            self.hf_model_name, hf_auth_token, compile_to="torch"
        )
        self.tempfile_name = get_resource_path("llm.torch.tempfile")
        with open(self.tempfile_name, "w+") as f:
            f.write(self.torch_ir)
        # The IR now lives on disk; drop the in-memory copy.
        del self.torch_ir
        gc.collect()

        self.device = device
        self.precision = precision
        self.max_tokens = model_config["max_tokens"]
        # Kept per instance so chat() stops on this model's own stop token
        # rather than on a hard-coded model's.
        self.stop_token = model_config["stop_token"]
        self.iree_module_dict = None
        self.compile()

    def compile(self) -> None:
        """Compile the temp IR file and store the result on the instance."""
        # this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
        self.iree_module_dict = get_iree_compiled_module(
            self.tempfile_name, device=self.device, frontend="torch"
        )
        # TODO: delete the temp file

    def chat(self, prompt):
        """Generate a reply to ``prompt``, yielding the decoded text so far.

        Yields the decoded history after every generated token, then once
        more after the loop ends with all tokens converted to ``int``.
        Generation stops at the model's stop token or after ``max_tokens``
        steps.
        """
        # The prompt does not change between steps, so tokenize it once.
        input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
        history = []
        for step in range(self.max_tokens):
            device_inputs = [
                ireert.asdevicearray(
                    self.iree_module_dict["config"], input_tensor
                )
            ]
            # The first step primes the model; later steps continue from it.
            # NOTE(review): run_forward is fed the same prompt tensor on
            # every step (unchanged from the original) -- confirm against
            # the exported module's expected inputs.
            entry_point = "run_initialize" if step == 0 else "run_forward"
            token = torch.tensor(
                self.iree_module_dict["vmfb"][entry_point](
                    *device_inputs
                ).to_host()[0][0]
            )

            history.append(token)
            yield self.tokenizer.decode(history)

            if token == self.stop_token:
                break

        history = [int(tok) for tok in history]
        yield self.tokenizer.decode(history)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import os

    # Access tokens must not be committed to source control. Read the
    # Hugging Face token from the HF_TOKEN environment variable instead.
    # The token that was previously hard-coded here should be revoked.
    lm = LanguageModel(
        "llama2_7b",
        hf_auth_token=os.environ.get("HF_TOKEN"),
        device="cpu-task",
    )
    print("model loaded")
    for i in lm.chat("Hello, I am a robot."):
        print(i)
|
||||
14
apps/shark_studio/api/utils.py
Normal file
14
apps/shark_studio/api/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def get_available_devices():
    """Return the list of device names offered for selection.

    Only ``"cpu-task"`` is reported.
    """
    devices = []
    devices.append("cpu-task")
    return devices
|
||||
|
||||
|
||||
def get_resource_path(relative_path):
    """Join ``relative_path`` onto the resource root directory.

    The root is ``sys._MEIPASS`` when that attribute exists (PyInstaller),
    otherwise the directory containing this source file.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        root = sys._MEIPASS
    except AttributeError:
        root = module_dir
    return os.path.join(root, relative_path)
|
||||
428
apps/shark_studio/web/index.py
Normal file
428
apps/shark_studio/web/index.py
Normal file
@@ -0,0 +1,428 @@
|
||||
from multiprocessing import Process, freeze_support
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from ui.chat import chat_element
|
||||
|
||||
# On macOS, set DYLD_LIBRARY_PATH to /usr/local/lib for this process.
# NOTE(review): presumably so native libraries installed there are found by
# the imports that follow -- confirm.
if sys.platform == "darwin":
    os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
|
||||
# import before IREE to avoid MLIR library issues
|
||||
import torch_mlir
|
||||
|
||||
# import PIL, transformers, sentencepiece # ensures inclusion in pyinstaller exe generation
|
||||
# from apps.stable_diffusion.src import args, clear_all
|
||||
# import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
|
||||
def launch_app(address):
    """Open ``address`` in a native pywebview window.

    The window is sized to 81% of the screen width and 91% of the screen
    height. The call returns when ``webview.start`` returns.

    Args:
        address: URL to load in the window.
    """
    from tkinter import Tk
    import webview

    # Tk is used only to measure the screen. Destroy the root window once
    # the size has been read so it does not linger for the app's lifetime.
    window = Tk()
    try:
        # get screen width and height of display and make it more reasonably
        # sized as we aren't making it full-screen or maximized
        width = int(window.winfo_screenwidth() * 0.81)
        height = int(window.winfo_screenheight() * 0.91)
    finally:
        window.destroy()

    webview.create_window(
        "SHARK AI Studio",
        url=address,
        width=width,
        height=height,
        text_select=True,
    )
    webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Debug-level logging is currently unconditional: the args.debug guard
    # below is commented out.
    # if args.debug:
    logging.basicConfig(level=logging.DEBUG)
    # required to do multiprocessing in a pyinstaller freeze
    freeze_support()
|
||||
# if args.api or "api" in args.ui.split(","):
|
||||
# from apps.stable_diffusion.web.ui import (
|
||||
# txt2img_api,
|
||||
# img2img_api,
|
||||
# upscaler_api,
|
||||
# inpaint_api,
|
||||
# outpaint_api,
|
||||
# llm_chat_api,
|
||||
# )
|
||||
#
|
||||
# from fastapi import FastAPI, APIRouter
|
||||
# import uvicorn
|
||||
#
|
||||
# # init global sd pipeline and config
|
||||
# global_obj._init()
|
||||
#
|
||||
# app = FastAPI()
|
||||
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
|
||||
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
|
||||
#
|
||||
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
|
||||
# app.add_api_route(
|
||||
# "/v1/chat/completions", llm_chat_api, methods=["post"]
|
||||
# )
|
||||
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
|
||||
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
|
||||
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
|
||||
# app.add_api_route(
|
||||
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
|
||||
# )
|
||||
# app.include_router(APIRouter())
|
||||
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
|
||||
# sys.exit(0)
|
||||
#
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
# from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
# config_gradio_tmp_imgs_folder,
|
||||
# )
|
||||
|
||||
# config_gradio_tmp_imgs_folder()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
|
||||
|
||||
# create_custom_models_folders()
|
||||
|
||||
def resource_path(relative_path):
    """Join ``relative_path`` onto the resource root directory.

    The root is ``sys._MEIPASS`` when that attribute exists (PyInstaller),
    otherwise the directory containing this source file.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if hasattr(sys, "_MEIPASS"):
        base_path = sys._MEIPASS
    else:
        base_path = script_dir
    return os.path.join(base_path, relative_path)

# Stylesheet handed to gr.Blocks(css=...) further down.
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
# from apps.stable_diffusion.web.ui import (
|
||||
# txt2img_web,
|
||||
# txt2img_custom_model,
|
||||
# txt2img_gallery,
|
||||
# txt2img_png_info_img,
|
||||
# txt2img_status,
|
||||
# txt2img_sendto_img2img,
|
||||
# txt2img_sendto_inpaint,
|
||||
# txt2img_sendto_outpaint,
|
||||
# txt2img_sendto_upscaler,
|
||||
## h2ogpt_upload,
|
||||
## h2ogpt_web,
|
||||
# img2img_web,
|
||||
# img2img_custom_model,
|
||||
# img2img_gallery,
|
||||
# img2img_init_image,
|
||||
# img2img_status,
|
||||
# img2img_sendto_inpaint,
|
||||
# img2img_sendto_outpaint,
|
||||
# img2img_sendto_upscaler,
|
||||
# inpaint_web,
|
||||
# inpaint_custom_model,
|
||||
# inpaint_gallery,
|
||||
# inpaint_init_image,
|
||||
# inpaint_status,
|
||||
# inpaint_sendto_img2img,
|
||||
# inpaint_sendto_outpaint,
|
||||
# inpaint_sendto_upscaler,
|
||||
# outpaint_web,
|
||||
# outpaint_custom_model,
|
||||
# outpaint_gallery,
|
||||
# outpaint_init_image,
|
||||
# outpaint_status,
|
||||
# outpaint_sendto_img2img,
|
||||
# outpaint_sendto_inpaint,
|
||||
# outpaint_sendto_upscaler,
|
||||
# upscaler_web,
|
||||
# upscaler_custom_model,
|
||||
# upscaler_gallery,
|
||||
# upscaler_init_image,
|
||||
# upscaler_status,
|
||||
# upscaler_sendto_img2img,
|
||||
# upscaler_sendto_inpaint,
|
||||
# upscaler_sendto_outpaint,
|
||||
## lora_train_web,
|
||||
## model_web,
|
||||
## model_config_web,
|
||||
# hf_models,
|
||||
# modelmanager_sendto_txt2img,
|
||||
# modelmanager_sendto_img2img,
|
||||
# modelmanager_sendto_inpaint,
|
||||
# modelmanager_sendto_outpaint,
|
||||
# modelmanager_sendto_upscaler,
|
||||
# stablelm_chat,
|
||||
# minigpt4_web,
|
||||
# outputgallery_web,
|
||||
# outputgallery_tab_select,
|
||||
# outputgallery_watch,
|
||||
# outputgallery_filename,
|
||||
# outputgallery_sendto_txt2img,
|
||||
# outputgallery_sendto_img2img,
|
||||
# outputgallery_sendto_inpaint,
|
||||
# outputgallery_sendto_outpaint,
|
||||
# outputgallery_sendto_upscaler,
|
||||
# )
|
||||
|
||||
# init global sd pipeline and config
|
||||
# global_obj._init()
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
    """Attach a click handler that forwards the first item and picks a tab.

    The handler returns ``x[0]["name"]`` (or ``None`` when ``x`` is empty)
    together with a ``gr.Tabs.update`` selecting ``selectedid``.
    """

    def _on_click(x):
        first_name = None if len(x) == 0 else x[0]["name"]
        return (first_name, gr.Tabs.update(selected=selectedid))

    button.click(_on_click, inputs, outputs)
|
||||
|
||||
def register_modelmanager_button(button, selectedid, inputs, outputs):
    """Attach a click handler that passes the input through and picks a tab.

    The handler returns the literal string ``"None"``, the input value
    unchanged, and a ``gr.Tabs.update`` selecting ``selectedid``.
    """

    def _on_click(x):
        tab_switch = gr.Tabs.update(selected=selectedid)
        return ("None", x, tab_switch)

    button.click(_on_click, inputs, outputs)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
    """Attach a click handler that passes the input through and picks a tab.

    The handler returns the input value unchanged together with a
    ``gr.Tabs.update`` selecting ``selectedid``.
    """

    def _on_click(x):
        tab_switch = gr.Tabs.update(selected=selectedid)
        return (x, tab_switch)

    button.click(_on_click, inputs, outputs)
|
||||
|
||||
# Top-level page layout. Only the "Chat Bot" tab is rendered; the other
# tabs are kept below as commented-out code.
with gr.Blocks(
    css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
    with gr.Tabs() as tabs:
        # NOTE: If adding, removing, or re-ordering tabs, make sure that they
        # have a unique id that doesn't clash with any of the other tabs,
        # and that the order in the code here is the order they should
        # appear in the ui, as the id value doesn't determine the order.

        # Where possible, avoid changing the id of any tab that is the
        # destination of one of the 'send to' buttons. If you do have to change
        # that id, make sure you update the relevant register_button_click calls
        # further down with the new id.
        # with gr.TabItem(label="Text-to-Image", id=0):
        # txt2img_web.render()
        # with gr.TabItem(label="Image-to-Image", id=1):
        # img2img_web.render()
        # with gr.TabItem(label="Inpainting", id=2):
        # inpaint_web.render()
        # with gr.TabItem(label="Outpainting", id=3):
        # outpaint_web.render()
        # with gr.TabItem(label="Upscaler", id=4):
        # upscaler_web.render()
        # if args.output_gallery:
        # with gr.TabItem(label="Output Gallery", id=5) as og_tab:
        # outputgallery_web.render()

        # # extra output gallery configuration
        # outputgallery_tab_select(og_tab.select)
        # outputgallery_watch(
        # [
        # txt2img_status,
        # img2img_status,
        # inpaint_status,
        # outpaint_status,
        # upscaler_status,
        # ]
        # )
        ## with gr.TabItem(label="Model Manager", id=6):
        ## model_web.render()
        ## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
        ## lora_train_web.render()
        # The chat UI is the only live tab; chat_element is imported from
        # ui.chat at the top of this file.
        with gr.TabItem(label="Chat Bot", id=0):
            chat_element.render()
|
||||
## with gr.TabItem(
|
||||
## label="Generate Sharding Config (Experimental)", id=9
|
||||
## ):
|
||||
## model_config_web.render()
|
||||
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
# minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
# register_button_click(
|
||||
# txt2img_sendto_img2img,
|
||||
# 1,
|
||||
# [txt2img_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# txt2img_sendto_inpaint,
|
||||
# 2,
|
||||
# [txt2img_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# txt2img_sendto_outpaint,
|
||||
# 3,
|
||||
# [txt2img_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# txt2img_sendto_upscaler,
|
||||
# 4,
|
||||
# [txt2img_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# img2img_sendto_inpaint,
|
||||
# 2,
|
||||
# [img2img_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# img2img_sendto_outpaint,
|
||||
# 3,
|
||||
# [img2img_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# img2img_sendto_upscaler,
|
||||
# 4,
|
||||
# [img2img_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# inpaint_sendto_img2img,
|
||||
# 1,
|
||||
# [inpaint_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# inpaint_sendto_outpaint,
|
||||
# 3,
|
||||
# [inpaint_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# inpaint_sendto_upscaler,
|
||||
# 4,
|
||||
# [inpaint_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# outpaint_sendto_img2img,
|
||||
# 1,
|
||||
# [outpaint_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# outpaint_sendto_inpaint,
|
||||
# 2,
|
||||
# [outpaint_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# outpaint_sendto_upscaler,
|
||||
# 4,
|
||||
# [outpaint_gallery],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# upscaler_sendto_img2img,
|
||||
# 1,
|
||||
# [upscaler_gallery],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# upscaler_sendto_inpaint,
|
||||
# 2,
|
||||
# [upscaler_gallery],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_button_click(
|
||||
# upscaler_sendto_outpaint,
|
||||
# 3,
|
||||
# [upscaler_gallery],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# if args.output_gallery:
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_txt2img,
|
||||
# 0,
|
||||
# [outputgallery_filename],
|
||||
# [txt2img_png_info_img, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_img2img,
|
||||
# 1,
|
||||
# [outputgallery_filename],
|
||||
# [img2img_init_image, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_inpaint,
|
||||
# 2,
|
||||
# [outputgallery_filename],
|
||||
# [inpaint_init_image, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_outpaint,
|
||||
# 3,
|
||||
# [outputgallery_filename],
|
||||
# [outpaint_init_image, tabs],
|
||||
# )
|
||||
# register_outputgallery_button(
|
||||
# outputgallery_sendto_upscaler,
|
||||
# 4,
|
||||
# [outputgallery_filename],
|
||||
# [upscaler_init_image, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_txt2img,
|
||||
# 0,
|
||||
# [hf_models],
|
||||
# [txt2img_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_img2img,
|
||||
# 1,
|
||||
# [hf_models],
|
||||
# [img2img_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_inpaint,
|
||||
# 2,
|
||||
# [hf_models],
|
||||
# [inpaint_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_outpaint,
|
||||
# 3,
|
||||
# [hf_models],
|
||||
# [outpaint_custom_model, tabs],
|
||||
# )
|
||||
# register_modelmanager_button(
|
||||
# modelmanager_sendto_upscaler,
|
||||
# 4,
|
||||
# [hf_models],
|
||||
# [upscaler_custom_model, tabs],
|
||||
# )
|
||||
|
||||
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
# NOTE(review): share=True together with server_name="0.0.0.0" looks like it
# exposes the UI beyond localhost -- confirm this is intended. The port is
# hard-coded while args.server_port is commented out.
sd_web.launch(
    share=True,
    inbrowser=True,
    server_name="0.0.0.0",
    server_port=11911,  # args.server_port,
)
|
||||
517
apps/shark_studio/web/ui/chat.py
Normal file
517
apps/shark_studio/web/ui/chat.py
Normal file
@@ -0,0 +1,517 @@
|
||||
import gradio as gr
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
from apps.shark_studio.api.utils import (
|
||||
get_available_devices,
|
||||
)
|
||||
from apps.shark_studio.api.llm import (
|
||||
llm_model_map,
|
||||
LanguageModel,
|
||||
)
|
||||
|
||||
|
||||
def user(message, history):
    """Start a new chat turn for ``message``.

    Returns an empty string and a new history list equal to ``history``
    plus a ``[message, ""]`` entry. ``history`` itself is not modified.
    """
    pending_turn = [message, ""]
    extended_history = history + [pending_turn]
    return "", extended_history
|
||||
|
||||
|
||||
language_model = None
|
||||
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
# System prompts keyed by model name. The three llama2 sizes use the same text.
start_message = dict.fromkeys(
    ("llama2_7b", "llama2_13b", "llama2_70b"),
    (
        "You are a helpful, respectful and honest assistant. Always answer "
        "as helpfully as possible, while being safe. Your answers should not "
        "include any harmful, unethical, racist, sexist, toxic, dangerous, or "
        "illegal content. Please ensure that your responses are socially "
        "unbiased and positive in nature. If a question does not make any "
        "sense, or is not factually coherent, explain why instead of "
        "answering something not correct. If you don't know the answer "
        "to a question, please don't share false information."
    ),
)
start_message["vicuna"] = (
    "A chat between a curious user and an artificial intelligence "
    "assistant. The assistant gives helpful, detailed, and "
    "polite answers to the user's questions.\n"
)
|
||||
|
||||
|
||||
def create_prompt(model_name, history, prompt_prefix):
    """Build the model prompt text from the chat history.

    Currently disabled: the first statement returns ``""``, so every
    statement after it is unreachable and is kept only as reference.

    Args:
        model_name: Model key; the unreachable body uses it to choose a
            prompt format and to look up ``start_message``.
        history: Sequence of ``[user_text, assistant_text]`` pairs.
        prompt_prefix: When truthy, the unreachable body prepends the
            model's start message.

    Returns:
        Always the empty string.
    """
    return ""
    # ---- unreachable reference implementation below ----
    system_message = ""
    if prompt_prefix:
        system_message = start_message[model_name]

    if "llama2" in model_name:
        B_INST, E_INST = "[INST]", "[/INST]"
        B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
        # Every turn after the first is wrapped in instruction markers.
        conversation = "".join(
            [f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
        )
        # The first turn optionally carries the system message.
        if prompt_prefix:
            msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
        else:
            msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
    elif model_name in ["vicuna"]:
        conversation = "".join(
            [
                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
                for item in history
            ]
        )
        msg = system_message + conversation
        msg = msg.strip()
    else:
        # Any other model: plain concatenation of all turns.
        conversation = "".join(
            ["".join([item[0], item[1]]) for item in history]
        )
        msg = system_message + conversation
        msg = msg.strip()
    return msg
|
||||
|
||||
|
||||
def get_default_config():
    """Generate a default sharding config for the vicuna model.

    Currently disabled: the first statement returns ``False``, so every
    statement after it is unreachable and is kept only as reference.

    Returns:
        Always ``False``.
    """
    return False
    # ---- unreachable reference implementation below ----
    import torch
    from transformers import AutoTokenizer

    hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
    # Placeholder prompt of seventeen "0" characters used as compile input.
    compilation_prompt = "".join(["0" for _ in range(17)])
    compilation_input_ids = tokenizer(
        compilation_prompt,
        return_tensors="pt",
    ).input_ids
    # NOTE(review): the reshape assumes the prompt tokenizes to exactly 19
    # ids -- confirm against the tokenizer.
    compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
        [1, 19]
    )
    firstVicunaCompileInput = (compilation_input_ids,)
    from apps.language_models.src.model_wrappers.vicuna_model import (
        CombinedModel,
    )
    from shark.shark_generate_model_config import GenerateConfigFile

    model = CombinedModel()
    c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
    c.split_into_layers()
|
||||
|
||||
|
||||
# model_vmfb_key = ""
|
||||
|
||||
|
||||
def chat_fn(
    prompt_prefix,
    history,
    model,
    device,
    precision,
    download_vmfb,
    config_file,
    cli=False,
    progress=gr.Progress(),
):
    """Stream the model's reply to the latest user message into ``history``.

    The module-level ``language_model`` is created on first use and reused
    afterwards; it is not rebuilt if ``model``, ``device`` or ``precision``
    change later.

    Args:
        prompt_prefix: System-prompt flag from the UI (currently unused).
        history: List of ``[user_text, assistant_text]`` pairs; the last
            entry is the turn to answer and is updated in place.
        model: Key into ``llm_model_map``.
        device: Device name passed to ``LanguageModel``.
        precision: Precision passed to ``LanguageModel``.
        download_vmfb: Currently unused.
        config_file: Currently unused.
        cli: Currently unused.
        progress: Gradio progress tracker (currently unused).

    Yields:
        ``(history, status)`` pairs, one per partial reply, where ``status``
        is an empty string.
    """
    # The previous version called language_model.chat(prompt_prefix) and
    # discarded the result. chat() is a generator, so nothing ran, and the
    # argument was the system-prompt flag rather than the user's message.
    # The unreachable legacy vicuna implementation that followed has been
    # removed; it remains available in version control.
    global language_model
    if language_model is None:
        language_model = LanguageModel(
            model, device=device, precision=precision
        )

    if not history:
        return

    # `user` appends [message, ""] to the history, so the newest prompt is
    # the first element of the last pair.
    prompt = history[-1][0]
    for partial_text in language_model.chat(prompt):
        history[-1][1] = partial_text
        yield history, ""
|
||||
|
||||
|
||||
def llm_chat_api(InputData: dict):
    """Handle a chat-completion or legacy completion API request.

    Currently disabled: the first statement returns ``None``, so every
    statement after it is unreachable and is kept only as reference.

    NOTE(review): before re-enabling, the unreachable body needs work. It
    relies on a global ``vicuna_model`` compared against ``0``, treats
    ``llm_model_map[model_name]`` as a model path, and calls
    ``create_prompt`` with two arguments although it takes three.

    Args:
        InputData: Request payload; the unreachable body reads the keys
            ``"messages"``, ``"model"``, ``"max_tokens"`` and ``"prompt"``.

    Returns:
        Always ``None``.
    """
    return None
    # ---- unreachable reference implementation below ----
    print(f"Input keys : {InputData.keys()}")
    # print(f"model : {InputData['model']}")
    is_chat_completion_api = (
        "messages" in InputData.keys()
    )  # else it is the legacy `completion` api
    # For Debugging input data from API
    # if is_chat_completion_api:
    # print(f"message -> role : {InputData['messages'][0]['role']}")
    # print(f"message -> content : {InputData['messages'][0]['content']}")
    # else:
    # print(f"prompt : {InputData['prompt']}")
    # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
    global vicuna_model
    model_name = (
        InputData["model"] if "model" in InputData.keys() else "codegen"
    )
    model_path = llm_model_map[model_name]
    device = "cpu-task"
    precision = "fp16"
    max_toks = (
        None
        if "max_tokens" not in InputData.keys()
        else InputData["max_tokens"]
    )
    if max_toks is None:
        max_toks = 128 if model_name == "codegen" else 512

    # make it working for codegen first
    from apps.language_models.scripts.vicuna import (
        UnshardedVicuna,
    )

    device_id = None
    if vicuna_model == 0:
        # Normalize the device string to one of the known device names.
        if "cuda" in device:
            device = "cuda"
        elif "sync" in device:
            device = "cpu-sync"
        elif "task" in device:
            device = "cpu-task"
        elif "vulkan" in device:
            device_id = int(device.split("://")[1])
            device = "vulkan"
        else:
            print("unrecognized device")

        vicuna_model = UnshardedVicuna(
            model_name,
            hf_model_path=model_path,
            device=device,
            precision=precision,
            max_num_tokens=max_toks,
            download_vmfb=True,
            load_mlir_from_shark_tank=True,
            device_id=device_id,
        )

    # TODO: add role dict for different models
    if is_chat_completion_api:
        # TODO: add functionality for multiple messages
        prompt = create_prompt(
            model_name, [(InputData["messages"][0]["content"], "")]
        )
    else:
        prompt = InputData["prompt"]
    print("prompt = ", prompt)

    res = vicuna_model.generate(prompt)
    # Drain the generator; only the last yielded item is kept.
    res_op = None
    for op in res:
        res_op = op

    if is_chat_completion_api:
        choices = [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": res_op,  # since we are yielding the result
                },
                "finish_reason": "stop",  # or length
            }
        ]
    else:
        choices = [
            {
                "text": res_op,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",  # or length
            }
        ]
    # The id and "created" fields are the current time formatted as
    # YYYYmmddHHMMSS plus microseconds.
    end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
    return {
        "id": end_time,
        "object": (
            "chat.completion" if is_chat_completion_api else "text_completion"
        ),
        "created": int(end_time),
        "choices": choices,
    }
|
||||
|
||||
|
||||
def view_json_file(file_obj):
    """Return the complete text of the file referenced by ``file_obj.name``.

    ``file_obj`` only needs a ``name`` attribute holding a path; the file is
    opened in text mode and read in one go.
    """
    with open(file_obj.name, "r") as handle:
        return handle.read()
|
||||
|
||||
|
||||
with gr.Blocks(title="Chat") as chat_element:
|
||||
with gr.Row():
|
||||
model_choices = list(llm_model_map.keys())
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
supported_devices = get_available_devices()
|
||||
enabled = True
|
||||
if len(supported_devices) == 0:
|
||||
supported_devices = ["cpu-task"]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0],
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int4",
|
||||
choices=[
|
||||
# "int4",
|
||||
# "int8",
|
||||
# "fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
with gr.Column():
|
||||
download_vmfb = gr.Checkbox(
|
||||
label="Download vmfb from Shark tank if available",
|
||||
value=True,
|
||||
interactive=True,
|
||||
)
|
||||
prompt_prefix = gr.Checkbox(
|
||||
label="Add System Prompt",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
label="Chat Message Box",
|
||||
placeholder="Chat Message Box",
|
||||
show_label=False,
|
||||
interactive=enabled,
|
||||
container=False,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
submit = gr.Button("Submit", interactive=enabled)
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[msg, chatbot],
|
||||
show_progress=False,
|
||||
queue=False,
|
||||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
],
|
||||
outputs=[chatbot, tokens_time],
|
||||
show_progress=False,
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
fn=None,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
cancels=[submit_event, submit_click_event],
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
87
apps/stable_diffusion/profiling_with_iree.md
Normal file
87
apps/stable_diffusion/profiling_with_iree.md
Normal file
@@ -0,0 +1,87 @@
|
||||
Compile / Run Instructions:
|
||||
|
||||
To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once.
|
||||
Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks.
|
||||
|
||||
Compile Commands FP32/FP16:
|
||||
|
||||
```shell
|
||||
Vulkan AMD:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
|
||||
# use --iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
|
||||
|
||||
CUDA NVIDIA:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
|
||||
|
||||
CPU:
|
||||
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
|
||||
```
|
||||
|
||||
|
||||
|
||||
Run / Benchmark Command (FP32 - NCHW):
|
||||
(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.)
|
||||
|
||||
```shell
|
||||
## Vulkan AMD:
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
## CUDA:
|
||||
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
## CPU:
|
||||
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
|
||||
```
|
||||
|
||||
Run via vulkan_gui for RGP Profiling:
|
||||
|
||||
To build the vulkan app for profiling UNet, follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb:
|
||||
```shell
|
||||
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
|
||||
```
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Debug Commands</summary>
|
||||
|
||||
## Debug commands and other advanced usage follows.
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
|
||||
```
|
||||
|
||||
## dump all dispatch .spv and isa using amdllpc
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
|
||||
```
|
||||
|
||||
## Compile and save the .vmfb (using vulkan fp16 as an example):
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
|
||||
```
|
||||
|
||||
## Capture an RGP trace
|
||||
|
||||
```shell
|
||||
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
|
||||
```
|
||||
|
||||
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
|
||||
|
||||
```shell
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
|
||||
```
|
||||
|
||||
## Run the unet module with iree-benchmark-module (same config as above):
|
||||
```shell
|
||||
## If you want to use .npz inputs:
|
||||
unzip ~/.local/shark_tank/<your unet>/inputs.npz
|
||||
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
|
||||
```
|
||||
|
||||
</details>
|
||||
1
apps/stable_diffusion/scripts/__init__.py
Normal file
1
apps/stable_diffusion/scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from apps.stable_diffusion.scripts.train_lora_word import lora_train
|
||||
128
apps/stable_diffusion/scripts/img2img.py
Normal file
128
apps/stable_diffusion/scripts/img2img.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import sys
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
StencilPipeline,
|
||||
resize_stencil,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for image-to-image generation.

    Every setting is read from the shared ``args`` namespace.  The function
    builds either a ``StencilPipeline`` or an ``Image2ImagePipeline``,
    generates images from ``args.img_path``, saves the first generated image
    and prints a summary of the run.
    """
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()

    image = Image.open(args.img_path).convert("RGB")
    # When the models get uploaded, it should be default to False.
    args.import_mlir = True

    use_stencil = args.use_stencil
    if use_stencil:
        # Stencil mode pins the scheduler and the base model, and writes the
        # dimensions returned by resize_stencil back onto ``args``.
        args.scheduler = "DDIM"
        args.hf_model_id = "runwayml/stable-diffusion-v1-5"
        image, args.width, args.height = resize_stencil(image)
    elif "Shark" in args.scheduler:
        print(
            f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler"
        )
        args.scheduler = "EulerDiscrete"

    cpu_scheduling = not args.scheduler.startswith("Shark")
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    set_init_device_flags()
    scheduler_obj = get_schedulers(args.hf_model_id)[args.scheduler]
    seed = utils.sanitize_seed(args.seed)
    # Adjust for height and width based on model

    # Both pipeline classes take the same leading positional arguments and
    # the same keyword options; only the stencil pipeline gets use_stencil.
    pretrained_args = (
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
    )
    pretrained_kwargs = dict(
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )
    if use_stencil:
        img2img_obj = StencilPipeline.from_pretrained(
            *pretrained_args, use_stencil=use_stencil, **pretrained_kwargs
        )
    else:
        img2img_obj = Image2ImagePipeline.from_pretrained(
            *pretrained_args, **pretrained_kwargs
        )

    start_time = time.time()
    generated_imgs = img2img_obj.generate_images(
        args.prompts,
        args.negative_prompts,
        image,
        args.batch_size,
        args.height,
        args.width,
        args.steps,
        args.strength,
        args.guidance_scale,
        seed,
        args.max_length,
        dtype,
        args.use_base_vae,
        cpu_scheduling,
        args.max_embeddings_multiples,
        use_stencil=use_stencil,
        control_mode=args.control_mode,
    )
    total_time = time.time() - start_time

    # Assemble the run summary from its fragments in one pass.
    text_output = "".join(
        [
            f"prompt={args.prompts}",
            f"\nnegative prompt={args.negative_prompts}",
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}",
            f"\nscheduler={args.scheduler}, device={args.device}",
            f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}",
            f", batch size={args.batch_size}, max_length={args.max_length}",
            img2img_obj.log,
            f"\nTotal image generation time: {total_time:.4f}sec",
        ]
    )

    extra_info = {"STRENGTH": args.strength}
    save_output_img(generated_imgs[0], seed, extra_info)
    print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
105
apps/stable_diffusion/scripts/inpaint.py
Normal file
105
apps/stable_diffusion/scripts/inpaint.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import get_generation_text_info
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for inpainting.

    Every setting is read from the shared ``args`` namespace.  One image is
    generated per batch (``args.batch_count`` batches); for each batch the
    first generated image is saved and a summary is printed.

    The summary and the saved image are labelled with the seed that was
    actually passed to the pipeline for that batch (``seeds[current_batch]``)
    rather than the raw ``args.seed``, so a run can be reproduced from its
    output.
    """
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()
    if args.mask_path is None:
        print("Flag --mask_path is required.")
        exit()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    # The schedulers are looked up for an inpainting model: the requested
    # model if its id contains "inpaint", otherwise the SD2 inpainting model.
    model_id = (
        args.hf_model_id
        if "inpaint" in args.hf_model_id
        else "stabilityai/stable-diffusion-2-inpainting"
    )
    schedulers = get_schedulers(model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)
    mask_image = Image.open(args.mask_path)

    # NOTE(review): the pipeline is built from args.hf_model_id, not from the
    # model_id computed above -- confirm whether that is intentional.
    inpaint_obj = InpaintPipeline.from_pretrained(
        scheduler=scheduler_obj,
        import_mlir=args.import_mlir,
        model_id=args.hf_model_id,
        ckpt_loc=args.ckpt_loc,
        custom_vae=args.custom_vae,
        precision=args.precision,
        max_length=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=args.use_tuned,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )

    seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
    for current_batch in range(args.batch_count):
        # Seed used for this batch; reported and saved below.
        batch_seed = seeds[current_batch]

        start_time = time.time()
        generated_imgs = inpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            mask_image,
            args.batch_size,
            args.height,
            args.width,
            args.inpaint_full_res,
            args.inpaint_full_res_padding,
            args.steps,
            args.guidance_scale,
            batch_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            args.max_embeddings_multiples,
        )
        total_time = time.time() - start_time

        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += (
            f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
        )
        text_output += f" seed={batch_seed}, size={args.height}x{args.width}"
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += inpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        save_output_img(generated_imgs[0], batch_seed)
        print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19
apps/stable_diffusion/scripts/main.py
Normal file
19
apps/stable_diffusion/scripts/main.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from apps.stable_diffusion.src import args
|
||||
from apps.stable_diffusion.scripts import (
|
||||
img2img,
|
||||
txt2img,
|
||||
# inpaint,
|
||||
# outpaint,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args.app == "txt2img":
|
||||
txt2img.main()
|
||||
elif args.app == "img2img":
|
||||
img2img.main()
|
||||
# elif args.app == "inpaint":
|
||||
# inpaint.main()
|
||||
# elif args.app == "outpaint":
|
||||
# outpaint.main()
|
||||
else:
|
||||
print(f"args.app value is {args.app} but this isn't supported")
|
||||
120
apps/stable_diffusion/scripts/outpaint.py
Normal file
120
apps/stable_diffusion/scripts/outpaint.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
OutpaintPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for outpainting.

    Every setting is read from the shared ``args`` namespace.  One image is
    generated per batch (``args.batch_count`` batches); for each batch the
    first generated image is saved together with the outpainting parameters
    and a summary is printed.

    The summary and the saved image are labelled with the seed that was
    actually passed to the pipeline for that batch (``seeds[current_batch]``)
    rather than the raw ``args.seed``.
    """
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        exit()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    # The schedulers are looked up for an inpainting model: the requested
    # model if its id contains "inpaint", otherwise the SD2 inpainting model.
    model_id = (
        args.hf_model_id
        if "inpaint" in args.hf_model_id
        else "stabilityai/stable-diffusion-2-inpainting"
    )
    schedulers = get_schedulers(model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    image = Image.open(args.img_path)

    # NOTE(review): the pipeline is built from args.hf_model_id, not from the
    # model_id computed above -- confirm whether that is intentional.
    outpaint_obj = OutpaintPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
        use_lora=args.use_lora,
        ondemand=args.ondemand,
    )

    seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
    for current_batch in range(args.batch_count):
        # Seed used for this batch; reported and saved below.
        batch_seed = seeds[current_batch]

        start_time = time.time()
        generated_imgs = outpaint_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            image,
            args.pixels,
            args.mask_blur,
            args.left,
            args.right,
            args.top,
            args.bottom,
            args.noise_q,
            args.color_variation,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            batch_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            args.max_embeddings_multiples,
        )
        total_time = time.time() - start_time

        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += (
            f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
        )
        text_output += f" seed={batch_seed}, size={args.height}x{args.width}"
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        text_output += outpaint_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        # save this information as metadata of output generated image.
        directions = []
        if args.left:
            directions.append("left")
        if args.right:
            directions.append("right")
        if args.top:
            directions.append("up")
        if args.bottom:
            directions.append("down")
        extra_info = {
            "PIXELS": args.pixels,
            "MASK_BLUR": args.mask_blur,
            "DIRECTIONS": directions,
            "NOISE_Q": args.noise_q,
            "COLOR_VARIATION": args.color_variation,
        }
        save_output_img(generated_imgs[0], batch_seed, extra_info)
        print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
240
apps/stable_diffusion/scripts/telegram_bot.py
Normal file
240
apps/stable_diffusion/scripts/telegram_bot.py
Normal file
@@ -0,0 +1,240 @@
|
||||
import logging
|
||||
import os
|
||||
from models.stable_diffusion.main import stable_diff_inf
|
||||
from models.stable_diffusion.utils import get_available_devices
|
||||
from dotenv import load_dotenv
|
||||
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
from telegram import BotCommand
|
||||
from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler
|
||||
from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters
|
||||
from io import BytesIO
|
||||
import random
|
||||
|
||||
log = logging.getLogger("TG.Bot")
|
||||
logging.basicConfig()
|
||||
log.warning("Start")
|
||||
load_dotenv()
|
||||
os.environ["AMD_ENABLE_LLPC"] = "0"
|
||||
TG_TOKEN = os.getenv("TG_TOKEN")
|
||||
SELECTED_MODEL = "stablediffusion"
|
||||
SELECTED_SCHEDULER = "EulerAncestralDiscrete"
|
||||
STEPS = 30
|
||||
NEGATIVE_PROMPT = (
|
||||
"Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra"
|
||||
" limbs,Gross proportions,Missing arms,Mutated hands,Long"
|
||||
" neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad"
|
||||
" anatomy,Cloned face,Malformed limbs,Missing legs,Too many"
|
||||
" fingers,blurry, lowres, text, error, cropped, worst quality, low"
|
||||
" quality, jpeg artifacts, out of frame, extra fingers, mutated hands,"
|
||||
" poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned"
|
||||
" face, malformed limbs, missing arms, missing legs, extra arms, extra"
|
||||
" legs, fused fingers, too many fingers"
|
||||
)
|
||||
GUIDANCE_SCALE = 6
|
||||
available_devices = get_available_devices()
|
||||
models_list = [
|
||||
"stablediffusion",
|
||||
"anythingv3",
|
||||
"analogdiffusion",
|
||||
"openjourney",
|
||||
"dreamlike",
|
||||
]
|
||||
sheds_list = [
|
||||
"DDIM",
|
||||
"PNDM",
|
||||
"LMSDiscrete",
|
||||
"DPMSolverMultistep",
|
||||
"EulerDiscrete",
|
||||
"EulerAncestralDiscrete",
|
||||
"SharkEulerDiscrete",
|
||||
]
|
||||
|
||||
|
||||
def image_to_bytes(image):
    """Serialise ``image`` as JPEG into an in-memory buffer.

    The returned ``BytesIO`` carries the name ``"image.jpeg"`` and is rewound
    to position 0 so it can be read immediately.
    """
    buffer = BytesIO()
    buffer.name = "image.jpeg"
    image.save(buffer, "JPEG")
    buffer.seek(0)
    return buffer
|
||||
|
||||
|
||||
def get_try_again_markup():
    """Build the inline keyboard holding a single "Try again" button.

    The button's callback data is ``"TRYAGAIN"``.
    """
    retry_button = InlineKeyboardButton("Try again", callback_data="TRYAGAIN")
    return InlineKeyboardMarkup([[retry_button]])
|
||||
|
||||
|
||||
def generate_image(prompt):
    """Run one inference for ``prompt`` using the module-level bot settings.

    A seed is drawn uniformly from 1..10000, and the first entry of
    ``available_devices`` is used as the device.

    Returns:
        ``(image, seed)``: the first element returned by ``stable_diff_inf``
        and the seed that was drawn.
    """
    seed = random.randint(1, 10000)
    log.warning(SELECTED_MODEL)
    log.warning(STEPS)
    inference_kwargs = {
        "prompt": prompt,
        "negative_prompt": NEGATIVE_PROMPT,
        "steps": STEPS,
        "guidance_scale": GUIDANCE_SCALE,
        "seed": seed,
        "scheduler_key": SELECTED_SCHEDULER,
        "variant": SELECTED_MODEL,
        "device_key": available_devices[0],
    }
    # The second element of the result is not used by the bot.
    image, _text = stable_diff_inf(**inference_kwargs)
    return image, seed
|
||||
|
||||
|
||||
async def generate_and_send_photo(
    update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
    """Generate an image from the incoming message text and send it back.

    A "Generating image..." reply is posted first and deleted once the image
    is ready; the photo is sent to the effective user as a reply to the
    original message, with a "Try again" button attached.
    """
    message = update.message
    progress_msg = await message.reply_text(
        "Generating image...", reply_to_message_id=message.message_id
    )
    im, seed = generate_image(prompt=message.text)

    bot = context.bot
    await bot.delete_message(
        chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
    )
    await bot.send_photo(
        update.effective_user.id,
        image_to_bytes(im),
        caption=f'"{message.text}" (Seed: {seed})',
        reply_markup=get_try_again_markup(),
        reply_to_message_id=message.message_id,
    )
|
||||
|
||||
|
||||
async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle every inline-keyboard callback of the bot.

    * Callback data naming a model (``models_list``) or a scheduler
      (``sheds_list``) updates the matching module-level selection and edits
      the keyboard message to confirm it.
    * ``"TRYAGAIN"`` regenerates an image from the text of the message the
      keyboard message replied to, and sends it like the first attempt.
    * Anything else, or a "Try again" whose original message is no longer
      available, is acknowledged and ignored.  Previously this posted a
      progress message that was never removed and then failed on unbound
      locals or on a missing replied-to message.
    """
    global SELECTED_MODEL, SELECTED_SCHEDULER

    query = update.callback_query
    if query.data in models_list:
        SELECTED_MODEL = query.data
        await query.answer()
        await query.edit_message_text(text=f"Selected model: {query.data}")
        return
    if query.data in sheds_list:
        SELECTED_SCHEDULER = query.data
        await query.answer()
        await query.edit_message_text(text=f"Selected scheduler: {query.data}")
        return

    replied_message = query.message.reply_to_message
    await query.answer()

    # Bail out before posting the progress message when there is nothing to
    # regenerate.
    if query.data != "TRYAGAIN" or replied_message is None:
        log.warning("Ignoring callback query with data %r", query.data)
        return

    progress_msg = await query.message.reply_text(
        "Generating image...", reply_to_message_id=replied_message.message_id
    )
    prompt = replied_message.text
    im, seed = generate_image(prompt)

    await context.bot.delete_message(
        chat_id=progress_msg.chat_id, message_id=progress_msg.message_id
    )
    await context.bot.send_photo(
        update.effective_user.id,
        image_to_bytes(im),
        caption=f'"{prompt}" (Seed: {seed})',
        reply_markup=get_try_again_markup(),
        reply_to_message_id=replied_message.message_id,
    )
|
||||
|
||||
|
||||
async def select_model_handler(update, context):
    """Reply with an inline keyboard offering one button per known model.

    Each button's label and callback data are both the model name.
    """
    rows = [
        [InlineKeyboardButton(text=name, callback_data=name)]
        for name in models_list
    ]
    await update.message.reply_text(
        text="Select model", reply_markup=InlineKeyboardMarkup(rows)
    )
|
||||
|
||||
|
||||
async def select_scheduler_handler(update, context):
    """Reply with an inline keyboard offering one button per known scheduler.

    Each button's label and callback data are both the scheduler name.
    """
    rows = [
        [InlineKeyboardButton(text=name, callback_data=name)]
        for name in sheds_list
    ]
    await update.message.reply_text(
        text="Select schedule", reply_markup=InlineKeyboardMarkup(rows)
    )
|
||||
|
||||
|
||||
async def set_steps_handler(update, context):
    """Handle ``/set_steps <n>``: store the step count in the global ``STEPS``.

    Replies with the raw argument on success, or with a usage hint when the
    argument is missing or is not an integer (``STEPS`` is then left as is).
    """
    global STEPS

    command_text = update.message.text
    log.warning(command_text)
    try:
        reply = command_text.split("/set_steps ")[1]
        STEPS = int(reply)
    except Exception:
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_steps 30"
        )
    await update.message.reply_text(reply)
|
||||
|
||||
|
||||
async def set_negative_prompt_handler(update, context):
    """Handle ``/set_negative_prompt <text>``: store it in ``NEGATIVE_PROMPT``.

    Replies with the stored text on success, or with a usage hint when no
    argument follows the command.
    """
    global NEGATIVE_PROMPT

    command_text = update.message.text
    log.warning(command_text)
    try:
        reply = command_text.split("/set_negative_prompt ")[1]
    except Exception:
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_negative_prompt ugly, bad art, mutated"
        )
    else:
        NEGATIVE_PROMPT = reply
    await update.message.reply_text(reply)
|
||||
|
||||
|
||||
async def set_guidance_scale_handler(update, context):
    """Handle ``/set_guidance_scale <n>``: store it in ``GUIDANCE_SCALE``.

    The argument must parse as an integer.  Replies with the raw argument on
    success, or with a usage hint otherwise (``GUIDANCE_SCALE`` is then left
    as is).
    """
    global GUIDANCE_SCALE

    command_text = update.message.text
    log.warning(command_text)
    try:
        reply = command_text.split("/set_guidance_scale ")[1]
        GUIDANCE_SCALE = int(reply)
    except Exception:
        reply = (
            "Invalid parameter for command. Correct command looks like\n"
            " /set_guidance_scale 7"
        )
    await update.message.reply_text(reply)
|
||||
|
||||
|
||||
async def setup_bot_commands(application: Application) -> None:
    """Register the bot's command list through ``application.bot``."""
    command_specs = (
        ("select_model", "to select model"),
        ("select_scheduler", "to select scheduler"),
        ("set_steps", "to set steps"),
        ("set_guidance_scale", "to set guidance scale"),
        ("set_negative_prompt", "to set negative prompt"),
    )
    await application.bot.set_my_commands(
        [BotCommand(name, description) for name, description in command_specs]
    )
|
||||
|
||||
|
||||
app = (
|
||||
ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build()
|
||||
)
|
||||
app.add_handler(CommandHandler("select_model", select_model_handler))
|
||||
app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler))
|
||||
app.add_handler(CommandHandler("set_steps", set_steps_handler))
|
||||
app.add_handler(
|
||||
CommandHandler("set_guidance_scale", set_guidance_scale_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
CommandHandler("set_negative_prompt", set_negative_prompt_handler)
|
||||
)
|
||||
app.add_handler(
|
||||
MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo)
|
||||
)
|
||||
app.add_handler(CallbackQueryHandler(button))
|
||||
log.warning("Start bot")
|
||||
app.run_polling()
|
||||
693
apps/stable_diffusion/scripts/train_lora_word.py
Normal file
693
apps/stable_diffusion/scripts/train_lora_word.py
Normal file
@@ -0,0 +1,693 @@
|
||||
# Install the required libs
|
||||
# pip install -U git+https://github.com/huggingface/diffusers.git
|
||||
# pip install accelerate transformers ftfy
|
||||
|
||||
# HuggingFace Token
|
||||
# YOUR_TOKEN = "<your HuggingFace access token>"  # never commit a real token
|
||||
|
||||
|
||||
# Import required libraries
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import random
|
||||
import torch_mlir
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
import logging
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
import torch._dynamo as dynamo
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
torch._dynamo.config.verbose = True
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from dataclasses import dataclass
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
clear_all,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import update_lora_weight
|
||||
|
||||
|
||||
# Setup the dataset
|
||||
class LoraDataset(Dataset):
    """Dataset pairing every file in a directory with one fixed prompt.

    Each item is a dict with:

    * ``"input_ids"`` - the first row of ``tokenizer(prompt, ...).input_ids``,
      padded/truncated to ``tokenizer.model_max_length``
      (assumes a HuggingFace-style tokenizer - TODO confirm against callers);
    * ``"pixel_values"`` - the image converted to RGB, optionally
      centre-cropped to a square, resized to ``size`` x ``size``, scaled to
      ``[-1, 1]`` as float32 and permuted to channel-first layout.

    Args:
        data_root: directory whose entries are all treated as image files.
        tokenizer: callable tokenizer (see above).
        size: edge length, in pixels, of the square output image.
        repeats: when ``set == "train"``, the reported length is the number
            of images multiplied by this factor.
        interpolation: one of ``"linear"``, ``"bilinear"``, ``"bicubic"``,
            ``"lanczos"``.
        set: ``"train"`` enables ``repeats``; any other value does not.
        prompt: text tokenized for every item.
        center_crop: crop the image to its central square before resizing.

    Raises:
        KeyError: if ``interpolation`` is not one of the supported names.
    """

    def __init__(
        self,
        data_root,
        tokenizer,
        size=512,
        repeats=100,
        interpolation="bicubic",
        set="train",
        prompt="myloraprompt",
        center_crop=False,
    ):
        self.data_root = data_root
        self.tokenizer = tokenizer
        self.size = size
        self.center_crop = center_crop
        self.prompt = prompt

        # Sorted so the sample order does not depend on the arbitrary order
        # in which os.listdir returns directory entries.
        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in sorted(os.listdir(self.data_root))
        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        if set == "train":
            self._length = self.num_images * repeats

        # PIL.Image.LINEAR was only an alias of BILINEAR and was removed in
        # Pillow 10; referencing it here made construction fail for every
        # interpolation choice, because the whole table is built eagerly.
        self.interpolation = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]

    def __len__(self):
        """Return the number of images, times ``repeats`` for the train set."""
        return self._length

    def __getitem__(self, i):
        """Load, preprocess and tokenize the sample at index ``i``.

        Indices wrap around the image list, which is what lets ``repeats``
        extend the dataset length.
        """
        example = {}
        image = Image.open(self.image_paths[i % self.num_images])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        example["input_ids"] = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids[0]

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            # Keep the largest centred square.
            crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        image = image.resize(
            (self.size, self.size), resample=self.interpolation
        )

        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] onto float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)

        # Height-width-channel -> channel-height-width.
        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
        return example
|
||||
|
||||
|
||||
def torch_device(device):
    """Convert a device selection string into ``"<type>"`` or ``"<type>:<rest>"`` form.

    The input may be a bare spec (``"vulkan://0"``) or a label followed by a
    spec (``"Some GPU => vulkan://0"``); in the latter case the text after the
    first ``"=>"`` is used. The spec is split on ``"://"``; a ``"metal"`` type
    is reported as ``"vulkan"``. If a second component exists the result is
    ``type + ":" + second``, otherwise just the type.
    """
    pieces = device.split("=>")
    # A single piece is the spec itself; otherwise the spec is the second piece.
    spec = (pieces[0] if len(pieces) == 1 else pieces[1]).strip()

    backend, *remainder = spec.split("://")
    if backend == "metal":
        backend = "vulkan"

    if not remainder:
        return backend
    return backend + ":" + remainder[0]
|
||||
|
||||
|
||||
########## Setting up the model ##########
|
||||
def lora_train(
    prompt: str,
    height: int,
    width: int,
    steps: int,
    guidance_scale: float,
    seed: int,
    batch_count: int,
    batch_size: int,
    scheduler: str,
    custom_model: str,
    hf_model_id: str,
    precision: str,
    device: str,
    max_length: int,
    training_images_dir: str,
    lora_save_dir: str,
    use_lora: str,
):
    """Train LoRA attention processors for a UNet on a directory of images.

    The function stores its inputs on the shared ``args`` object, loads the
    text encoder, VAE and UNet from ``args.hf_model_id``, freezes them,
    installs (or loads) LoRA attention processors on the UNet, runs ``steps``
    optimization steps through a torch-dynamo backend that compiles with
    torch-mlir and executes with ``SharkInference``, and finally saves the
    attention processors to ``lora_save_dir``.

    Args:
        prompt: Single prompt used for every training image.
        height, width: Stored on ``args``.
        steps: Number of optimization steps (``max_train_steps``).
        guidance_scale, batch_count, scheduler: Accepted for signature
            compatibility; not used in this function body.
        seed: Passed to ``torch.manual_seed``.
        batch_size: Training batch size.
        custom_model: ``"None"``, a ``.ckpt``/``.safetensors`` path (stored in
            ``args.ckpt_loc``), or a model id (stored in ``args.hf_model_id``).
        hf_model_id: Model id used when ``custom_model == "None"``.
        precision, max_length: Stored on ``args``.
        device: Device string, normalized through ``torch_device``.
        training_images_dir: Directory handed to ``LoraDataset``.
        lora_save_dir: Destination for ``unet.save_attn_procs``.
        use_lora: Non-empty to start from existing LoRA weights.

    Returns:
        ``(None, <error message>)`` when neither a custom model nor a model id
        was supplied; otherwise ``None`` after training completes.
    """
    from apps.stable_diffusion.web.ui.utils import (
        get_custom_model_pathfile,
        Config,
    )
    import apps.stable_diffusion.web.utils.global_obj as global_obj

    print(
        "Note LoRA training is not compatible with the latest torch-mlir branch"
    )
    print(
        "To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion"
    )
    torch.manual_seed(seed)

    args.prompts = [prompt]
    args.steps = steps

    # set ckpt_loc and hf_model_id.
    args.ckpt_loc = ""
    args.hf_model_id = ""
    if custom_model == "None":
        if not hf_model_id:
            return (
                None,
                "Please provide either custom model or huggingface model ID, both must not be "
                "empty.",
            )
        args.hf_model_id = hf_model_id
    elif ".ckpt" in custom_model or ".safetensors" in custom_model:
        # NOTE(review): on this path args.hf_model_id stays "" while the
        # from_pretrained calls below still read args.hf_model_id - confirm
        # whether checkpoint files are actually supported here.
        args.ckpt_loc = custom_model
    else:
        args.hf_model_id = custom_model

    args.training_images_dir = training_images_dir
    args.lora_save_dir = lora_save_dir

    args.precision = precision
    args.batch_size = batch_size
    args.max_length = max_length
    args.height = height
    args.width = width
    args.device = torch_device(device)
    args.use_lora = use_lora

    # Load the Stable Diffusion model
    text_encoder = CLIPTextModel.from_pretrained(
        args.hf_model_id, subfolder="text_encoder"
    )
    vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(
        args.hf_model_id, subfolder="unet"
    )

    def freeze_params(params):
        """Disable gradient tracking for every parameter in ``params``."""
        for param in params:
            param.requires_grad = False

    # Freeze everything but LoRA
    freeze_params(vae.parameters())
    freeze_params(unet.parameters())
    freeze_params(text_encoder.parameters())

    # Move vae, unet and text encoder to device
    vae.to(args.device)
    unet.to(args.device)
    text_encoder.to(args.device)

    if use_lora != "":
        update_lora_weight(unet, args.use_lora, "unet")
    else:
        # Build one fresh LoRA processor per attention layer, sized from the
        # UNet block the layer belongs to.
        lora_attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAXFormersAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
            )

        unet.set_attn_processor(lora_attn_procs)
    # Needed on both paths: the optimizer below trains these parameters.
    lora_layers = AttnProcsLayers(unet.attn_processors)

    class VaeModel(torch.nn.Module):
        """Wrapper exposing the first element of ``vae.encode`` as ``forward``."""

        def __init__(self):
            super().__init__()
            self.vae = vae

        def forward(self, input):
            x = self.vae.encode(input, return_dict=False)[0]
            return x

    class UnetModel(torch.nn.Module):
        """Wrapper exposing the first element of ``unet.forward`` as ``forward``."""

        def __init__(self):
            super().__init__()
            self.unet = unet

        def forward(self, x, y, z):
            return self.unet.forward(x, y, z, return_dict=False)[0]

    shark_vae = VaeModel()
    shark_unet = UnetModel()

    ####### Creating our training data ########

    tokenizer = CLIPTokenizer.from_pretrained(
        args.hf_model_id,
        subfolder="tokenizer",
    )

    # Let's create the Dataset and Dataloader
    train_dataset = LoraDataset(
        data_root=args.training_images_dir,
        tokenizer=tokenizer,
        size=vae.sample_size,
        prompt=args.prompts[0],
        repeats=100,
        center_crop=False,
        set="train",
    )

    def create_dataloader(train_batch_size=1):
        """Return a shuffling ``DataLoader`` over ``train_dataset``."""
        return torch.utils.data.DataLoader(
            train_dataset, batch_size=train_batch_size, shuffle=True
        )

    # Create noise_scheduler for training
    noise_scheduler = DDPMScheduler.from_config(
        args.hf_model_id, subfolder="scheduler"
    )

    ######## Training ###########

    # Define hyperparameters for our training. If you are not happy with your results,
    # you can tune the `learning_rate` and the `max_train_steps`

    # Setting up all training args
    hyperparameters = {
        "learning_rate": 5e-04,
        "scale_lr": True,
        "max_train_steps": steps,
        "train_batch_size": batch_size,
        "gradient_accumulation_steps": 1,
        "gradient_checkpointing": True,
        "mixed_precision": "fp16",
        "seed": 42,
        "output_dir": "sd-concept-output",
    }
    # creating output directory
    # A single attempt: the previous `while not os.path.exists(...)` retry loop
    # never terminated when the directory could not be created.
    out_dir = os.path.join(os.getcwd(), hyperparameters["output_dir"])
    try:
        os.makedirs(out_dir, exist_ok=True)
    except OSError:
        print("Output directory not created")

    ###### Torch-MLIR Compilation ######

    def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
        """Strip ``None`` entries from the graph's output list/tuple.

        Returns the sorted positions that were removed so the caller can
        re-insert ``None`` at the same places in the runtime result.
        """
        removed_indexes = []
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, (list, tuple)):
                    node_arg = list(node_arg)
                    node_args_len = len(node_arg)
                    # Walk from the back so popping does not shift the
                    # indexes still to be visited.
                    for i in range(node_args_len):
                        curr_index = node_args_len - (i + 1)
                        if node_arg[curr_index] is None:
                            removed_indexes.append(curr_index)
                            node_arg.pop(curr_index)
                    node.args = (tuple(node_arg),)
                    break

        if len(removed_indexes) > 0:
            fx_g.graph.lint()
            fx_g.graph.eliminate_dead_code()
            fx_g.recompile()
        removed_indexes.sort()
        return removed_indexes

    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
        """
        Replace tuple with tuple element in functions that return one-element tuples.
        Returns true if an unwrapping took place, and false otherwise.
        """
        unwrapped_tuple = False
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple):
                    if len(node_arg) == 1:
                        node.args = (node_arg[0],)
                        unwrapped_tuple = True
                        break

        if unwrapped_tuple:
            fx_g.graph.lint()
            fx_g.recompile()
        return unwrapped_tuple

    def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
        """Return True when the first output node returns an empty tuple."""
        for node in fx_g.graph.nodes:
            if node.op == "output":
                assert (
                    len(node.args) == 1
                ), "Output node must have a single argument"
                node_arg = node.args[0]
                if isinstance(node_arg, tuple):
                    return len(node_arg) == 0
        return False

    def transform_fx(fx_g):
        """Follow every ``aten.empty`` call with ``aten.zero_`` so its result is zero-filled."""
        for node in fx_g.graph.nodes:
            if node.op == "call_function":
                # aten.empty should be filled with zeros.
                if node.target in [torch.ops.aten.empty]:
                    with fx_g.graph.inserting_after(node):
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten.zero_,
                            args=(node,),
                        )
                        node.append(new_node)
                        node.replace_all_uses_with(new_node)
                        # replace_all_uses_with also rewired new_node's own
                        # argument; point it back at the original node.
                        new_node.args = (node,)

        fx_g.graph.lint()

    @make_simple_dynamo_backend
    def refbackend_torchdynamo_backend(
        fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]
    ):
        """Dynamo backend: lower the FX graph with torch-mlir and run it via SharkInference."""
        # handling usage of empty tensor without initializing
        transform_fx(fx_graph)
        fx_graph.recompile()
        if _returns_nothing(fx_graph):
            return fx_graph
        removed_none_indexes = _remove_nones(fx_graph)
        was_unwrapped = _unwrap_single_tuple_return(fx_graph)

        mlir_module = torch_mlir.compile(
            fx_graph, example_inputs, output_type="linalg-on-tensors"
        )

        bytecode_stream = BytesIO()
        mlir_module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()

        shark_module = SharkInference(
            mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor"
        )
        shark_module.compile()

        def compiled_callable(*inputs):
            inputs = [x.numpy() for x in inputs]
            result = shark_module("forward", inputs)
            if was_unwrapped:
                # Undo _unwrap_single_tuple_return: restore the 1-element list.
                result = [
                    result,
                ]
            if not isinstance(result, list):
                result = torch.from_numpy(result)
            else:
                result = tuple(torch.from_numpy(x) for x in result)
                result = list(result)
                # Put back the None outputs that _remove_nones stripped.
                for removed_index in removed_none_indexes:
                    result.insert(removed_index, None)
                result = tuple(result)
            return result

        return compiled_callable

    def predictions(torch_func, jit_func, batchA, batchB):
        """Call ``jit_func`` on the numpy views of both batches and return its result.

        ``torch_func`` is accepted but not used.
        """
        res = jit_func(batchA.numpy(), batchB.numpy())
        if res is not None:
            # prediction = torch.from_numpy(res)
            prediction = res
        else:
            prediction = None
        return prediction

    logger = logging.getLogger(__name__)

    train_batch_size = hyperparameters["train_batch_size"]
    gradient_accumulation_steps = hyperparameters[
        "gradient_accumulation_steps"
    ]
    learning_rate = hyperparameters["learning_rate"]
    if hyperparameters["scale_lr"]:
        learning_rate = (
            learning_rate
            * gradient_accumulation_steps
            * train_batch_size
            # * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        lora_layers.parameters(),  # only the LoRA layers are optimized
        lr=learning_rate,
    )

    # Training function
    def train_func(batch_pixel_values, batch_input_ids):
        """Run one optimization step on a batch and return the loss tensor."""
        # Convert images to latent space
        latents = shark_vae(batch_pixel_values).sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            noise_scheduler.num_train_timesteps,
            (bsz,),
            device=latents.device,
        ).long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch_input_ids)[0]

        # Predict the noise residual
        noise_pred = shark_unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states,
        )

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(
                f"Unknown prediction type {noise_scheduler.config.prediction_type}"
            )

        loss = (
            F.mse_loss(noise_pred, target, reduction="none")
            .mean([1, 2, 3])
            .mean()
        )
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        return loss

    def training_function():
        """Drive ``train_func`` through dynamo until ``max_train_steps`` is reached."""
        max_train_steps = hyperparameters["max_train_steps"]

        train_dataloader = create_dataloader(train_batch_size)

        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / gradient_accumulation_steps
        )
        num_train_epochs = math.ceil(
            max_train_steps / num_update_steps_per_epoch
        )

        # Train!
        total_batch_size = (
            train_batch_size
            * gradient_accumulation_steps
            # train_batch_size * accelerator.num_processes * gradient_accumulation_steps
        )

        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(
            f" Instantaneous batch size per device = {train_batch_size}"
        )
        logger.info(
            f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
        )
        logger.info(
            f" Gradient Accumulation steps = {gradient_accumulation_steps}"
        )
        logger.info(f" Total optimization steps = {max_train_steps}")
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(
            # range(max_train_steps), disable=not accelerator.is_local_main_process
            range(max_train_steps)
        )
        progress_bar.set_description("Steps")
        global_step = 0

        # Wrap train_func once; previously the dynamo wrapper and the adapter
        # were rebuilt on every training step.
        dynamo_callable = dynamo.optimize(refbackend_torchdynamo_backend)(
            train_func
        )

        def lam_func(x, y):
            # predictions() passes numpy arrays; dynamo_callable wants tensors.
            return dynamo_callable(torch.from_numpy(x), torch.from_numpy(y))

        for epoch in range(num_train_epochs):
            unet.train()
            for step, batch in enumerate(train_dataloader):
                loss = predictions(
                    train_func,
                    lam_func,
                    batch["pixel_values"],
                    batch["input_ids"],
                )

                # One optimizer step was taken inside train_func.
                progress_bar.update(1)
                global_step += 1

                logs = {"loss": loss.detach().item()}
                progress_bar.set_postfix(**logs)

                if global_step >= max_train_steps:
                    break

    training_function()

    # Save the lora weights
    unet.save_attn_procs(args.lora_save_dir)

    for param in itertools.chain(unet.parameters(), text_encoder.parameters()):
        if param.grad is not None:
            del param.grad  # free some memory
        torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: train a LoRA from the parsed ``args``.
    if args.clear_all:
        clear_all()

    set_init_device_flags()
    # The scheduler object is not passed to lora_train; the lookup is kept so
    # an unknown ``args.scheduler`` still fails here, before training starts.
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    if len(args.prompts) != 1:
        print("Need exactly one prompt for the LoRA word")
        if not args.prompts:
            # Without any prompt, args.prompts[0] below would raise
            # IndexError; stop with a non-zero exit status instead.
            raise SystemExit(1)
    lora_train(
        args.prompts[0],
        args.height,
        args.width,
        args.training_steps,
        args.guidance_scale,
        args.seed,
        args.batch_count,
        args.batch_size,
        args.scheduler,
        "None",
        args.hf_model_id,
        args.precision,
        args.device,
        args.max_length,
        args.training_images_dir,
        args.lora_save_dir,
        args.use_lora,
    )
|
||||
131
apps/stable_diffusion/scripts/tuner.py
Normal file
131
apps/stable_diffusion/scripts/tuner.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shark_tuner.codegen_tuner import SharkCodegenTuner
|
||||
from shark_tuner.iree_utils import (
|
||||
dump_dispatches,
|
||||
create_context,
|
||||
export_module_to_mlir_file,
|
||||
)
|
||||
from shark_tuner.model_annotation import model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.utils import set_init_device_flags
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import (
|
||||
get_device_args,
|
||||
load_winograd_configs,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def load_mlir_module():
    """Build the model selected by ``args.annotation_model`` and return its MLIR.

    A ``SharkifyStableDiffusionModel`` is constructed from the current ``args``
    with ``use_tuned=False`` and ``return_mlir=True``; the model is flagged as
    an upscaler when ``"upscaler"`` occurs in ``args.hf_model_id``.

    Returns:
        ``(mlir_module, model_name)`` for the ``"unet"`` or ``"vae"`` submodel.

    Raises:
        ValueError: if ``args.annotation_model`` is neither ``"unet"`` nor
            ``"vae"``.
    """
    sd_model = SharkifyStableDiffusionModel(
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        max_len=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        is_upscaler="upscaler" in args.hf_model_id,
        use_tuned=False,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        return_mlir=True,
    )

    selected = args.annotation_model
    if selected == "unet":
        return sd_model.unet(), sd_model.model_name["unet"]
    if selected == "vae":
        return sd_model.vae(), sd_model.model_name["vae"]
    raise ValueError(f"{args.annotation_model} is not supported for tuning.")
|
||||
|
||||
|
||||
def main():
    """Annotate the selected model, dump its dispatches and tune each one.

    Steps: load the MLIR module, derive the device spec (and Vulkan target
    triple), run ``model_annotation`` (with winograd on Vulkan), export the
    annotated module to ``~/tmp/temp.mlir``, dump dispatches next to it and
    run ``SharkCodegenTuner`` on every ``.mlir`` file in the dispatch directory.
    """
    args.use_tuned = False
    set_init_device_flags()
    mlir_module, model_name = load_mlir_module()

    # Get device and device specific arguments
    device, device_spec_args = get_device_args()
    is_vulkan = device == "vulkan"
    device_spec, vulkan_target_triple = "", ""
    if device_spec_args:
        # Value part of the last "<flag>=<value>" argument.
        device_spec = device_spec_args[-1].split("=")[-1].strip()
        if is_vulkan:
            vulkan_target_triple = device_spec
            device_spec = device_spec.split("-")[0]

    # Add winograd annotation for vulkan device
    use_winograd = is_vulkan and args.annotation_model in ["unet", "vae"]
    winograd_config = load_winograd_configs() if use_winograd else ""

    # Dump model dispatches
    generates_dir = Path.home() / "tmp"
    dump_mlir = generates_dir / "temp.mlir"
    dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches"

    with create_context() as ctx:
        input_module = model_annotation(
            ctx,
            input_contents=mlir_module,
            config_path=winograd_config,
            search_op="conv",
            winograd=use_winograd,
        )
        if not os.path.exists(generates_dir):
            os.makedirs(generates_dir)
        # Exported while the context is still open.
        export_module_to_mlir_file(input_module, dump_mlir)

    dump_dispatches(
        dump_mlir,
        device,
        dispatch_dir,
        vulkan_target_triple,
        use_winograd=use_winograd,
    )

    # Tune each dispatch
    dtype = "f32" if args.precision != "fp16" else "f16"
    config_filename = f"{model_name}_{device_spec}_configs.json"

    dispatch_files = (
        entry for entry in os.listdir(dispatch_dir) if entry.endswith(".mlir")
    )
    for f_path in dispatch_files:
        SharkCodegenTuner(
            os.path.join(dispatch_dir, f_path),
            device,
            "random",
            args.num_iters,
            args.tuned_config_dir,
            dtype,
            args.search_op,
            batch_size=1,
            config_filename=config_filename,
            use_dispatch=True,
            vulkan_target_triple=vulkan_target_triple,
        ).tune()


if __name__ == "__main__":
    main()
|
||||
88
apps/stable_diffusion/scripts/txt2img.py
Normal file
88
apps/stable_diffusion/scripts/txt2img.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import torch
|
||||
import transformers
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImagePipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Generate ``args.batch_count`` text-to-image batches and save the results.

    Builds a ``Text2ImagePipeline`` from the parsed ``args``, derives one seed
    per batch with ``utils.batch_seeds``, and for each batch generates images,
    saves the first one and prints a summary of the settings and timing.
    """
    if args.clear_all:
        clear_all()

    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    txt2img_obj = Text2ImagePipeline.from_pretrained(
        scheduler=scheduler_obj,
        import_mlir=args.import_mlir,
        model_id=args.hf_model_id,
        ckpt_loc=args.ckpt_loc,
        precision=args.precision,
        max_length=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=args.use_tuned,
        custom_vae=args.custom_vae,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        use_quantize=args.use_quantize,
        ondemand=args.ondemand,
    )

    seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
    for current_batch in range(args.batch_count):
        batch_seed = seeds[current_batch]
        start_time = time.time()
        generated_imgs = txt2img_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            batch_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            args.max_embeddings_multiples,
        )
        total_time = time.time() - start_time
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += (
            f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
        )
        text_output += (
            f"seed={batch_seed}, size={args.height}x{args.width}"
        )
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
        text_output += txt2img_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        # Save with the seed that actually produced this batch (the same one
        # reported above), not the raw command-line seed.
        save_output_img(generated_imgs[0], batch_seed)
        print(text_output)


if __name__ == "__main__":
    main()
|
||||
96
apps/stable_diffusion/scripts/txt2img_sdxl.py
Normal file
96
apps/stable_diffusion/scripts/txt2img_sdxl.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Generate ``args.batch_count`` SDXL text-to-image batches and save the results.

    Forces the supported SDXL configuration (fp16, 1024x1024, max_length 77,
    DDIM scheduler) onto ``args``, builds a ``Text2ImageSDXLPipeline``, derives
    one seed per batch with ``utils.batch_seeds``, and for each batch generates
    images, saves the first one and prints a summary of settings and timing.
    """
    if args.clear_all:
        clear_all()

    # TODO: prompt_embeds and text_embeds from base_model.json requires fixing
    args.precision = "fp16"
    args.height = 1024
    args.width = 1024
    args.max_length = 77
    args.scheduler = "DDIM"
    print(
        "Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
    )
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    cpu_scheduling = not args.scheduler.startswith("Shark")
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)
    scheduler_obj = schedulers[args.scheduler]
    seed = args.seed
    txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
        scheduler=scheduler_obj,
        import_mlir=args.import_mlir,
        model_id=args.hf_model_id,
        ckpt_loc=args.ckpt_loc,
        precision=args.precision,
        max_length=args.max_length,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        use_base_vae=args.use_base_vae,
        use_tuned=args.use_tuned,
        custom_vae=args.custom_vae,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        debug=args.import_debug if args.import_mlir else False,
        use_lora=args.use_lora,
        use_quantize=args.use_quantize,
        ondemand=args.ondemand,
    )

    seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
    for current_batch in range(args.batch_count):
        batch_seed = seeds[current_batch]
        start_time = time.time()
        generated_imgs = txt2img_obj.generate_images(
            args.prompts,
            args.negative_prompts,
            args.batch_size,
            args.height,
            args.width,
            args.steps,
            args.guidance_scale,
            batch_seed,
            args.max_length,
            dtype,
            args.use_base_vae,
            cpu_scheduling,
            args.max_embeddings_multiples,
        )
        total_time = time.time() - start_time
        text_output = f"prompt={args.prompts}"
        text_output += f"\nnegative prompt={args.negative_prompts}"
        text_output += (
            f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
        )
        text_output += f"\nscheduler={args.scheduler}, device={args.device}"
        text_output += (
            f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
        )
        text_output += (
            f"seed={batch_seed}, size={args.height}x{args.width}"
        )
        text_output += (
            f", batch size={args.batch_size}, max_length={args.max_length}"
        )
        # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
        text_output += txt2img_obj.log
        text_output += f"\nTotal image generation time: {total_time:.4f}sec"

        # Save with the seed that actually produced this batch (the same one
        # reported above), not the raw command-line seed.
        save_output_img(generated_imgs[0], batch_seed)
        print(text_output)


if __name__ == "__main__":
    main()
|
||||
92
apps/stable_diffusion/scripts/upscaler.py
Normal file
92
apps/stable_diffusion/scripts/upscaler.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import time
|
||||
from PIL import Image
|
||||
import transformers
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
UpscalerPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: upscale the image given by --img_path.
    if args.clear_all:
        clear_all()

    if args.img_path is None:
        print("Flag --img_path is required.")
        # Same exit status as the previous bare exit(), but without relying
        # on the ``site``-provided helper.
        raise SystemExit

    # When the models get uploaded, it should be defaulted to False.
    args.import_mlir = True

    cpu_scheduling = not args.scheduler.startswith("Shark")
    dtype = torch.float32 if args.precision == "fp32" else torch.half
    set_init_device_flags()
    schedulers = get_schedulers(args.hf_model_id)

    scheduler_obj = schedulers[args.scheduler]
    # PIL's resize takes (width, height); the previous (height, width) order
    # transposed the target size whenever the two differ.
    image = (
        Image.open(args.img_path)
        .convert("RGB")
        .resize((args.width, args.height))
    )
    seed = utils.sanitize_seed(args.seed)
    # Adjust for height and width based on model

    upscaler_obj = UpscalerPipeline.from_pretrained(
        scheduler_obj,
        args.import_mlir,
        args.hf_model_id,
        args.ckpt_loc,
        args.custom_vae,
        args.precision,
        args.max_length,
        args.batch_size,
        args.height,
        args.width,
        args.use_base_vae,
        args.use_tuned,
        low_cpu_mem_usage=args.low_cpu_mem_usage,
        use_lora=args.use_lora,
        ddpm_scheduler=schedulers["DDPM"],
        ondemand=args.ondemand,
    )

    start_time = time.time()
    generated_imgs = upscaler_obj.generate_images(
        args.prompts,
        args.negative_prompts,
        image,
        args.batch_size,
        args.height,
        args.width,
        args.steps,
        args.noise_level,
        args.guidance_scale,
        seed,
        args.max_length,
        dtype,
        args.use_base_vae,
        cpu_scheduling,
        args.max_embeddings_multiples,
    )
    total_time = time.time() - start_time
    text_output = f"prompt={args.prompts}"
    text_output += f"\nnegative prompt={args.negative_prompts}"
    text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
    text_output += f"\nscheduler={args.scheduler}, device={args.device}"
    text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
    text_output += (
        f", batch size={args.batch_size}, max_length={args.max_length}"
    )
    text_output += upscaler_obj.log
    text_output += f"\nTotal image generation time: {total_time:.4f}sec"

    extra_info = {"NOISE LEVEL": args.noise_level}
    save_output_img(generated_imgs[0], seed, extra_info)
    print(text_output)
|
||||
48
apps/stable_diffusion/shark_sd.spec
Normal file
48
apps/stable_diffusion/shark_sd.spec
Normal file
@@ -0,0 +1,48 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports
|
||||
|
||||
# PyInstaller build description for the `nodai_shark_studio` executable.
# Analysis, PYZ and EXE are not imported in this file; presumably PyInstaller
# injects them when it executes the spec — confirm.

# No extra binaries are listed explicitly.
binaries = []

# Passed as `cipher=` to both Analysis and PYZ below.
block_cipher = None

# Entry script is web/index.py; pathex, datas and hiddenimports are the lists
# imported from apps.stable_diffusion.shark_studio_imports.
a = Analysis(
    ['web/index.py'],
    pathex=pathex,
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
    module_collection_mode={
        'gradio': 'py',  # Collect gradio package as source .py files
    },
)
# Archive built from the pure-Python modules found by the analysis.
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# Scripts, binaries, zipfiles and datas are all handed to EXE directly;
# console=True keeps a console window, upx=False disables UPX.
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='nodai_shark_studio',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=False,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
|
||||
85
apps/stable_diffusion/shark_sd_cli.spec
Normal file
85
apps/stable_diffusion/shark_sd_cli.spec
Normal file
@@ -0,0 +1,85 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
|
||||
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
# Data files and distribution metadata bundled into the frozen CLI.
# Order matches the original statement order.
datas = []
datas += collect_data_files('torch')

# copy_metadata() is keyed by *distribution* (pip) name.
for _dist in (
    'torch',
    'tqdm',
    'regex',
    'requests',
    'packaging',
    'filelock',
    'numpy',
    'tokenizers',
    'importlib_metadata',
    'torch-mlir',
    'omegaconf',
    'safetensors',
):
    datas += copy_metadata(_dist)

# collect_data_files() is keyed by *import package* name, so pip names such as
# 'opencv-python' / 'py-cpuinfo' collect nothing. Use the import names that
# shark_studio_imports.py already uses ('cv2', 'cpuinfo').
for _pkg in (
    'diffusers',
    'transformers',
    'cv2',
    'pytorch_lightning',
    'skimage',
    'gradio',
    'gradio_client',
    'iree',
    # NOTE(review): 'google-cloud-storage' is also a pip name, not an import
    # package; left unchanged because the intended package is not evident
    # from this file — confirm whether it should be 'google.cloud.storage'
    # or dropped.
    'google-cloud-storage',
    'shark',
    'cpuinfo',
):
    datas += collect_data_files(_pkg)

# Project resource files, copied into the bundle's `resources` directory.
datas += [
    ( 'src/utils/resources/prompts.json', 'resources' ),
    ( 'src/utils/resources/model_db.json', 'resources' ),
    ( 'src/utils/resources/opt_flags.json', 'resources' ),
    ( 'src/utils/resources/base_model.json', 'resources' ),
]
|
||||
|
||||
# PyInstaller build description for the `shark_sd_cli` executable.
# Analysis, PYZ and EXE are not imported in this file; presumably PyInstaller
# injects them when it executes the spec — confirm.

# No extra binaries are listed explicitly.
binaries = []

# Passed as `cipher=` to both Analysis and PYZ below.
block_cipher = None

# Modules named explicitly as hidden imports, plus every skimage / iree
# submodule whose dotted name does not contain "tests".
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]

# Entry script is scripts/main.py, resolved with '.' on the search path.
a = Analysis(
    ['scripts/main.py'],
    pathex=['.'],
    binaries=binaries,
    datas=datas,
    hiddenimports=hiddenimports,
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)
# Archive built from the pure-Python modules found by the analysis.
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# Scripts, binaries, zipfiles and datas are all handed to EXE directly.
# Unlike shark_sd.spec, this spec sets upx=True.
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='shark_sd_cli',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
|
||||
91
apps/stable_diffusion/shark_studio_imports.py
Normal file
91
apps/stable_diffusion/shark_studio_imports.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from PyInstaller.utils.hooks import collect_data_files
|
||||
from PyInstaller.utils.hooks import copy_metadata
|
||||
from PyInstaller.utils.hooks import collect_submodules
|
||||
|
||||
import sys
|
||||
|
||||
# Raise the interpreter recursion limit to five times its current value
# (presumably so PyInstaller's analysis of large packages does not overflow
# the stack — confirm).
sys.setrecursionlimit(5 * sys.getrecursionlimit())

# python path for pyinstaller
pathex = [
    ".",
    "./apps/language_models/langchain",
    "./apps/language_models/src/pipelines/minigpt4_utils",
]

# datafiles for pyinstaller
# The tables below are walked in order, so `datas` ends up with exactly the
# same entries, in the same sequence, as one `datas += ...` line per name.
_METADATA_DISTS = (
    "torch",
    "tokenizers",
    "tqdm",
    "regex",
    "requests",
    "packaging",
    "filelock",
    "numpy",
    "importlib_metadata",
    "torch-mlir",
    "omegaconf",
    "safetensors",
    "Pillow",
    "sentencepiece",
    "pyyaml",
    "huggingface-hub",
    "gradio",
)

# (package name, include_py_files) pairs; None means "call without the
# keyword", exactly like the plain collect_data_files(name) form.
_DATA_PACKAGES = (
    ("torch", None),
    ("tokenizers", None),
    ("tiktoken", None),
    ("accelerate", None),
    ("diffusers", None),
    ("transformers", None),
    ("pytorch_lightning", None),
    ("skimage", None),
    ("gradio", None),
    ("gradio_client", None),
    ("iree", None),
    ("shark", True),
    ("timm", True),
    ("tqdm", None),
    ("tkinter", None),
    ("webview", None),
    ("sentencepiece", None),
    ("jsonschema", None),
    ("jsonschema_specifications", None),
    ("cpuinfo", None),
    ("langchain", None),
    ("cv2", None),
    ("einops", None),
)

datas = []
for _dist_name in _METADATA_DISTS:
    datas += copy_metadata(_dist_name)

for _package_name, _with_py in _DATA_PACKAGES:
    if _with_py is None:
        datas += collect_data_files(_package_name)
    else:
        datas += collect_data_files(_package_name, include_py_files=_with_py)

# Project-local resources as (source glob, destination directory) pairs.
datas += [
    ("src/utils/resources/prompts.json", "resources"),
    ("src/utils/resources/model_db.json", "resources"),
    ("src/utils/resources/opt_flags.json", "resources"),
    ("src/utils/resources/base_model.json", "resources"),
    ("web/ui/css/*", "ui/css"),
    ("web/ui/js/*", "ui/js"),
    ("web/ui/logos/*", "logos"),
    (
        "../language_models/src/pipelines/minigpt4_utils/configs/*",
        "minigpt4_utils/configs",
    ),
    (
        "../language_models/src/pipelines/minigpt4_utils/prompts/*",
        "minigpt4_utils/prompts",
    ),
]
|
||||
|
||||
|
||||
# hidden imports for pyinstaller
def _submodules_excluding(package, banned=("tests",)):
    """Return the submodules of `package` whose dotted name contains none of
    the substrings in `banned`."""
    return [
        name
        for name in collect_submodules(package)
        if not any(fragment in name for fragment in banned)
    ]


# Substrings that disqualify a `transformers` submodule.
blacklist = ["tests", "convert"]

hiddenimports = ["shark", "shark.shark_inference", "apps"]
for _pkg in ("skimage", "gradio", "diffusers"):
    hiddenimports += _submodules_excluding(_pkg)
hiddenimports += _submodules_excluding("transformers", blacklist)
hiddenimports += _submodules_excluding("iree")
hiddenimports += ["iree._runtime"]
|
||||
19
apps/stable_diffusion/src/__init__.py
Normal file
19
apps/stable_diffusion/src/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
args,
|
||||
set_init_device_flags,
|
||||
prompt_examples,
|
||||
get_available_devices,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
resize_stencil,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
Text2ImageSDXLPipeline,
|
||||
Image2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
OutpaintPipeline,
|
||||
StencilPipeline,
|
||||
UpscalerPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import get_schedulers
|
||||
12
apps/stable_diffusion/src/models/__init__.py
Normal file
12
apps/stable_diffusion/src/models/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from apps.stable_diffusion.src.models.model_wrappers import (
|
||||
SharkifyStableDiffusionModel,
|
||||
)
|
||||
from apps.stable_diffusion.src.models.opt_params import (
|
||||
get_vae_encode,
|
||||
get_vae,
|
||||
get_unet,
|
||||
get_clip,
|
||||
get_tokenizer,
|
||||
get_params,
|
||||
get_variant_version,
|
||||
)
|
||||
1388
apps/stable_diffusion/src/models/model_wrappers.py
Normal file
1388
apps/stable_diffusion/src/models/model_wrappers.py
Normal file
File diff suppressed because it is too large
Load Diff
133
apps/stable_diffusion/src/models/opt_params.py
Normal file
133
apps/stable_diffusion/src/models/opt_params.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import sys
|
||||
from transformers import CLIPTokenizer
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
models_db,
|
||||
args,
|
||||
get_shark_model,
|
||||
get_opt_flags,
|
||||
)
|
||||
|
||||
|
||||
# Maps a Hugging Face model id to a [variant, version] pair.
# get_variant_version() returns the pair unchanged, and the get_unet / get_vae /
# get_vae_encode / get_clip helpers unpack it to build their model_db keys.
# Note that "stabilityai/stable-diffusion-2-1" and "...-2-1-base" both map to
# the "v2_1base" version.
hf_model_variant_map = {
    "Linaqruf/anything-v3.0": ["anythingv3", "v1_4"],
    "dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"],
    "prompthero/openjourney": ["openjourney", "v1_4"],
    "wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"],
    "stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
    "stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
    "CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
    "runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
    "stabilityai/stable-diffusion-2-inpainting": [
        "stablediffusion",
        "inpaint_v2",
    ],
}
|
||||
|
||||
|
||||
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
    """Return the (bucket, model key, IREE flags) triple for the int8 UNet.

    The bucket and model key are hard-coded; the flags are the "unet" flags
    for fp16 precision.

    Exits the process via ``sys.exit`` when the requested configuration is
    not exactly height=512, width=512 and max_length=77.
    """
    bucket_key = "gs://shark_tank/prashant_nod"
    model_key = "unet_int8"
    iree_flags = get_opt_flags("unet", precision="fp16")
    # Every one of the three settings is required, so a mismatch in ANY of
    # them must abort. The previous `and` chain only aborted when all three
    # were wrong at once, letting e.g. 512x768 through.
    if args.height != 512 or args.width != 512 or args.max_length != 77:
        sys.exit(
            "The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
        )
    return bucket_key, model_key, iree_flags
|
||||
|
||||
|
||||
def get_variant_version(hf_model_id):
    """Look up the [variant, version] pair registered for `hf_model_id`.

    Raises KeyError when the id is not a key of hf_model_variant_map.
    """
    variant_and_version = hf_model_variant_map[hf_model_id]
    return variant_and_version
|
||||
|
||||
|
||||
def get_params(bucket_key, model_key, model, is_tuned, precision):
    """Resolve a bucket, a model name and IREE flags for one sub-model.

    `bucket_key` is looked up in models_db[0] and `model_key` in models_db[1];
    a miss in either raises a generic Exception naming both keys.
    `is_tuned` and `precision` are accepted but not read here: the flags are
    always requested with precision="fp16".

    Returns (bucket, model_name, iree_flags).
    """
    try:
        bucket, model_name = models_db[0][bucket_key], models_db[1][model_key]
    except KeyError:
        raise Exception(
            f"{bucket_key}/{model_key} is not present in the models database"
        )
    return bucket, model_name, get_opt_flags(model, precision="fp16")
|
||||
|
||||
|
||||
def get_unet():
    """Fetch the UNet module selected by the global `args`.

    With args.use_quantize == "int8" the hard-coded quantized model is
    returned. Otherwise the model_db keys are built from the variant/version
    of args.hf_model_id, the precision, max_length and tuned state; the device
    is appended only for tuned, non-vulkan runs.
    """
    variant, version = get_variant_version(args.hf_model_id)
    # Tuned model is present only for `fp16` precision.
    if args.use_tuned:
        is_tuned = "tuned"
    else:
        is_tuned = "untuned"

    # TODO: Get the quantize model from model_db.json
    if args.use_quantize == "int8":
        return get_shark_model(*get_quantize_model())

    if "vulkan" in args.device or not args.use_tuned:
        # Generic keys: no device suffix.
        bucket_key = f"{variant}/{is_tuned}"
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}"
    else:
        # Tuned, non-vulkan: keys carry the device.
        bucket_key = f"{variant}/{is_tuned}/{args.device}"
        model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"

    resolved = get_params(
        bucket_key, model_key, "unet", is_tuned, args.precision
    )
    return get_shark_model(*resolved)
|
||||
|
||||
|
||||
def get_vae_encode():
    """Fetch the VAE-encode module selected by the global `args`.

    Keys always use `length_77`; the device is appended only for tuned,
    non-vulkan runs. Flags are requested for the "vae" model.
    """
    variant, version = get_variant_version(args.hf_model_id)
    # Tuned model is present only for `fp16` precision.
    if args.use_tuned:
        is_tuned = "tuned"
    else:
        is_tuned = "untuned"

    if "vulkan" in args.device or not args.use_tuned:
        # Generic keys: no device suffix.
        bucket_key = f"{variant}/{is_tuned}"
        model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}"
    else:
        # Tuned, non-vulkan: keys carry the device.
        bucket_key = f"{variant}/{is_tuned}/{args.device}"
        model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}"

    resolved = get_params(
        bucket_key, model_key, "vae", is_tuned, args.precision
    )
    return get_shark_model(*resolved)
|
||||
|
||||
|
||||
def get_vae():
    """Fetch the VAE (decode) module selected by the global `args`.

    Keys always use `length_77`; "/base" is inserted after the tuned state
    when args.use_base_vae is set, and the device is appended only for tuned,
    non-vulkan runs.
    """
    variant, version = get_variant_version(args.hf_model_id)
    # Tuned model is present only for `fp16` precision.
    if args.use_tuned:
        is_tuned = "tuned"
    else:
        is_tuned = "untuned"
    if args.use_base_vae:
        is_base = "/base"
    else:
        is_base = ""

    if "vulkan" in args.device or not args.use_tuned:
        # Generic keys: no device suffix.
        bucket_key = f"{variant}/{is_tuned}"
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}"
    else:
        # Tuned, non-vulkan: keys carry the device.
        bucket_key = f"{variant}/{is_tuned}/{args.device}"
        model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}"

    resolved = get_params(
        bucket_key, model_key, "vae", is_tuned, args.precision
    )
    return get_shark_model(*resolved)
|
||||
|
||||
|
||||
def get_clip():
    """Fetch the CLIP text-encoder module for args.hf_model_id.

    The keys are always the untuned fp32 ones; only args.max_length varies.
    """
    variant, version = get_variant_version(args.hf_model_id)
    bucket_key = f"{variant}/untuned"
    model_key = (
        f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned"
    )
    resolved = get_params(bucket_key, model_key, "clip", "untuned", "fp32")
    return get_shark_model(*resolved)
|
||||
|
||||
|
||||
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
    """Load the CLIPTokenizer for the active Hugging Face model.

    When `hf_model_id` is given it is first written to the global
    args.hf_model_id (a side effect visible to later callers); the tokenizer
    is then loaded from args.hf_model_id using `subfolder`.
    """
    override_requested = hf_model_id is not None
    if override_requested:
        args.hf_model_id = hf_model_id

    return CLIPTokenizer.from_pretrained(
        args.hf_model_id, subfolder=subfolder
    )
|
||||
21
apps/stable_diffusion/src/pipelines/__init__.py
Normal file
21
apps/stable_diffusion/src/pipelines/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
|
||||
Text2ImageSDXLPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
|
||||
InpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import (
|
||||
OutpaintPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import (
|
||||
StencilPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import (
|
||||
UpscalerPipeline,
|
||||
)
|
||||
@@ -0,0 +1,234 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
|
||||
class Image2ImagePipeline(StableDiffusionPipeline):
    """Image-to-image pipeline built on StableDiffusionPipeline.

    The input image is resized, VAE-encoded into latents and noised with the
    scheduler according to `strength`; denoising and decoding are delegated to
    helpers inherited from the base class (produce_img_latents,
    decode_latents, load_vae, unload_vae, encode_prompts_weight).
    """

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
            DPMSolverSinglestepScheduler,
            KDPM2AncestralDiscreteScheduler,
            HeunDiscreteScheduler,
            DDPMScheduler,
            KDPM2DiscreteScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        lora_strength: float,
        ondemand: bool,
    ):
        """Forward all arguments to the base pipeline and start with no VAE
        encoder loaded."""
        super().__init__(
            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
        )
        # Created lazily by load_vae_encode().
        self.vae_encode = None

    def load_vae_encode(self):
        """Populate self.vae_encode if it is not already loaded.

        With import_mlir or LoRA enabled the encoder comes from
        self.sd_model.vae_encode(). Otherwise get_vae_encode() is tried first
        and self.sd_model.vae_encode() is used as the fallback when it fails.
        """
        if self.vae_encode is not None:
            return

        if self.import_mlir or self.use_lora:
            self.vae_encode = self.sd_model.vae_encode()
        else:
            try:
                self.vae_encode = get_vae_encode()
            # Only ordinary failures trigger the fallback; a bare `except:`
            # would also swallow KeyboardInterrupt and SystemExit.
            except Exception:
                print("download pipeline failed, falling back to import_mlir")
                self.vae_encode = self.sd_model.vae_encode()

    def unload_vae_encode(self):
        """Drop the loaded VAE encoder and reset the attribute to None."""
        del self.vae_encode
        self.vae_encode = None

    def prepare_image_latents(
        self,
        image,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        strength,
        dtype,
        resample_type,
    ):
        """Turn `image` into noised latents plus the timesteps left to run.

        `image` must provide `.resize(size, resample=...)` (presumably a PIL
        image — confirm against callers). `batch_size` is accepted but not
        read here.

        Returns (latents, timesteps), where `timesteps` is the tail of the
        scheduler's timesteps selected by `strength`.
        """
        # Pre process image -> get image encoded -> process latents

        # TODO: process with variable HxW combos

        # Pre-process image
        resample_type = (
            resamplers[resample_type]
            if resample_type in resampler_list
            # Fallback to Lanczos
            else Image.Resampling.LANCZOS
        )

        image = image.resize((width, height), resample=resample_type)
        image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
        # Scale 0..255 to 0..1 here, then to -1..1 after the layout change.
        image_arr = image_arr / 255.0
        image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
        image_arr = 2 * (image_arr - 0.5)

        # set scheduler steps
        self.scheduler.set_timesteps(num_inference_steps)
        init_timestep = min(
            int(num_inference_steps * strength), num_inference_steps
        )
        t_start = max(num_inference_steps - init_timestep, 0)
        # timesteps reduced as per strength
        # NOTE(review): when int(num_inference_steps * strength) is 0 the
        # slice is presumably empty and timesteps[0] below raises IndexError.
        timesteps = self.scheduler.timesteps[t_start:]
        # new number of steps to be used as per strength will be
        # num_inference_steps = num_inference_steps - t_start

        # image encode
        latents = self.encode_image((image_arr,))
        latents = torch.from_numpy(latents).to(dtype)
        # add noise to data
        noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
        latents = self.scheduler.add_noise(
            latents, noise, timesteps[0].repeat(1)
        )

        return latents, timesteps

    def encode_image(self, input_image):
        """Run the VAE encoder's "forward" entry on `input_image`.

        Loads the encoder on demand, unloads it again when self.ondemand is
        set, and appends the inference time (ms) to self.log.
        """
        self.load_vae_encode()
        vae_encode_start = time.time()
        latents = self.vae_encode("forward", input_image)
        vae_inf_time = (time.time() - vae_encode_start) * 1000
        if self.ondemand:
            self.unload_vae_encode()
        self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"

        return latents

    def generate_images(
        self,
        prompts,
        neg_prompts,
        image,
        batch_size,
        height,
        width,
        num_inference_steps,
        strength,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
        max_embeddings_multiples,
        stencils,
        images,
        resample_type,
        control_mode,
        preprocessed_hints=[],
    ):
        """Run the full image-to-image generation and return decoded images.

        `prompts` / `neg_prompts` may be a single string or a list; each is
        repeated `batch_size` times. A `seed` outside [0, uint32 max) is
        replaced with a random one. `stencils`, `images`, `control_mode` and
        `preprocessed_hints` are accepted but not read in this method.

        Returns the list built from every decode_latents() result.
        """
        # prompts and negative prompts must be a list.
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(neg_prompts, str):
            neg_prompts = [neg_prompts]

        prompts = prompts * batch_size
        neg_prompts = neg_prompts * batch_size

        # seed generator to create the initial latent noise. Also handle out of range seeds.
        uint32_info = np.iinfo(np.uint32)
        uint32_min, uint32_max = uint32_info.min, uint32_info.max
        if seed < uint32_min or seed >= uint32_max:
            seed = randint(uint32_min, uint32_max)
        generator = torch.manual_seed(seed)

        # Get text embeddings with weight emphasis from prompts
        text_embeddings = self.encode_prompts_weight(
            prompts,
            neg_prompts,
            max_length,
            max_embeddings_multiples=max_embeddings_multiples,
        )

        # guidance scale as a float32 tensor.
        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        # Prepare input image latent
        image_latents, final_timesteps = self.prepare_image_latents(
            image=image,
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            strength=strength,
            dtype=dtype,
            resample_type=resample_type,
        )

        # Get Image latents
        latents = self.produce_img_latents(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=final_timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
        )

        # Img latents -> PIL images
        all_imgs = []
        self.load_vae()
        for i in tqdm(range(0, latents.shape[0], batch_size)):
            imgs = self.decode_latents(
                latents=latents[i : i + batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=cpu_scheduling,
            )
            all_imgs.extend(imgs)
        if self.ondemand:
            self.unload_vae()

        return all_imgs
|
||||
@@ -0,0 +1,490 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image, ImageOps
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class InpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
    self,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        SharkEulerDiscreteScheduler,
        DEISMultistepScheduler,
        DPMSolverSinglestepScheduler,
        KDPM2AncestralDiscreteScheduler,
        HeunDiscreteScheduler,
        DDPMScheduler,
        KDPM2DiscreteScheduler,
    ],
    sd_model: SharkifyStableDiffusionModel,
    import_mlir: bool,
    use_lora: str,
    lora_strength: float,
    ondemand: bool,
):
    """Forward every argument, positionally and in order, to the base
    pipeline constructor, then start with no VAE encoder loaded."""
    super().__init__(
        scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
    )
    # Created lazily by load_vae_encode().
    self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
    """Populate self.vae_encode if it is not already loaded.

    With import_mlir or LoRA enabled the encoder comes from
    self.sd_model.vae_encode(). Otherwise get_vae_encode() is tried first and
    self.sd_model.vae_encode() is used as the fallback when it fails.
    """
    if self.vae_encode is not None:
        return

    if self.import_mlir or self.use_lora:
        self.vae_encode = self.sd_model.vae_encode()
    else:
        try:
            self.vae_encode = get_vae_encode()
        # Only ordinary failures trigger the fallback; a bare `except:` would
        # also swallow KeyboardInterrupt and SystemExit.
        except Exception:
            print("download pipeline failed, falling back to import_mlir")
            self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
    """Drop the loaded VAE encoder and reset the attribute to None."""
    # Delete the attribute first, then re-create it as None so that
    # load_vae_encode() sees "not loaded" on its next call.
    del self.vae_encode
    self.vae_encode = None
|
||||
|
||||
def prepare_latents(
    self,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    dtype,
):
    """Draw the starting noise latents for a run.

    Noise of shape (batch_size, 4, height // 8, width // 8) is sampled in
    float32 from `generator`, cast to `dtype`, and scaled by the scheduler's
    init_noise_sigma after the scheduler's timesteps have been set.
    """
    latent_shape = (batch_size, 4, height // 8, width // 8)
    noise = torch.randn(
        latent_shape, generator=generator, dtype=torch.float32
    )
    noise = noise.to(dtype)

    # Timesteps are set before init_noise_sigma is read.
    self.scheduler.set_timesteps(num_inference_steps)
    return noise * self.scheduler.init_noise_sigma
|
||||
|
||||
def get_crop_region(self, mask, pad=0):
    """Return the (x1, y1, x2, y2) box around the non-zero part of `mask`.

    `mask` is a 2-D array indexed [row, column]. The box is grown by `pad`
    on every side and clamped to the mask's own bounds. For a mask with no
    non-zero entries the margins cover the whole axis, so the result is
    (w - pad, h - pad, pad, pad) after clamping.
    """
    h, w = mask.shape
    occupied = mask != 0
    cols_used = occupied.any(axis=0)
    rows_used = occupied.any(axis=1)

    def _empty_run(flags):
        # Number of leading False entries; the full length if none is True.
        hits = np.flatnonzero(flags)
        return int(hits[0]) if hits.size else int(flags.size)

    margin_left = _empty_run(cols_used)
    margin_right = _empty_run(cols_used[::-1])
    margin_top = _empty_run(rows_used)
    margin_bottom = _empty_run(rows_used[::-1])

    x1 = int(max(margin_left - pad, 0))
    y1 = int(max(margin_top - pad, 0))
    x2 = int(min(w - margin_right + pad, w))
    y2 = int(min(h - margin_bottom + pad, h))
    return (x1, y1, x2, y2)
|
||||
|
||||
def expand_crop_region(
    self,
    crop_region,
    processing_width,
    processing_height,
    image_width,
    image_height,
):
    """Grow `crop_region` along one axis towards the processing aspect ratio.

    If the region is wider than the processing ratio its height is grown,
    otherwise its width is. The grown span is shifted back inside
    [0, image size] and finally clamped at the image size.

    Returns (x1, y1, x2, y2). Raises ZeroDivisionError when the region has
    zero height (y2 == y1).
    """

    def _grow(lo, hi, extra, limit):
        # Split `extra` between both ends, then shift back inside
        # [0, limit] and clamp the far end.
        lo -= extra // 2
        hi += extra - extra // 2
        if hi >= limit:
            overshoot = hi - limit
            hi -= overshoot
            lo -= overshoot
        if lo < 0:
            hi -= lo
            lo -= lo
        if hi >= limit:
            hi = limit
        return lo, hi

    x1, y1, x2, y2 = crop_region

    region_aspect = (x2 - x1) / (y2 - y1)
    target_aspect = processing_width / processing_height

    if region_aspect > target_aspect:
        wanted_height = (x2 - x1) / target_aspect
        y1, y2 = _grow(
            y1, y2, int(wanted_height - (y2 - y1)), image_height
        )
    else:
        wanted_width = (y2 - y1) * target_aspect
        x1, x2 = _grow(x1, x2, int(wanted_width - (x2 - x1)), image_width)

    return x1, y1, x2, y2
|
||||
|
||||
def resize_image(self, resize_mode, im, width, height):
    """Resize `im` onto a new width x height RGB canvas, keeping its aspect ratio.

    resize_mode:
        0: scale the image so it covers the canvas, centre it, and let the
           canvas crop the excess.
        anything else: scale the image so it fits inside the canvas, centre
           it, and fill the remaining bands by stretching the resized image's
           outermost edge.
    """
    cover = resize_mode == 0
    ratio = width / height
    src_ratio = im.width / im.height

    # Pick which side is pinned to the canvas; the comparisons flip between
    # "cover" and "fit".
    if cover:
        pin_width = ratio > src_ratio
        pin_height = ratio <= src_ratio
    else:
        pin_width = ratio < src_ratio
        pin_height = ratio >= src_ratio

    if pin_width:
        src_w = width
    else:
        src_w = im.width * height // im.height
    if pin_height:
        src_h = height
    else:
        src_h = im.height * width // im.width

    resized = im.resize((src_w, src_h), resample=Image.LANCZOS)
    res = Image.new("RGB", (width, height))
    centred_at = (width // 2 - src_w // 2, height // 2 - src_h // 2)
    res.paste(resized, box=centred_at)

    if cover:
        return res

    if ratio < src_ratio:
        # Bands above and below: stretch the top and bottom edge rows.
        fill_height = height // 2 - src_h // 2
        band = (width, fill_height)
        top_band = resized.resize(band, box=(0, 0, width, 0))
        res.paste(top_band, box=(0, 0))
        bottom_band = resized.resize(
            band, box=(0, resized.height, width, resized.height)
        )
        res.paste(bottom_band, box=(0, fill_height + src_h))
    elif ratio > src_ratio:
        # Bands left and right: stretch the left and right edge columns.
        fill_width = width // 2 - src_w // 2
        band = (fill_width, height)
        left_band = resized.resize(band, box=(0, 0, 0, height))
        res.paste(left_band, box=(0, 0))
        right_band = resized.resize(
            band, box=(resized.width, 0, resized.width, height)
        )
        res.paste(right_band, box=(fill_width + src_w, 0))

    return res
|
||||
|
||||
def prepare_mask_and_masked_image(
    self,
    image,
    mask,
    height,
    width,
    inpaint_full_res,
    inpaint_full_res_padding,
):
    """Convert an image/mask pair into tensors for inpainting.

    Both inputs are resized to (width, height). With `inpaint_full_res` set,
    the pair is additionally cropped to the padded mask region, resized back
    to (width, height), and an overlay of the unmasked pixels is prepared.

    Returns (mask, masked_image, paste_to, overlay_image):
      * mask: tensor binarised to 0 / 1 at a 0.5 threshold,
      * masked_image: image scaled by `/ 127.5 - 1.0`, zeroed where mask is 1,
      * paste_to: (x, y, w, h) of the crop, or () without inpaint_full_res,
      * overlay_image: the prepared overlay, or None without inpaint_full_res.
    """
    # preprocess image
    image = image.resize((width, height))
    mask = mask.resize((width, height))

    paste_to = ()
    overlay_image = None
    if inpaint_full_res:
        # prepare overlay image
        # Only pixels outside the mask are pasted (the paste mask is the
        # inverted mask).
        overlay_image = Image.new("RGB", (image.width, image.height))
        overlay_image.paste(
            image.convert("RGB"),
            mask=ImageOps.invert(mask.convert("L")),
        )

        # prepare mask
        mask = mask.convert("L")
        crop_region = self.get_crop_region(
            np.array(mask), inpaint_full_res_padding
        )
        crop_region = self.expand_crop_region(
            crop_region, width, height, mask.width, mask.height
        )
        x1, y1, x2, y2 = crop_region
        mask = mask.crop(crop_region)
        mask = self.resize_image(1, mask, width, height)
        # Stored as (x, y, w, h), the order apply_overlay() unpacks.
        paste_to = (x1, y1, x2 - x1, y2 - y1)

        # prepare image
        image = image.crop(crop_region)
        image = self.resize_image(1, image, width, height)

    # Wrap a single image in a list so both list branches below apply.
    if isinstance(image, (Image.Image, np.ndarray)):
        image = [image]

    if isinstance(image, list) and isinstance(image[0], Image.Image):
        image = [np.array(i.convert("RGB"))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
    elif isinstance(image, list) and isinstance(image[0], np.ndarray):
        image = np.concatenate([i[None, :] for i in image], axis=0)

    # Move the last axis to position 1, then rescale by / 127.5 - 1.0.
    image = image.transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # preprocess mask
    if isinstance(mask, (Image.Image, np.ndarray)):
        mask = [mask]

    if isinstance(mask, list) and isinstance(mask[0], Image.Image):
        mask = np.concatenate(
            [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
        )
        mask = mask.astype(np.float32) / 255.0
    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
        # NOTE(review): ndarray masks are not rescaled here, so they are
        # presumably expected to be in 0..1 already — confirm.
        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

    # Binarise at 0.5.
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Keep only the pixels where the mask is 0.
    masked_image = image * (mask < 0.5)

    return mask, masked_image, paste_to, overlay_image
|
||||
|
||||
def prepare_mask_latents(
    self,
    mask,
    masked_image,
    batch_size,
    height,
    width,
    dtype,
):
    """Produce the latent-resolution mask and the encoded masked image.

    Args:
        mask: 4-D torch tensor; it is resampled with
            ``torch.nn.functional.interpolate`` to ``(height // 8, width // 8)``.
        masked_image: torch tensor handed to ``self.vae_encode``.
        batch_size: number of samples the outputs must cover.
        height, width: pixel dimensions; divided by 8 for the latent size.
        dtype: torch dtype both inputs are cast to.

    Returns:
        ``(mask, masked_image_latents)``, each repeated along dim 0 up to
        ``batch_size`` when it started with fewer entries.

    Raises:
        ValueError: when ``batch_size`` is not a multiple of the number of
            masks / encoded images that were supplied.
    """
    latent_size = (height // 8, width // 8)
    mask = torch.nn.functional.interpolate(mask, size=latent_size).to(dtype)

    # The encoder is called with a "forward" tag and a 1-tuple of inputs; its
    # result is fed to torch.from_numpy, so it must be a numpy array.
    self.load_vae_encode()
    encoded = self.vae_encode("forward", (masked_image.to(dtype),))
    masked_image_latents = torch.from_numpy(encoded)
    if self.ondemand:
        self.unload_vae_encode()

    # duplicate mask and masked_image_latents for each generation per prompt,
    # using mps friendly method
    num_masks = mask.shape[0]
    if num_masks < batch_size:
        if batch_size % num_masks != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // num_masks, 1, 1, 1)

    num_latents = masked_image_latents.shape[0]
    if num_latents < batch_size:
        if batch_size % num_latents != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(
            batch_size // num_latents, 1, 1, 1
        )

    return mask, masked_image_latents
|
||||
|
||||
def apply_overlay(self, image, paste_loc, overlay):
    """Resize ``image`` to the paste box and paste it onto ``overlay``.

    Args:
        image: image passed to ``self.resize_image`` (called with mode ``0``).
        paste_loc: ``(x, y, w, h)`` - paste position and target size.
        overlay: object exposing ``paste(image, (x, y))``; it is modified
            in place.

    Returns:
        The same ``overlay`` object, after pasting.
    """
    left, top, box_w, box_h = paste_loc
    fitted = self.resize_image(0, image, box_w, box_h)
    overlay.paste(fitted, (left, top))
    return overlay
|
||||
|
||||
def generate_images(
    self,
    prompts,
    neg_prompts,
    image,
    mask_image,
    batch_size,
    height,
    width,
    inpaint_full_res,
    inpaint_full_res_padding,
    num_inference_steps,
    guidance_scale,
    seed,
    max_length,
    dtype,
    use_base_vae,
    cpu_scheduling,
    max_embeddings_multiples,
):
    """Run the inpainting pipeline end to end and return the result images.

    Steps, in order: normalise the prompts, seed torch, draw the initial
    latents, encode the prompts, preprocess image and mask, encode the mask
    latents, run ``self.produce_img_latents`` and decode the latents in
    chunks of ``batch_size``.

    Returns:
        A list of decoded images. When ``inpaint_full_res`` is truthy only
        the first decoded image is used: it is pasted onto the overlay image
        via ``self.apply_overlay`` and returned as a one-element list.
    """
    # Prompts and negative prompts must be lists, one entry per sample.
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(neg_prompts, str):
        neg_prompts = [neg_prompts]
    prompts = prompts * batch_size
    neg_prompts = neg_prompts * batch_size

    # Seed the generator used for the initial latent noise; a seed outside
    # the uint32 range is replaced by a random one.
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    generator = torch.manual_seed(seed)

    # Initial latents.
    init_latents = self.prepare_latents(
        batch_size=batch_size,
        height=height,
        width=width,
        generator=generator,
        num_inference_steps=num_inference_steps,
        dtype=dtype,
    )

    # Text embeddings with weight emphasis from the prompts.
    text_embeddings = self.encode_prompts_weight(
        prompts,
        neg_prompts,
        max_length,
        max_embeddings_multiples=max_embeddings_multiples,
    )

    # Guidance scale as a float32 tensor.
    guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

    # Preprocess mask and image.
    preprocessed = self.prepare_mask_and_masked_image(
        image,
        mask_image,
        height,
        width,
        inpaint_full_res,
        inpaint_full_res_padding,
    )
    mask, masked_image, paste_to, overlay_image = preprocessed

    # Mask latent variables.
    mask, masked_image_latents = self.prepare_mask_latents(
        mask=mask,
        masked_image=masked_image,
        batch_size=batch_size,
        height=height,
        width=width,
        dtype=dtype,
    )

    # Image latents.
    latents = self.produce_img_latents(
        latents=init_latents,
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        total_timesteps=self.scheduler.timesteps,
        dtype=dtype,
        cpu_scheduling=cpu_scheduling,
        mask=mask,
        masked_image_latents=masked_image_latents,
    )

    # Image latents -> PIL images, decoded one batch-sized chunk at a time.
    all_imgs = []
    self.load_vae()
    for start in tqdm(range(0, latents.shape[0], batch_size)):
        decoded = self.decode_latents(
            latents=latents[start : start + batch_size],
            use_base_vae=use_base_vae,
            cpu_scheduling=cpu_scheduling,
        )
        all_imgs.extend(decoded)
    if self.ondemand:
        self.unload_vae()

    if not inpaint_full_res:
        return all_imgs

    # Full-resolution mode: only the first decoded image is composited.
    return [self.apply_overlay(all_imgs[0], paste_to, overlay_image)]
|
||||
@@ -0,0 +1,588 @@
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from PIL import Image, ImageDraw, ImageFilter
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
import math
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class OutpaintPipeline(StableDiffusionPipeline):
|
||||
def __init__(
    self,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        SharkEulerDiscreteScheduler,
        DEISMultistepScheduler,
        DPMSolverSinglestepScheduler,
        KDPM2AncestralDiscreteScheduler,
        HeunDiscreteScheduler,
        DDPMScheduler,
        KDPM2DiscreteScheduler,
    ],
    sd_model: SharkifyStableDiffusionModel,
    import_mlir: bool,
    use_lora: str,
    lora_strength: float,
    ondemand: bool,
):
    """Forward all arguments to the base pipeline and clear the encoder slot.

    ``self.vae_encode`` starts as ``None``; ``load_vae_encode`` fills it on
    first use.
    """
    super().__init__(
        scheduler,
        sd_model,
        import_mlir,
        use_lora,
        lora_strength,
        ondemand,
    )
    # Populated on demand by load_vae_encode().
    self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
    """Populate ``self.vae_encode`` unless it is already set.

    With ``self.import_mlir`` or ``self.use_lora`` truthy the module comes
    from ``self.sd_model.vae_encode()``. Otherwise ``get_vae_encode()`` is
    tried first, and any ordinary failure falls back to
    ``self.sd_model.vae_encode()``.
    """
    if self.vae_encode is not None:
        return

    if self.import_mlir or self.use_lora:
        self.vae_encode = self.sd_model.vae_encode()
        return

    try:
        self.vae_encode = get_vae_encode()
    except Exception:
        # Catch Exception rather than everything so that Ctrl-C
        # (KeyboardInterrupt) and SystemExit still propagate.
        print("download pipeline failed, falling back to import_mlir")
        self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
    """Drop the current ``self.vae_encode`` reference and reset it to ``None``."""
    # Delete first, then re-create the attribute so later
    # ``self.vae_encode is not None`` checks keep working.
    del self.vae_encode
    self.vae_encode = None
|
||||
|
||||
def prepare_latents(
    self,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    dtype,
):
    """Draw the initial noise latents and configure the scheduler.

    Noise of shape ``(batch_size, 4, height // 8, width // 8)`` is sampled
    in float32 from ``generator`` and cast to ``dtype``. The scheduler
    timesteps are then set, and the noise is scaled by
    ``self.scheduler.init_noise_sigma`` (read after ``set_timesteps``).
    """
    latent_shape = (batch_size, 4, height // 8, width // 8)
    noise = torch.randn(
        latent_shape, generator=generator, dtype=torch.float32
    ).to(dtype)

    self.scheduler.set_timesteps(num_inference_steps)
    return noise * self.scheduler.init_noise_sigma
|
||||
|
||||
def prepare_mask_and_masked_image(
    self, image, mask, mask_blur, width, height
):
    """Turn a PIL image/mask pair into the tensors used for inpainting.

    Both inputs must be PIL-style images: ``resize`` (and ``filter`` for the
    mask when ``mask_blur > 0``) is called on them before anything else. The
    previous list / ``np.ndarray`` branches could never be reached for that
    reason and have been removed.

    Args:
        image: source image; converted to RGB.
        mask: mask image; converted to single-channel "L".
        mask_blur: Gaussian blur radius applied to the mask when positive.
        width, height: size both images are resized to.

    Returns:
        ``(mask, masked_image)`` where ``mask`` is a ``(1, 1, H, W)`` float32
        tensor of 0/1 values and ``masked_image`` is the ``(1, 3, H, W)``
        float32 image scaled to [-1, 1] with the ``mask == 1`` pixels zeroed.
    """
    if mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
    image = image.resize((width, height))
    mask = mask.resize((width, height))

    # Image: HWC uint8 -> 1xCxHxW float32 in [-1, 1].
    image_np = np.array(image.convert("RGB"))[None, :].transpose(0, 3, 1, 2)
    image_t = torch.from_numpy(image_np).to(dtype=torch.float32) / 127.5 - 1.0

    # Mask: HxW uint8 -> 1x1xHxW float32, binarised at 0.5.
    mask_np = np.array(mask.convert("L"))[None, None, :]
    mask_np = mask_np.astype(np.float32) / 255.0
    mask_np[mask_np < 0.5] = 0
    mask_np[mask_np >= 0.5] = 1
    mask_t = torch.from_numpy(mask_np)

    # Keep only the pixels where the mask is 0.
    masked_image = image_t * (mask_t < 0.5)

    return mask_t, masked_image
|
||||
|
||||
def prepare_mask_latents(
    self,
    mask,
    masked_image,
    batch_size,
    height,
    width,
    dtype,
):
    """Produce the latent-resolution mask and the encoded masked image.

    Args:
        mask: 4-D torch tensor; it is resampled with
            ``torch.nn.functional.interpolate`` to ``(height // 8, width // 8)``.
        masked_image: torch tensor handed to ``self.vae_encode``.
        batch_size: number of samples the outputs must cover.
        height, width: pixel dimensions; divided by 8 for the latent size.
        dtype: torch dtype both inputs are cast to.

    Returns:
        ``(mask, masked_image_latents)``, each repeated along dim 0 up to
        ``batch_size`` when it started with fewer entries.

    Raises:
        ValueError: when ``batch_size`` is not a multiple of the number of
            masks / encoded images that were supplied.
    """
    latent_size = (height // 8, width // 8)
    mask = torch.nn.functional.interpolate(mask, size=latent_size).to(dtype)

    # The encoder is called with a "forward" tag and a 1-tuple of inputs; its
    # result is fed to torch.from_numpy, so it must be a numpy array.
    self.load_vae_encode()
    encoded = self.vae_encode("forward", (masked_image.to(dtype),))
    masked_image_latents = torch.from_numpy(encoded)
    if self.ondemand:
        self.unload_vae_encode()

    # duplicate mask and masked_image_latents for each generation per prompt,
    # using mps friendly method
    num_masks = mask.shape[0]
    if num_masks < batch_size:
        if batch_size % num_masks != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // num_masks, 1, 1, 1)

    num_latents = masked_image_latents.shape[0]
    if num_latents < batch_size:
        if batch_size % num_latents != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(
            batch_size // num_latents, 1, 1, 1
        )

    return mask, masked_image_latents
|
||||
|
||||
def get_matched_noise(
    self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05
):
    """Fill the masked area of an image with noise shaped like the source.

    Seeded random noise is windowed in the frequency domain, its spectrum is
    re-weighted with the magnitude and phase of the source image's FFT, and
    its per-channel histogram is matched to the unmasked source pixels.

    Args:
        _np_src_image: float array indexed ``[a, b, channel]``. The first
            two axes are called ``width`` and ``height`` below, following
            the original naming - assumes values in [0, 1]; confirm against
            the caller.
        np_mask_rgb: array of the same shape; weight 1 marks pixels to be
            replaced by noise, 0 keeps the source pixel.
        noise_q: exponent applied to the source spectrum magnitude.
        color_variation: share of per-channel (coloured) noise; the rest is
            the grey average of the noise.

    Returns:
        Array of the source's shape, clipped to [0, 1].

    NOTE(review): a source with maximum 0, or any zero FFT magnitude, makes
    the divisions below produce NaN; the behaviour is left unchanged here.
    """

    # Helper FFT routines that keep ortho normalisation and shift before and
    # after the transform; data with a third axis is handled per channel.
    def _fft2(data):
        if data.ndim > 2:  # has channels
            out_fft = np.zeros(
                (data.shape[0], data.shape[1], data.shape[2]),
                dtype=np.complex128,
            )
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_fft[:, :, c] = np.fft.fft2(
                    np.fft.fftshift(c_data), norm="ortho"
                )
                out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
        else:  # one channel
            out_fft = np.zeros(
                (data.shape[0], data.shape[1]), dtype=np.complex128
            )
            out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
            out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])

        return out_fft

    def _ifft2(data):
        if data.ndim > 2:  # has channels
            out_ifft = np.zeros(
                (data.shape[0], data.shape[1], data.shape[2]),
                dtype=np.complex128,
            )
            for c in range(data.shape[2]):
                c_data = data[:, :, c]
                out_ifft[:, :, c] = np.fft.ifft2(
                    np.fft.fftshift(c_data), norm="ortho"
                )
                out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
        else:  # one channel
            out_ifft = np.zeros(
                (data.shape[0], data.shape[1]), dtype=np.complex128
            )
            out_ifft[:, :] = np.fft.ifft2(
                np.fft.fftshift(data), norm="ortho"
            )
            out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])

        return out_ifft

    def _get_gaussian_window(width, height, std=3.14, mode=0):
        # 2-D window on a [-1, 1) grid stretched by the aspect ratio.
        # mode 0 is exp(-r^2 * std); any other mode uses a rational falloff.
        window_scale_x = float(width / min(width, height))
        window_scale_y = float(height / min(width, height))

        window = np.zeros((width, height))
        x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
        for y in range(height):
            fy = (y / height * 2.0 - 1.0) * window_scale_y
            if mode == 0:
                window[:, y] = np.exp(-(x**2 + fy**2) * std)
            else:
                window[:, y] = (
                    1 / ((x**2 + 1.0) * (fy**2 + 1.0))
                ) ** (std / 3.14)

        return window

    def _get_masked_window_rgb(np_mask_grey, hardness=1.0):
        # Replicate the grey mask (optionally raised to `hardness`) into
        # three identical channels.
        np_mask_rgb = np.zeros(
            (np_mask_grey.shape[0], np_mask_grey.shape[1], 3)
        )
        if hardness != 1.0:
            hardened = np_mask_grey[:] ** hardness
        else:
            hardened = np_mask_grey[:]
        for c in range(3):
            np_mask_rgb[:, :, c] = hardened[:]
        return np_mask_rgb

    def _match_cumulative_cdf(source, template):
        # Map each source value to the template value at the same quantile.
        src_values, src_unique_indices, src_counts = np.unique(
            source.ravel(), return_inverse=True, return_counts=True
        )
        tmpl_values, tmpl_counts = np.unique(
            template.ravel(), return_counts=True
        )

        # calculate normalized quantiles for each array
        src_quantiles = np.cumsum(src_counts) / source.size
        tmpl_quantiles = np.cumsum(tmpl_counts) / template.size

        interp_a_values = np.interp(
            src_quantiles, tmpl_quantiles, tmpl_values
        )
        return interp_a_values[src_unique_indices].reshape(source.shape)

    def _match_histograms(image, reference):
        # Per-channel histogram matching over the last axis.
        if image.ndim != reference.ndim:
            raise ValueError(
                "Image and reference must have the same number of channels."
            )

        if image.shape[-1] != reference.shape[-1]:
            raise ValueError(
                "Number of channels in the input image and reference image must match!"
            )

        matched = np.empty(image.shape, dtype=image.dtype)
        for channel in range(image.shape[-1]):
            matched_channel = _match_cumulative_cdf(
                image[..., channel], reference[..., channel]
            )
            matched[..., channel] = matched_channel

        matched = matched.astype(np.float64, copy=False)
        return matched

    width = _np_src_image.shape[0]
    height = _np_src_image.shape[1]
    num_channels = _np_src_image.shape[2]

    np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0
    img_mask = np_mask_grey > 1e-6  # pixels that receive noise
    ref_mask = np_mask_grey < 1e-3  # pixels used as histogram reference

    # Rather than leave the masked area black, fill it with the mean of the
    # source image (scaled by the mask) before taking the FFT.
    windowed_image = _np_src_image * (
        1.0 - _get_masked_window_rgb(np_mask_grey)
    )
    windowed_image /= np.max(windowed_image)
    windowed_image += np.average(_np_src_image) * np_mask_rgb

    # Feature statistics from the masked source image.
    src_fft = _fft2(windowed_image)
    src_dist = np.absolute(src_fft)
    src_phase = src_fft / src_dist

    # Static seed: the noise itself is deterministic, so outpainting only
    # follows the global seed.
    rng = np.random.default_rng(0)

    # mode=1 selects the rational (non-exponential) window.
    noise_window = _get_gaussian_window(width, height, mode=1)
    noise_rgb = rng.random((width, height, num_channels))
    noise_grey = np.sum(noise_rgb, axis=2) / 3.0
    # The colourfulness of the starting noise is blended towards greyscale.
    noise_rgb *= color_variation
    for c in range(num_channels):
        noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey

    noise_fft = _fft2(noise_rgb)
    for c in range(num_channels):
        noise_fft[:, :, c] *= noise_window
    noise_rgb = np.real(_ifft2(noise_fft))
    shaped_noise_fft = _fft2(noise_rgb)
    # Perform the actual shaping.
    shaped_noise_fft[:, :, :] = (
        np.absolute(shaped_noise_fft[:, :, :]) ** 2
        * (src_dist**noise_q)
        * src_phase
    )

    # With brightness_variation fixed at 0.0 this is a float copy of the
    # source; kept so the histogram reference is computed as before.
    brightness_variation = 0.0
    contrast_adjusted_np_src = (
        _np_src_image[:] * (brightness_variation + 1.0)
        - brightness_variation * 2.0
    )

    shaped_noise = np.real(_ifft2(shaped_noise_fft))
    shaped_noise -= np.min(shaped_noise)
    shaped_noise /= np.max(shaped_noise)
    shaped_noise[img_mask, :] = _match_histograms(
        shaped_noise[img_mask, :],
        contrast_adjusted_np_src[ref_mask, :],
    )
    # Source where the mask is 0, shaped noise where it is 1.
    shaped_noise = (
        _np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb
    )

    return np.clip(shaped_noise, 0.0, 1.0)
|
||||
|
||||
def generate_images(
    self,
    prompts,
    neg_prompts,
    image,
    pixels,
    mask_blur,
    is_left,
    is_right,
    is_top,
    is_bottom,
    noise_q,
    color_variation,
    batch_size,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    seed,
    max_length,
    dtype,
    use_base_vae,
    cpu_scheduling,
    max_embeddings_multiples,
):
    """Outpaint ``image`` by ``pixels`` on each selected side.

    The image is resized to ``(width, height)`` and then grown one side at a
    time (left, right, top, bottom). Each step pads the canvas, fills the
    new strip with source-matched noise, and inpaints it with the diffusion
    model. The same initial latents and text embeddings are reused for every
    side.

    Returns:
        A one-element list holding the final expanded PIL image.

    NOTE(review): the per-side growth is derived from ``image.width`` /
    ``image.height`` (the original size) although expansion starts from the
    copy resized to ``(width, height)``; if the two differ the result is not
    aligned to 64 pixels - confirm whether that is intended.
    """
    # Prompts and negative prompts must be lists, one entry per sample.
    if isinstance(prompts, str):
        prompts = [prompts]

    if isinstance(neg_prompts, str):
        neg_prompts = [neg_prompts]

    prompts = prompts * batch_size
    neg_prompts = neg_prompts * batch_size

    # Seed the generator used for the initial latent noise; a seed outside
    # the uint32 range is replaced by a random one.
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    generator = torch.manual_seed(seed)

    # Initial latents.
    init_latents = self.prepare_latents(
        batch_size=batch_size,
        height=height,
        width=width,
        generator=generator,
        num_inference_steps=num_inference_steps,
        dtype=dtype,
    )

    # Text embeddings with weight emphasis from the prompts.
    text_embeddings = self.encode_prompts_weight(
        prompts,
        neg_prompts,
        max_length,
        max_embeddings_multiples=max_embeddings_multiples,
    )

    # Guidance scale as a float32 tensor.
    guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

    # Requested growth per side, then redistributed so that the total
    # target size is a multiple of 64.
    left = pixels if is_left else 0
    right = pixels if is_right else 0
    up = pixels if is_top else 0
    down = pixels if is_bottom else 0
    target_w = math.ceil((image.width + left + right) / 64) * 64
    target_h = math.ceil((image.height + up + down) / 64) * 64

    if left > 0:
        left = left * (target_w - image.width) // (left + right)
    if right > 0:
        right = target_w - image.width - left
    if up > 0:
        up = up * (target_h - image.height) // (up + down)
    if down > 0:
        down = target_h - image.height - up

    def expand(
        init_img,
        expand_pixels,
        is_left=False,
        is_right=False,
        is_top=False,
        is_bottom=False,
    ):
        """Grow ``init_img`` by ``expand_pixels`` on one side and inpaint it."""
        is_horiz = is_left or is_right
        is_vert = is_top or is_bottom
        pixels_horiz = expand_pixels if is_horiz else 0
        pixels_vert = expand_pixels if is_vert else 0

        res_w = init_img.width + pixels_horiz
        res_h = init_img.height + pixels_vert
        # Working canvas is rounded up to a multiple of 64.
        process_res_w = math.ceil(res_w / 64) * 64
        process_res_h = math.ceil(res_h / 64) * 64

        img = Image.new("RGB", (process_res_w, process_res_h))
        img.paste(
            init_img,
            (pixels_horiz if is_left else 0, pixels_vert if is_top else 0),
        )

        # White = area to generate, black rectangle = area to keep. The
        # kept area is shrunk by mask_blur on the side being expanded.
        msk = Image.new("RGB", (process_res_w, process_res_h), "white")
        draw = ImageDraw.Draw(msk)
        draw.rectangle(
            (
                expand_pixels + mask_blur if is_left else 0,
                expand_pixels + mask_blur if is_top else 0,
                (
                    msk.width - expand_pixels - mask_blur
                    if is_right
                    else res_w
                ),
                (
                    msk.height - expand_pixels - mask_blur
                    if is_bottom
                    else res_h
                ),
            ),
            fill="black",
        )

        np_image = (np.asarray(img) / 255.0).astype(np.float64)
        np_mask = (np.asarray(msk) / 255.0).astype(np.float64)
        noised = self.get_matched_noise(
            np_image, np_mask, noise_q, color_variation
        )
        output_image = Image.fromarray(
            np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8),
            mode="RGB",
        )

        # Region handed to the model: at most width x height, anchored on
        # the side that is being expanded.
        target_width = (
            min(width, init_img.width + pixels_horiz)
            if is_horiz
            else img.width
        )
        target_height = (
            min(height, init_img.height + pixels_vert)
            if is_vert
            else img.height
        )
        crop_region = (
            0 if is_left else output_image.width - target_width,
            0 if is_top else output_image.height - target_height,
            target_width if is_left else output_image.width,
            target_height if is_top else output_image.height,
        )
        mask_to_process = msk.crop(crop_region)
        image_to_process = output_image.crop(crop_region)

        # Preprocess mask and image.
        mask, masked_image = self.prepare_mask_and_masked_image(
            image_to_process, mask_to_process, mask_blur, width, height
        )

        # Mask latent variables.
        mask, masked_image_latents = self.prepare_mask_latents(
            mask=mask,
            masked_image=masked_image,
            batch_size=batch_size,
            height=height,
            width=width,
            dtype=dtype,
        )

        # Image latents.
        latents = self.produce_img_latents(
            latents=init_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
            mask=mask,
            masked_image_latents=masked_image_latents,
        )

        # Image latents -> PIL images.
        all_imgs = []
        self.load_vae()
        for i in tqdm(range(0, latents.shape[0], batch_size)):
            imgs = self.decode_latents(
                latents=latents[i : i + batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=cpu_scheduling,
            )
            all_imgs.extend(imgs)
        # Release the VAE in on-demand mode, as the inpaint pipeline does;
        # load_vae() is called again for the next side.
        if self.ondemand:
            self.unload_vae()

        # Only the first decoded image is pasted back, then the canvas is
        # trimmed from the padded size to the real result size.
        res_img = all_imgs[0].resize(
            (image_to_process.width, image_to_process.height)
        )
        output_image.paste(
            res_img,
            (
                0 if is_left else output_image.width - res_img.width,
                0 if is_top else output_image.height - res_img.height,
            ),
        )
        output_image = output_image.crop((0, 0, res_w, res_h))

        return output_image

    img = image.resize((width, height))
    if left > 0:
        img = expand(img, left, is_left=True)
    if right > 0:
        img = expand(img, right, is_right=True)
    if up > 0:
        img = expand(img, up, is_top=True)
    if down > 0:
        img = expand(img, down, is_bottom=True)

    return [img]
|
||||
@@ -0,0 +1,606 @@
|
||||
import torch
|
||||
import time
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from random import randint
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
DDPMScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class StencilPipeline(StableDiffusionPipeline):
|
||||
def __init__(
    self,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        SharkEulerDiscreteScheduler,
        DEISMultistepScheduler,
        DPMSolverSinglestepScheduler,
        KDPM2AncestralDiscreteScheduler,
        HeunDiscreteScheduler,
        DDPMScheduler,
        KDPM2DiscreteScheduler,
    ],
    sd_model: SharkifyStableDiffusionModel,
    import_mlir: bool,
    use_lora: str,
    lora_strength: float,
    ondemand: bool,
    controlnet_names: list[str],
):
    """Set up the base pipeline plus one controlnet slot per name.

    For every entry of ``controlnet_names`` there is a slot in
    ``controlnet`` / ``controlnet_512`` (the loaded modules) and in
    ``controlnet_id`` / ``controlnet_512_id`` (the model name loaded into
    that slot). All slots start empty.
    """
    super().__init__(
        scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
    )
    slot_count = len(controlnet_names)
    self.controlnet = [None] * slot_count
    self.controlnet_512 = [None] * slot_count
    # Empty id slots hold None (previously the `str` type object), which is
    # the same value the unload_* methods write.
    self.controlnet_id = [None] * slot_count
    self.controlnet_512_id = [None] * slot_count
    self.controlnet_names = controlnet_names
    # Populated on demand by load_vae_encode().
    self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
    """Populate ``self.vae_encode`` unless it is already set.

    With ``self.import_mlir`` or ``self.use_lora`` truthy the module comes
    from ``self.sd_model.vae_encode()``. Otherwise ``get_vae_encode()`` is
    tried first, and any ordinary failure falls back to
    ``self.sd_model.vae_encode()``.
    """
    if self.vae_encode is not None:
        return

    if self.import_mlir or self.use_lora:
        self.vae_encode = self.sd_model.vae_encode()
        return

    try:
        self.vae_encode = get_vae_encode()
    except Exception:
        # Catch Exception rather than everything so that Ctrl-C
        # (KeyboardInterrupt) and SystemExit still propagate.
        print("download pipeline failed, falling back to import_mlir")
        self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def unload_vae_encode(self):
    """Drop the current ``self.vae_encode`` reference and reset it to ``None``."""
    # Delete first, then re-create the attribute so later
    # ``self.vae_encode is not None`` checks keep working.
    del self.vae_encode
    self.vae_encode = None
|
||||
|
||||
def load_controlnet(self, index, model_name):
    """Load ``model_name`` into controlnet slot ``index`` if needed.

    Does nothing when ``model_name`` is ``None`` or when the slot already
    holds a module recorded under the same name. Otherwise the slot's id is
    updated first and the module is obtained from
    ``self.sd_model.controlnet(model_name)``.
    """
    if model_name is None:
        return

    already_loaded = (
        self.controlnet[index] is not None
        and self.controlnet_id[index] is not None
        and self.controlnet_id[index] == model_name
    )
    if not already_loaded:
        self.controlnet_id[index] = model_name
        self.controlnet[index] = self.sd_model.controlnet(model_name)
|
||||
|
||||
def unload_controlnet(self, index):
    """Empty controlnet slot ``index`` without disturbing the other slots.

    The slot is cleared by assignment. The earlier
    ``del self.controlnet[index]`` removed the list element itself, which
    shifted every later slot down by one; the follow-up assignment then
    overwrote a neighbouring module or raised ``IndexError`` on the last
    slot.
    """
    self.controlnet_id[index] = None
    self.controlnet[index] = None
|
||||
|
||||
def load_controlnet_512(self, index, model_name):
    """Load the ``use_large=True`` variant of ``model_name`` into slot ``index``.

    Does nothing when ``model_name`` is ``None`` (same guard as
    ``load_controlnet``) or when the slot already holds a module recorded
    under the same name. Otherwise the slot's id is updated first and the
    module is obtained from
    ``self.sd_model.controlnet(model_name, use_large=True)``.
    """
    if model_name is None:
        return
    if (
        self.controlnet_512[index] is not None
        and self.controlnet_512_id[index] == model_name
    ):
        return
    self.controlnet_512_id[index] = model_name
    self.controlnet_512[index] = self.sd_model.controlnet(
        model_name, use_large=True
    )
|
||||
|
||||
def unload_controlnet_512(self, index):
    """Empty large-controlnet slot ``index`` without disturbing other slots.

    The slot is cleared by assignment. The earlier
    ``del self.controlnet_512[index]`` removed the list element itself,
    which shifted every later slot down by one; the follow-up assignment
    then overwrote a neighbouring module or raised ``IndexError`` on the
    last slot.
    """
    self.controlnet_512_id[index] = None
    self.controlnet_512[index] = None
|
||||
|
||||
def prepare_latents(
    self,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    dtype,
):
    """Draw the initial noise latents and configure the scheduler.

    Noise of shape ``(batch_size, 4, height // 8, width // 8)`` is sampled
    in float32 from ``generator`` and cast to ``dtype``. The scheduler
    timesteps are then set, its ``is_scale_input_called`` flag is forced to
    ``True``, and the noise is scaled by ``self.scheduler.init_noise_sigma``.
    """
    latent_shape = (batch_size, 4, height // 8, width // 8)
    noise = torch.randn(
        latent_shape, generator=generator, dtype=torch.float32
    ).to(dtype)

    scheduler = self.scheduler
    scheduler.set_timesteps(num_inference_steps)
    scheduler.is_scale_input_called = True
    return noise * scheduler.init_noise_sigma
|
||||
|
||||
def prepare_image_latents(
    self,
    image,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    strength,
    dtype,
    resample_type,
):
    """Encode ``image`` to latents and noise it for an img2img-style start.

    Pipeline: resize the image, scale pixels to [-1, 1], encode with
    ``self.encode_image``, then add scheduler noise at the first timestep
    kept for the requested ``strength``.

    Args:
        image: PIL-style image (``resize`` is called on it).
        batch_size: accepted but not used by this method.
        height, width: size the image is resized to.
        generator: torch generator used for the added noise.
        num_inference_steps: number of scheduler steps to configure.
        strength: fraction of the schedule to run; the first
            ``num_inference_steps - int(num_inference_steps * strength)``
            timesteps are skipped.
        dtype: torch dtype for the image tensor, latents and noise.
        resample_type: key into ``resamplers``; anything not listed in
            ``resampler_list`` falls back to Lanczos.

    Returns:
        ``(latents, timesteps)`` - the noised latents and the remaining
        scheduler timesteps.

    NOTE(review): if the strength leaves no timesteps (e.g. 0),
    ``timesteps[0]`` fails on the empty slice - confirm callers never pass
    that.
    """
    # TODO: process with variable HxW combos

    # Pre-process image.
    if resample_type in resampler_list:
        resample_filter = resamplers[resample_type]
    else:
        # Fallback to Lanczos
        resample_filter = Image.Resampling.LANCZOS

    resized = image.resize((width, height), resample=resample_filter)
    pixels = np.stack([np.array(resized)], axis=0) / 255.0
    image_arr = torch.from_numpy(pixels).permute(0, 3, 1, 2).to(dtype)
    image_arr = 2 * (image_arr - 0.5)

    # Set scheduler steps, then keep only the tail selected by `strength`.
    self.scheduler.set_timesteps(num_inference_steps)
    init_timestep = min(
        int(num_inference_steps * strength), num_inference_steps
    )
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = self.scheduler.timesteps[t_start:]

    # Image encode.
    latents = torch.from_numpy(self.encode_image((image_arr,))).to(dtype)

    # Add noise to the encoded latents at the first remaining timestep.
    noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
    latents = self.scheduler.add_noise(
        latents, noise, timesteps[0].repeat(1)
    )

    return latents, timesteps
|
||||
|
||||
def produce_stencil_latents(
    self,
    latents,
    text_embeddings,
    guidance_scale,
    total_timesteps,
    dtype,
    cpu_scheduling,
    stencil_hints=[None],
    controlnet_conditioning_scale: float = 1.0,
    control_mode="Balanced",  # Prompt, Balanced, or Controlnet
    mask=None,
    masked_image_latents=None,
    return_all_latents=False,
):
    """Run the ControlNet-guided denoising loop and return the latents.

    For every timestep each controlnet is evaluated in turn (each one
    receives the accumulator produced by the previous one), then the unet
    is called with the resulting 13 control tensors and 13 per-tensor
    scales, and finally the scheduler advances the latents.

    Args:
        latents: Initial latents.
        text_embeddings: Numpy array of text embeddings (converted with
            ``torch.from_numpy``).
        guidance_scale: Passed to the unet; also used as the control scale
            in "Controlnet" mode.
        total_timesteps: Sequence of timesteps to iterate over.
        dtype: Torch dtype used for embeddings, timestep and model input.
        cpu_scheduling: When true the model input is converted to numpy and
            ``scheduler.step(...).prev_sample`` is used.
        stencil_hints: One hint per controlnet. The default list is never
            mutated here.
        controlnet_conditioning_scale: Currently unused by this method.
        control_mode: One of "Prompt", "Balanced" or "Controlnet".
        mask, masked_image_latents: When both are given they are
            concatenated onto the model input along dim 1.
        return_all_latents: Return the concatenated history instead of only
            the final latents.

    Returns:
        The final latents, or every intermediate latent concatenated along
        dim 0 when ``return_all_latents`` is true.
    """
    step_time_sum = 0
    latent_history = [latents]
    text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
    text_embeddings_numpy = text_embeddings.detach().numpy()
    assert control_mode in ["Prompt", "Balanced", "Controlnet"]

    # Prompts longer than model_max_length use the "_512" model variants.
    use_base_models = text_embeddings.shape[1] <= self.model_max_length
    if use_base_models:
        self.load_unet()
    else:
        self.load_unet_512()

    # Load every configured controlnet and remember the names in use.
    # The accumulator must live outside the loop: resetting it on each
    # iteration dropped all but the last name (and raised NameError when
    # there were no controlnets at all).
    use_names = []
    for i, name in enumerate(self.controlnet_names):
        if name is None:
            continue
        use_names.append(name)
        if use_base_models:
            self.load_controlnet(i, name)
        else:
            self.load_controlnet_512(i, name)
    self.controlnet_names = use_names

    for i, t in tqdm(enumerate(total_timesteps)):
        step_start_time = time.time()
        timestep = torch.tensor([t]).to(dtype)
        latent_model_input = self.scheduler.scale_model_input(latents, t)
        if mask is not None and masked_image_latents is not None:
            latent_model_input = torch.cat(
                [
                    torch.from_numpy(np.asarray(latent_model_input)),
                    mask,
                    masked_image_latents,
                ],
                dim=1,
            ).to(dtype)
        if cpu_scheduling:
            latent_model_input = latent_model_input.detach().numpy()

        # The controlnets always receive a torch tensor.
        if not torch.is_tensor(latent_model_input):
            latent_model_input_1 = torch.from_numpy(
                np.asarray(latent_model_input)
            ).to(dtype)
        else:
            latent_model_input_1 = latent_model_input

        # Multicontrolnet: start from an all-zero accumulator (13 tensors).
        height = latent_model_input_1.shape[2]
        width = latent_model_input_1.shape[3]
        # NOTE: this deliberately keeps the original rebinding of `dtype`,
        # which also affects the timestep tensor of later iterations.
        dtype = latent_model_input_1.dtype

        def _zeros(channels, downscale):
            return torch.zeros(
                (
                    2,
                    channels,
                    int(height / downscale),
                    int(width / downscale),
                ),
                dtype=dtype,
            )

        control_acc = (
            [_zeros(320, 1)] * 3
            + [_zeros(320, 2)]
            + [_zeros(640, 2)] * 2
            + [_zeros(640, 4)]
            + [_zeros(1280, 4)] * 2
            + [_zeros(1280, 8)] * 4
        )

        # A separate index name avoids clobbering the outer step counter.
        for hint_idx, controlnet_hint in enumerate(stencil_hints):
            # NOTE(review): the original code does nothing special for a
            # None hint (`pass`), so the controlnet is still invoked with
            # it. Skipping instead would leave `control` undefined below;
            # confirm the intended behaviour before changing this.
            controlnet = (
                self.controlnet[hint_idx]
                if use_base_models
                else self.controlnet_512[hint_idx]
            )
            control = controlnet(
                "forward",
                (
                    latent_model_input_1,
                    timestep,
                    text_embeddings,
                    controlnet_hint,
                    *control_acc,
                ),
                send_to_host=False,
            )
            # The first 13 outputs feed the unet; the remainder is the
            # accumulator handed to the next controlnet.
            control_acc = control[13:]
            control = control[:13]

        timestep = timestep.detach().numpy()
        # Profiling Unet.
        profile_device = start_profiling(file_path="unet.rdc")
        # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.

        dtype = latents.dtype
        if control_mode == "Balanced":
            control_scale = [
                torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
            ]
        elif control_mode == "Prompt":
            control_scale = [
                torch.tensor(0.825**x, dtype=dtype)
                for x in range(len(control))
            ]
        elif control_mode == "Controlnet":
            control_scale = [
                torch.tensor(float(guidance_scale), dtype=dtype)
                for _ in range(len(control))
            ]

        unet = self.unet if use_base_models else self.unet_512
        noise_pred = unet(
            "forward",
            (
                latent_model_input,
                timestep,
                text_embeddings_numpy,
                guidance_scale,
                # Exactly 13 control tensors followed by their 13 scales.
                *(control[k] for k in range(13)),
                *(control_scale[k] for k in range(13)),
            ),
            send_to_host=False,
        )
        end_profiling(profile_device)

        if cpu_scheduling:
            noise_pred = torch.from_numpy(noise_pred.to_host())
            latents = self.scheduler.step(
                noise_pred, t, latents
            ).prev_sample
        else:
            latents = self.scheduler.step(noise_pred, t, latents)

        latent_history.append(latents)
        step_time = (time.time() - step_start_time) * 1000
        # self.log += (
        #     f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
        # )
        step_time_sum += step_time

    if self.ondemand:
        self.unload_unet()
        self.unload_unet_512()
        for i in range(len(self.controlnet_names)):
            self.unload_controlnet(i)
            self.unload_controlnet_512(i)

    # Guard against an empty schedule instead of dividing by zero.
    num_steps = len(total_timesteps)
    avg_step_time = step_time_sum / num_steps if num_steps else 0.0
    self.log += f"\nAverage step time: {avg_step_time}ms/it"

    if not return_all_latents:
        return latents
    all_latents = torch.cat(latent_history, dim=0)
    return all_latents
|
||||
|
||||
def encode_image(self, input_image):
    """Run the VAE encoder on ``input_image`` and log how long it took.

    The encoder is loaded first and, when ``self.ondemand`` is set,
    unloaded again right after the forward call.

    Returns:
        Whatever ``self.vae_encode("forward", input_image)`` returns.
    """
    self.load_vae_encode()

    started_at = time.time()
    encoded = self.vae_encode("forward", input_image)
    elapsed_ms = (time.time() - started_at) * 1000

    if self.ondemand:
        self.unload_vae_encode()
    self.log += f"\nVAE Encode Inference time (ms): {elapsed_ms:.3f}"

    return encoded
|
||||
|
||||
def generate_images(
    self,
    prompts,
    neg_prompts,
    image,
    batch_size,
    height,
    width,
    num_inference_steps,
    strength,
    guidance_scale,
    seed,
    max_length,
    dtype,
    use_base_vae,
    cpu_scheduling,
    max_embeddings_multiples,
    stencils,
    stencil_images,
    resample_type,
    control_mode,
    preprocessed_hints,
):
    """Generate images guided by one or more ControlNet stencils.

    Steps: build the controlnet hints, normalise the prompts, seed the
    RNG, encode the prompts, prepare the initial latents (from ``image``
    when one is set, otherwise from noise), run the stencil denoising loop
    and decode the latents with the VAE.

    Args:
        prompts, neg_prompts: A string or list of strings; replicated
            ``batch_size`` times.
        image: Optional init image; when not None the latents come from
            ``prepare_image_latents``.
        seed: Seeds ``torch.manual_seed``; values outside the uint32 range
            are replaced by a random one.
        stencils: Per-controlnet stencil identifiers (None entries are
            skipped).
        stencil_images: Images used to compute a hint when no preprocessed
            hint is available for a stencil.
        preprocessed_hints: Optional ready-made hints, one per stencil.
        (Remaining arguments are forwarded to the helpers called below.)

    Returns:
        List of decoded images.
    """
    # Start from the caller-supplied hints, reshaping the ones present.
    stencil_hints = []
    self.sd_model.stencils = stencils
    for hint in preprocessed_hints:
        if hint is not None:
            hint = controlnet_hint_reshaping(
                hint,
                height,
                width,
                dtype,
                num_images_per_prompt=1,
            )
        stencil_hints.append(hint)

    # Compute a hint for every active stencil that has none yet.
    for i, stencil in enumerate(stencils):
        if stencil is None:
            continue
        if len(stencil_hints) > i:
            if stencil_hints[i] is not None:
                print(f"Using preprocessed controlnet hint for {stencil}")
                continue
        # NOTE(review): this rebinds the `image` parameter, so the
        # `image is not None` check below will then take the
        # image-latents path using the stencil image. Confirm that this is
        # intended; behaviour is left unchanged here.
        image = stencil_images[i]
        stencil_hints.append(
            controlnet_hint_conversion(
                image,
                stencil,
                height,
                width,
                dtype,
                num_images_per_prompt=1,
            )
        )

    # prompts and negative prompts must be a list.
    if isinstance(prompts, str):
        prompts = [prompts]

    if isinstance(neg_prompts, str):
        neg_prompts = [neg_prompts]

    prompts = prompts * batch_size
    neg_prompts = neg_prompts * batch_size

    # Seed the generator that creates the initial latent noise.
    # Out-of-range seeds are replaced by a random one.
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    generator = torch.manual_seed(seed)

    # Get text embeddings with weight emphasis from prompts
    text_embeddings = self.encode_prompts_weight(
        prompts,
        neg_prompts,
        max_length,
        max_embeddings_multiples=max_embeddings_multiples,
    )

    # guidance scale as a float32 tensor.
    guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
    if image is not None:
        # Prepare input image latent
        init_latents, final_timesteps = self.prepare_image_latents(
            image=image,
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            strength=strength,
            dtype=dtype,
            resample_type=resample_type,
        )
    else:
        # Prepare initial latent.
        init_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )
        final_timesteps = self.scheduler.timesteps

    # Get Image latents
    latents = self.produce_stencil_latents(
        latents=init_latents,
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        total_timesteps=final_timesteps,
        dtype=dtype,
        cpu_scheduling=cpu_scheduling,
        control_mode=control_mode,
        stencil_hints=stencil_hints,
    )

    # Img latents -> PIL images
    all_imgs = []
    self.load_vae()
    for i in tqdm(range(0, latents.shape[0], batch_size)):
        imgs = self.decode_latents(
            latents=latents[i : i + batch_size],
            use_base_vae=use_base_vae,
            cpu_scheduling=cpu_scheduling,
        )
        all_imgs.extend(imgs)
    if self.ondemand:
        self.unload_vae()

    return all_imgs
|
||||
@@ -0,0 +1,169 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
class Text2ImagePipeline(StableDiffusionPipeline):
    """Text-to-image pipeline on top of the shared StableDiffusionPipeline."""

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            KDPM2DiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            DEISMultistepScheduler,
            DDPMScheduler,
            DPMSolverSinglestepScheduler,
            KDPM2AncestralDiscreteScheduler,
            HeunDiscreteScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        lora_strength: float,
        ondemand: bool,
    ):
        """Forward every argument unchanged to the base pipeline."""
        super().__init__(
            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
        )

    @classmethod
    def favored_base_models(cls, model_id):
        """Return the preferred base model ids (independent of ``model_id``)."""
        return [
            "stabilityai/stable-diffusion-2-1",
            "CompVis/stable-diffusion-v1-4",
        ]

    def prepare_latents(
        self,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        dtype,
    ):
        """Draw the initial noise latents and configure the scheduler.

        Noise of shape ``(batch_size, 4, height // 8, width // 8)`` is
        sampled in float32, cast to ``dtype`` and scaled by the scheduler's
        ``init_noise_sigma`` (read after ``set_timesteps``).
        """
        latent_shape = (batch_size, 4, height // 8, width // 8)
        noise = torch.randn(
            latent_shape, generator=generator, dtype=torch.float32
        )
        noise = noise.to(dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.is_scale_input_called = True
        return noise * self.scheduler.init_noise_sigma

    def generate_images(
        self,
        prompts,
        neg_prompts,
        batch_size,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
        max_embeddings_multiples,
    ):
        """Generate images from text prompts.

        Normalises the prompts, seeds the RNG, prepares the initial
        latents, encodes the prompts, runs the denoising loop and decodes
        the result with the VAE. Returns the list of decoded images.
        """
        # Both prompt arguments may be a bare string; work with lists and
        # replicate them once per batch entry.
        prompt_list = [prompts] if isinstance(prompts, str) else prompts
        neg_prompt_list = (
            [neg_prompts] if isinstance(neg_prompts, str) else neg_prompts
        )
        prompts = prompt_list * batch_size
        neg_prompts = neg_prompt_list * batch_size

        # Seed the generator for the initial latent noise; seeds outside
        # the uint32 range are replaced by a random one.
        # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
        seed_range = np.iinfo(np.uint32)
        lowest_seed, highest_seed = seed_range.min, seed_range.max
        if seed < lowest_seed or seed >= highest_seed:
            seed = randint(lowest_seed, highest_seed)
        generator = torch.manual_seed(seed)

        # Initial latents.
        init_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )

        # Text embeddings with weight emphasis from the prompts.
        text_embeddings = self.encode_prompts_weight(
            prompts,
            neg_prompts,
            max_length,
            max_embeddings_multiples=max_embeddings_multiples,
        )

        # Guidance scale as a float32 tensor.
        guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

        # Denoise.
        latents = self.produce_img_latents(
            latents=init_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            total_timesteps=self.scheduler.timesteps,
            dtype=dtype,
            cpu_scheduling=cpu_scheduling,
        )

        # Latents -> PIL images, one batch-sized slice at a time.
        decoded = []
        self.load_vae()
        for start in range(0, latents.shape[0], batch_size):
            decoded.extend(
                self.decode_latents(
                    latents=latents[start : start + batch_size],
                    use_base_vae=use_base_vae,
                    cpu_scheduling=cpu_scheduling,
                )
            )
        if self.ondemand:
            self.unload_vae()

        return decoded
|
||||
@@ -0,0 +1,236 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
    """SDXL text-to-image pipeline on top of StableDiffusionPipeline."""

    def __init__(
        self,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            KDPM2DiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            SharkEulerDiscreteScheduler,
            SharkEulerAncestralDiscreteScheduler,
            DEISMultistepScheduler,
            DDPMScheduler,
            DPMSolverSinglestepScheduler,
            KDPM2AncestralDiscreteScheduler,
            HeunDiscreteScheduler,
        ],
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
        lora_strength: float,
        ondemand: bool,
        is_fp32_vae: bool,
    ):
        """Initialise the base pipeline and remember ``is_fp32_vae``."""
        super().__init__(
            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
        )
        self.is_fp32_vae = is_fp32_vae

    @classmethod
    def favored_base_models(cls, model_id):
        """Return the preferred base models, turbo first for turbo ids."""
        turbo = "stabilityai/sdxl-turbo"
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        if "turbo" in model_id:
            return [turbo, base]
        return [base, turbo]

    def prepare_latents(
        self,
        batch_size,
        height,
        width,
        generator,
        num_inference_steps,
        dtype,
    ):
        """Draw the initial noise latents and configure the scheduler.

        Noise of shape ``(batch_size, 4, height // 8, width // 8)`` is
        sampled in float32, cast to ``dtype`` and scaled by the scheduler's
        ``init_noise_sigma`` (read after ``set_timesteps``).
        """
        latent_shape = (batch_size, 4, height // 8, width // 8)
        noise = torch.randn(
            latent_shape, generator=generator, dtype=torch.float32
        )
        noise = noise.to(dtype)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.is_scale_input_called = True
        return noise * self.scheduler.init_noise_sigma

    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype
    ):
        """Build the added time-id tensor from the three size tuples.

        The tuples are concatenated in the order original size, crop
        offset, target size. The resulting embedding width is checked
        against the hard-coded expected value before the ids are returned
        as a ``(1, n)`` tensor of ``dtype``.

        Raises:
            ValueError: if the computed embedding width is not 2816.
        """
        time_ids = list(original_size + crops_coords_top_left + target_size)

        # self.unet.config.addition_time_embed_dim IS 256.
        # self.text_encoder_2.config.projection_dim IS 1280.
        # self.unet.add_embedding.linear_1.in_features IS 2816.
        expected_add_embed_dim = 2816
        passed_add_embed_dim = 256 * len(time_ids) + 1280

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of {passed_add_embed_dim} was created. "
                "The model has an incorrect config. "
                "Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        return torch.tensor([time_ids], dtype=dtype)

    def generate_images(
        self,
        prompts,
        neg_prompts,
        batch_size,
        height,
        width,
        num_inference_steps,
        guidance_scale,
        seed,
        max_length,
        dtype,
        use_base_vae,
        cpu_scheduling,
        max_embeddings_multiples,
    ):
        """Generate images from text prompts with the SDXL models.

        Normalises the prompts, seeds the RNG, prepares the initial
        latents, encodes the prompts, assembles the additional text/time
        conditioning, runs the SDXL denoising loop and decodes the latents.
        Returns the list of decoded images.
        """
        # Both prompt arguments may be a bare string; work with lists and
        # replicate them once per batch entry.
        prompt_list = [prompts] if isinstance(prompts, str) else prompts
        neg_prompt_list = (
            [neg_prompts] if isinstance(neg_prompts, str) else neg_prompts
        )
        prompts = prompt_list * batch_size
        neg_prompts = neg_prompt_list * batch_size

        # Seed the generator for the initial latent noise; seeds outside
        # the uint32 range are replaced by a random one.
        # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
        seed_range = np.iinfo(np.uint32)
        lowest_seed, highest_seed = seed_range.min, seed_range.max
        if seed < lowest_seed or seed >= highest_seed:
            seed = randint(lowest_seed, highest_seed)
        generator = torch.manual_seed(seed)

        # Initial latents.
        init_latents = self.prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            generator=generator,
            num_inference_steps=num_inference_steps,
            dtype=dtype,
        )

        # Text embeddings.
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt_sdxl(
            prompt=prompts,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=neg_prompts,
        )

        # Timesteps.
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # Added time ids: original size, crop offset (0, 0), target size.
        image_size = (height, width)
        single_time_ids = self._get_add_time_ids(
            image_size,
            (0, 0),
            image_size,
            dtype=prompt_embeds.dtype,
        )

        # Negative conditioning first, then positive.
        prompt_embeds = torch.cat(
            [negative_prompt_embeds, prompt_embeds], dim=0
        ).to(dtype)
        add_text_embeds = torch.cat(
            [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
        ).to(dtype)
        add_time_ids = (
            torch.cat([single_time_ids, single_time_ids], dim=0)
            .repeat(batch_size, 1)
            .to(dtype)
        )

        # Guidance scale as a tensor of the working dtype.
        guidance_scale = torch.tensor(guidance_scale).to(dtype)

        # Denoise.
        latents = self.produce_img_latents_sdxl(
            init_latents,
            timesteps,
            add_text_embeds,
            add_time_ids,
            prompt_embeds,
            cpu_scheduling,
            guidance_scale,
            dtype,
        )

        # Latents -> PIL images, one batch-sized slice at a time.
        decoded = []
        self.load_vae()
        for start in range(0, latents.shape[0], batch_size):
            decoded.extend(
                self.decode_latents_sdxl(
                    latents[start : start + batch_size],
                    is_fp32_vae=self.is_fp32_vae,
                )
            )
        if self.ondemand:
            self.unload_vae()

        return decoded
|
||||
@@ -0,0 +1,360 @@
|
||||
import inspect
|
||||
import torch
|
||||
import time
|
||||
from tqdm.auto import tqdm
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from transformers import CLIPTokenizer
|
||||
from typing import Union
|
||||
from shark.shark_inference import SharkInference
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_IDLE,
|
||||
SD_STATE_CANCEL,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
|
||||
|
||||
def preprocess(image):
    """Convert the upscaler input into a batched torch tensor.

    * A tensor is returned unchanged.
    * A single PIL image is treated as a one-element list.
    * A list of PIL images is resized down to the nearest multiple of 64 in
      each dimension (taken from the first image), stacked, scaled from
      [0, 255] to [-1, 1] and moved to channels-first layout.
    * A list of tensors is concatenated along dim 0.
    """
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, Image.Image):
        width, height = first.size
        # Round each side down to an integer multiple of 64.
        width -= width % 64
        height -= height % 64

        frames = np.stack(
            [np.array(frame.resize((width, height))) for frame in image],
            axis=0,
        )
        frames = np.array(frames).astype(np.float32) / 255.0
        frames = frames.transpose(0, 3, 1, 2)
        frames = 2.0 * frames - 1.0
        image = torch.from_numpy(frames)
    elif isinstance(first, torch.Tensor):
        image = torch.cat(image, dim=0)
    return image
|
||||
|
||||
|
||||
class UpscalerPipeline(StableDiffusionPipeline):
|
||||
def __init__(
    self,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        SharkEulerDiscreteScheduler,
        DEISMultistepScheduler,
        DDPMScheduler,
        DPMSolverSinglestepScheduler,
        KDPM2DiscreteScheduler,
        KDPM2AncestralDiscreteScheduler,
        HeunDiscreteScheduler,
    ],
    low_res_scheduler: Union[
        DDIMScheduler,
        DDPMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        SharkEulerDiscreteScheduler,
        DEISMultistepScheduler,
        DPMSolverSinglestepScheduler,
        KDPM2DiscreteScheduler,
        KDPM2AncestralDiscreteScheduler,
        HeunDiscreteScheduler,
    ],
    sd_model: SharkifyStableDiffusionModel,
    import_mlir: bool,
    use_lora: str,
    lora_strength: float,
    ondemand: bool,
):
    """Initialise the base pipeline and the upscaler-specific state.

    ``low_res_scheduler`` is stored on the instance and the pipeline
    status starts out as ``SD_STATE_IDLE``. All other arguments are
    forwarded unchanged to the base class.
    """
    super().__init__(
        scheduler,
        sd_model,
        import_mlir,
        use_lora,
        lora_strength,
        ondemand,
    )
    # Upscaler-only attributes, set after the base class is initialised.
    self.low_res_scheduler = low_res_scheduler
    self.status = SD_STATE_IDLE
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional kwargs that ``self.scheduler.step`` accepts.

    ``eta`` and ``generator`` are only included when the scheduler's
    ``step`` signature has a parameter of that name.

    Returns:
        Dict with zero, one or both of the keys "eta" and "generator".
    """
    step_params = set(
        inspect.signature(self.scheduler.step).parameters.keys()
    )

    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    # Not every scheduler takes a generator either.
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
|
||||
|
||||
def decode_latents(self, latents, use_base_vae, cpu_scheduling):
    """Decode latents into PIL images with the VAE and log the timing.

    Args:
        latents: Latent tensor; cast to float and divided by 0.08333.
        use_base_vae: Accepted for signature parity; not used here.
        cpu_scheduling: When true the latents are passed to the VAE as a
            numpy array instead of a tensor.

    Returns:
        List of PIL images, one per batch entry.
    """
    scaled = 1 / 0.08333 * (latents.float())
    vae_input = scaled.detach().numpy() if cpu_scheduling else scaled

    profile_device = start_profiling(file_path="vae.rdc")
    started_at = time.time()
    decoded = self.vae("forward", (vae_input,))
    elapsed_ms = (time.time() - started_at) * 1000
    end_profiling(profile_device)
    self.log += f"\nVAE Inference time (ms): {elapsed_ms:.3f}"

    # Scale to [0, 255], round, then move channels last for PIL.
    pixels = (torch.from_numpy(decoded).detach().cpu() * 255.0).numpy()
    pixels = pixels.round()
    pixels = torch.from_numpy(pixels).to(torch.uint8).permute(0, 2, 3, 1)

    return [Image.fromarray(frame) for frame in pixels.numpy()]
|
||||
|
||||
def prepare_latents(
    self,
    batch_size,
    height,
    width,
    generator,
    num_inference_steps,
    dtype,
):
    """Draw the initial noise latents and configure the scheduler.

    Noise of shape ``(batch_size, 4, height, width)`` is sampled in
    float32, cast to ``dtype`` and scaled by the scheduler's
    ``init_noise_sigma`` (read after ``set_timesteps``).
    """
    latent_shape = (batch_size, 4, height, width)
    noise = torch.randn(
        latent_shape, generator=generator, dtype=torch.float32
    )
    noise = noise.to(dtype)

    self.scheduler.set_timesteps(num_inference_steps)
    self.scheduler.is_scale_input_called = True
    return noise * self.scheduler.init_noise_sigma
|
||||
|
||||
def produce_img_latents(
    self,
    latents,
    image,
    text_embeddings,
    guidance_scale,
    noise_level,
    total_timesteps,
    dtype,
    cpu_scheduling,
    extra_step_kwargs,
    return_all_latents=False,
):
    """Run the upscaler denoising loop and return the resulting latents.

    Each step duplicates the latents for classifier-free guidance, scales
    them with the scheduler, concatenates the conditioning ``image`` along
    dim 1, runs the unet, combines the unconditional and text predictions
    with ``guidance_scale`` and advances the scheduler. The loop stops
    early when ``self.status`` becomes ``SD_STATE_CANCEL``.

    Args:
        latents: Initial latents.
        image: Conditioning image tensor concatenated onto the model input.
        text_embeddings: Numpy array of text embeddings.
        guidance_scale: Classifier-free guidance weight.
        noise_level: Passed through to the unet.
        total_timesteps: Sequence of timesteps to iterate over.
        dtype: Torch dtype used for embeddings and the timestep.
        cpu_scheduling: When true the model input is converted to numpy and
            ``scheduler.step(...).prev_sample`` is used.
        extra_step_kwargs: Extra keyword arguments for ``scheduler.step``.
        return_all_latents: Return the concatenated history instead of only
            the final latents.

    Returns:
        The final latents, or every intermediate latent concatenated along
        dim 0 when ``return_all_latents`` is true.
    """
    step_time_sum = 0
    steps_run = 0
    latent_history = [latents]
    text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
    text_embeddings_numpy = text_embeddings.detach().numpy()
    self.status = SD_STATE_IDLE

    # Prompts longer than model_max_length use the "_512" unet variant.
    use_base_unet = text_embeddings.shape[1] <= self.model_max_length
    if use_base_unet:
        self.load_unet()
    else:
        self.load_unet_512()

    for i, t in tqdm(enumerate(total_timesteps)):
        step_start_time = time.time()
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = self.scheduler.scale_model_input(
            latent_model_input, t
        )
        latent_model_input = torch.cat([latent_model_input, image], dim=1)
        timestep = torch.tensor([t]).to(dtype).detach().numpy()
        if cpu_scheduling:
            latent_model_input = latent_model_input.detach().numpy()

        # Profiling Unet.
        profile_device = start_profiling(file_path="unet.rdc")
        unet = self.unet if use_base_unet else self.unet_512
        noise_pred = unet(
            "forward",
            (
                latent_model_input,
                timestep,
                text_embeddings_numpy,
                noise_level,
            ),
        )
        end_profiling(profile_device)

        # Classifier-free guidance: uncond + scale * (text - uncond).
        noise_pred = torch.from_numpy(noise_pred)
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        step_output = self.scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        )
        latents = step_output.prev_sample if cpu_scheduling else step_output

        latent_history.append(latents)
        step_time = (time.time() - step_start_time) * 1000
        # self.log += (
        #     f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
        # )
        step_time_sum += step_time
        steps_run += 1

        if self.status == SD_STATE_CANCEL:
            break

    if self.ondemand:
        self.unload_unet()
        self.unload_unet_512()

    # Average over the steps that actually ran: a cancelled run executes
    # fewer than len(total_timesteps) steps, and an empty schedule must
    # not divide by zero.
    avg_step_time = step_time_sum / steps_run if steps_run else 0.0
    self.log += f"\nAverage step time: {avg_step_time}ms/it"

    if not return_all_latents:
        return latents
    all_latents = torch.cat(latent_history, dim=0)
    return all_latents
|
||||
|
||||
def generate_images(
    self,
    prompts,
    neg_prompts,
    image,
    batch_size,
    height,
    width,
    num_inference_steps,
    noise_level,
    guidance_scale,
    seed,
    max_length,
    dtype,
    use_base_vae,
    cpu_scheduling,
    max_embeddings_multiples,
):
    """Run the full pipeline for one request: encode the prompts, noise the
    input image, produce latents and decode them.

    The incoming ``height``/``width`` values are not used for the latents;
    the spatial size is taken from the preprocessed image instead.

    Returns:
        A list holding everything returned by ``self.decode_latents`` for
        each ``batch_size``-sized slice of the final latents.
    """
    # Both prompt arguments are handled as lists; wrap bare strings.
    prompts = [prompts] if isinstance(prompts, str) else prompts
    neg_prompts = [neg_prompts] if isinstance(neg_prompts, str) else neg_prompts

    prompts = prompts * batch_size
    neg_prompts = neg_prompts * batch_size

    # Seed the generator used for the initial latent noise. Seeds outside
    # the uint32 range are replaced by a random in-range value.
    # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
    seed_bounds = np.iinfo(np.uint32)
    lowest_seed = seed_bounds.min
    highest_seed = seed_bounds.max
    if seed >= highest_seed or seed < lowest_seed:
        seed = randint(lowest_seed, highest_seed)
    generator = torch.manual_seed(seed)

    # Text embeddings with weight emphasis from the prompts.
    text_embeddings = self.encode_prompts_weight(
        prompts,
        neg_prompts,
        max_length,
        max_embeddings_multiples=max_embeddings_multiples,
    )

    # Preprocess the input image.
    image = preprocess(image).to(dtype)

    # Add noise to the image at the requested noise level.
    noise_level = torch.tensor([noise_level], dtype=torch.long)
    image_noise = torch.randn(image.shape, generator=generator).to(dtype)
    image = self.low_res_scheduler.add_noise(image, image_noise, noise_level)
    # The noised image is duplicated along the first axis and the noise
    # level is repeated once per resulting row.
    image = torch.cat([image, image])
    noise_level = torch.cat([noise_level] * image.shape[0])

    # Initial latents use the spatial size of the (duplicated) image.
    image_height, image_width = image.shape[2:]
    init_latents = self.prepare_latents(
        batch_size=batch_size,
        height=image_height,
        width=image_width,
        generator=generator,
        num_inference_steps=num_inference_steps,
        dtype=dtype,
    )

    eta = 0.0
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # guidance scale as a float32 tensor.
    # guidance_scale = torch.tensor(guidance_scale).to(torch.float32)

    # Denoising loop producing the image latents.
    latents = self.produce_img_latents(
        latents=init_latents,
        image=image,
        text_embeddings=text_embeddings,
        guidance_scale=guidance_scale,
        noise_level=noise_level,
        total_timesteps=self.scheduler.timesteps,
        dtype=dtype,
        cpu_scheduling=cpu_scheduling,
        extra_step_kwargs=extra_step_kwargs,
    )

    # Image latents -> decoded images, one batch-sized slice at a time.
    decoded_images = []
    self.load_vae()
    for start in tqdm(range(0, latents.shape[0], batch_size)):
        decoded_images.extend(
            self.decode_latents(
                latents=latents[start : start + batch_size],
                use_base_vae=use_base_vae,
                cpu_scheduling=cpu_scheduling,
            )
        )
    if self.ondemand:
        self.unload_vae()

    return decoded_images
|
||||
File diff suppressed because it is too large
Load Diff
7
apps/stable_diffusion/src/schedulers/__init__.py
Normal file
7
apps/stable_diffusion/src/schedulers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
141
apps/stable_diffusion/src/schedulers/sd_schedulers.py
Normal file
141
apps/stable_diffusion/src/schedulers/sd_schedulers.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
    """Instantiate every supported scheduler from ``model_id``.

    Each scheduler is loaded with ``from_pretrained(model_id,
    subfolder="scheduler", ...)`` and stored under its display name. The two
    SHARK schedulers are additionally compiled before the mapping is returned.

    Returns:
        dict mapping scheduler display name to scheduler instance, in the
        order listed below.
    """
    # TODO: Robust scheduler setup on pipeline creation -- if we don't
    # set batch_size here, the SHARK schedulers will
    # compile with batch size = 1 regardless of whether the model
    # outputs latents of a larger batch size, e.g. SDXL.
    # However, obviously, searching for whether the base model ID
    # contains "xl" is not very robust.
    batch_size = 2 if "xl" in model_id.lower() else 1

    # (display name, scheduler class, keyword overrides for from_pretrained)
    scheduler_table = (
        ("PNDM", PNDMScheduler, {}),
        ("DDPM", DDPMScheduler, {}),
        ("KDPM2Discrete", KDPM2DiscreteScheduler, {}),
        ("LMSDiscrete", LMSDiscreteScheduler, {}),
        ("DDIM", DDIMScheduler, {}),
        ("LCMScheduler", LCMScheduler, {}),
        (
            "DPMSolverMultistep",
            DPMSolverMultistepScheduler,
            {"algorithm_type": "dpmsolver"},
        ),
        (
            "DPMSolverMultistep++",
            DPMSolverMultistepScheduler,
            {"algorithm_type": "dpmsolver++"},
        ),
        (
            "DPMSolverMultistepKarras",
            DPMSolverMultistepScheduler,
            {"use_karras_sigmas": True},
        ),
        (
            "DPMSolverMultistepKarras++",
            DPMSolverMultistepScheduler,
            {"algorithm_type": "dpmsolver++", "use_karras_sigmas": True},
        ),
        (
            "DPMSolverSDEKarras++",
            DPMSolverMultistepScheduler,
            {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True},
        ),
        (
            "DPMSolverSDE++",
            DPMSolverMultistepScheduler,
            {"algorithm_type": "sde-dpmsolver++"},
        ),
        ("EulerDiscrete", EulerDiscreteScheduler, {}),
        ("EulerAncestralDiscrete", EulerAncestralDiscreteScheduler, {}),
        ("DEISMultistep", DEISMultistepScheduler, {}),
        ("SharkEulerDiscrete", SharkEulerDiscreteScheduler, {}),
        (
            "SharkEulerAncestralDiscrete",
            SharkEulerAncestralDiscreteScheduler,
            {},
        ),
        ("DPMSolverSinglestep", DPMSolverSinglestepScheduler, {}),
        ("KDPM2AncestralDiscrete", KDPM2AncestralDiscreteScheduler, {}),
        ("HeunDiscrete", HeunDiscreteScheduler, {}),
    )

    schedulers = {}
    for display_name, scheduler_cls, overrides in scheduler_table:
        schedulers[display_name] = scheduler_cls.from_pretrained(
            model_id,
            subfolder="scheduler",
            **overrides,
        )

    # Only the SHARK schedulers need an explicit compile step.
    for shark_name in ("SharkEulerDiscrete", "SharkEulerAncestralDiscrete"):
        schedulers[shark_name].compile(batch_size)
    return schedulers
|
||||
@@ -0,0 +1,247 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
    """Euler-ancestral scheduler whose input scaling and step update run as
    compiled SHARK modules.

    ``compile`` has to be called before ``scale_model_input`` or ``step``:
    it creates ``self.scaling_model`` and ``self.step_model`` for a fixed
    batch size and for the latent shape derived from ``args.height`` and
    ``args.width``.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
    ):
        """Forward all configuration values to the parent scheduler and
        initialise the SHARK-specific attributes."""
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
            timestep_spacing,
            steps_offset,
        )
        # TODO: make it dynamic so we dont have to worry about batch size
        self.batch_size = None
        self.init_input_shape = None

    def compile(self, batch_size=1):
        """Create ``self.scaling_model`` and ``self.step_model``.

        With ``args.import_mlir`` set, both modules are traced locally from
        the small torch modules defined below. Otherwise prebuilt artifacts
        are requested from the scheduler bucket, and any failure there falls
        back to the local path (which also sets ``args.import_mlir = True``).

        Args:
            batch_size: Leading dimension of the example tensors the modules
                are traced with; stored on ``self.batch_size``.
        """
        SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
        device = args.device.split(":", 1)[0].strip()
        self.batch_size = batch_size

        model_input = {
            "eulera": {
                "output": torch.randn(
                    batch_size, 4, args.height // 8, args.width // 8
                ),
                "latent": torch.randn(
                    batch_size, 4, args.height // 8, args.width // 8
                ),
                "sigma": torch.tensor(1).to(torch.float32),
                "sigma_from": torch.tensor(1).to(torch.float32),
                "sigma_to": torch.tensor(1).to(torch.float32),
                "noise": torch.randn(
                    batch_size, 4, args.height // 8, args.width // 8
                ),
            },
        }

        example_latent = model_input["eulera"]["latent"]
        example_output = model_input["eulera"]["output"]
        example_noise = model_input["eulera"]["noise"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
            example_noise = example_noise.half()
        # The sigma examples stay float32 regardless of precision.
        example_sigma = model_input["eulera"]["sigma"]
        example_sigma_from = model_input["eulera"]["sigma_from"]
        example_sigma_to = model_input["eulera"]["sigma_to"]

        class ScalingModel(torch.nn.Module):
            """Divides the latent by ``sqrt(sigma**2 + 1)``."""

            def __init__(self):
                super().__init__()

            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepEpsilonModel(torch.nn.Module):
            """Ancestral Euler update for ``prediction_type == "epsilon"``."""

            def __init__(self):
                super().__init__()

            def forward(
                self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
            ):
                sigma_up = (
                    sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
                ) ** 0.5
                sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
                dt = sigma_down - sigma
                pred_original_sample = latent - sigma * noise_pred
                derivative = (latent - pred_original_sample) / sigma
                prev_sample = latent + derivative * dt
                return prev_sample + noise * sigma_up

        class SchedulerStepVPredictionModel(torch.nn.Module):
            """Ancestral Euler update for ``prediction_type == "v_prediction"``."""

            def __init__(self):
                super().__init__()

            # The parameter order must match SchedulerStepEpsilonModel:
            # both variants are traced and invoked with the same positional
            # tuple (noise_pred, latent, sigma, sigma_from, sigma_to, noise).
            # NOTE(review): artifacts cached from the previous argument
            # order may need to be rebuilt -- confirm how caching is keyed.
            def forward(
                self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
            ):
                sigma_up = (
                    sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
                ) ** 0.5
                sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
                dt = sigma_down - sigma
                pred_original_sample = noise_pred * (
                    -sigma / (sigma**2 + 1) ** 0.5
                ) + (latent / (sigma**2 + 1))
                derivative = (latent - pred_original_sample) / sigma
                prev_sample = latent + derivative * dt
                return prev_sample + noise * sigma_up

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )

        def _import(self):
            # Trace and compile both modules locally from the torch modules.
            scaling_model = ScalingModel()
            self.scaling_model, _ = compile_through_fx(
                model=scaling_model,
                inputs=(example_latent, example_sigma),
                extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
                + args.precision,
                extra_args=iree_flags,
            )

            pred_type_model_dict = {
                "epsilon": SchedulerStepEpsilonModel(),
                "v_prediction": SchedulerStepVPredictionModel(),
            }
            step_model = pred_type_model_dict[self.config.prediction_type]
            self.step_model, _ = compile_through_fx(
                step_model,
                (
                    example_output,
                    example_latent,
                    example_sigma,
                    example_sigma_from,
                    example_sigma_to,
                    example_noise,
                ),
                extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
                + args.precision,
                extra_args=iree_flags,
            )

        if args.import_mlir:
            _import(self)

        else:
            try:
                self.scaling_model = get_shark_model(
                    SCHEDULER_BUCKET,
                    "euler_a_scale_model_input_" + args.precision,
                    iree_flags,
                )
                self.step_model = get_shark_model(
                    SCHEDULER_BUCKET,
                    "euler_a_step_"
                    + self.config.prediction_type
                    + args.precision,
                    iree_flags,
                )
            # Exception (not a bare except) so that Ctrl-C / SystemExit
            # abort instead of silently starting a local compile.
            except Exception:
                print(
                    "failed to download model, falling back and using import_mlir"
                )
                args.import_mlir = True
                _import(self)

    def scale_model_input(self, sample, timestep):
        """Scale ``sample`` with the compiled scaling module, using the sigma
        at the current step index (initialised from ``timestep`` on first use).
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        return self.scaling_model(
            "forward",
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(
        self,
        noise_pred,
        timestep,
        latent,
        generator: Optional[torch.Generator] = None,
        return_dict: Optional[bool] = False,
    ):
        """Advance one step through the compiled step module.

        Args:
            noise_pred: Model output for the current step.
            timestep: Current timestep; only used to initialise the step
                index on the first call.
            latent: Current latents.
            generator: Optional generator for the ancestral noise.
            return_dict: Accepted for signature compatibility; not read here.

        Returns:
            The raw output of ``self.step_model`` (no wrapper object).
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        # Match the dtype the step module was traced with in compile():
        # the example noise is only converted to half for fp16 precision.
        noise_dtype = (
            torch.float16 if args.precision == "fp16" else torch.float32
        )
        noise = randn_tensor(
            torch.Size(noise_pred.shape),
            dtype=noise_dtype,
            device="cpu",
            generator=generator,
        )
        step_inputs = [
            noise_pred,
            latent,
            sigma,
            sigma_from,
            sigma_to,
            noise,
        ]
        # TODO: deal with dynamic inputs in turbine flow.
        # update step index since we're done with the variable and will return with compiled module output.
        self._step_index += 1

        if noise_pred.shape[0] < self.batch_size:
            # The batched inputs (noise_pred, latent, noise) are duplicated
            # along axis 0 before being handed to the compiled module.
            for i in [0, 1, 5]:
                try:
                    step_inputs[i] = torch.tensor(step_inputs[i])
                except Exception:
                    # presumably a device-resident array that has to be
                    # copied to the host first -- TODO confirm
                    step_inputs[i] = torch.tensor(step_inputs[i].to_host())
                step_inputs[i] = torch.cat(
                    (step_inputs[i], step_inputs[i]), axis=0
                )
            return self.step_model(
                "forward",
                tuple(step_inputs),
                send_to_host=True,
            )

        return self.step_model(
            "forward",
            tuple(step_inputs),
            send_to_host=False,
        )
|
||||
245
apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py
Normal file
245
apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
    """Euler scheduler whose input scaling and step update run as compiled
    SHARK modules.

    ``compile`` has to be called before ``scale_model_input`` or ``step``:
    it creates ``self.scaling_model`` and ``self.step_model`` for a fixed
    batch size and for the latent shape derived from ``args.height`` and
    ``args.width``.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        interpolation_type: str = "linear",
        use_karras_sigmas: bool = False,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        timestep_spacing: str = "linspace",
        timestep_type: str = "discrete",
        steps_offset: int = 0,
    ):
        """Forward all configuration values to the parent scheduler and set
        the default batch size."""
        super().__init__(
            num_train_timesteps,
            beta_start,
            beta_end,
            beta_schedule,
            trained_betas,
            prediction_type,
            interpolation_type,
            use_karras_sigmas,
            sigma_min,
            sigma_max,
            timestep_spacing,
            timestep_type,
            steps_offset,
        )
        # TODO: make it dynamic so we dont have to worry about batch size
        self.batch_size = 1

    def compile(self, batch_size=1):
        """Create ``self.scaling_model`` and ``self.step_model``.

        With ``args.import_mlir`` set, both modules are traced locally from
        the small torch modules defined below. Otherwise prebuilt artifacts
        are requested from the scheduler bucket, and any failure there falls
        back to the local path (which also sets ``args.import_mlir = True``).

        Args:
            batch_size: Leading dimension of the example tensors the modules
                are traced with; stored on ``self.batch_size``.
        """
        SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
        device = args.device.split(":", 1)[0].strip()
        self.batch_size = batch_size

        model_input = {
            "euler": {
                "latent": torch.randn(
                    batch_size, 4, args.height // 8, args.width // 8
                ),
                "output": torch.randn(
                    batch_size, 4, args.height // 8, args.width // 8
                ),
                "sigma": torch.tensor(1).to(torch.float32),
                "dt": torch.tensor(1).to(torch.float32),
            },
        }

        example_latent = model_input["euler"]["latent"]
        example_output = model_input["euler"]["output"]
        if args.precision == "fp16":
            example_latent = example_latent.half()
            example_output = example_output.half()
        # sigma and dt examples stay float32 regardless of precision.
        example_sigma = model_input["euler"]["sigma"]
        example_dt = model_input["euler"]["dt"]

        class ScalingModel(torch.nn.Module):
            """Divides the latent by ``sqrt(sigma**2 + 1)``."""

            def __init__(self):
                super().__init__()

            def forward(self, latent, sigma):
                return latent / ((sigma**2 + 1) ** 0.5)

        class SchedulerStepEpsilonModel(torch.nn.Module):
            """Euler update for ``prediction_type == "epsilon"``."""

            def __init__(self):
                super().__init__()

            def forward(self, noise_pred, sigma_hat, latent, dt):
                pred_original_sample = latent - sigma_hat * noise_pred
                derivative = (latent - pred_original_sample) / sigma_hat
                return latent + derivative * dt

        class SchedulerStepSampleModel(torch.nn.Module):
            """Euler update where the model output is used directly as the
            predicted original sample."""

            def __init__(self):
                super().__init__()

            def forward(self, noise_pred, sigma_hat, latent, dt):
                pred_original_sample = noise_pred
                derivative = (latent - pred_original_sample) / sigma_hat
                return latent + derivative * dt

        class SchedulerStepVPredictionModel(torch.nn.Module):
            """Euler update for ``prediction_type == "v_prediction"``."""

            def __init__(self):
                super().__init__()

            def forward(self, noise_pred, sigma, latent, dt):
                pred_original_sample = noise_pred * (
                    -sigma / (sigma**2 + 1) ** 0.5
                ) + (latent / (sigma**2 + 1))
                derivative = (latent - pred_original_sample) / sigma
                return latent + derivative * dt

        iree_flags = []
        if len(args.iree_vulkan_target_triple) > 0:
            iree_flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )

        def _import(self):
            # Trace and compile both modules locally from the torch modules.
            scaling_model = ScalingModel()
            self.scaling_model, _ = compile_through_fx(
                model=scaling_model,
                inputs=(example_latent, example_sigma),
                extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
                + args.precision,
                extra_args=iree_flags,
            )

            pred_type_model_dict = {
                "epsilon": SchedulerStepEpsilonModel(),
                "v_prediction": SchedulerStepVPredictionModel(),
                "sample": SchedulerStepSampleModel(),
                "original_sample": SchedulerStepSampleModel(),
            }
            step_model = pred_type_model_dict[self.config.prediction_type]
            self.step_model, _ = compile_through_fx(
                step_model,
                (example_output, example_sigma, example_latent, example_dt),
                extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
                + args.precision,
                extra_args=iree_flags,
            )

        if args.import_mlir:
            _import(self)

        else:
            try:
                # "sample" and "original_sample" share one prebuilt artifact.
                step_model_type = (
                    "sample"
                    if "sample" in self.config.prediction_type
                    else self.config.prediction_type
                )
                self.scaling_model = get_shark_model(
                    SCHEDULER_BUCKET,
                    "euler_scale_model_input_" + args.precision,
                    iree_flags,
                )
                self.step_model = get_shark_model(
                    SCHEDULER_BUCKET,
                    "euler_step_" + step_model_type + args.precision,
                    iree_flags,
                )
            # Exception (not a bare except) so that Ctrl-C / SystemExit
            # abort instead of silently starting a local compile.
            except Exception:
                print(
                    "failed to download model, falling back and using import_mlir"
                )
                args.import_mlir = True
                _import(self)

    def scale_model_input(self, sample, timestep):
        """Scale ``sample`` with the compiled scaling module, using the sigma
        at the current step index (initialised from ``timestep`` on first use).
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        return self.scaling_model(
            "forward",
            (
                sample,
                sigma,
            ),
            send_to_host=False,
        )

    def step(
        self,
        noise_pred,
        timestep,
        latent,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: Optional[bool] = False,
    ):
        """Advance one step through the compiled step module.

        Args:
            noise_pred: Model output for the current step; a numpy array is
                converted to a torch tensor first.
            timestep: Current timestep; only used to initialise the step
                index on the first call.
            latent: Current latents.
            s_churn, s_tmin, s_tmax: Control ``gamma``; it is non-zero only
                while ``s_tmin <= sigma <= s_tmax``.
            s_noise: Multiplier applied to the sampled noise.
            generator: Optional generator for the sampled noise.
            return_dict: Accepted for signature compatibility; not read here.

        Returns:
            The raw output of ``self.step_model`` (no wrapper object).
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        gamma = (
            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )

        sigma_hat = sigma * (gamma + 1)

        noise_pred = (
            torch.from_numpy(noise_pred)
            if isinstance(noise_pred, np.ndarray)
            else noise_pred
        )

        # Noise is drawn on every call, even when gamma == 0 leaves it
        # unused, so the generator advances identically on every step.
        noise = randn_tensor(
            torch.Size(noise_pred.shape),
            dtype=torch.float16,
            device="cpu",
            generator=generator,
        )

        eps = noise * s_noise

        if gamma > 0:
            latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.config.prediction_type == "v_prediction":
            sigma_hat = sigma

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        # Advance the index here because the compiled module's output is
        # returned directly below.
        self._step_index += 1

        return self.step_model(
            "forward",
            (
                noise_pred,
                sigma_hat,
                latent,
                dt,
            ),
            send_to_host=False,
        )
|
||||
49
apps/stable_diffusion/src/utils/__init__.py
Normal file
49
apps/stable_diffusion/src/utils/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from apps.stable_diffusion.src.utils.profiler import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.resources import (
|
||||
prompt_examples,
|
||||
models_db,
|
||||
base_models,
|
||||
opt_flags,
|
||||
resource_path,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
get_shark_model,
|
||||
compile_through_fx,
|
||||
set_iree_runtime_flags,
|
||||
map_device_to_name_path,
|
||||
set_init_device_flags,
|
||||
get_available_devices,
|
||||
get_opt_flags,
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
parse_seed_input,
|
||||
batch_seeds,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
get_generation_text_info,
|
||||
update_lora_weight,
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
|
||||
from apps.stable_diffusion.src.utils.resamplers import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import re
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
    """Download a checkpoint from a Civitai URL unless it is already on disk.

    The file name is taken from the quoted value in the response's
    ``Content-Disposition`` header and the file is stored under
    ``<cwd>/<args.ckpt_dir or "models">/``.

    Args:
        url: Download URL; redirects are followed.

    Returns:
        The checkpoint's local path as a POSIX-style string.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the response has no ``Content-Disposition`` header.
        IndexError: if that header carries no quoted file name.
        ValueError: if the advertised name has no usable final component.
    """
    with requests.get(url, allow_redirects=True, stream=True) as response:
        response.raise_for_status()

        # civitai api returns the filename in the content disposition
        advertised_name = re.findall(
            '"([^"]*)"', response.headers["Content-Disposition"]
        )[0]
        # The header is server-controlled: keep only the last path component
        # so a name such as "../../x" cannot leave the checkpoint directory.
        base_filename = Path(advertised_name.replace("\\", "/")).name
        if base_filename in ("", ".."):
            raise ValueError(
                f"unusable checkpoint file name in response: {advertised_name!r}"
            )
        destination_path = (
            Path.cwd() / (args.ckpt_dir or "models") / base_filename
        )

        if destination_path.is_file():
            # we already have this model downloaded
            print(f"civitai model already downloaded to {destination_path}")
        else:
            # we don't have this model downloaded yet
            print(
                f"downloading civitai model from {url} to {destination_path}"
            )
            destination_path.parent.mkdir(parents=True, exist_ok=True)

            # The length header is optional (e.g. chunked responses); tqdm
            # accepts total=None and then shows a counter without a bar.
            length_header = response.headers.get("content-length", "")
            size = int(length_header) if length_header.isdigit() else None

            # Stream into a sibling ".part" file and rename it only once the
            # download finished, so an interrupted transfer is never mistaken
            # for a complete checkpoint on the next call.
            partial_path = destination_path.with_name(base_filename + ".part")
            try:
                with open(partial_path, "wb") as f, tqdm(
                    total=size, unit="iB", unit_scale=True
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=65536):
                        f.write(chunk)
                        progress_bar.update(len(chunk))
                partial_path.replace(destination_path)
            finally:
                # No-op after a successful rename; removes leftovers otherwise.
                if partial_path.exists():
                    partial_path.unlink()

        return destination_path.as_posix()
|
||||
20
apps/stable_diffusion/src/utils/profiler.py
Normal file
20
apps/stable_diffusion/src/utils/profiler.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
    """Begin profiling on the configured device when Vulkan debugging is on.

    Profiling starts only if ``shark_args.vulkan_debug_utils`` is truthy and
    ``args.device`` contains ``"vulkan"``.

    Args:
        file_path: Passed to the device's ``begin_profiling`` as ``file_path``.
        profiling_mode: Passed to ``begin_profiling`` as ``mode``.

    Returns:
        The device profiling was started on, or ``None`` when profiling is
        not enabled. Hand the result to ``end_profiling``.
    """
    from shark.parser import shark_args

    if shark_args.vulkan_debug_utils and "vulkan" in args.device:
        # Import the submodule explicitly: a plain `import iree` does not
        # load `iree.runtime` unless some other module already imported it.
        import iree.runtime

        print(f"Profiling and saving to {file_path}.")
        vulkan_device = iree.runtime.get_device(args.device)
        vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path)
        return vulkan_device
    return None
|
||||
|
||||
|
||||
def end_profiling(device):
    """Stop profiling on ``device`` and return whatever its
    ``end_profiling()`` returns; a falsy ``device`` is ignored and yields
    ``None``."""
    if not device:
        return None
    return device.end_profiling()
|
||||
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import PIL.Image as Image
|
||||
|
||||
# Display name -> PIL resampling filter, in the order the names are listed.
resamplers = {
    label: getattr(Image.Resampling, member)
    for label, member in (
        ("Lanczos", "LANCZOS"),
        ("Nearest Neighbor", "NEAREST"),
        ("Bilinear", "BILINEAR"),
        ("Bicubic", "BICUBIC"),
        ("Hamming", "HAMMING"),
        ("Box", "BOX"),
    )
}

# Live view over the display names above.
resampler_list = resamplers.keys()
|
||||
37
apps/stable_diffusion/src/utils/resources.py
Normal file
37
apps/stable_diffusion/src/utils/resources.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def resource_path(relative_path):
    """Join ``relative_path`` onto the resource base directory.

    The base is ``sys._MEIPASS`` when that attribute exists (the original
    docstring names PyInstaller as the case this covers), otherwise the
    directory containing this module.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    base_dir = getattr(sys, "_MEIPASS", module_dir)
    return os.path.join(base_dir, relative_path)
|
||||
|
||||
|
||||
def get_json_file(path):
    """Load the JSON resource at ``path`` (resolved via ``resource_path``).

    Returns the parsed content, or an empty list when the file does not
    exist. A message is printed whenever the result is falsy, which covers
    both a missing file and a file whose content is empty.
    """
    resolved = resource_path(path)
    if not os.path.exists(resolved):
        loaded = []
    else:
        with open(resolved, encoding="utf-8") as handle:
            loaded = json.load(handle)

    if not loaded:
        print(f"Unable to fetch {path}")

    return loaded
|
||||
|
||||
|
||||
# TODO: This shouldn't be called from here, every time the file imports
# it will run all the global vars.
#
# Loaded in order:
#   prompt_examples - content of resources/prompts.json
#   models_db       - content of resources/model_db.json
#   base_models     - input configuration for the different models; also
#                     helps in providing information for the variants
#   opt_flags       - optimization flags for different models
prompt_examples, models_db, base_models, opt_flags = map(
    get_json_file,
    (
        "resources/prompts.json",
        "resources/model_db.json",
        "resources/base_model.json",
        "resources/opt_flags.json",
    ),
)
|
||||
495
apps/stable_diffusion/src/utils/resources/base_model.json
Normal file
495
apps/stable_diffusion/src/utils/resources/base_model.json
Normal file
@@ -0,0 +1,495 @@
|
||||
{
|
||||
"clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"2*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"sdxl_clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"1*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
"1*batch_size",3,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"vae": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"height","width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
},
|
||||
"vae_upscaler": {
|
||||
"latents" : {
|
||||
"shape" : [
|
||||
"1*batch_size",4,"8*height","8*width"
|
||||
],
|
||||
"dtype":"f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"unet": {
|
||||
"stabilityai/stable-diffusion-2-1": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-2-inpainting": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"runwayml/stable-diffusion-inpainting": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
9,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-x4-upscaler": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
7,
|
||||
"8*height",
|
||||
"8*width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
1024
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"noise_level": {
|
||||
"shape": [2],
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"stabilityai/sdxl-turbo": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adapter": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"controlnet_hint": {
|
||||
"shape": [1, 3, "8*height", "8*width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stencil_unet": {
|
||||
"CompVis/stable-diffusion-v1-4": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"embedding": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
768
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 2,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"control13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale1": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale2": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale3": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale4": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale5": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale6": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale7": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale8": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale9": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale10": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale11": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale12": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale13": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
23
apps/stable_diffusion/src/utils/resources/model_config.json
Normal file
23
apps/stable_diffusion/src/utils/resources/model_config.json
Normal file
@@ -0,0 +1,23 @@
|
||||
[
|
||||
{
|
||||
"stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
|
||||
"stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
|
||||
"stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
|
||||
"stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting",
|
||||
"stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting",
|
||||
"anythingv3/v1_4":"Linaqruf/anything-v3.0",
|
||||
"analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
|
||||
"openjourney/v1_4":"prompthero/openjourney",
|
||||
"dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
|
||||
},
|
||||
{
|
||||
"stablediffusion/fp16":"fp16",
|
||||
"stablediffusion/fp32":"main",
|
||||
"anythingv3/fp16":"diffusers",
|
||||
"anythingv3/fp32":"diffusers",
|
||||
"analogdiffusion/fp16":"main",
|
||||
"analogdiffusion/fp32":"main",
|
||||
"openjourney/fp16":"main",
|
||||
"openjourney/fp32":"main"
|
||||
}
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user